diff --git a/.build/server.crt b/.build/server.crt index da0e620b83..d161ab2652 100644 --- a/.build/server.crt +++ b/.build/server.crt @@ -1,57 +1,20 @@ -Certificate: - Data: - Version: 3 (0x2) - Serial Number: 12599801177921850358 (0xaedb7c6a2a948bf6) - Signature Algorithm: sha1WithRSAEncryption - Issuer: C=AU, ST=Some-State, O=Internet Widgits Pty Ltd, CN=localhost - Validity - Not Before: Mar 13 11:19:37 2015 GMT - Not After : Apr 12 11:19:37 2015 GMT - Subject: C=AU, ST=Some-State, O=Internet Widgits Pty Ltd, CN=localhost - Subject Public Key Info: - Public Key Algorithm: rsaEncryption - Public-Key: (1024 bit) - Modulus: - 00:c1:df:3f:3b:b8:59:b1:33:ae:9c:ec:6b:44:41: - 7b:0a:cd:51:62:98:e2:11:f0:a0:7d:65:67:9b:49: - 88:15:91:cf:30:f1:23:dc:3c:00:83:76:be:59:df: - 9c:66:8f:eb:f3:a7:73:a0:eb:2a:26:85:d2:48:aa: - 4f:88:1b:b8:31:22:df:bd:e3:1b:6f:4f:70:c3:b2: - f4:a8:14:07:0e:77:d7:fe:91:b1:b1:3d:0a:cc:5e: - 32:ac:31:06:d3:d7:cf:e5:fc:3c:c0:db:c0:6b:0e: - 00:e0:a5:32:4a:2d:90:63:37:7e:c8:e6:5d:ad:df: - 30:81:7e:65:4a:6d:71:a5:9b - Exponent: 65537 (0x10001) - X509v3 extensions: - X509v3 Subject Key Identifier: - 17:4C:64:08:33:71:2A:34:33:CA:15:3E:F3:B8:98:1A:E7:8E:64:F4 - X509v3 Authority Key Identifier: - keyid:17:4C:64:08:33:71:2A:34:33:CA:15:3E:F3:B8:98:1A:E7:8E:64:F4 - - X509v3 Basic Constraints: - CA:TRUE - Signature Algorithm: sha1WithRSAEncryption - 81:f4:69:3e:b1:c0:9f:4b:82:10:8d:3e:7c:98:70:2a:f3:24: - ca:33:13:35:1d:9e:84:dc:b4:f1:17:1f:e6:18:d5:86:51:b3: - ce:3e:4a:97:39:cc:7e:74:94:01:da:68:43:df:b0:b6:fc:29: - 0c:86:ce:5f:0c:3d:c6:f0:8c:c0:f5:86:e7:0b:3f:fb:b0:d6: - b0:2c:9a:9e:15:be:31:dc:6d:bb:32:92:b7:36:fb:65:5a:f1: - d2:44:04:fe:eb:97:f2:8a:31:2e:4c:fd:f9:80:00:8d:91:81: - c1:90:97:18:fa:e2:c6:1c:ff:28:d1:58:94:b3:b5:9f:7a:f7: - 39:b3 -----BEGIN CERTIFICATE----- -MIICgDCCAemgAwIBAgIJAK7bfGoqlIv2MA0GCSqGSIb3DQEBBQUAMFkxCzAJBgNV -BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX 
-aWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0xNTAzMTMxMTE5 -MzdaFw0xNTA0MTIxMTE5MzdaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21l -LVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNV -BAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAwd8/O7hZ -sTOunOxrREF7Cs1RYpjiEfCgfWVnm0mIFZHPMPEj3DwAg3a+Wd+cZo/r86dzoOsq -JoXSSKpPiBu4MSLfveMbb09ww7L0qBQHDnfX/pGxsT0KzF4yrDEG09fP5fw8wNvA -aw4A4KUySi2QYzd+yOZdrd8wgX5lSm1xpZsCAwEAAaNQME4wHQYDVR0OBBYEFBdM -ZAgzcSo0M8oVPvO4mBrnjmT0MB8GA1UdIwQYMBaAFBdMZAgzcSo0M8oVPvO4mBrn -jmT0MAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEAgfRpPrHAn0uCEI0+ -fJhwKvMkyjMTNR2ehNy08Rcf5hjVhlGzzj5KlznMfnSUAdpoQ9+wtvwpDIbOXww9 -xvCMwPWG5ws/+7DWsCyanhW+MdxtuzKStzb7ZVrx0kQE/uuX8ooxLkz9+YAAjZGB -wZCXGPrixhz/KNFYlLO1n3r3ObM= +MIIDUjCCAjoCFAwuj6RwuZSjCGYHja8m9tbr3nFeMA0GCSqGSIb3DQEBCwUAMGgx +EzARBgNVBAoTCk15IENvbXBhbnkxCzAJBgNVBAsTAklUMRAwDgYDVQQHEwdNeSBU +b3duMQ8wDQYDVQQIEwZNb3Njb3cxCzAJBgNVBAYTAlJVMRQwEgYDVQQDEwtsb2Nh +bGhvc3RDQTAeFw0yMTA0MTAxMzA0MDBaFw0yMjA0MTAxMzA0MDBaMGMxEzARBgNV +BAoTCk15IENvbXBhbnkxCzAJBgNVBAsTAklUMRAwDgYDVQQHEwdNeSBUb3duMQ8w +DQYDVQQIEwZNb3Njb3cxCzAJBgNVBAYTAlJVMQ8wDQYDVQQDEwZzZXJ2ZXIwggEi +MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC8LoQbo2DFwC17gZwJ8xrPKHGX +UKxoo5UcyZ3/2zZ006TYkswssejKksuiICTMI89OD8n55pNTZkXPUH7oR2oIyxTY +SiWPiNzbEh0FOxH9Kh5gmajqM/4X44OaprmyQ56m4Y2LZO2nZ9hHoe+ZRoan3+pa +g8weOM/n/wYuXZtdElOxNsB8pg09K4gevHVaLaSBCEeQfHev51vClFdN3+orBi/r +hnQF3vdw7oMT1JSH75Ray51wRaypLIslAc2DcPFTCQJMmXXMTcAcxmjAVUGrfY+d +sSCdXnOZtd7yk+0X0bVGKLBkCTOP7QpmfOVu9bOhscDiK5EoAaDKqdHSMUfhAgMB +AAEwDQYJKoZIhvcNAQELBQADggEBAKCo2Y1uKbudA8JpV6yo35tc7Z6n03++BAdq +egUBKOiE4ze7xQ7lmlt572ptqXlU/8JuPWa2Qb/wGksR0HpVPTAeU3pbXz1dcCXC +A9wCtSxapjyCYbkDrDl2FQuK0OfJi0q71JZU66D58Qu0l45nWON30to9dSiw3zPw +Rdk7X86GHYIBHKsj7mjiy1v8jH1sXeWvThOmU6+rv8UY8VuJiu4MQDdYa0Y5KFh/ +OL3tVsi7zoNu2OXY1cTKuUpKMQPbO+WSdelYromYK2OAXaNqnC27GegPqvCFWJ2I +9NZuXYj3X+j0ydZSKVjDgCda8H68olBnO0zh44XirCBef7uTVLw= -----END CERTIFICATE----- diff --git a/.build/server.key 
b/.build/server.key index e7772f4b86..b6dd15913f 100644 --- a/.build/server.key +++ b/.build/server.key @@ -1,15 +1,27 @@ -----BEGIN RSA PRIVATE KEY----- -MIICXgIBAAKBgQDB3z87uFmxM66c7GtEQXsKzVFimOIR8KB9ZWebSYgVkc8w8SPc -PACDdr5Z35xmj+vzp3Og6yomhdJIqk+IG7gxIt+94xtvT3DDsvSoFAcOd9f+kbGx -PQrMXjKsMQbT18/l/DzA28BrDgDgpTJKLZBjN37I5l2t3zCBfmVKbXGlmwIDAQAB -AoGBAJnspubCcivXzb33kx7JImisJP60RWFa/AEzPrQzCGGft7Gy8vbLiNjXsT/n -4uQnRn3YKFzN+VRGkXNyDN0SrQSrRrFST56aLBhqe4BEO3l6JQJQ6h1y5aW7/R+y -ehV9HIQd+RFgcyejStXJnXYC7lPycOjT4SGG/7mOZkOIbRmJAkEA8NwHvvxsLW77 -UwVto7us0oR0Ey8/vCgbEruZTdr+rVeOKKUvM4K1r4hMunXc2kJ+hhYYMoF2wfIv -gpPq1F+GLQJBAM4PFV8pL+fLqQqoRh/2dGDBKQU5wlQS+A4sTAPTdy3V1zx3BE8s -KJeYIk2Z72HqNLAL/LUC/gwKwuVg+3k0v+cCQQC2HZhZxyDAZabwSi1xXMk6z924 -V8R4L1bxHhm3bXudc5NQlj2PVCiuFX/2iIG6IgbqubAIGC3ETauwrskjVSrtAkEA -gltgddcki0t4IVnbIxVTAnKwdLHZkj591tmHLVR2LPT/OS1B+KRC+cQwz0729cao -lka/E/RUq2GTcnEsJb2NOQJAOHwXsZJj+qrMQmHmmej6X2Rro4bX7cJmyK99mAtC -LigiSdiY/uDqJ/p+cHPH9g2RLWeFetUAZID94uNRk1peAg== +MIIEowIBAAKCAQEAvC6EG6NgxcAte4GcCfMazyhxl1CsaKOVHMmd/9s2dNOk2JLM +LLHoypLLoiAkzCPPTg/J+eaTU2ZFz1B+6EdqCMsU2Eolj4jc2xIdBTsR/SoeYJmo +6jP+F+ODmqa5skOepuGNi2Ttp2fYR6HvmUaGp9/qWoPMHjjP5/8GLl2bXRJTsTbA +fKYNPSuIHrx1Wi2kgQhHkHx3r+dbwpRXTd/qKwYv64Z0Bd73cO6DE9SUh++UWsud +cEWsqSyLJQHNg3DxUwkCTJl1zE3AHMZowFVBq32PnbEgnV5zmbXe8pPtF9G1Riiw +ZAkzj+0KZnzlbvWzobHA4iuRKAGgyqnR0jFH4QIDAQABAoIBADnMS7U1dAao5Q9X +GrcPnP9dm63vEFU/URA7eLTZ/prZWntOczmTFz4I4lSUbNjqcsS2IsIHqN5nvi9T +uPbc4Ft9DJT2CR1R2wvKP3GY2AibBCOFbpUojPWHYqeAZ+6xyCvXgSL8R+YwBgTS +XwYD3F35b0CH1Iy/xFOsR5i8FXj7He8lOBA76fPrH64DEBTB2zUGztu4qpfv57v5 +sfTISi2ZOqPpXc+8Fw0RPeVWQgSRUh7U3lzL8bNBod6lYcjkhF5Yqet4MdHSyWMT +aKdZ2GRHHdWjpyx6J0cD/bjjaTSDqTD8r265mPzY6bq4t6UQMq4KeDnbeiextDf4 +ELT90YUCgYEA6insCSDJddhFZ51guPPyYE9GL8QQfnzLvFOA4qWsi0u9SAbJ9aS0 +vABaEuot0PyYPwMYq7st07z3DSKno4tisPJ2X7v2nEWxv8MjgczWpltPTPaEdmZE +WGIwG3pyh5wJk1b3VpBJB5jkjtJfGmUJaezU10bzm4QhPiEawemCjucCgYEAzbri +/6EZPbJJa9hGtkJEEVLwbQ2U/CE7mZXL+AcPlS3qMSwyz/1OArPxdTRR4S3sYRRO 
+fsRDBL8LED/kKUDWNni/zkzmFf/hVkmGd9zc6eif4Zr1gmtHlsHQdaMGxsomzxGL +qydBqDN+4TMmHmUmp2jR/0LIF5UMlNoCvHcxgfcCgYEAnOBNE6h1j4++n7Yd0IsO +PFufx+xwqGzvCVJgLHeV6xRo0NJLh1g7BSCvN7DP1Q0E6mImqxaRkyMr2A75hGWj +TqyBhY2ln/hJJxGSvij/PSA7NnKJN9E3xIazeBVGmXd+Ksm+lq2/X2mc5domgMZj +0iUqSrdsCSoyIy+Gf5bzMs0CgYBcquG044vLDpOj0DeJwS+H3iQN+yAwsYd3FtJZ +VlTejV//5ji9Fwwci5EnifmXxGfFErCIyT6m1KbXGvBa5KmYv6sl8d1x62BEzbmU +JBgeBHp/1JzhshD9BzAuzNAwmr4AZ5bR8UzRxuBP8AorhsRyg/STVjFq7ehM5CZ3 +Xfke4QKBgHCPo3R/oi/E2E7OIM/ELlDpvPQTMrV+rYlMFsy3JRvataIqEGnVbhOR +4dQHEM3u2bJxN79wUYYmZuymVB78wKxTn6hGWcGoM6Y8mrJjVv9D8V0Gc0sWw5pF +KZxuCgzjaN2T7i1LsXEV3gaQrKItToEpGPzSI23egFaG6g5SFqBt -----END RSA PRIVATE KEY----- diff --git a/.devcontainer/db/Dockerfile b/.devcontainer/db/Dockerfile index 76eb48a2fa..64cc3febb1 100644 --- a/.devcontainer/db/Dockerfile +++ b/.devcontainer/db/Dockerfile @@ -1,3 +1,3 @@ -FROM postgres:alpine -RUN apk update && \ - apk add --no-cache openssl +FROM postgres +RUN apt-get update && \ + apt-get install -y --no-install-recommends openssl postgresql-16-postgis-3 diff --git a/.devcontainer/db/init-db.sh b/.devcontainer/db/init-db.sh index 24804402fe..b4ccb371e9 100644 --- a/.devcontainer/db/init-db.sh +++ b/.devcontainer/db/init-db.sh @@ -19,7 +19,10 @@ echo "Configuring md5 authentication in $PGDATA/pg_hba.conf" echo 'local all all trust' > $PGDATA/pg_hba.conf echo "host all all all md5" >> $PGDATA/pg_hba.conf -# Standard test account for Npgsql -psql -U postgres -c "CREATE USER npgsql_tests SUPERUSER PASSWORD 'npgsql_tests'" -psql -U postgres -c "CREATE DATABASE npgsql_tests OWNER npgsql_tests" -psql -U postgres -c "CREATE EXTENSION ltree" npgsql_tests +# Standard test account for Npgsql and enable extensions +psql -U postgres <> /etc/apt/sources.list.d/pgdg.list' + + sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt/ jammy-pgdg main ${{ matrix.pg_major }}" >> /etc/apt/sources.list.d/pgdg.list' sudo apt-get update -qq sudo apt-get install -qq postgresql-${{ matrix.pg_major 
}} - sudo -u postgres psql -c "CREATE USER npgsql_tests SUPERUSER PASSWORD 'npgsql_tests'" - sudo -u postgres psql -c "CREATE DATABASE npgsql_tests OWNER npgsql_tests" - sudo -u postgres psql -c "CREATE EXTENSION citext" npgsql_tests - sudo -u postgres psql -c "CREATE EXTENSION hstore" npgsql_tests - sudo -u postgres psql -c "CREATE EXTENSION ltree" npgsql_tests + export PGDATA=/etc/postgresql/${{ matrix.pg_major }}/main + + sudo cp $GITHUB_WORKSPACE/.build/{server.crt,server.key} $PGDATA + sudo chmod 600 $PGDATA/{server.crt,server.key} + sudo chown postgres $PGDATA/{server.crt,server.key} + + # Create npgsql_tests user with md5 password 'npgsql_tests' + sudo -u postgres psql -c "CREATE USER npgsql_tests SUPERUSER PASSWORD 'md5adf74603a5772843f53e812f03dacb02'" + + sudo -u postgres psql -c "CREATE USER npgsql_tests_ssl SUPERUSER PASSWORD 'npgsql_tests_ssl'" + sudo -u postgres psql -c "CREATE USER npgsql_tests_nossl SUPERUSER PASSWORD 'npgsql_tests_nossl'" # To disable PostGIS for prereleases (because it usually isn't available until late), surround with the following: - # if [ -z "${{ matrix.pg_prerelease }}" ]; then + if [ -z "${{ matrix.pg_prerelease }}" ]; then sudo apt-get install -qq postgresql-${{ matrix.pg_major }}-postgis-${{ env.postgis_version }} - sudo -u postgres psql -c "CREATE EXTENSION postgis" npgsql_tests + fi + + if [ ${{ matrix.pg_major }} -ge 14 ]; then + sudo sed -i "s|unix_socket_directories = '/var/run/postgresql'|unix_socket_directories = '/var/run/postgresql, @/npgsql_unix'|" $PGDATA/postgresql.conf + fi - export PGDATA=/etc/postgresql/${{ matrix.pg_major }}/main + sudo sed -i 's/max_connections = 100/max_connections = 500/' $PGDATA/postgresql.conf sudo sed -i 's/#ssl = off/ssl = on/' $PGDATA/postgresql.conf - sudo sed -i 's/#max_prepared_transactions = 0/max_prepared_transactions = 10/' $PGDATA/postgresql.conf + sudo sed -i "s|ssl_cert_file =|ssl_cert_file = '$PGDATA/server.crt' #|" $PGDATA/postgresql.conf + sudo sed -i "s|ssl_key_file 
=|ssl_key_file = '$PGDATA/server.key' #|" $PGDATA/postgresql.conf sudo sed -i 's/#password_encryption = md5/password_encryption = scram-sha-256/' $PGDATA/postgresql.conf sudo sed -i 's/#wal_level =/wal_level = logical #/' $PGDATA/postgresql.conf sudo sed -i 's/#max_wal_senders =/max_wal_senders = 50 #/' $PGDATA/postgresql.conf + sudo sed -i 's/#logical_decoding_work_mem =/logical_decoding_work_mem = 64kB #/' $PGDATA/postgresql.conf sudo sed -i 's/#wal_sender_timeout =/wal_sender_timeout = 3s #/' $PGDATA/postgresql.conf sudo sed -i "s/#synchronous_standby_names =/synchronous_standby_names = 'npgsql_test_sync_standby' #/" $PGDATA/postgresql.conf sudo sed -i "s/#synchronous_commit =/synchronous_commit = local #/" $PGDATA/postgresql.conf + sudo sed -i "s/#max_prepared_transactions = 0/max_prepared_transactions = 100/" $PGDATA/postgresql.conf + # Disable trust authentication, requiring MD5 passwords - some tests must fail if a password isn't provided. sudo sh -c "echo 'local all all trust' > $PGDATA/pg_hba.conf" sudo sh -c "echo 'host all npgsql_tests_scram all scram-sha-256' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'hostssl all npgsql_tests_ssl all md5' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'hostnossl all npgsql_tests_ssl all reject' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'hostnossl all npgsql_tests_nossl all md5' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'hostssl all npgsql_tests_nossl all reject' >> $PGDATA/pg_hba.conf" sudo sh -c "echo 'host all all all md5' >> $PGDATA/pg_hba.conf" sudo sh -c "echo 'host replication all all md5' >> $PGDATA/pg_hba.conf" + sudo pg_ctlcluster ${{ matrix.pg_major }} main restart # user 'npgsql_tests_scram' must be created with password encrypted as scram-sha-256 (which only applies after restart) sudo -u postgres psql -c "CREATE USER npgsql_tests_scram SUPERUSER PASSWORD 'npgsql_tests_scram'" + + # Uncomment the following to SSH into the agent running the build (https://github.com/mxschmitt/action-tmate) + #- uses: 
actions/checkout@v4 + #- name: Setup tmate session + # uses: mxschmitt/action-tmate@v3 - name: Start PostgreSQL ${{ matrix.pg_major }} (Windows) if: startsWith(matrix.os, 'windows') @@ -102,44 +164,58 @@ jobs: unzip pgsql.zip -x 'pgsql/include/**' 'pgsql/doc/**' 'pgsql/pgAdmin 4/**' 'pgsql/StackBuilder/**' # Match Npgsql CI Docker image and stash one level up - cp {$GITHUB_WORKSPACE/.build,pgsql}/server.crt - cp {$GITHUB_WORKSPACE/.build,pgsql}/server.key + cp $GITHUB_WORKSPACE/.build/{server.crt,server.key} pgsql # Find OSGEO version number OSGEO_VERSION=$(\ curl -Ls https://download.osgeo.org/postgis/windows/pg${{ matrix.pg_major }} | - sed -n 's/.*>postgis-bundle-pg${{ matrix.pg_major }}-\(${{ env.postgis_version }}.[0-9]*.[0-9]*\)x64.zip<.*/\1/p') + sed -n 's/.*>postgis-bundle-pg${{ matrix.pg_major }}-\(${{ env.postgis_version }}.[0-9]*.[0-9]*\)x64.zip<.*/\1/p' | + tail -n 1) + if [ -z "$OSGEO_VERSION" ]; then + OSGEO_VERSION=$(\ + curl -Ls https://download.osgeo.org/postgis/windows/pg${{ matrix.pg_major }}/archive | + sed -n 's/.*>postgis-bundle-pg${{ matrix.pg_major }}-\(${{ env.postgis_version }}.[0-9]*.[0-9]*\)x64.zip<.*/\1/p' | + tail -n 1) + POSTGIS_PATH="archive/" + else + POSTGIS_PATH="" + fi # Install PostGIS echo "Installing PostGIS (version: ${OSGEO_VERSION})" POSTGIS_FILE="postgis-bundle-pg${{ matrix.pg_major }}-${OSGEO_VERSION}x64" - curl -o postgis.zip -L https://download.osgeo.org/postgis/windows/pg${{ matrix.pg_major }}/${POSTGIS_FILE}.zip + curl -o postgis.zip -L https://download.osgeo.org/postgis/windows/pg${{ matrix.pg_major }}/${POSTGIS_PATH}${POSTGIS_FILE}.zip unzip postgis.zip -d postgis cp -a postgis/$POSTGIS_FILE/. 
pgsql/ # Start PostgreSQL pgsql/bin/initdb -D pgsql/PGDATA -E UTF8 -U postgres SOCKET_DIR=$(echo "$LOCALAPPDATA\Temp" | sed 's|\\|/|g') + sed -i "s|max_connections = 100|max_connections = 500|" pgsql/PGDATA/postgresql.conf sed -i "s|#unix_socket_directories = ''|unix_socket_directories = '$SOCKET_DIR'|" pgsql/PGDATA/postgresql.conf sed -i "s|#wal_level =|wal_level = logical #|" pgsql/PGDATA/postgresql.conf sed -i "s|#max_wal_senders =|max_wal_senders = 50 #|" pgsql/PGDATA/postgresql.conf + sed -i "s|#logical_decoding_work_mem =|logical_decoding_work_mem = 64kB #|" pgsql/PGDATA/postgresql.conf sed -i "s|#wal_sender_timeout =|wal_sender_timeout = 3s #|" pgsql/PGDATA/postgresql.conf sed -i "s|#synchronous_standby_names =|synchronous_standby_names = 'npgsql_test_sync_standby' #|" pgsql/PGDATA/postgresql.conf sed -i "s|#synchronous_commit =|synchronous_commit = local #|" pgsql/PGDATA/postgresql.conf - pgsql/bin/pg_ctl -D pgsql/PGDATA -l logfile -o '-c max_prepared_transactions=10 -c ssl=true -c ssl_cert_file=../server.crt -c ssl_key_file=../server.key' start + sed -i "s|#max_prepared_transactions = 0|max_prepared_transactions = 100|" pgsql/PGDATA/postgresql.conf + pgsql/bin/pg_ctl -D pgsql/PGDATA -l logfile -o '-c ssl=true -c ssl_cert_file=../server.crt -c ssl_key_file=../server.key' start - # Configure test account - pgsql/bin/psql -U postgres -c "CREATE ROLE npgsql_tests SUPERUSER LOGIN PASSWORD 'npgsql_tests'" - pgsql/bin/psql -U postgres -c "CREATE DATABASE npgsql_tests OWNER npgsql_tests" - pgsql/bin/psql -U postgres -c "CREATE EXTENSION citext" npgsql_tests - pgsql/bin/psql -U postgres -c "CREATE EXTENSION hstore" npgsql_tests - pgsql/bin/psql -U postgres -c "CREATE EXTENSION ltree" npgsql_tests - pgsql/bin/psql -U postgres -c "CREATE EXTENSION postgis" npgsql_tests + # Create npgsql_tests user with md5 password 'npgsql_tests' + pgsql/bin/psql -U postgres -c "CREATE ROLE npgsql_tests SUPERUSER LOGIN PASSWORD 'md5adf74603a5772843f53e812f03dacb02'" + + 
pgsql/bin/psql -U postgres -c "CREATE ROLE npgsql_tests_ssl SUPERUSER LOGIN PASSWORD 'npgsql_tests_ssl'" + pgsql/bin/psql -U postgres -c "CREATE ROLE npgsql_tests_nossl SUPERUSER LOGIN PASSWORD 'npgsql_tests_nossl'" # user 'npgsql_tests_scram' must be created with password encrypted as scram-sha-256 (which only applies after restart) - sed -i "s|#password_encryption = md5|password_encryption = scram-sha-256|" pgsql/PGDATA/postgresql.conf + if [ ${{ matrix.pg_major }} -ge 14 ]; then + sed -i "s|password_encryption = md5|password_encryption = scram-sha-256|" pgsql/PGDATA/postgresql.conf + else + sed -i "s|#password_encryption = md5|password_encryption = scram-sha-256|" pgsql/PGDATA/postgresql.conf + fi - pgsql/bin/pg_ctl -D pgsql/PGDATA -l logfile -o '-c max_prepared_transactions=10 -c ssl=true -c ssl_cert_file=../server.crt -c ssl_key_file=../server.key' restart + pgsql/bin/pg_ctl -D pgsql/PGDATA -l logfile -o '-c ssl=true -c ssl_cert_file=../server.crt -c ssl_key_file=../server.key' restart pgsql/bin/psql -U postgres -c "CREATE ROLE npgsql_tests_scram SUPERUSER LOGIN PASSWORD 'npgsql_tests_scram'" @@ -151,31 +227,122 @@ jobs: else echo "host all npgsql_tests_scram all scram-sha-256" > pgsql/PGDATA/pg_hba.conf fi + echo "hostssl all npgsql_tests_ssl all md5" >> pgsql/PGDATA/pg_hba.conf + echo "hostnossl all npgsql_tests_ssl all reject" >> pgsql/PGDATA/pg_hba.conf + echo "hostnossl all npgsql_tests_nossl all md5" >> pgsql/PGDATA/pg_hba.conf + echo "hostssl all npgsql_tests_nossl all reject" >> pgsql/PGDATA/pg_hba.conf echo "host all all all md5" >> pgsql/PGDATA/pg_hba.conf echo "host replication all all md5" >> pgsql/PGDATA/pg_hba.conf + - name: Start PostgreSQL ${{ matrix.pg_major }} (MacOS) + if: startsWith(matrix.os, 'macos') + run: | + PGDATA=/usr/local/var/postgresql@${{ matrix.pg_major }} + + sudo sed -i '' 's/#ssl = off/ssl = on/' $PGDATA/postgresql.conf + cp $GITHUB_WORKSPACE/.build/{server.crt,server.key} $PGDATA + chmod 600 $PGDATA/{server.crt,server.key} + 
+ postgreService=$(brew services list | grep -oe "postgresql\S*") + + brew services start $postgreService + echo "Check PostgreSQL service is running" + i=5 + COMMAND='pg_isready' + while [ $i -gt 0 ]; do + echo "Check PostgreSQL service status" + eval $COMMAND && break + ((i--)) + if [ $i == 0 ]; then + echo "PostgreSQL service not ready, all attempts exhausted" + exit 1 + fi + echo "PostgreSQL service not ready, wait 5 more sec, attempts left: $i" + sleep 5 + done + + # Create npgsql_tests user with md5 password 'npgsql_tests' + psql -c "CREATE USER npgsql_tests SUPERUSER PASSWORD 'md5adf74603a5772843f53e812f03dacb02'" postgres + + psql -c "CREATE USER npgsql_tests_ssl SUPERUSER PASSWORD 'npgsql_tests_ssl'" postgres + psql -c "CREATE USER npgsql_tests_nossl SUPERUSER PASSWORD 'npgsql_tests_nossl'" postgres + + sudo sed -i '' 's/max_connections = 100/max_connections = 500/' $PGDATA/postgresql.conf + sudo sed -i '' 's/#password_encryption = md5/password_encryption = scram-sha-256/' $PGDATA/postgresql.conf + sudo sed -i '' 's/#wal_level =/wal_level = logical #/' $PGDATA/postgresql.conf + sudo sed -i '' 's/#max_wal_senders =/max_wal_senders = 50 #/' $PGDATA/postgresql.conf + sudo sed -i '' 's/#logical_decoding_work_mem =/logical_decoding_work_mem = 64kB #/' $PGDATA/postgresql.conf + sudo sed -i '' 's/#wal_sender_timeout =/wal_sender_timeout = 3s #/' $PGDATA/postgresql.conf + sudo sed -i '' "s/#synchronous_standby_names =/synchronous_standby_names = 'npgsql_test_sync_standby' #/" $PGDATA/postgresql.conf + sudo sed -i '' "s/#synchronous_commit =/synchronous_commit = local #/" $PGDATA/postgresql.conf + sudo sed -i '' "s/#max_prepared_transactions = 0/max_prepared_transactions = 100/" $PGDATA/postgresql.conf + # Disable trust authentication, requiring MD5 passwords - some tests must fail if a password isn't provided. 
+ sudo sh -c "echo 'local all all trust' > $PGDATA/pg_hba.conf" + sudo sh -c "echo 'hostssl all npgsql_tests_ssl all md5' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'hostnossl all npgsql_tests_ssl all reject' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'hostnossl all npgsql_tests_nossl all md5' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'hostssl all npgsql_tests_nossl all reject' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'host all npgsql_tests_scram all scram-sha-256' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'host all all all md5' >> $PGDATA/pg_hba.conf" + sudo sh -c "echo 'host replication all all md5' >> $PGDATA/pg_hba.conf" + + brew services restart $postgreService + echo "Check PostgreSQL service is running" + i=5 + COMMAND='pg_isready' + while [ $i -gt 0 ]; do + echo "Check PostgreSQL service status" + eval $COMMAND && break + ((i--)) + if [ $i == 0 ]; then + echo "PostgreSQL service not ready, all attempts exhausted" + exit 1 + fi + echo "PostgreSQL service not ready, wait 5 more sec, attempts left: $i" + sleep 5 + done + psql -c "CREATE USER npgsql_tests_scram SUPERUSER PASSWORD 'npgsql_tests_scram'" postgres + # TODO: Once test/Npgsql.Specification.Tests work, switch to just testing on the solution - name: Test - run: dotnet test test/Npgsql.Tests --logger "GitHubActions;report-warnings=false" + run: | + dotnet test -c ${{ matrix.config }} -f ${{ matrix.test_tfm }} test/Npgsql.Tests --logger "GitHubActions;report-warnings=false" + dotnet test -c ${{ matrix.config }} -f ${{ matrix.test_tfm }} test/Npgsql.DependencyInjection.Tests --logger "GitHubActions;report-warnings=false" shell: bash - name: Test Plugins - run: dotnet test test/Npgsql.PluginTests --logger "GitHubActions;report-warnings=false" + if: "!startsWith(matrix.os, 'macos')" + run: | + if [ -z "${{ matrix.pg_prerelease }}" ]; then + dotnet test -c ${{ matrix.config }} -f ${{ matrix.test_tfm }} test/Npgsql.PluginTests --logger "GitHubActions;report-warnings=false" + fi shell: bash - publish: - 
needs: build - - runs-on: windows-latest + - id: analyze_tag + name: Analyze tag + shell: bash + run: | + if [[ ${{ github.ref }} =~ ^refs/tags/v[0-9]+\.[0-9]+\.[0-9]+ ]]; then + echo "Release tag detected" + echo "::set-output name=is_release::true" + if [[ ${{ github.ref }} =~ ^refs/tags/v[0-9]+\.[0-9]+\.[0-9]+.*- ]]; then + echo "Prerelease tag detected" + echo "::set-output name=is_prerelease::true" + fi + fi + publish-ci: + needs: build + runs-on: ubuntu-22.04 if: github.event_name == 'push' && github.repository == 'npgsql/npgsql' + environment: myget steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: NuGet Cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.nuget/packages key: ${{ runner.os }}-nuget-${{ hashFiles('**/Directory.Build.targets') }} @@ -183,34 +350,55 @@ jobs: ${{ runner.os }}-nuget- - name: Setup .NET Core SDK - uses: actions/setup-dotnet@v1 + uses: actions/setup-dotnet@v3.2.0 with: dotnet-version: ${{ env.dotnet_sdk_version }} - - name: Pack NuGet packages (CI versions) - if: startsWith(github.ref, 'refs/heads/') - run: dotnet pack Npgsql.sln --configuration Release --output nupkgs --version-suffix "ci.$(date -u +%Y%m%dT%H%M%S)+sha.${GITHUB_SHA:0:9}" -p:ContinuousIntegrationBuild=true - shell: bash - - - name: Pack NuGet packages (Release versions) - if: startsWith(github.ref, 'refs/tags/v') - run: dotnet pack Npgsql.sln --configuration Release --output nupkgs -p:ContinuousIntegrationBuild=true - shell: bash + - name: Pack + run: dotnet pack Npgsql.sln --configuration Release --property:PackageOutputPath="$PWD/nupkgs" --version-suffix "ci.$(date -u +%Y%m%dT%H%M%S)+sha.${GITHUB_SHA:0:9}" -p:ContinuousIntegrationBuild=true - name: Upload artifacts (nupkg) - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: - name: Npgsql.nupkgs + name: Npgsql.CI path: nupkgs - name: Publish packages to MyGet (vnext) if: startsWith(github.ref, 'refs/heads/') && startsWith(github.ref, 
'refs/heads/hotfix/') == false - run: dotnet nuget push "*.nupkg" --api-key ${{ secrets.MYGET_FEED_TOKEN }} --source https://www.myget.org/F/npgsql-unstable/api/v3/index.json + run: dotnet nuget push "*.nupkg" --api-key ${{ secrets.MYGET_FEED_TOKEN }} --source https://www.myget.org/F/npgsql-vnext/api/v3/index.json working-directory: nupkgs - shell: bash - name: Publish packages to MyGet (patch) if: startsWith(github.ref, 'refs/heads/hotfix/') run: dotnet nuget push "*.nupkg" --api-key ${{ secrets.MYGET_FEED_TOKEN }} --source https://www.myget.org/F/npgsql/api/v3/index.json working-directory: nupkgs - shell: bash + + release: + needs: build + runs-on: ubuntu-22.04 + if: github.event_name == 'push' && startsWith(github.repository, 'npgsql/') && needs.build.outputs.is_release == 'true' + environment: nuget.org + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup .NET Core SDK + uses: actions/setup-dotnet@v3.2.0 + with: + dotnet-version: ${{ env.dotnet_sdk_version }} + + - name: Pack + run: dotnet pack Npgsql.sln --configuration Release --property:PackageOutputPath="$PWD/nupkgs" -p:ContinuousIntegrationBuild=true + + - name: Upload artifacts + uses: actions/upload-artifact@v3 + with: + name: Npgsql.Release + path: nupkgs + + # TODO: Create a release + + - name: Publish to nuget.org + run: dotnet nuget push "*.nupkg" --api-key ${{ secrets.NUGET_ORG_API_KEY }} --source https://api.nuget.org/v3/index.json + working-directory: nupkgs diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 0000000000..0e721ebc22 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,93 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. 
+# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: + - main + - 'hotfix/**' + - 'release/**' + pull_request: + # The branches below must be a subset of the branches above + branches: + - main + - 'hotfix/**' + - 'release/**' + schedule: + - cron: '21 0 * * 4' + +# Cancel previous PR branch commits (head_ref is only defined on PRs) +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +env: + dotnet_sdk_version: '8.0.100' + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'csharp' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Learn more about CodeQL language support at https://git.io/codeql-language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + - name: Setup .NET Core SDK + uses: actions/setup-dotnet@v3.2.0 + with: + dotnet-version: ${{ env.dotnet_sdk_version }} + + - name: Build + run: dotnet build -c Release + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 
+ # If this step fails, then you should remove it and run the build manually (see below) + #- name: Autobuild + # uses: github/codeql-action/autobuild@v2 + + # ℹ️ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/native-aot.yml b/.github/workflows/native-aot.yml new file mode 100644 index 0000000000..6ff04ffa5f --- /dev/null +++ b/.github/workflows/native-aot.yml @@ -0,0 +1,134 @@ +name: NativeAOT + +on: + push: + branches: + - main + - 'hotfix/**' + tags: + - '*' + pull_request: + +# Cancel previous PR branch commits (head_ref is only defined on PRs) +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +env: + dotnet_sdk_version: '8.0.100' + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + # Uncomment and edit the following to use nightly/preview builds +# nuget_config: | +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04] + pg_major: [15] + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: NuGet Cache + uses: actions/cache@v3 + with: + path: ~/.nuget/packages + key: ${{ runner.os }}-nuget-${{ hashFiles('**/Directory.Build.targets') }} + restore-keys: | + ${{ runner.os }}-nuget- + + - name: Setup .NET Core SDK + uses: actions/setup-dotnet@v3.2.0 + with: + dotnet-version: | + ${{ env.dotnet_sdk_version }} + +# - name: Setup nuget config +# run: echo "$nuget_config" > NuGet.config + + - name: Setup Native AOT prerequisites + run: sudo apt-get install clang zlib1g-dev + shell: bash + + - name: Build + run: dotnet 
publish test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj -r linux-x64 -c Release -f net8.0 -p:OptimizationPreference=Size + shell: bash + + # Uncomment the following to SSH into the agent running the build (https://github.com/mxschmitt/action-tmate) + #- uses: actions/checkout@v4 + #- name: Setup tmate session + # uses: mxschmitt/action-tmate@v3 + + - name: Start PostgreSQL + run: | + sudo systemctl start postgresql.service + sudo -u postgres psql -c "CREATE USER npgsql_tests SUPERUSER PASSWORD 'npgsql_tests'" + sudo -u postgres psql -c "CREATE DATABASE npgsql_tests OWNER npgsql_tests" + + - name: Run + run: test/Npgsql.NativeAotTests/bin/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests + + - name: Write binary size to summary + run: | + size="$(ls -l test/Npgsql.NativeAotTests/bin/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests | cut -d ' ' -f 5)" + echo "Binary size is $size bytes ($((size / (1024 * 1024))) mb)" >> $GITHUB_STEP_SUMMARY + + - name: Dump mstat + run: dotnet run --project test/MStatDumper/MStatDumper.csproj -c release -f net8.0 -- "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.mstat" md >> $GITHUB_STEP_SUMMARY + + - name: Upload mstat + uses: actions/upload-artifact@v3.1.2 + with: + name: npgsql.mstat + path: "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.mstat" + retention-days: 3 + + - name: Upload codedgen dgml + uses: actions/upload-artifact@v3.1.2 + with: + name: npgsql.codegen.dgml.xml + path: "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.codegen.dgml.xml" + retention-days: 3 + + - name: Upload scan dgml + uses: actions/upload-artifact@v3.1.2 + with: + name: npgsql.scan.dgml.xml + path: "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.scan.dgml.xml" + retention-days: 3 + + - name: Assert binary size + run: | + size="$(ls -l 
test/Npgsql.NativeAotTests/bin/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests | cut -d ' ' -f 5)" + echo "Binary size is $size bytes ($((size / (1024 * 1024))) mb)" + + if (( size > 7340032 )); then + echo "Binary size exceeds 7mb threshold" + exit 1 + fi diff --git a/.github/workflows/rich-code-nav.yml b/.github/workflows/rich-code-nav.yml new file mode 100644 index 0000000000..1990c8ff78 --- /dev/null +++ b/.github/workflows/rich-code-nav.yml @@ -0,0 +1,44 @@ +name: Rich Code Navigation + +on: + push: + branches: + - main + - stable + tags: + - '*' + +env: + dotnet_sdk_version: '8.0.100' + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + +jobs: + build: + runs-on: windows-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: NuGet Cache + uses: actions/cache@v3 + with: + path: ~/.nuget/packages + key: ${{ runner.os }}-nuget-${{ hashFiles('**/Directory.Build.targets') }} + restore-keys: | + ${{ runner.os }}-nuget- + + - name: Setup .NET Core SDK + uses: actions/setup-dotnet@v3.2.0 + with: + dotnet-version: ${{ env.dotnet_sdk_version }} + + - name: Build + run: dotnet build Npgsql.sln --configuration Debug + shell: bash + + - name: Rich Navigation Indexing + uses: microsoft/RichCodeNavIndexer@v0.1 + with: + languages: csharp + repo-token: ${{ github.token }} diff --git a/.github/workflows/trigger-doc-build.yml b/.github/workflows/trigger-doc-build.yml index 68796750d8..dfbe89601e 100644 --- a/.github/workflows/trigger-doc-build.yml +++ b/.github/workflows/trigger-doc-build.yml @@ -6,11 +6,11 @@ name: Trigger Documentation Build on: push: branches: - - stable + - docs jobs: build: - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 steps: - name: Trigger documentation build run: | diff --git a/Directory.Build.props b/Directory.Build.props index 4500471878..a3f26fa3ef 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,9 +1,16 @@  - 5.0.0 + 8.0.3 + latest + true + enable + latest + true + $(MSBuildThisFileDirectory)Npgsql.snk + 
true + true - - Copyright 2020 © The Npgsql Development Team + Copyright 2023 © The Npgsql Development Team Npgsql PostgreSQL https://github.com/npgsql/npgsql @@ -12,24 +19,14 @@ true snupkg true - $(NoWarn);NU5105 - - - 9.0 - true - enable - latest - - - true - $(MSBuildThisFileDirectory)Npgsql.snk - true + $(NoWarn);NETSDK1138 + true - + disable - $(NoWarn);CS8632 + $(NoWarn);CS8632;CS8600 diff --git a/Directory.Build.targets b/Directory.Build.targets deleted file mode 100644 index 75fef39ad9..0000000000 --- a/Directory.Build.targets +++ /dev/null @@ -1,32 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/Directory.Packages.props b/Directory.Packages.props new file mode 100644 index 0000000000..2c3d303907 --- /dev/null +++ b/Directory.Packages.props @@ -0,0 +1,55 @@ + + + 8.0.0 + $(SystemVersion) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/LICENSE b/LICENSE index a4c7bd2b54..efec310cda 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2002-2019, Npgsql +Copyright (c) 2002-2023, Npgsql Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/Npgsql.sln b/Npgsql.sln index bf95732712..80ef02c3a8 100644 --- a/Npgsql.sln +++ b/Npgsql.sln @@ -19,8 +19,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.NodaTime", "src\Npgs EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.PluginTests", "test\Npgsql.PluginTests\Npgsql.PluginTests.csproj", "{9BD7FC3D-6956-42A8-A586-2558C499EBA2}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.LegacyPostgis", "src\Npgsql.LegacyPostgis\Npgsql.LegacyPostgis.csproj", "{D96CC113-7D64-4B31-9DCC-13FDE92C1ECE}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.NetTopologySuite", "src\Npgsql.NetTopologySuite\Npgsql.NetTopologySuite.csproj", 
"{6CB12050-DC9B-4155-BADD-BFDD54CDD70F}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.GeoJSON", "src\Npgsql.GeoJSON\Npgsql.GeoJSON.csproj", "{F7C53EBD-0075-474F-A083-419257D04080}" @@ -30,11 +28,32 @@ EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{004A2E0F-D34A-44D4-8DF0-D2BC63B57073}" ProjectSection(SolutionItems) = preProject .editorconfig = .editorconfig - .github\workflows\build.yml = .github\workflows\build.yml Directory.Build.props = Directory.Build.props - Directory.Build.targets = Directory.Build.targets + Directory.Packages.props = Directory.Packages.props + README.md = README.md + global.json = global.json + NuGet.config = NuGet.config EndProjectSection EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.SourceGenerators", "src\Npgsql.SourceGenerators\Npgsql.SourceGenerators.csproj", "{63026A19-60B8-4906-81CB-216F30E8094B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.OpenTelemetry", "src\Npgsql.OpenTelemetry\Npgsql.OpenTelemetry.csproj", "{DA29F063-1828-47D8-B051-800AF7C9A0BE}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Github", "Github", "{BA7B6F53-D24D-45AC-927A-266857EA8D1E}" + ProjectSection(SolutionItems) = preProject + .github\workflows\build.yml = .github\workflows\build.yml + .github\dependabot.yml = .github\dependabot.yml + .github\workflows\codeql-analysis.yml = .github\workflows\codeql-analysis.yml + .github\workflows\rich-code-nav.yml = .github\workflows\rich-code-nav.yml + .github\workflows\native-aot.yml = .github\workflows\native-aot.yml + EndProjectSection +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.DependencyInjection", "src\Npgsql.DependencyInjection\Npgsql.DependencyInjection.csproj", "{B58E12EB-E43D-4D77-894E-5157D2269836}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.DependencyInjection.Tests", 
"test\Npgsql.DependencyInjection.Tests\Npgsql.DependencyInjection.Tests.csproj", "{EB2530FC-69F7-4DCB-A8B3-3671A157ED32}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.NativeAotTests", "test\Npgsql.NativeAotTests\Npgsql.NativeAotTests.csproj", "{20F2E9D6-A69E-4BAE-9236-574B0AA59139}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -91,14 +110,6 @@ Global {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Release|Any CPU.Build.0 = Release|Any CPU {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Release|x86.ActiveCfg = Release|Any CPU {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Release|x86.Build.0 = Release|Any CPU - {D96CC113-7D64-4B31-9DCC-13FDE92C1ECE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {D96CC113-7D64-4B31-9DCC-13FDE92C1ECE}.Debug|Any CPU.Build.0 = Debug|Any CPU - {D96CC113-7D64-4B31-9DCC-13FDE92C1ECE}.Debug|x86.ActiveCfg = Debug|Any CPU - {D96CC113-7D64-4B31-9DCC-13FDE92C1ECE}.Debug|x86.Build.0 = Debug|Any CPU - {D96CC113-7D64-4B31-9DCC-13FDE92C1ECE}.Release|Any CPU.ActiveCfg = Release|Any CPU - {D96CC113-7D64-4B31-9DCC-13FDE92C1ECE}.Release|Any CPU.Build.0 = Release|Any CPU - {D96CC113-7D64-4B31-9DCC-13FDE92C1ECE}.Release|x86.ActiveCfg = Release|Any CPU - {D96CC113-7D64-4B31-9DCC-13FDE92C1ECE}.Release|x86.Build.0 = Release|Any CPU {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Debug|Any CPU.Build.0 = Debug|Any CPU {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Debug|x86.ActiveCfg = Debug|Any CPU @@ -123,6 +134,46 @@ Global {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Release|Any CPU.Build.0 = Release|Any CPU {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Release|x86.ActiveCfg = Release|Any CPU {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Release|x86.Build.0 = Release|Any CPU + {63026A19-60B8-4906-81CB-216F30E8094B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {63026A19-60B8-4906-81CB-216F30E8094B}.Debug|Any CPU.Build.0 = Debug|Any CPU + 
{63026A19-60B8-4906-81CB-216F30E8094B}.Debug|x86.ActiveCfg = Debug|Any CPU + {63026A19-60B8-4906-81CB-216F30E8094B}.Debug|x86.Build.0 = Debug|Any CPU + {63026A19-60B8-4906-81CB-216F30E8094B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {63026A19-60B8-4906-81CB-216F30E8094B}.Release|Any CPU.Build.0 = Release|Any CPU + {63026A19-60B8-4906-81CB-216F30E8094B}.Release|x86.ActiveCfg = Release|Any CPU + {63026A19-60B8-4906-81CB-216F30E8094B}.Release|x86.Build.0 = Release|Any CPU + {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|x86.ActiveCfg = Debug|Any CPU + {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|x86.Build.0 = Debug|Any CPU + {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Release|Any CPU.Build.0 = Release|Any CPU + {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Release|x86.ActiveCfg = Release|Any CPU + {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Release|x86.Build.0 = Release|Any CPU + {B58E12EB-E43D-4D77-894E-5157D2269836}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B58E12EB-E43D-4D77-894E-5157D2269836}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B58E12EB-E43D-4D77-894E-5157D2269836}.Debug|x86.ActiveCfg = Debug|Any CPU + {B58E12EB-E43D-4D77-894E-5157D2269836}.Debug|x86.Build.0 = Debug|Any CPU + {B58E12EB-E43D-4D77-894E-5157D2269836}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B58E12EB-E43D-4D77-894E-5157D2269836}.Release|Any CPU.Build.0 = Release|Any CPU + {B58E12EB-E43D-4D77-894E-5157D2269836}.Release|x86.ActiveCfg = Release|Any CPU + {B58E12EB-E43D-4D77-894E-5157D2269836}.Release|x86.Build.0 = Release|Any CPU + {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Debug|x86.ActiveCfg = Debug|Any CPU + 
{EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Debug|x86.Build.0 = Debug|Any CPU + {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Release|Any CPU.Build.0 = Release|Any CPU + {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Release|x86.ActiveCfg = Release|Any CPU + {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Release|x86.Build.0 = Release|Any CPU + {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Debug|Any CPU.Build.0 = Debug|Any CPU + {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Debug|x86.ActiveCfg = Debug|Any CPU + {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Debug|x86.Build.0 = Debug|Any CPU + {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Release|Any CPU.ActiveCfg = Release|Any CPU + {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Release|Any CPU.Build.0 = Release|Any CPU + {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Release|x86.ActiveCfg = Release|Any CPU + {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -134,10 +185,15 @@ Global {9CBE603F-6746-411D-A5FD-CB2C948CD7D0} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} {D8DF12D6-FA70-4653-BD8F-C188944836DE} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} {9BD7FC3D-6956-42A8-A586-2558C499EBA2} = {ED612DB1-AB32-4603-95E7-891BACA71C39} - {D96CC113-7D64-4B31-9DCC-13FDE92C1ECE} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} {6CB12050-DC9B-4155-BADD-BFDD54CDD70F} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} {F7C53EBD-0075-474F-A083-419257D04080} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} {A77E5FAF-D775-4AB4-8846-8965C2104E60} = {ED612DB1-AB32-4603-95E7-891BACA71C39} + {63026A19-60B8-4906-81CB-216F30E8094B} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} + {DA29F063-1828-47D8-B051-800AF7C9A0BE} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} + {BA7B6F53-D24D-45AC-927A-266857EA8D1E} = {004A2E0F-D34A-44D4-8DF0-D2BC63B57073} + 
{B58E12EB-E43D-4D77-894E-5157D2269836} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} + {EB2530FC-69F7-4DCB-A8B3-3671A157ED32} = {ED612DB1-AB32-4603-95E7-891BACA71C39} + {20F2E9D6-A69E-4BAE-9236-574B0AA59139} = {ED612DB1-AB32-4603-95E7-891BACA71C39} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {C90AEECD-DB4C-4BE6-B506-16A449852FB8} diff --git a/Npgsql.sln.DotSettings b/Npgsql.sln.DotSettings index 09e69a826c..890df2d4be 100644 --- a/Npgsql.sln.DotSettings +++ b/Npgsql.sln.DotSettings @@ -81,26 +81,67 @@ True True True + True True + True True + True + True + True True + True True + True + True + True + True + True + True True + True True + True + True True True True True + True + True + True + True + True True + True True True True True + True + True + True + True True + True + True + True + True + True + True True True + True + True + True + True + True + True + True True True + True True True + True + True True diff --git a/NuGet.config b/NuGet.config new file mode 100644 index 0000000000..e49ffd89d8 --- /dev/null +++ b/NuGet.config @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/README.md b/README.md index 06114eba27..2b7bfef019 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,53 @@ # Npgsql - the .NET data provider for PostgreSQL [![stable](https://img.shields.io/nuget/v/Npgsql.svg?label=stable)](https://www.nuget.org/packages/Npgsql/) -[![unstable](https://img.shields.io/myget/npgsql-unstable/v/npgsql.svg?label=unstable)](https://www.myget.org/feed/npgsql-unstable/package/nuget/Npgsql) [![next patch](https://img.shields.io/myget/npgsql/v/npgsql.svg?label=next%20patch)](https://www.myget.org/feed/npgsql/package/nuget/Npgsql) -[![build](https://img.shields.io/github/workflow/status/npgsql/npgsql/Build)](https://github.com/npgsql/npgsql/actions) +[![daily builds (vnext)](https://img.shields.io/myget/npgsql-vnext/v/npgsql.svg?label=vnext)](https://www.myget.org/feed/npgsql-vnext/package/nuget/Npgsql) 
+[![build](https://github.com/npgsql/npgsql/actions/workflows/build.yml/badge.svg)](https://github.com/npgsql/npgsql/actions/workflows/build.yml) [![gitter](https://img.shields.io/badge/gitter-join%20chat-brightgreen.svg)](https://gitter.im/npgsql/npgsql) ## What is Npgsql? -Npgsql is a .NET data provider for PostgreSQL. It allows you to connect and interact with PostgreSQL server using .NET. +Npgsql is the open source .NET data provider for PostgreSQL. It allows you to connect and interact with PostgreSQL server using .NET. -For any additional information, please visit the Npgsql website at [https://www.npgsql.org](https://www.npgsql.org). +For the full documentation, please visit [the Npgsql website](https://www.npgsql.org). For the Entity Framework Core provider that works with this provider, see [Npgsql.EntityFrameworkCore.PostgreSQL](https://github.com/npgsql/efcore.pg). + +## Quickstart + +Here's a basic code snippet to get you started: + +```csharp +using Npgsql; + +var connString = "Host=myserver;Username=mylogin;Password=mypass;Database=mydatabase"; + +var dataSourceBuilder = new NpgsqlDataSourceBuilder(connString); +var dataSource = dataSourceBuilder.Build(); + +var conn = await dataSource.OpenConnectionAsync(); + +// Insert some data +await using (var cmd = new NpgsqlCommand("INSERT INTO data (some_field) VALUES (@p)", conn)) +{ + cmd.Parameters.AddWithValue("p", "Hello world"); + await cmd.ExecuteNonQueryAsync(); +} + +// Retrieve all rows +await using (var cmd = new NpgsqlCommand("SELECT some_field FROM data", conn)) +await using (var reader = await cmd.ExecuteReaderAsync()) +{ + while (await reader.ReadAsync()) + Console.WriteLine(reader.GetString(0)); +} +``` + +## Key features + +* High-performance PostgreSQL driver. Regularly figures in the top contenders on the [TechEmpower Web Framework Benchmarks](https://www.techempower.com/benchmarks/). 
+* Full support of most PostgreSQL types, including advanced ones such as arrays, enums, ranges, multiranges, composites, JSON, PostGIS and others. +* Highly-efficient bulk import/export API. +* Failover, load balancing and general multi-host support. +* Great integration with Entity Framework Core via [Npgsql.EntityFrameworkCore.PostgreSQL](https://www.nuget.org/packages/Npgsql.EntityFrameworkCore.PostgreSQL). + +For the full documentation, please visit the Npgsql website at [https://www.npgsql.org](https://www.npgsql.org). diff --git a/global.json b/global.json index f9f3192fb0..c4fc1c4611 100644 --- a/global.json +++ b/global.json @@ -1,7 +1,7 @@ { "sdk": { - "version": "5.0.100", + "version": "8.0.100", "rollForward": "latestMajor", - "allowPrerelease": "true" + "allowPrerelease": "false" } } diff --git a/src/.editorconfig b/src/.editorconfig new file mode 100644 index 0000000000..6574a9291a --- /dev/null +++ b/src/.editorconfig @@ -0,0 +1,16 @@ +# Public API Analyzers + +root = false + +[*.cs] + +# Constructor make noninheritable base class inheritable +dotnet_diagnostic.RS0022.severity = none + +# Do not add multiple public overloads with optional parameters +dotnet_diagnostic.RS0026.severity = none + +# Public API with optional parameter(s) should have the most parameters amongst its public overloads. 
+dotnet_diagnostic.RS0027.severity = none + +dotnet_diagnostic.CA2007.severity = warning; diff --git a/src/Directory.Build.props b/src/Directory.Build.props index b94a8a91bd..169a5988a2 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -3,8 +3,14 @@ true + + true + + + + diff --git a/src/Npgsql.DependencyInjection/Npgsql.DependencyInjection.csproj b/src/Npgsql.DependencyInjection/Npgsql.DependencyInjection.csproj new file mode 100644 index 0000000000..b3b92f69c6 --- /dev/null +++ b/src/Npgsql.DependencyInjection/Npgsql.DependencyInjection.csproj @@ -0,0 +1,27 @@ + + + + Shay Rojansky + + netstandard2.0;net7.0 + net8.0 + npgsql;postgresql;postgres;ado;ado.net;database;sql;di;dependency injection + README.md + + + + + + + + + + + + + + + + + + diff --git a/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.Obsolete.cs b/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.Obsolete.cs new file mode 100644 index 0000000000..6e2b4e7d4f --- /dev/null +++ b/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.Obsolete.cs @@ -0,0 +1,220 @@ +using System; +using System.ComponentModel; +using Npgsql; + +namespace Microsoft.Extensions.DependencyInjection; + +public static partial class NpgsqlServiceCollectionExtensions +{ + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. 
+ [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddNpgsqlDataSourceCore( + serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. + [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddNpgsqlDataSourceCore(serviceCollection, serviceKey: null, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. 
+ [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. + [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddNpgsqlSlimDataSourceCore(serviceCollection, serviceKey: null, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the + /// . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. 
+ [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddMultiHostNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddMultiHostNpgsqlDataSourceCore( + serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. + [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddMultiHostNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddMultiHostNpgsqlDataSourceCore( + serviceCollection, serviceKey: null, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the + /// . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. 
+ [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddMultiHostNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. + [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddMultiHostNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey: null, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); +} diff --git a/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.cs b/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.cs new file mode 100644 index 0000000000..7e22029a40 --- /dev/null +++ b/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.cs @@ -0,0 +1,551 @@ +using System; +using System.Data.Common; +using Microsoft.Extensions.DependencyInjection.Extensions; +using 
Microsoft.Extensions.Logging; +using Npgsql; + +// ReSharper disable once CheckNamespace +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension method for setting up Npgsql services in an . +/// +public static partial class NpgsqlServiceCollectionExtensions +{ + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddNpgsqlDataSourceCore( + serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? 
serviceKey = null) + => AddNpgsqlDataSourceCore(serviceCollection, serviceKey, connectionString, + static (_, builder, state) => ((Action)state!)(builder) + , connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddNpgsqlDataSourceCore(serviceCollection, serviceKey, connectionString, + static (sp, builder, state) => ((Action)state!)(sp, builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? 
serviceKey = null) + => AddNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddNpgsqlSlimDataSourceCore(serviceCollection, serviceKey, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. 
+ public static IServiceCollection AddNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddNpgsqlSlimDataSourceCore(serviceCollection, serviceKey, connectionString, + static (sp, builder, state) => ((Action)state!)(sp, builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the + /// . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddMultiHostNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddMultiHostNpgsqlDataSourceCore( + serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. 
+ public static IServiceCollection AddMultiHostNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddMultiHostNpgsqlDataSourceCore( + serviceCollection, serviceKey, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddMultiHostNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddMultiHostNpgsqlDataSourceCore( + serviceCollection, serviceKey, connectionString, + static (sp, builder, state) => ((Action)state!)(sp, builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the + /// . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. 
+ /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddMultiHostNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddMultiHostNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. 
+ /// Defaults to . + /// + /// The of the data source. + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddMultiHostNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey, connectionString, + static (sp, builder, state) => ((Action)state!)(sp, builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + static IServiceCollection AddNpgsqlDataSourceCore( + this IServiceCollection serviceCollection, + object? serviceKey, + string connectionString, + Action? dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime, + object? state) + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlDataSource), + serviceKey, + (sp, key) => + { + var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionString); + dataSourceBuilder.UseLoggerFactory(sp.GetService()); + dataSourceBuilderAction?.Invoke(sp, dataSourceBuilder, state); + return dataSourceBuilder.Build(); + }, + dataSourceLifetime)); + + AddCommonServices(serviceCollection, serviceKey, connectionLifetime, dataSourceLifetime); + + return serviceCollection; + } + + static IServiceCollection AddNpgsqlSlimDataSourceCore( + this IServiceCollection serviceCollection, + object? serviceKey, + string connectionString, + Action? dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime, + object? 
state) + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlDataSource), + serviceKey, + (sp, key) => + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(connectionString); + dataSourceBuilder.UseLoggerFactory(sp.GetService()); + dataSourceBuilderAction?.Invoke(sp, dataSourceBuilder, state); + return dataSourceBuilder.Build(); + }, + dataSourceLifetime)); + + AddCommonServices(serviceCollection, serviceKey, connectionLifetime, dataSourceLifetime); + + return serviceCollection; + } + + static IServiceCollection AddMultiHostNpgsqlDataSourceCore( + this IServiceCollection serviceCollection, + object? serviceKey, + string connectionString, + Action? dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime, + object? state) + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlMultiHostDataSource), + serviceKey, + (sp, key) => + { + var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionString); + dataSourceBuilder.UseLoggerFactory(sp.GetService()); + dataSourceBuilderAction?.Invoke(sp, dataSourceBuilder, state); + return dataSourceBuilder.BuildMultiHost(); + }, + dataSourceLifetime)); + + if (serviceKey is not null) + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlDataSource), + serviceKey, + (sp, key) => sp.GetRequiredKeyedService(key), + dataSourceLifetime)); + } + else + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlDataSource), + sp => sp.GetRequiredService(), + dataSourceLifetime)); + + } + + AddCommonServices(serviceCollection, serviceKey, connectionLifetime, dataSourceLifetime); + + return serviceCollection; + } + + static IServiceCollection AddMultiHostNpgsqlSlimDataSourceCore( + this IServiceCollection serviceCollection, + object? serviceKey, + string connectionString, + Action? dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime, + object? 
state) + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlMultiHostDataSource), + serviceKey, + (sp, _) => + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(connectionString); + dataSourceBuilder.UseLoggerFactory(sp.GetService()); + dataSourceBuilderAction?.Invoke(sp, dataSourceBuilder, state); + return dataSourceBuilder.BuildMultiHost(); + }, + dataSourceLifetime)); + + if (serviceKey is not null) + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlDataSource), + serviceKey, + (sp, key) => sp.GetRequiredKeyedService(key), + dataSourceLifetime)); + } + else + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlDataSource), + sp => sp.GetRequiredService(), + dataSourceLifetime)); + + } + + AddCommonServices(serviceCollection, serviceKey, connectionLifetime, dataSourceLifetime); + + return serviceCollection; + } + + static void AddCommonServices( + IServiceCollection serviceCollection, + object? serviceKey, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + { + // We don't try to invoke KeyedService methods if there is no service key. + // This allows user code that use non-standard containers without support for IKeyedServiceProvider to keep on working. 
+ if (serviceKey is not null) + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlConnection), + serviceKey, + (sp, key) => sp.GetRequiredKeyedService(key).CreateConnection(), + connectionLifetime)); + + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(DbDataSource), + serviceKey, + (sp, key) => sp.GetRequiredKeyedService(key), + dataSourceLifetime)); + + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(DbConnection), + serviceKey, + (sp, key) => sp.GetRequiredKeyedService(key), + connectionLifetime)); + } + else + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(NpgsqlConnection), + sp => sp.GetRequiredService().CreateConnection(), + connectionLifetime)); + + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(DbDataSource), + sp => sp.GetRequiredService(), + dataSourceLifetime)); + + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(DbConnection), + sp => sp.GetRequiredService(), + connectionLifetime)); + } + } +} diff --git a/src/Npgsql.DependencyInjection/Properties/AssemblyInfo.cs b/src/Npgsql.DependencyInjection/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..1a340b1a15 --- /dev/null +++ b/src/Npgsql.DependencyInjection/Properties/AssemblyInfo.cs @@ -0,0 +1,5 @@ +using System.Runtime.CompilerServices; + +#if NET5_0_OR_GREATER +[module: SkipLocalsInit] +#endif diff --git a/src/Npgsql.DependencyInjection/PublicAPI.Shipped.txt b/src/Npgsql.DependencyInjection/PublicAPI.Shipped.txt new file mode 100644 index 0000000000..4066bf5273 --- /dev/null +++ b/src/Npgsql.DependencyInjection/PublicAPI.Shipped.txt @@ -0,0 +1,18 @@ +#nullable enable +Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! 
connectionString, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! 
+static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! 
dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! 
+static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! 
dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! diff --git a/src/Npgsql.DependencyInjection/PublicAPI.Unshipped.txt b/src/Npgsql.DependencyInjection/PublicAPI.Unshipped.txt new file mode 100644 index 0000000000..34f2d889e9 --- /dev/null +++ b/src/Npgsql.DependencyInjection/PublicAPI.Unshipped.txt @@ -0,0 +1,5 @@ +#nullable enable +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! 
serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! 
diff --git a/src/Npgsql.DependencyInjection/README.md b/src/Npgsql.DependencyInjection/README.md new file mode 100644 index 0000000000..7b22c8d15d --- /dev/null +++ b/src/Npgsql.DependencyInjection/README.md @@ -0,0 +1,80 @@ +Npgsql is the open source .NET data provider for PostgreSQL. It allows you to connect and interact with PostgreSQL server using .NET. + +This package helps set up Npgsql in applications using dependency injection, notably ASP.NET applications. It allows easy configuration of your Npgsql connections and registers the appropriate services in your DI container. + +For example, if using the ASP.NET minimal web API, simply use the following to register Npgsql: + +```csharp +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddNpgsqlDataSource("Host=pg_server;Username=test;Password=test;Database=test"); +``` + +This registers a transient [`NpgsqlConnection`](https://www.npgsql.org/doc/api/Npgsql.NpgsqlConnection.html) which can get injected into your controllers: + +```csharp +app.MapGet("/", async (NpgsqlConnection connection) => +{ + await connection.OpenAsync(); + await using var command = new NpgsqlCommand("SELECT number FROM data LIMIT 1", connection); + return "Hello World: " + await command.ExecuteScalarAsync(); +}); +``` + +But wait! 
If all you want is to execute some simple SQL, just use the singleton [`NpgsqlDataSource`](https://www.npgsql.org/doc/api/Npgsql.NpgsqlDataSource.html) to execute a command directly: + +```csharp +app.MapGet("/", async (NpgsqlDataSource dataSource) => +{ + await using var command = dataSource.CreateCommand("SELECT number FROM data LIMIT 1"); + return "Hello World: " + await command.ExecuteScalarAsync(); +}); +``` + +[`NpgsqlDataSource`](https://www.npgsql.org/doc/api/Npgsql.NpgsqlDataSource.html) can also come in handy when you need more than one connection: + +```csharp +app.MapGet("/", async (NpgsqlDataSource dataSource) => +{ + await using var connection1 = await dataSource.OpenConnectionAsync(); + await using var connection2 = await dataSource.OpenConnectionAsync(); + // Use the two connections... +}); +``` + +The `AddNpgsqlDataSource` method also accepts a lambda parameter allowing you to configure aspects of Npgsql beyond the connection string, e.g. to configure `UseLoggerFactory` and `UseNetTopologySuite`: + +```csharp +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddNpgsqlDataSource( + "Host=pg_server;Username=test;Password=test;Database=test", + builder => builder + .UseLoggerFactory(loggerFactory) + .UseNetTopologySuite()); +``` + +Finally, starting with Npgsql and .NET 8.0, you can now register multiple data sources (and connections), using a service key to distinguish between them: + +```c# +var builder = WebApplication.CreateBuilder(args); + +builder.Services + .AddNpgsqlDataSource("Host=localhost;Database=CustomersDB;Username=test;Password=test", serviceKey: DatabaseType.CustomerDb) + .AddNpgsqlDataSource("Host=localhost;Database=OrdersDB;Username=test;Password=test", serviceKey: DatabaseType.OrdersDb); + +var app = builder.Build(); + +app.MapGet("/", async ([FromKeyedServices(DatabaseType.OrdersDb)] NpgsqlConnection connection) + => connection.ConnectionString); + +app.Run(); + +enum DatabaseType +{ + CustomerDb, + OrdersDb 
+} +``` + +For more information, [see the Npgsql documentation](https://www.npgsql.org/doc/index.html). diff --git a/src/Npgsql.GeoJSON/BoundingBoxBuilder.cs b/src/Npgsql.GeoJSON/BoundingBoxBuilder.cs deleted file mode 100644 index 8ae4ce8b69..0000000000 --- a/src/Npgsql.GeoJSON/BoundingBoxBuilder.cs +++ /dev/null @@ -1,54 +0,0 @@ -using GeoJSON.Net.Geometry; - -namespace Npgsql.GeoJSON -{ - sealed class BoundingBoxBuilder - { - bool _hasAltitude; - double _minLongitude, _maxLongitude; - double _minLatitude, _maxLatitude; - double _minAltitude, _maxAltitude; - - internal BoundingBoxBuilder() - { - _hasAltitude = false; - - _minLongitude = double.PositiveInfinity; - _minLatitude = double.PositiveInfinity; - _minAltitude = double.PositiveInfinity; - - _maxLongitude = double.NegativeInfinity; - _maxLatitude = double.NegativeInfinity; - _maxAltitude = double.NegativeInfinity; - } - - internal void Accumulate(Position position) - { - if (_minLongitude > position.Longitude) - _minLongitude = position.Longitude; - if (_maxLongitude < position.Longitude) - _maxLongitude = position.Longitude; - - if (_minLatitude > position.Latitude) - _minLatitude = position.Latitude; - if (_maxLatitude < position.Latitude) - _maxLatitude = position.Latitude; - - if (position.Altitude.HasValue) - { - var altitude = position.Altitude.Value; - if (_minAltitude > altitude) - _minAltitude = altitude; - if (_maxAltitude < altitude) - _maxAltitude = altitude; - - _hasAltitude = true; - } - } - - internal double[] Build() - => _hasAltitude - ? 
new[] { _minLongitude, _minLatitude, _minAltitude, _maxLongitude, _maxLatitude, _maxAltitude } - : new[] { _minLongitude, _minLatitude, _maxLongitude, _maxLatitude }; - } -} diff --git a/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs b/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs index 23091be713..14da2f893e 100644 --- a/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs +++ b/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs @@ -1,590 +1,589 @@ -namespace Npgsql.GeoJSON +namespace Npgsql.GeoJSON; + +public partial class CrsMap { - readonly partial struct CrsMap + /// + /// These entries came from spatial_res_sys. They are used to elide memory allocations + /// if they are identical to the entries for the current connection. Otherwise, + /// memory allocated for overrided entries only (added, removed, or modified). + /// + internal static readonly CrsMapEntry[] WellKnown = { - /// - /// These entries came from spatial_res_sys. They are used to elide memory allocations - /// if they are identical to the entries for the current connection. Otherwise, - /// memory allocated for overrided entries only (added, removed, or modified). 
- /// - internal static readonly CrsMapEntry[] WellKnown = - { - new CrsMapEntry(2000, 2180, "EPSG"), - new CrsMapEntry(2188, 2217, "EPSG"), - new CrsMapEntry(2219, 2220, "EPSG"), - new CrsMapEntry(2222, 2292, "EPSG"), - new CrsMapEntry(2294, 2295, "EPSG"), - new CrsMapEntry(2308, 2962, "EPSG"), - new CrsMapEntry(2964, 2973, "EPSG"), - new CrsMapEntry(2975, 2984, "EPSG"), - new CrsMapEntry(2987, 3051, "EPSG"), - new CrsMapEntry(3054, 3138, "EPSG"), - new CrsMapEntry(3140, 3143, "EPSG"), - new CrsMapEntry(3146, 3172, "EPSG"), - new CrsMapEntry(3174, 3294, "EPSG"), - new CrsMapEntry(3296, 3791, "EPSG"), - new CrsMapEntry(3793, 3802, "EPSG"), - new CrsMapEntry(3812, 3812, "EPSG"), - new CrsMapEntry(3814, 3816, "EPSG"), - new CrsMapEntry(3819, 3819, "EPSG"), - new CrsMapEntry(3821, 3822, "EPSG"), - new CrsMapEntry(3824, 3829, "EPSG"), - new CrsMapEntry(3832, 3852, "EPSG"), - new CrsMapEntry(3854, 3854, "EPSG"), - new CrsMapEntry(3857, 3857, "EPSG"), - new CrsMapEntry(3873, 3885, "EPSG"), - new CrsMapEntry(3887, 3887, "EPSG"), - new CrsMapEntry(3889, 3893, "EPSG"), - new CrsMapEntry(3901, 3903, "EPSG"), - new CrsMapEntry(3906, 3912, "EPSG"), - new CrsMapEntry(3920, 3920, "EPSG"), - new CrsMapEntry(3942, 3950, "EPSG"), - new CrsMapEntry(3968, 3970, "EPSG"), - new CrsMapEntry(3973, 3976, "EPSG"), - new CrsMapEntry(3978, 3979, "EPSG"), - new CrsMapEntry(3985, 3989, "EPSG"), - new CrsMapEntry(3991, 3992, "EPSG"), - new CrsMapEntry(3994, 3997, "EPSG"), - new CrsMapEntry(4000, 4016, "EPSG"), - new CrsMapEntry(4018, 4039, "EPSG"), - new CrsMapEntry(4041, 4063, "EPSG"), - new CrsMapEntry(4071, 4071, "EPSG"), - new CrsMapEntry(4073, 4073, "EPSG"), - new CrsMapEntry(4075, 4075, "EPSG"), - new CrsMapEntry(4079, 4079, "EPSG"), - new CrsMapEntry(4081, 4083, "EPSG"), - new CrsMapEntry(4087, 4088, "EPSG"), - new CrsMapEntry(4093, 4100, "EPSG"), - new CrsMapEntry(4120, 4176, "EPSG"), - new CrsMapEntry(4178, 4185, "EPSG"), - new CrsMapEntry(4188, 4289, "EPSG"), - new CrsMapEntry(4291, 
4304, "EPSG"), - new CrsMapEntry(4306, 4319, "EPSG"), - new CrsMapEntry(4322, 4322, "EPSG"), - new CrsMapEntry(4324, 4324, "EPSG"), - new CrsMapEntry(4326, 4326, "EPSG"), - new CrsMapEntry(4328, 4328, "EPSG"), - new CrsMapEntry(4330, 4338, "EPSG"), - new CrsMapEntry(4340, 4340, "EPSG"), - new CrsMapEntry(4342, 4342, "EPSG"), - new CrsMapEntry(4344, 4344, "EPSG"), - new CrsMapEntry(4346, 4346, "EPSG"), - new CrsMapEntry(4348, 4348, "EPSG"), - new CrsMapEntry(4350, 4350, "EPSG"), - new CrsMapEntry(4352, 4352, "EPSG"), - new CrsMapEntry(4354, 4354, "EPSG"), - new CrsMapEntry(4356, 4356, "EPSG"), - new CrsMapEntry(4358, 4358, "EPSG"), - new CrsMapEntry(4360, 4360, "EPSG"), - new CrsMapEntry(4362, 4362, "EPSG"), - new CrsMapEntry(4364, 4364, "EPSG"), - new CrsMapEntry(4366, 4366, "EPSG"), - new CrsMapEntry(4368, 4368, "EPSG"), - new CrsMapEntry(4370, 4370, "EPSG"), - new CrsMapEntry(4372, 4372, "EPSG"), - new CrsMapEntry(4374, 4374, "EPSG"), - new CrsMapEntry(4376, 4376, "EPSG"), - new CrsMapEntry(4378, 4378, "EPSG"), - new CrsMapEntry(4380, 4380, "EPSG"), - new CrsMapEntry(4382, 4382, "EPSG"), - new CrsMapEntry(4384, 4385, "EPSG"), - new CrsMapEntry(4387, 4387, "EPSG"), - new CrsMapEntry(4389, 4415, "EPSG"), - new CrsMapEntry(4417, 4434, "EPSG"), - new CrsMapEntry(4437, 4439, "EPSG"), - new CrsMapEntry(4455, 4457, "EPSG"), - new CrsMapEntry(4462, 4463, "EPSG"), - new CrsMapEntry(4465, 4465, "EPSG"), - new CrsMapEntry(4467, 4468, "EPSG"), - new CrsMapEntry(4470, 4471, "EPSG"), - new CrsMapEntry(4473, 4475, "EPSG"), - new CrsMapEntry(4479, 4479, "EPSG"), - new CrsMapEntry(4481, 4481, "EPSG"), - new CrsMapEntry(4483, 4556, "EPSG"), - new CrsMapEntry(4558, 4559, "EPSG"), - new CrsMapEntry(4568, 4589, "EPSG"), - new CrsMapEntry(4600, 4647, "EPSG"), - new CrsMapEntry(4652, 4824, "EPSG"), - new CrsMapEntry(4826, 4826, "EPSG"), - new CrsMapEntry(4839, 4839, "EPSG"), - new CrsMapEntry(4855, 4880, "EPSG"), - new CrsMapEntry(4882, 4882, "EPSG"), - new CrsMapEntry(4884, 4884, 
"EPSG"), - new CrsMapEntry(4886, 4886, "EPSG"), - new CrsMapEntry(4888, 4888, "EPSG"), - new CrsMapEntry(4890, 4890, "EPSG"), - new CrsMapEntry(4892, 4892, "EPSG"), - new CrsMapEntry(4894, 4894, "EPSG"), - new CrsMapEntry(4896, 4897, "EPSG"), - new CrsMapEntry(4899, 4899, "EPSG"), - new CrsMapEntry(4901, 4904, "EPSG"), - new CrsMapEntry(4906, 4906, "EPSG"), - new CrsMapEntry(4908, 4908, "EPSG"), - new CrsMapEntry(4910, 4920, "EPSG"), - new CrsMapEntry(4922, 4922, "EPSG"), - new CrsMapEntry(4924, 4924, "EPSG"), - new CrsMapEntry(4926, 4926, "EPSG"), - new CrsMapEntry(4928, 4928, "EPSG"), - new CrsMapEntry(4930, 4930, "EPSG"), - new CrsMapEntry(4932, 4932, "EPSG"), - new CrsMapEntry(4934, 4934, "EPSG"), - new CrsMapEntry(4936, 4936, "EPSG"), - new CrsMapEntry(4938, 4938, "EPSG"), - new CrsMapEntry(4940, 4940, "EPSG"), - new CrsMapEntry(4942, 4942, "EPSG"), - new CrsMapEntry(4944, 4944, "EPSG"), - new CrsMapEntry(4946, 4946, "EPSG"), - new CrsMapEntry(4948, 4948, "EPSG"), - new CrsMapEntry(4950, 4950, "EPSG"), - new CrsMapEntry(4952, 4952, "EPSG"), - new CrsMapEntry(4954, 4954, "EPSG"), - new CrsMapEntry(4956, 4956, "EPSG"), - new CrsMapEntry(4958, 4958, "EPSG"), - new CrsMapEntry(4960, 4960, "EPSG"), - new CrsMapEntry(4962, 4962, "EPSG"), - new CrsMapEntry(4964, 4964, "EPSG"), - new CrsMapEntry(4966, 4966, "EPSG"), - new CrsMapEntry(4968, 4968, "EPSG"), - new CrsMapEntry(4970, 4970, "EPSG"), - new CrsMapEntry(4972, 4972, "EPSG"), - new CrsMapEntry(4974, 4974, "EPSG"), - new CrsMapEntry(4976, 4976, "EPSG"), - new CrsMapEntry(4978, 4978, "EPSG"), - new CrsMapEntry(4980, 4980, "EPSG"), - new CrsMapEntry(4982, 4982, "EPSG"), - new CrsMapEntry(4984, 4984, "EPSG"), - new CrsMapEntry(4986, 4986, "EPSG"), - new CrsMapEntry(4988, 4988, "EPSG"), - new CrsMapEntry(4990, 4990, "EPSG"), - new CrsMapEntry(4992, 4992, "EPSG"), - new CrsMapEntry(4994, 4994, "EPSG"), - new CrsMapEntry(4996, 4996, "EPSG"), - new CrsMapEntry(4998, 4998, "EPSG"), - new CrsMapEntry(5011, 5011, "EPSG"), - 
new CrsMapEntry(5013, 5016, "EPSG"), - new CrsMapEntry(5018, 5018, "EPSG"), - new CrsMapEntry(5041, 5042, "EPSG"), - new CrsMapEntry(5048, 5048, "EPSG"), - new CrsMapEntry(5069, 5072, "EPSG"), - new CrsMapEntry(5105, 5130, "EPSG"), - new CrsMapEntry(5132, 5132, "EPSG"), - new CrsMapEntry(5167, 5188, "EPSG"), - new CrsMapEntry(5221, 5221, "EPSG"), - new CrsMapEntry(5223, 5223, "EPSG"), - new CrsMapEntry(5228, 5229, "EPSG"), - new CrsMapEntry(5233, 5235, "EPSG"), - new CrsMapEntry(5243, 5244, "EPSG"), - new CrsMapEntry(5246, 5247, "EPSG"), - new CrsMapEntry(5250, 5250, "EPSG"), - new CrsMapEntry(5252, 5259, "EPSG"), - new CrsMapEntry(5262, 5262, "EPSG"), - new CrsMapEntry(5264, 5264, "EPSG"), - new CrsMapEntry(5266, 5266, "EPSG"), - new CrsMapEntry(5269, 5275, "EPSG"), - new CrsMapEntry(5292, 5311, "EPSG"), - new CrsMapEntry(5316, 5316, "EPSG"), - new CrsMapEntry(5318, 5318, "EPSG"), - new CrsMapEntry(5320, 5322, "EPSG"), - new CrsMapEntry(5324, 5325, "EPSG"), - new CrsMapEntry(5329, 5332, "EPSG"), - new CrsMapEntry(5337, 5337, "EPSG"), - new CrsMapEntry(5340, 5341, "EPSG"), - new CrsMapEntry(5343, 5349, "EPSG"), - new CrsMapEntry(5352, 5352, "EPSG"), - new CrsMapEntry(5354, 5358, "EPSG"), - new CrsMapEntry(5360, 5363, "EPSG"), - new CrsMapEntry(5365, 5365, "EPSG"), - new CrsMapEntry(5367, 5369, "EPSG"), - new CrsMapEntry(5371, 5371, "EPSG"), - new CrsMapEntry(5373, 5373, "EPSG"), - new CrsMapEntry(5379, 5379, "EPSG"), - new CrsMapEntry(5381, 5383, "EPSG"), - new CrsMapEntry(5387, 5389, "EPSG"), - new CrsMapEntry(5391, 5391, "EPSG"), - new CrsMapEntry(5393, 5393, "EPSG"), - new CrsMapEntry(5396, 5396, "EPSG"), - new CrsMapEntry(5451, 5451, "EPSG"), - new CrsMapEntry(5456, 5464, "EPSG"), - new CrsMapEntry(5466, 5467, "EPSG"), - new CrsMapEntry(5469, 5469, "EPSG"), - new CrsMapEntry(5472, 5472, "EPSG"), - new CrsMapEntry(5479, 5482, "EPSG"), - new CrsMapEntry(5487, 5487, "EPSG"), - new CrsMapEntry(5489, 5490, "EPSG"), - new CrsMapEntry(5498, 5500, "EPSG"), - new 
CrsMapEntry(5513, 5514, "EPSG"), - new CrsMapEntry(5518, 5520, "EPSG"), - new CrsMapEntry(5523, 5524, "EPSG"), - new CrsMapEntry(5527, 5527, "EPSG"), - new CrsMapEntry(5530, 5539, "EPSG"), - new CrsMapEntry(5544, 5544, "EPSG"), - new CrsMapEntry(5546, 5546, "EPSG"), - new CrsMapEntry(5550, 5552, "EPSG"), - new CrsMapEntry(5554, 5556, "EPSG"), - new CrsMapEntry(5558, 5559, "EPSG"), - new CrsMapEntry(5561, 5583, "EPSG"), - new CrsMapEntry(5588, 5589, "EPSG"), - new CrsMapEntry(5591, 5591, "EPSG"), - new CrsMapEntry(5593, 5593, "EPSG"), - new CrsMapEntry(5596, 5596, "EPSG"), - new CrsMapEntry(5598, 5598, "EPSG"), - new CrsMapEntry(5623, 5625, "EPSG"), - new CrsMapEntry(5627, 5629, "EPSG"), - new CrsMapEntry(5631, 5639, "EPSG"), - new CrsMapEntry(5641, 5641, "EPSG"), - new CrsMapEntry(5643, 5644, "EPSG"), - new CrsMapEntry(5646, 5646, "EPSG"), - new CrsMapEntry(5649, 5655, "EPSG"), - new CrsMapEntry(5659, 5659, "EPSG"), - new CrsMapEntry(5663, 5685, "EPSG"), - new CrsMapEntry(5698, 5700, "EPSG"), - new CrsMapEntry(5707, 5708, "EPSG"), - new CrsMapEntry(5825, 5825, "EPSG"), - new CrsMapEntry(5828, 5828, "EPSG"), - new CrsMapEntry(5832, 5837, "EPSG"), - new CrsMapEntry(5839, 5839, "EPSG"), - new CrsMapEntry(5842, 5842, "EPSG"), - new CrsMapEntry(5844, 5858, "EPSG"), - new CrsMapEntry(5875, 5877, "EPSG"), - new CrsMapEntry(5879, 5880, "EPSG"), - new CrsMapEntry(5884, 5884, "EPSG"), - new CrsMapEntry(5886, 5887, "EPSG"), - new CrsMapEntry(5890, 5890, "EPSG"), - new CrsMapEntry(5921, 5940, "EPSG"), - new CrsMapEntry(5942, 5942, "EPSG"), - new CrsMapEntry(5945, 5976, "EPSG"), - new CrsMapEntry(6050, 6125, "EPSG"), - new CrsMapEntry(6128, 6129, "EPSG"), - new CrsMapEntry(6133, 6133, "EPSG"), - new CrsMapEntry(6135, 6135, "EPSG"), - new CrsMapEntry(6141, 6141, "EPSG"), - new CrsMapEntry(6144, 6176, "EPSG"), - new CrsMapEntry(6190, 6190, "EPSG"), - new CrsMapEntry(6204, 6204, "EPSG"), - new CrsMapEntry(6207, 6207, "EPSG"), - new CrsMapEntry(6210, 6211, "EPSG"), - new 
CrsMapEntry(6307, 6307, "EPSG"), - new CrsMapEntry(6309, 6309, "EPSG"), - new CrsMapEntry(6311, 6312, "EPSG"), - new CrsMapEntry(6316, 6318, "EPSG"), - new CrsMapEntry(6320, 6320, "EPSG"), - new CrsMapEntry(6322, 6323, "EPSG"), - new CrsMapEntry(6325, 6325, "EPSG"), - new CrsMapEntry(6328, 6356, "EPSG"), - new CrsMapEntry(6362, 6363, "EPSG"), - new CrsMapEntry(6365, 6372, "EPSG"), - new CrsMapEntry(6381, 6387, "EPSG"), - new CrsMapEntry(6391, 6391, "EPSG"), - new CrsMapEntry(6393, 6637, "EPSG"), - new CrsMapEntry(6646, 6646, "EPSG"), - new CrsMapEntry(6649, 6666, "EPSG"), - new CrsMapEntry(6668, 6692, "EPSG"), - new CrsMapEntry(6696, 6697, "EPSG"), - new CrsMapEntry(6700, 6700, "EPSG"), - new CrsMapEntry(6703, 6704, "EPSG"), - new CrsMapEntry(6706, 6709, "EPSG"), - new CrsMapEntry(6720, 6723, "EPSG"), - new CrsMapEntry(6732, 6738, "EPSG"), - new CrsMapEntry(6781, 6781, "EPSG"), - new CrsMapEntry(6783, 6863, "EPSG"), - new CrsMapEntry(6867, 6868, "EPSG"), - new CrsMapEntry(6870, 6871, "EPSG"), - new CrsMapEntry(6875, 6876, "EPSG"), - new CrsMapEntry(6879, 6887, "EPSG"), - new CrsMapEntry(6892, 6894, "EPSG"), - new CrsMapEntry(6915, 6915, "EPSG"), - new CrsMapEntry(6917, 6917, "EPSG"), - new CrsMapEntry(6922, 6925, "EPSG"), - new CrsMapEntry(6927, 6927, "EPSG"), - new CrsMapEntry(6931, 6934, "EPSG"), - new CrsMapEntry(6956, 6959, "EPSG"), - new CrsMapEntry(6962, 6962, "EPSG"), - new CrsMapEntry(6978, 6978, "EPSG"), - new CrsMapEntry(6980, 6981, "EPSG"), - new CrsMapEntry(6983, 6985, "EPSG"), - new CrsMapEntry(6987, 6988, "EPSG"), - new CrsMapEntry(6990, 6991, "EPSG"), - new CrsMapEntry(6996, 6997, "EPSG"), - new CrsMapEntry(7005, 7007, "EPSG"), - new CrsMapEntry(7035, 7035, "EPSG"), - new CrsMapEntry(7037, 7037, "EPSG"), - new CrsMapEntry(7039, 7039, "EPSG"), - new CrsMapEntry(7041, 7041, "EPSG"), - new CrsMapEntry(7057, 7071, "EPSG"), - new CrsMapEntry(7073, 7081, "EPSG"), - new CrsMapEntry(7084, 7084, "EPSG"), - new CrsMapEntry(7086, 7086, "EPSG"), - new 
CrsMapEntry(7088, 7088, "EPSG"), - new CrsMapEntry(7109, 7128, "EPSG"), - new CrsMapEntry(7131, 7134, "EPSG"), - new CrsMapEntry(7136, 7137, "EPSG"), - new CrsMapEntry(7139, 7139, "EPSG"), - new CrsMapEntry(7142, 7142, "EPSG"), - new CrsMapEntry(7257, 7371, "EPSG"), - new CrsMapEntry(7373, 7376, "EPSG"), - new CrsMapEntry(7400, 7423, "EPSG"), - new CrsMapEntry(7528, 7645, "EPSG"), - new CrsMapEntry(7656, 7656, "EPSG"), - new CrsMapEntry(7658, 7658, "EPSG"), - new CrsMapEntry(7660, 7660, "EPSG"), - new CrsMapEntry(7662, 7662, "EPSG"), - new CrsMapEntry(7664, 7664, "EPSG"), - new CrsMapEntry(7677, 7677, "EPSG"), - new CrsMapEntry(7679, 7679, "EPSG"), - new CrsMapEntry(7681, 7681, "EPSG"), - new CrsMapEntry(7683, 7684, "EPSG"), - new CrsMapEntry(7686, 7686, "EPSG"), - new CrsMapEntry(7692, 7696, "EPSG"), - new CrsMapEntry(7755, 7787, "EPSG"), - new CrsMapEntry(7789, 7789, "EPSG"), - new CrsMapEntry(7791, 7796, "EPSG"), - new CrsMapEntry(7798, 7801, "EPSG"), - new CrsMapEntry(7803, 7805, "EPSG"), - new CrsMapEntry(7815, 7815, "EPSG"), - new CrsMapEntry(7825, 7831, "EPSG"), - new CrsMapEntry(7842, 7842, "EPSG"), - new CrsMapEntry(7844, 7859, "EPSG"), - new CrsMapEntry(7877, 7879, "EPSG"), - new CrsMapEntry(7881, 7884, "EPSG"), - new CrsMapEntry(7886, 7887, "EPSG"), - new CrsMapEntry(7899, 7899, "EPSG"), - new CrsMapEntry(7914, 7914, "EPSG"), - new CrsMapEntry(7916, 7916, "EPSG"), - new CrsMapEntry(7918, 7918, "EPSG"), - new CrsMapEntry(7920, 7920, "EPSG"), - new CrsMapEntry(7922, 7922, "EPSG"), - new CrsMapEntry(7924, 7924, "EPSG"), - new CrsMapEntry(7926, 7926, "EPSG"), - new CrsMapEntry(7928, 7928, "EPSG"), - new CrsMapEntry(7930, 7930, "EPSG"), - new CrsMapEntry(7954, 7956, "EPSG"), - new CrsMapEntry(7991, 7992, "EPSG"), - new CrsMapEntry(8013, 8032, "EPSG"), - new CrsMapEntry(8035, 8036, "EPSG"), - new CrsMapEntry(8042, 8045, "EPSG"), - new CrsMapEntry(8058, 8059, "EPSG"), - new CrsMapEntry(8065, 8068, "EPSG"), - new CrsMapEntry(8082, 8084, "EPSG"), - new 
CrsMapEntry(8086, 8086, "EPSG"), - new CrsMapEntry(8088, 8088, "EPSG"), - new CrsMapEntry(8090, 8093, "EPSG"), - new CrsMapEntry(8095, 8173, "EPSG"), - new CrsMapEntry(8177, 8177, "EPSG"), - new CrsMapEntry(8179, 8182, "EPSG"), - new CrsMapEntry(8184, 8185, "EPSG"), - new CrsMapEntry(8187, 8187, "EPSG"), - new CrsMapEntry(8189, 8189, "EPSG"), - new CrsMapEntry(8191, 8191, "EPSG"), - new CrsMapEntry(8193, 8193, "EPSG"), - new CrsMapEntry(8196, 8198, "EPSG"), - new CrsMapEntry(8200, 8210, "EPSG"), - new CrsMapEntry(8212, 8214, "EPSG"), - new CrsMapEntry(8216, 8216, "EPSG"), - new CrsMapEntry(8218, 8218, "EPSG"), - new CrsMapEntry(8220, 8220, "EPSG"), - new CrsMapEntry(8222, 8222, "EPSG"), - new CrsMapEntry(8224, 8227, "EPSG"), - new CrsMapEntry(8230, 8230, "EPSG"), - new CrsMapEntry(8232, 8233, "EPSG"), - new CrsMapEntry(8237, 8238, "EPSG"), - new CrsMapEntry(8240, 8240, "EPSG"), - new CrsMapEntry(8242, 8242, "EPSG"), - new CrsMapEntry(8246, 8247, "EPSG"), - new CrsMapEntry(8249, 8250, "EPSG"), - new CrsMapEntry(8252, 8253, "EPSG"), - new CrsMapEntry(8255, 8255, "EPSG"), - new CrsMapEntry(8311, 8350, "EPSG"), - new CrsMapEntry(20004, 20032, "EPSG"), - new CrsMapEntry(20064, 20092, "EPSG"), - new CrsMapEntry(20135, 20138, "EPSG"), - new CrsMapEntry(20248, 20258, "EPSG"), - new CrsMapEntry(20348, 20358, "EPSG"), - new CrsMapEntry(20436, 20440, "EPSG"), - new CrsMapEntry(20499, 20499, "EPSG"), - new CrsMapEntry(20538, 20539, "EPSG"), - new CrsMapEntry(20790, 20791, "EPSG"), - new CrsMapEntry(20822, 20824, "EPSG"), - new CrsMapEntry(20934, 20936, "EPSG"), - new CrsMapEntry(21035, 21037, "EPSG"), - new CrsMapEntry(21095, 21097, "EPSG"), - new CrsMapEntry(21100, 21100, "EPSG"), - new CrsMapEntry(21148, 21150, "EPSG"), - new CrsMapEntry(21291, 21292, "EPSG"), - new CrsMapEntry(21413, 21423, "EPSG"), - new CrsMapEntry(21453, 21463, "EPSG"), - new CrsMapEntry(21473, 21483, "EPSG"), - new CrsMapEntry(21500, 21500, "EPSG"), - new CrsMapEntry(21780, 21782, "EPSG"), - new 
CrsMapEntry(21817, 21818, "EPSG"), - new CrsMapEntry(21891, 21894, "EPSG"), - new CrsMapEntry(21896, 21899, "EPSG"), - new CrsMapEntry(22032, 22033, "EPSG"), - new CrsMapEntry(22091, 22092, "EPSG"), - new CrsMapEntry(22171, 22177, "EPSG"), - new CrsMapEntry(22181, 22187, "EPSG"), - new CrsMapEntry(22191, 22197, "EPSG"), - new CrsMapEntry(22234, 22236, "EPSG"), - new CrsMapEntry(22275, 22275, "EPSG"), - new CrsMapEntry(22277, 22277, "EPSG"), - new CrsMapEntry(22279, 22279, "EPSG"), - new CrsMapEntry(22281, 22281, "EPSG"), - new CrsMapEntry(22283, 22283, "EPSG"), - new CrsMapEntry(22285, 22285, "EPSG"), - new CrsMapEntry(22287, 22287, "EPSG"), - new CrsMapEntry(22289, 22289, "EPSG"), - new CrsMapEntry(22291, 22291, "EPSG"), - new CrsMapEntry(22293, 22293, "EPSG"), - new CrsMapEntry(22300, 22300, "EPSG"), - new CrsMapEntry(22332, 22332, "EPSG"), - new CrsMapEntry(22391, 22392, "EPSG"), - new CrsMapEntry(22521, 22525, "EPSG"), - new CrsMapEntry(22700, 22700, "EPSG"), - new CrsMapEntry(22770, 22770, "EPSG"), - new CrsMapEntry(22780, 22780, "EPSG"), - new CrsMapEntry(22832, 22832, "EPSG"), - new CrsMapEntry(22991, 22994, "EPSG"), - new CrsMapEntry(23028, 23038, "EPSG"), - new CrsMapEntry(23090, 23090, "EPSG"), - new CrsMapEntry(23095, 23095, "EPSG"), - new CrsMapEntry(23239, 23240, "EPSG"), - new CrsMapEntry(23433, 23433, "EPSG"), - new CrsMapEntry(23700, 23700, "EPSG"), - new CrsMapEntry(23830, 23853, "EPSG"), - new CrsMapEntry(23866, 23872, "EPSG"), - new CrsMapEntry(23877, 23884, "EPSG"), - new CrsMapEntry(23886, 23894, "EPSG"), - new CrsMapEntry(23946, 23948, "EPSG"), - new CrsMapEntry(24047, 24048, "EPSG"), - new CrsMapEntry(24100, 24100, "EPSG"), - new CrsMapEntry(24200, 24200, "EPSG"), - new CrsMapEntry(24305, 24306, "EPSG"), - new CrsMapEntry(24311, 24313, "EPSG"), - new CrsMapEntry(24342, 24347, "EPSG"), - new CrsMapEntry(24370, 24383, "EPSG"), - new CrsMapEntry(24500, 24500, "EPSG"), - new CrsMapEntry(24547, 24548, "EPSG"), - new CrsMapEntry(24571, 24571, 
"EPSG"), - new CrsMapEntry(24600, 24600, "EPSG"), - new CrsMapEntry(24718, 24720, "EPSG"), - new CrsMapEntry(24817, 24821, "EPSG"), - new CrsMapEntry(24877, 24882, "EPSG"), - new CrsMapEntry(24891, 24893, "EPSG"), - new CrsMapEntry(25000, 25000, "EPSG"), - new CrsMapEntry(25231, 25231, "EPSG"), - new CrsMapEntry(25391, 25395, "EPSG"), - new CrsMapEntry(25700, 25700, "EPSG"), - new CrsMapEntry(25828, 25838, "EPSG"), - new CrsMapEntry(25884, 25884, "EPSG"), - new CrsMapEntry(25932, 25932, "EPSG"), - new CrsMapEntry(26191, 26195, "EPSG"), - new CrsMapEntry(26237, 26237, "EPSG"), - new CrsMapEntry(26331, 26332, "EPSG"), - new CrsMapEntry(26391, 26393, "EPSG"), - new CrsMapEntry(26432, 26432, "EPSG"), - new CrsMapEntry(26591, 26592, "EPSG"), - new CrsMapEntry(26632, 26632, "EPSG"), - new CrsMapEntry(26692, 26692, "EPSG"), - new CrsMapEntry(26701, 26722, "EPSG"), - new CrsMapEntry(26729, 26760, "EPSG"), - new CrsMapEntry(26766, 26787, "EPSG"), - new CrsMapEntry(26791, 26799, "EPSG"), - new CrsMapEntry(26801, 26803, "EPSG"), - new CrsMapEntry(26811, 26815, "EPSG"), - new CrsMapEntry(26819, 26826, "EPSG"), - new CrsMapEntry(26830, 26837, "EPSG"), - new CrsMapEntry(26841, 26870, "EPSG"), - new CrsMapEntry(26891, 26899, "EPSG"), - new CrsMapEntry(26901, 26923, "EPSG"), - new CrsMapEntry(26929, 26946, "EPSG"), - new CrsMapEntry(26948, 26998, "EPSG"), - new CrsMapEntry(27037, 27040, "EPSG"), - new CrsMapEntry(27120, 27120, "EPSG"), - new CrsMapEntry(27200, 27200, "EPSG"), - new CrsMapEntry(27205, 27232, "EPSG"), - new CrsMapEntry(27258, 27260, "EPSG"), - new CrsMapEntry(27291, 27292, "EPSG"), - new CrsMapEntry(27391, 27398, "EPSG"), - new CrsMapEntry(27429, 27429, "EPSG"), - new CrsMapEntry(27492, 27493, "EPSG"), - new CrsMapEntry(27500, 27500, "EPSG"), - new CrsMapEntry(27561, 27564, "EPSG"), - new CrsMapEntry(27571, 27574, "EPSG"), - new CrsMapEntry(27581, 27584, "EPSG"), - new CrsMapEntry(27591, 27594, "EPSG"), - new CrsMapEntry(27700, 27700, "EPSG"), - new 
CrsMapEntry(28191, 28193, "EPSG"), - new CrsMapEntry(28232, 28232, "EPSG"), - new CrsMapEntry(28348, 28358, "EPSG"), - new CrsMapEntry(28402, 28432, "EPSG"), - new CrsMapEntry(28462, 28492, "EPSG"), - new CrsMapEntry(28600, 28600, "EPSG"), - new CrsMapEntry(28991, 28992, "EPSG"), - new CrsMapEntry(29100, 29101, "EPSG"), - new CrsMapEntry(29118, 29122, "EPSG"), - new CrsMapEntry(29168, 29172, "EPSG"), - new CrsMapEntry(29177, 29185, "EPSG"), - new CrsMapEntry(29187, 29195, "EPSG"), - new CrsMapEntry(29220, 29221, "EPSG"), - new CrsMapEntry(29333, 29333, "EPSG"), - new CrsMapEntry(29371, 29371, "EPSG"), - new CrsMapEntry(29373, 29373, "EPSG"), - new CrsMapEntry(29375, 29375, "EPSG"), - new CrsMapEntry(29377, 29377, "EPSG"), - new CrsMapEntry(29379, 29379, "EPSG"), - new CrsMapEntry(29381, 29381, "EPSG"), - new CrsMapEntry(29383, 29383, "EPSG"), - new CrsMapEntry(29385, 29385, "EPSG"), - new CrsMapEntry(29635, 29636, "EPSG"), - new CrsMapEntry(29700, 29702, "EPSG"), - new CrsMapEntry(29738, 29739, "EPSG"), - new CrsMapEntry(29849, 29850, "EPSG"), - new CrsMapEntry(29871, 29873, "EPSG"), - new CrsMapEntry(29900, 29903, "EPSG"), - new CrsMapEntry(30161, 30179, "EPSG"), - new CrsMapEntry(30200, 30200, "EPSG"), - new CrsMapEntry(30339, 30340, "EPSG"), - new CrsMapEntry(30491, 30494, "EPSG"), - new CrsMapEntry(30729, 30732, "EPSG"), - new CrsMapEntry(30791, 30792, "EPSG"), - new CrsMapEntry(30800, 30800, "EPSG"), - new CrsMapEntry(31028, 31028, "EPSG"), - new CrsMapEntry(31121, 31121, "EPSG"), - new CrsMapEntry(31154, 31154, "EPSG"), - new CrsMapEntry(31170, 31171, "EPSG"), - new CrsMapEntry(31251, 31259, "EPSG"), - new CrsMapEntry(31265, 31268, "EPSG"), - new CrsMapEntry(31275, 31279, "EPSG"), - new CrsMapEntry(31281, 31297, "EPSG"), - new CrsMapEntry(31300, 31300, "EPSG"), - new CrsMapEntry(31370, 31370, "EPSG"), - new CrsMapEntry(31461, 31469, "EPSG"), - new CrsMapEntry(31528, 31529, "EPSG"), - new CrsMapEntry(31600, 31600, "EPSG"), - new CrsMapEntry(31700, 31700, 
"EPSG"), - new CrsMapEntry(31838, 31839, "EPSG"), - new CrsMapEntry(31900, 31901, "EPSG"), - new CrsMapEntry(31965, 32003, "EPSG"), - new CrsMapEntry(32005, 32031, "EPSG"), - new CrsMapEntry(32033, 32058, "EPSG"), - new CrsMapEntry(32061, 32062, "EPSG"), - new CrsMapEntry(32064, 32067, "EPSG"), - new CrsMapEntry(32074, 32077, "EPSG"), - new CrsMapEntry(32081, 32086, "EPSG"), - new CrsMapEntry(32098, 32100, "EPSG"), - new CrsMapEntry(32104, 32104, "EPSG"), - new CrsMapEntry(32107, 32130, "EPSG"), - new CrsMapEntry(32133, 32158, "EPSG"), - new CrsMapEntry(32161, 32161, "EPSG"), - new CrsMapEntry(32164, 32167, "EPSG"), - new CrsMapEntry(32180, 32199, "EPSG"), - new CrsMapEntry(32201, 32260, "EPSG"), - new CrsMapEntry(32301, 32360, "EPSG"), - new CrsMapEntry(32401, 32460, "EPSG"), - new CrsMapEntry(32501, 32560, "EPSG"), - new CrsMapEntry(32601, 32667, "EPSG"), - new CrsMapEntry(32701, 32761, "EPSG"), - new CrsMapEntry(32766, 32766, "EPSG"), - new CrsMapEntry(900913, 900913, "spatialreferencing.org"), - }; - } + new(2000, 2180, "EPSG"), + new(2188, 2217, "EPSG"), + new(2219, 2220, "EPSG"), + new(2222, 2292, "EPSG"), + new(2294, 2295, "EPSG"), + new(2308, 2962, "EPSG"), + new(2964, 2973, "EPSG"), + new(2975, 2984, "EPSG"), + new(2987, 3051, "EPSG"), + new(3054, 3138, "EPSG"), + new(3140, 3143, "EPSG"), + new(3146, 3172, "EPSG"), + new(3174, 3294, "EPSG"), + new(3296, 3791, "EPSG"), + new(3793, 3802, "EPSG"), + new(3812, 3812, "EPSG"), + new(3814, 3816, "EPSG"), + new(3819, 3819, "EPSG"), + new(3821, 3822, "EPSG"), + new(3824, 3829, "EPSG"), + new(3832, 3852, "EPSG"), + new(3854, 3854, "EPSG"), + new(3857, 3857, "EPSG"), + new(3873, 3885, "EPSG"), + new(3887, 3887, "EPSG"), + new(3889, 3893, "EPSG"), + new(3901, 3903, "EPSG"), + new(3906, 3912, "EPSG"), + new(3920, 3920, "EPSG"), + new(3942, 3950, "EPSG"), + new(3968, 3970, "EPSG"), + new(3973, 3976, "EPSG"), + new(3978, 3979, "EPSG"), + new(3985, 3989, "EPSG"), + new(3991, 3992, "EPSG"), + new(3994, 3997, "EPSG"), + 
new(4000, 4016, "EPSG"), + new(4018, 4039, "EPSG"), + new(4041, 4063, "EPSG"), + new(4071, 4071, "EPSG"), + new(4073, 4073, "EPSG"), + new(4075, 4075, "EPSG"), + new(4079, 4079, "EPSG"), + new(4081, 4083, "EPSG"), + new(4087, 4088, "EPSG"), + new(4093, 4100, "EPSG"), + new(4120, 4176, "EPSG"), + new(4178, 4185, "EPSG"), + new(4188, 4289, "EPSG"), + new(4291, 4304, "EPSG"), + new(4306, 4319, "EPSG"), + new(4322, 4322, "EPSG"), + new(4324, 4324, "EPSG"), + new(4326, 4326, "EPSG"), + new(4328, 4328, "EPSG"), + new(4330, 4338, "EPSG"), + new(4340, 4340, "EPSG"), + new(4342, 4342, "EPSG"), + new(4344, 4344, "EPSG"), + new(4346, 4346, "EPSG"), + new(4348, 4348, "EPSG"), + new(4350, 4350, "EPSG"), + new(4352, 4352, "EPSG"), + new(4354, 4354, "EPSG"), + new(4356, 4356, "EPSG"), + new(4358, 4358, "EPSG"), + new(4360, 4360, "EPSG"), + new(4362, 4362, "EPSG"), + new(4364, 4364, "EPSG"), + new(4366, 4366, "EPSG"), + new(4368, 4368, "EPSG"), + new(4370, 4370, "EPSG"), + new(4372, 4372, "EPSG"), + new(4374, 4374, "EPSG"), + new(4376, 4376, "EPSG"), + new(4378, 4378, "EPSG"), + new(4380, 4380, "EPSG"), + new(4382, 4382, "EPSG"), + new(4384, 4385, "EPSG"), + new(4387, 4387, "EPSG"), + new(4389, 4415, "EPSG"), + new(4417, 4434, "EPSG"), + new(4437, 4439, "EPSG"), + new(4455, 4457, "EPSG"), + new(4462, 4463, "EPSG"), + new(4465, 4465, "EPSG"), + new(4467, 4468, "EPSG"), + new(4470, 4471, "EPSG"), + new(4473, 4475, "EPSG"), + new(4479, 4479, "EPSG"), + new(4481, 4481, "EPSG"), + new(4483, 4556, "EPSG"), + new(4558, 4559, "EPSG"), + new(4568, 4589, "EPSG"), + new(4600, 4647, "EPSG"), + new(4652, 4824, "EPSG"), + new(4826, 4826, "EPSG"), + new(4839, 4839, "EPSG"), + new(4855, 4880, "EPSG"), + new(4882, 4882, "EPSG"), + new(4884, 4884, "EPSG"), + new(4886, 4886, "EPSG"), + new(4888, 4888, "EPSG"), + new(4890, 4890, "EPSG"), + new(4892, 4892, "EPSG"), + new(4894, 4894, "EPSG"), + new(4896, 4897, "EPSG"), + new(4899, 4899, "EPSG"), + new(4901, 4904, "EPSG"), + new(4906, 4906, "EPSG"), + 
new(4908, 4908, "EPSG"), + new(4910, 4920, "EPSG"), + new(4922, 4922, "EPSG"), + new(4924, 4924, "EPSG"), + new(4926, 4926, "EPSG"), + new(4928, 4928, "EPSG"), + new(4930, 4930, "EPSG"), + new(4932, 4932, "EPSG"), + new(4934, 4934, "EPSG"), + new(4936, 4936, "EPSG"), + new(4938, 4938, "EPSG"), + new(4940, 4940, "EPSG"), + new(4942, 4942, "EPSG"), + new(4944, 4944, "EPSG"), + new(4946, 4946, "EPSG"), + new(4948, 4948, "EPSG"), + new(4950, 4950, "EPSG"), + new(4952, 4952, "EPSG"), + new(4954, 4954, "EPSG"), + new(4956, 4956, "EPSG"), + new(4958, 4958, "EPSG"), + new(4960, 4960, "EPSG"), + new(4962, 4962, "EPSG"), + new(4964, 4964, "EPSG"), + new(4966, 4966, "EPSG"), + new(4968, 4968, "EPSG"), + new(4970, 4970, "EPSG"), + new(4972, 4972, "EPSG"), + new(4974, 4974, "EPSG"), + new(4976, 4976, "EPSG"), + new(4978, 4978, "EPSG"), + new(4980, 4980, "EPSG"), + new(4982, 4982, "EPSG"), + new(4984, 4984, "EPSG"), + new(4986, 4986, "EPSG"), + new(4988, 4988, "EPSG"), + new(4990, 4990, "EPSG"), + new(4992, 4992, "EPSG"), + new(4994, 4994, "EPSG"), + new(4996, 4996, "EPSG"), + new(4998, 4998, "EPSG"), + new(5011, 5011, "EPSG"), + new(5013, 5016, "EPSG"), + new(5018, 5018, "EPSG"), + new(5041, 5042, "EPSG"), + new(5048, 5048, "EPSG"), + new(5069, 5072, "EPSG"), + new(5105, 5130, "EPSG"), + new(5132, 5132, "EPSG"), + new(5167, 5188, "EPSG"), + new(5221, 5221, "EPSG"), + new(5223, 5223, "EPSG"), + new(5228, 5229, "EPSG"), + new(5233, 5235, "EPSG"), + new(5243, 5244, "EPSG"), + new(5246, 5247, "EPSG"), + new(5250, 5250, "EPSG"), + new(5252, 5259, "EPSG"), + new(5262, 5262, "EPSG"), + new(5264, 5264, "EPSG"), + new(5266, 5266, "EPSG"), + new(5269, 5275, "EPSG"), + new(5292, 5311, "EPSG"), + new(5316, 5316, "EPSG"), + new(5318, 5318, "EPSG"), + new(5320, 5322, "EPSG"), + new(5324, 5325, "EPSG"), + new(5329, 5332, "EPSG"), + new(5337, 5337, "EPSG"), + new(5340, 5341, "EPSG"), + new(5343, 5349, "EPSG"), + new(5352, 5352, "EPSG"), + new(5354, 5358, "EPSG"), + new(5360, 5363, "EPSG"), + 
new(5365, 5365, "EPSG"), + new(5367, 5369, "EPSG"), + new(5371, 5371, "EPSG"), + new(5373, 5373, "EPSG"), + new(5379, 5379, "EPSG"), + new(5381, 5383, "EPSG"), + new(5387, 5389, "EPSG"), + new(5391, 5391, "EPSG"), + new(5393, 5393, "EPSG"), + new(5396, 5396, "EPSG"), + new(5451, 5451, "EPSG"), + new(5456, 5464, "EPSG"), + new(5466, 5467, "EPSG"), + new(5469, 5469, "EPSG"), + new(5472, 5472, "EPSG"), + new(5479, 5482, "EPSG"), + new(5487, 5487, "EPSG"), + new(5489, 5490, "EPSG"), + new(5498, 5500, "EPSG"), + new(5513, 5514, "EPSG"), + new(5518, 5520, "EPSG"), + new(5523, 5524, "EPSG"), + new(5527, 5527, "EPSG"), + new(5530, 5539, "EPSG"), + new(5544, 5544, "EPSG"), + new(5546, 5546, "EPSG"), + new(5550, 5552, "EPSG"), + new(5554, 5556, "EPSG"), + new(5558, 5559, "EPSG"), + new(5561, 5583, "EPSG"), + new(5588, 5589, "EPSG"), + new(5591, 5591, "EPSG"), + new(5593, 5593, "EPSG"), + new(5596, 5596, "EPSG"), + new(5598, 5598, "EPSG"), + new(5623, 5625, "EPSG"), + new(5627, 5629, "EPSG"), + new(5631, 5639, "EPSG"), + new(5641, 5641, "EPSG"), + new(5643, 5644, "EPSG"), + new(5646, 5646, "EPSG"), + new(5649, 5655, "EPSG"), + new(5659, 5659, "EPSG"), + new(5663, 5685, "EPSG"), + new(5698, 5700, "EPSG"), + new(5707, 5708, "EPSG"), + new(5825, 5825, "EPSG"), + new(5828, 5828, "EPSG"), + new(5832, 5837, "EPSG"), + new(5839, 5839, "EPSG"), + new(5842, 5842, "EPSG"), + new(5844, 5858, "EPSG"), + new(5875, 5877, "EPSG"), + new(5879, 5880, "EPSG"), + new(5884, 5884, "EPSG"), + new(5886, 5887, "EPSG"), + new(5890, 5890, "EPSG"), + new(5921, 5940, "EPSG"), + new(5942, 5942, "EPSG"), + new(5945, 5976, "EPSG"), + new(6050, 6125, "EPSG"), + new(6128, 6129, "EPSG"), + new(6133, 6133, "EPSG"), + new(6135, 6135, "EPSG"), + new(6141, 6141, "EPSG"), + new(6144, 6176, "EPSG"), + new(6190, 6190, "EPSG"), + new(6204, 6204, "EPSG"), + new(6207, 6207, "EPSG"), + new(6210, 6211, "EPSG"), + new(6307, 6307, "EPSG"), + new(6309, 6309, "EPSG"), + new(6311, 6312, "EPSG"), + new(6316, 6318, "EPSG"), + 
new(6320, 6320, "EPSG"), + new(6322, 6323, "EPSG"), + new(6325, 6325, "EPSG"), + new(6328, 6356, "EPSG"), + new(6362, 6363, "EPSG"), + new(6365, 6372, "EPSG"), + new(6381, 6387, "EPSG"), + new(6391, 6391, "EPSG"), + new(6393, 6637, "EPSG"), + new(6646, 6646, "EPSG"), + new(6649, 6666, "EPSG"), + new(6668, 6692, "EPSG"), + new(6696, 6697, "EPSG"), + new(6700, 6700, "EPSG"), + new(6703, 6704, "EPSG"), + new(6706, 6709, "EPSG"), + new(6720, 6723, "EPSG"), + new(6732, 6738, "EPSG"), + new(6781, 6781, "EPSG"), + new(6783, 6863, "EPSG"), + new(6867, 6868, "EPSG"), + new(6870, 6871, "EPSG"), + new(6875, 6876, "EPSG"), + new(6879, 6887, "EPSG"), + new(6892, 6894, "EPSG"), + new(6915, 6915, "EPSG"), + new(6917, 6917, "EPSG"), + new(6922, 6925, "EPSG"), + new(6927, 6927, "EPSG"), + new(6931, 6934, "EPSG"), + new(6956, 6959, "EPSG"), + new(6962, 6962, "EPSG"), + new(6978, 6978, "EPSG"), + new(6980, 6981, "EPSG"), + new(6983, 6985, "EPSG"), + new(6987, 6988, "EPSG"), + new(6990, 6991, "EPSG"), + new(6996, 6997, "EPSG"), + new(7005, 7007, "EPSG"), + new(7035, 7035, "EPSG"), + new(7037, 7037, "EPSG"), + new(7039, 7039, "EPSG"), + new(7041, 7041, "EPSG"), + new(7057, 7071, "EPSG"), + new(7073, 7081, "EPSG"), + new(7084, 7084, "EPSG"), + new(7086, 7086, "EPSG"), + new(7088, 7088, "EPSG"), + new(7109, 7128, "EPSG"), + new(7131, 7134, "EPSG"), + new(7136, 7137, "EPSG"), + new(7139, 7139, "EPSG"), + new(7142, 7142, "EPSG"), + new(7257, 7371, "EPSG"), + new(7373, 7376, "EPSG"), + new(7400, 7423, "EPSG"), + new(7528, 7645, "EPSG"), + new(7656, 7656, "EPSG"), + new(7658, 7658, "EPSG"), + new(7660, 7660, "EPSG"), + new(7662, 7662, "EPSG"), + new(7664, 7664, "EPSG"), + new(7677, 7677, "EPSG"), + new(7679, 7679, "EPSG"), + new(7681, 7681, "EPSG"), + new(7683, 7684, "EPSG"), + new(7686, 7686, "EPSG"), + new(7692, 7696, "EPSG"), + new(7755, 7787, "EPSG"), + new(7789, 7789, "EPSG"), + new(7791, 7796, "EPSG"), + new(7798, 7801, "EPSG"), + new(7803, 7805, "EPSG"), + new(7815, 7815, "EPSG"), + 
new(7825, 7831, "EPSG"), + new(7842, 7842, "EPSG"), + new(7844, 7859, "EPSG"), + new(7877, 7879, "EPSG"), + new(7881, 7884, "EPSG"), + new(7886, 7887, "EPSG"), + new(7899, 7899, "EPSG"), + new(7914, 7914, "EPSG"), + new(7916, 7916, "EPSG"), + new(7918, 7918, "EPSG"), + new(7920, 7920, "EPSG"), + new(7922, 7922, "EPSG"), + new(7924, 7924, "EPSG"), + new(7926, 7926, "EPSG"), + new(7928, 7928, "EPSG"), + new(7930, 7930, "EPSG"), + new(7954, 7956, "EPSG"), + new(7991, 7992, "EPSG"), + new(8013, 8032, "EPSG"), + new(8035, 8036, "EPSG"), + new(8042, 8045, "EPSG"), + new(8058, 8059, "EPSG"), + new(8065, 8068, "EPSG"), + new(8082, 8084, "EPSG"), + new(8086, 8086, "EPSG"), + new(8088, 8088, "EPSG"), + new(8090, 8093, "EPSG"), + new(8095, 8173, "EPSG"), + new(8177, 8177, "EPSG"), + new(8179, 8182, "EPSG"), + new(8184, 8185, "EPSG"), + new(8187, 8187, "EPSG"), + new(8189, 8189, "EPSG"), + new(8191, 8191, "EPSG"), + new(8193, 8193, "EPSG"), + new(8196, 8198, "EPSG"), + new(8200, 8210, "EPSG"), + new(8212, 8214, "EPSG"), + new(8216, 8216, "EPSG"), + new(8218, 8218, "EPSG"), + new(8220, 8220, "EPSG"), + new(8222, 8222, "EPSG"), + new(8224, 8227, "EPSG"), + new(8230, 8230, "EPSG"), + new(8232, 8233, "EPSG"), + new(8237, 8238, "EPSG"), + new(8240, 8240, "EPSG"), + new(8242, 8242, "EPSG"), + new(8246, 8247, "EPSG"), + new(8249, 8250, "EPSG"), + new(8252, 8253, "EPSG"), + new(8255, 8255, "EPSG"), + new(8311, 8350, "EPSG"), + new(20004, 20032, "EPSG"), + new(20064, 20092, "EPSG"), + new(20135, 20138, "EPSG"), + new(20248, 20258, "EPSG"), + new(20348, 20358, "EPSG"), + new(20436, 20440, "EPSG"), + new(20499, 20499, "EPSG"), + new(20538, 20539, "EPSG"), + new(20790, 20791, "EPSG"), + new(20822, 20824, "EPSG"), + new(20934, 20936, "EPSG"), + new(21035, 21037, "EPSG"), + new(21095, 21097, "EPSG"), + new(21100, 21100, "EPSG"), + new(21148, 21150, "EPSG"), + new(21291, 21292, "EPSG"), + new(21413, 21423, "EPSG"), + new(21453, 21463, "EPSG"), + new(21473, 21483, "EPSG"), + new(21500, 21500, 
"EPSG"), + new(21780, 21782, "EPSG"), + new(21817, 21818, "EPSG"), + new(21891, 21894, "EPSG"), + new(21896, 21899, "EPSG"), + new(22032, 22033, "EPSG"), + new(22091, 22092, "EPSG"), + new(22171, 22177, "EPSG"), + new(22181, 22187, "EPSG"), + new(22191, 22197, "EPSG"), + new(22234, 22236, "EPSG"), + new(22275, 22275, "EPSG"), + new(22277, 22277, "EPSG"), + new(22279, 22279, "EPSG"), + new(22281, 22281, "EPSG"), + new(22283, 22283, "EPSG"), + new(22285, 22285, "EPSG"), + new(22287, 22287, "EPSG"), + new(22289, 22289, "EPSG"), + new(22291, 22291, "EPSG"), + new(22293, 22293, "EPSG"), + new(22300, 22300, "EPSG"), + new(22332, 22332, "EPSG"), + new(22391, 22392, "EPSG"), + new(22521, 22525, "EPSG"), + new(22700, 22700, "EPSG"), + new(22770, 22770, "EPSG"), + new(22780, 22780, "EPSG"), + new(22832, 22832, "EPSG"), + new(22991, 22994, "EPSG"), + new(23028, 23038, "EPSG"), + new(23090, 23090, "EPSG"), + new(23095, 23095, "EPSG"), + new(23239, 23240, "EPSG"), + new(23433, 23433, "EPSG"), + new(23700, 23700, "EPSG"), + new(23830, 23853, "EPSG"), + new(23866, 23872, "EPSG"), + new(23877, 23884, "EPSG"), + new(23886, 23894, "EPSG"), + new(23946, 23948, "EPSG"), + new(24047, 24048, "EPSG"), + new(24100, 24100, "EPSG"), + new(24200, 24200, "EPSG"), + new(24305, 24306, "EPSG"), + new(24311, 24313, "EPSG"), + new(24342, 24347, "EPSG"), + new(24370, 24383, "EPSG"), + new(24500, 24500, "EPSG"), + new(24547, 24548, "EPSG"), + new(24571, 24571, "EPSG"), + new(24600, 24600, "EPSG"), + new(24718, 24720, "EPSG"), + new(24817, 24821, "EPSG"), + new(24877, 24882, "EPSG"), + new(24891, 24893, "EPSG"), + new(25000, 25000, "EPSG"), + new(25231, 25231, "EPSG"), + new(25391, 25395, "EPSG"), + new(25700, 25700, "EPSG"), + new(25828, 25838, "EPSG"), + new(25884, 25884, "EPSG"), + new(25932, 25932, "EPSG"), + new(26191, 26195, "EPSG"), + new(26237, 26237, "EPSG"), + new(26331, 26332, "EPSG"), + new(26391, 26393, "EPSG"), + new(26432, 26432, "EPSG"), + new(26591, 26592, "EPSG"), + new(26632, 
26632, "EPSG"), + new(26692, 26692, "EPSG"), + new(26701, 26722, "EPSG"), + new(26729, 26760, "EPSG"), + new(26766, 26787, "EPSG"), + new(26791, 26799, "EPSG"), + new(26801, 26803, "EPSG"), + new(26811, 26815, "EPSG"), + new(26819, 26826, "EPSG"), + new(26830, 26837, "EPSG"), + new(26841, 26870, "EPSG"), + new(26891, 26899, "EPSG"), + new(26901, 26923, "EPSG"), + new(26929, 26946, "EPSG"), + new(26948, 26998, "EPSG"), + new(27037, 27040, "EPSG"), + new(27120, 27120, "EPSG"), + new(27200, 27200, "EPSG"), + new(27205, 27232, "EPSG"), + new(27258, 27260, "EPSG"), + new(27291, 27292, "EPSG"), + new(27391, 27398, "EPSG"), + new(27429, 27429, "EPSG"), + new(27492, 27493, "EPSG"), + new(27500, 27500, "EPSG"), + new(27561, 27564, "EPSG"), + new(27571, 27574, "EPSG"), + new(27581, 27584, "EPSG"), + new(27591, 27594, "EPSG"), + new(27700, 27700, "EPSG"), + new(28191, 28193, "EPSG"), + new(28232, 28232, "EPSG"), + new(28348, 28358, "EPSG"), + new(28402, 28432, "EPSG"), + new(28462, 28492, "EPSG"), + new(28600, 28600, "EPSG"), + new(28991, 28992, "EPSG"), + new(29100, 29101, "EPSG"), + new(29118, 29122, "EPSG"), + new(29168, 29172, "EPSG"), + new(29177, 29185, "EPSG"), + new(29187, 29195, "EPSG"), + new(29220, 29221, "EPSG"), + new(29333, 29333, "EPSG"), + new(29371, 29371, "EPSG"), + new(29373, 29373, "EPSG"), + new(29375, 29375, "EPSG"), + new(29377, 29377, "EPSG"), + new(29379, 29379, "EPSG"), + new(29381, 29381, "EPSG"), + new(29383, 29383, "EPSG"), + new(29385, 29385, "EPSG"), + new(29635, 29636, "EPSG"), + new(29700, 29702, "EPSG"), + new(29738, 29739, "EPSG"), + new(29849, 29850, "EPSG"), + new(29871, 29873, "EPSG"), + new(29900, 29903, "EPSG"), + new(30161, 30179, "EPSG"), + new(30200, 30200, "EPSG"), + new(30339, 30340, "EPSG"), + new(30491, 30494, "EPSG"), + new(30729, 30732, "EPSG"), + new(30791, 30792, "EPSG"), + new(30800, 30800, "EPSG"), + new(31028, 31028, "EPSG"), + new(31121, 31121, "EPSG"), + new(31154, 31154, "EPSG"), + new(31170, 31171, "EPSG"), + 
new(31251, 31259, "EPSG"), + new(31265, 31268, "EPSG"), + new(31275, 31279, "EPSG"), + new(31281, 31297, "EPSG"), + new(31300, 31300, "EPSG"), + new(31370, 31370, "EPSG"), + new(31461, 31469, "EPSG"), + new(31528, 31529, "EPSG"), + new(31600, 31600, "EPSG"), + new(31700, 31700, "EPSG"), + new(31838, 31839, "EPSG"), + new(31900, 31901, "EPSG"), + new(31965, 32003, "EPSG"), + new(32005, 32031, "EPSG"), + new(32033, 32058, "EPSG"), + new(32061, 32062, "EPSG"), + new(32064, 32067, "EPSG"), + new(32074, 32077, "EPSG"), + new(32081, 32086, "EPSG"), + new(32098, 32100, "EPSG"), + new(32104, 32104, "EPSG"), + new(32107, 32130, "EPSG"), + new(32133, 32158, "EPSG"), + new(32161, 32161, "EPSG"), + new(32164, 32167, "EPSG"), + new(32180, 32199, "EPSG"), + new(32201, 32260, "EPSG"), + new(32301, 32360, "EPSG"), + new(32401, 32460, "EPSG"), + new(32501, 32560, "EPSG"), + new(32601, 32667, "EPSG"), + new(32701, 32761, "EPSG"), + new(32766, 32766, "EPSG"), + new(900913, 900913, "spatialreferencing.org"), + }; } diff --git a/src/Npgsql.GeoJSON/CrsMap.cs b/src/Npgsql.GeoJSON/CrsMap.cs index 01f6701d99..dd556d9b33 100644 --- a/src/Npgsql.GeoJSON/CrsMap.cs +++ b/src/Npgsql.GeoJSON/CrsMap.cs @@ -1,109 +1,59 @@ -using System; -namespace Npgsql.GeoJSON +namespace Npgsql.GeoJSON; + +/// +/// A map of entries that map the authority to the inclusive range of SRID. +/// +public partial class CrsMap { - /// - /// An entry which maps the authority to the inclusive range of SRID. - /// - readonly struct CrsMapEntry - { - internal readonly int MinSrid; - internal readonly int MaxSrid; - internal readonly string? Authority; + readonly CrsMapEntry[]? _overriden; - internal CrsMapEntry(int minSrid, int maxSrid, string? authority) - { - MinSrid = minSrid; - MaxSrid = maxSrid; - Authority = authority != null - ? string.IsInterned(authority) ?? authority - : null; - } - } + internal CrsMap(CrsMapEntry[]? overriden) + => _overriden = overriden; + + internal string? 
GetAuthority(int srid) + => GetAuthority(_overriden, srid) ?? GetAuthority(WellKnown, srid); - ref struct CrsMapBuilder + static string? GetAuthority(CrsMapEntry[]? entries, int srid) { - CrsMapEntry[] _overrides; - int _overridenIndex; - int _wellKnownIndex; + if (entries == null) + return null; - internal void Add(in CrsMapEntry entry) + var left = 0; + var right = entries.Length; + while (left <= right) { - var wellKnown = CrsMap.WellKnown[_wellKnownIndex]; - if (wellKnown.MinSrid == entry.MinSrid && - wellKnown.MaxSrid == entry.MaxSrid && - string.Equals(wellKnown.Authority, entry.Authority, StringComparison.Ordinal)) - { - _wellKnownIndex++; - return; - } + var middle = left + (right - left) / 2; + var entry = entries[middle]; - if (wellKnown.MinSrid < entry.MinSrid) - { - do - _wellKnownIndex++; - while (CrsMap.WellKnown.Length < _wellKnownIndex && - CrsMap.WellKnown[_wellKnownIndex].MaxSrid < entry.MaxSrid); - AddCore(new CrsMapEntry(wellKnown.MinSrid, Math.Min(wellKnown.MaxSrid, entry.MinSrid - 1), null)); - } - - AddCore(entry); - } - - void AddCore(in CrsMapEntry entry) - { - var index = _overridenIndex + 1; - if (_overrides == null) - _overrides = new CrsMapEntry[4]; + if (srid < entry.MinSrid) + right = middle - 1; else - if (_overrides.Length == index) - Array.Resize(ref _overrides, _overrides.Length << 1); - - _overrides[_overridenIndex] = entry; - _overridenIndex = index; + if (srid > entry.MaxSrid) + left = middle + 1; + else + return entry.Authority; } - internal CrsMap Build() - { - if (_overrides != null && _overrides.Length < _overridenIndex) - Array.Resize(ref _overrides, _overridenIndex); - - return new CrsMap(_overrides); - } + return null; } +} - readonly partial struct CrsMap - { - readonly CrsMapEntry[]? _overriden; - - internal CrsMap(CrsMapEntry[]? overriden) - => _overriden = overriden; - - internal string? GetAuthority(int srid) - => GetAuthority(_overriden, srid) ?? GetAuthority(WellKnown, srid); - - static string? 
GetAuthority(CrsMapEntry[]? entries, int srid) - { - if (entries == null) - return null; - - var left = 0; - var right = entries.Length; - while (left <= right) - { - var middle = left + (right - left) / 2; - var entry = entries[middle]; - - if (srid < entry.MinSrid) - right = middle - 1; - else - if (srid > entry.MaxSrid) - left = middle + 1; - else - return entry.Authority; - } +/// +/// An entry which maps the authority to the inclusive range of SRID. +/// +readonly struct CrsMapEntry +{ + internal readonly int MinSrid; + internal readonly int MaxSrid; + internal readonly string? Authority; - return null; - } + internal CrsMapEntry(int minSrid, int maxSrid, string? authority) + { + MinSrid = minSrid; + MaxSrid = maxSrid; + Authority = authority != null + ? string.IsInterned(authority) ?? authority + : null; } } diff --git a/src/Npgsql.GeoJSON/CrsMapExtensions.cs b/src/Npgsql.GeoJSON/CrsMapExtensions.cs new file mode 100644 index 0000000000..dde5e0f688 --- /dev/null +++ b/src/Npgsql.GeoJSON/CrsMapExtensions.cs @@ -0,0 +1,51 @@ +using System; +using System.Threading.Tasks; +using Npgsql.GeoJSON.Internal; + +namespace Npgsql.GeoJSON; + +/// +/// Extensions for getting a CrsMap from a database. +/// +public static class CrsMapExtensions +{ + /// + /// Gets the full crs details from the database. + /// + /// + public static async Task GetCrsMapAsync(this NpgsqlDataSource dataSource) + { + var builder = new CrsMapBuilder(); + using var cmd = GetCsrCommand(dataSource); + using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + + while (await reader.ReadAsync().ConfigureAwait(false)) + builder.Add(new CrsMapEntry(reader.GetInt32(0), reader.GetInt32(1), reader.GetString(2))); + + return builder.Build(); + } + + /// + /// Gets the full crs details from the database. 
+ /// + /// + public static CrsMap GetCrsMap(this NpgsqlDataSource dataSource) + { + var builder = new CrsMapBuilder(); + using var cmd = GetCsrCommand(dataSource); + using var reader = cmd.ExecuteReader(); + + while (reader.Read()) + builder.Add(new CrsMapEntry(reader.GetInt32(0), reader.GetInt32(1), reader.GetString(2))); + + return builder.Build(); + } + + static NpgsqlCommand GetCsrCommand(NpgsqlDataSource dataSource) + => dataSource.CreateCommand(""" + SELECT min(srid), max(srid), auth_name + FROM(SELECT srid, auth_name, srid - rank() OVER(PARTITION BY auth_name ORDER BY srid) AS range FROM spatial_ref_sys) AS s + GROUP BY range, auth_name + ORDER BY 1; + """); +} diff --git a/src/Npgsql.GeoJSON/GeoJSONHandler.cs b/src/Npgsql.GeoJSON/GeoJSONHandler.cs deleted file mode 100644 index 53861a06a7..0000000000 --- a/src/Npgsql.GeoJSON/GeoJSONHandler.cs +++ /dev/null @@ -1,753 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.ObjectModel; -using System.Threading; -using System.Threading.Tasks; -using GeoJSON.Net; -using GeoJSON.Net.CoordinateReferenceSystem; -using GeoJSON.Net.Geometry; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.GeoJSON -{ - [Flags] - public enum GeoJSONOptions - { - None = 0, - BoundingBox = 1, - ShortCRS = 2, - LongCRS = 4 - } - - public sealed class GeoJSONHandlerFactory : NpgsqlTypeHandlerFactory - { - readonly GeoJSONOptions _options; - - public GeoJSONHandlerFactory(GeoJSONOptions options = GeoJSONOptions.None) - => _options = options; - - static readonly ConcurrentDictionary s_crsMaps = new ConcurrentDictionary(); - - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - { - var crsMap = (_options & (GeoJSONOptions.ShortCRS | GeoJSONOptions.LongCRS)) == GeoJSONOptions.None - ? 
default : s_crsMaps.GetOrAdd(conn.ConnectionString, _ => - { - var builder = new CrsMapBuilder(); - using (var cmd = new NpgsqlCommand( - "SELECT min(srid), max(srid), auth_name " + - "FROM(SELECT srid, auth_name, srid - rank() OVER(ORDER BY srid) AS range " + - "FROM spatial_ref_sys) AS s GROUP BY range, auth_name ORDER BY 1;", conn)) - using (var reader = cmd.ExecuteReader()) - while (reader.Read()) - { - builder.Add(new CrsMapEntry( - reader.GetInt32(0), - reader.GetInt32(1), - reader.GetString(2))); - } - return builder.Build(); - }); - return new GeoJsonHandler(postgresType, _options, crsMap); - } - } - - sealed class GeoJsonHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler - { - readonly GeoJSONOptions _options; - readonly CrsMap _crsMap; - NamedCRS? _lastCrs; - int _lastSrid; - - internal GeoJsonHandler(PostgresType postgresType, GeoJSONOptions options, CrsMap crsMap) - : base(postgresType) - { - _options = options; - _crsMap = crsMap; - } - - GeoJSONOptions CrsType => _options & (GeoJSONOptions.ShortCRS | GeoJSONOptions.LongCRS); - - bool BoundingBox => (_options & GeoJSONOptions.BoundingBox) != 0; - - static bool HasSrid(EwkbGeometryType type) - => (type & EwkbGeometryType.HasSrid) != 0; - - static bool HasZ(EwkbGeometryType type) - => (type & EwkbGeometryType.HasZ) != 0; - - static bool HasM(EwkbGeometryType type) - => (type & EwkbGeometryType.HasM) != 0; - - static bool HasZ(IPosition coordinates) - => coordinates.Altitude.HasValue; - - const int SizeOfLength = sizeof(int); - const int SizeOfHeader = sizeof(byte) + sizeof(EwkbGeometryType); - const int SizeOfHeaderWithLength = SizeOfHeader + SizeOfLength; - const int SizeOfPoint2D = 2 * sizeof(double); - const int SizeOfPoint3D = 3 * sizeof(double); - - static int SizeOfPoint(bool hasZ) - => hasZ ? 
SizeOfPoint3D : SizeOfPoint2D; - - static int SizeOfPoint(EwkbGeometryType type) - { - var size = SizeOfPoint2D; - if (HasZ(type)) - size += sizeof(double); - if (HasM(type)) - size += sizeof(double); - return size; - } - - #region Throw - - static Exception UnknownPostGisType() - => throw new InvalidOperationException("Invalid PostGIS type"); - - static Exception AllOrNoneCoordiantesMustHaveZ(NpgsqlParameter? parameter, string typeName) - => parameter is null - ? new ArgumentException($"The Z coordinate must be specified for all or none elements of {typeName}") - : new ArgumentException($"The Z coordinate must be specified for all or none elements of {typeName} in the {parameter.ParameterName} parameter", parameter.ParameterName); - - #endregion - - #region Read - - public override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (Point)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (LineString)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (Polygon)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (MultiPoint)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (MultiLineString)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => (MultiPolygon)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (GeometryCollection)await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => await ReadGeometry(buf, async); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (IGeometryObject)await ReadGeometry(buf, async); - - async ValueTask ReadGeometry(NpgsqlReadBuffer buf, bool async) - { - var boundingBox = BoundingBox ? new BoundingBoxBuilder() : null; - var geometry = await ReadGeometryCore(buf, async, boundingBox); - - geometry.BoundingBoxes = boundingBox?.Build(); - return geometry; - } - - async ValueTask ReadGeometryCore(NpgsqlReadBuffer buf, bool async, BoundingBoxBuilder? boundingBox) - { - await buf.Ensure(SizeOfHeader, async); - var littleEndian = buf.ReadByte() > 0; - var type = (EwkbGeometryType)buf.ReadUInt32(littleEndian); - - GeoJSONObject geometry; - NamedCRS? 
crs = null; - - if (HasSrid(type)) - { - await buf.Ensure(4, async); - crs = GetCrs(buf.ReadInt32(littleEndian)); - } - - switch (type & EwkbGeometryType.BaseType) - { - case EwkbGeometryType.Point: - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - geometry = new Point(position); - break; - } - - case EwkbGeometryType.LineString: - { - await buf.Ensure(SizeOfLength, async); - var coordinates = new Position[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < coordinates.Length; ++i) - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - coordinates[i] = position; - } - geometry = new LineString(coordinates); - break; - } - - case EwkbGeometryType.Polygon: - { - await buf.Ensure(SizeOfLength, async); - var lines = new LineString[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < lines.Length; ++i) - { - var coordinates = new Position[buf.ReadInt32(littleEndian)]; - for (var j = 0; j < coordinates.Length; ++j) - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - coordinates[j] = position; - } - lines[i] = new LineString(coordinates); - } - geometry = new Polygon(lines); - break; - } - - case EwkbGeometryType.MultiPoint: - { - await buf.Ensure(SizeOfLength, async); - var points = new Point[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < points.Length; ++i) - { - await buf.Ensure(SizeOfHeader + SizeOfPoint(type), async); - await buf.Skip(SizeOfHeader, async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - points[i] = new Point(position); - } - geometry = new MultiPoint(points); - break; - } - - case EwkbGeometryType.MultiLineString: - { - await buf.Ensure(SizeOfLength, async); - var lines = new LineString[buf.ReadInt32(littleEndian)]; - for (var 
i = 0; i < lines.Length; ++i) - { - await buf.Ensure(SizeOfHeaderWithLength, async); - await buf.Skip(SizeOfHeader, async); - var coordinates = new Position[buf.ReadInt32(littleEndian)]; - for (var j = 0; j < coordinates.Length; ++j) - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - coordinates[j] = position; - } - lines[i] = new LineString(coordinates); - } - geometry = new MultiLineString(lines); - break; - } - - case EwkbGeometryType.MultiPolygon: - { - await buf.Ensure(SizeOfLength, async); - var polygons = new Polygon[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < polygons.Length; ++i) - { - await buf.Ensure(SizeOfHeaderWithLength, async); - await buf.Skip(SizeOfHeader, async); - var lines = new LineString[buf.ReadInt32(littleEndian)]; - for (var j = 0; j < lines.Length; ++j) - { - var coordinates = new Position[buf.ReadInt32(littleEndian)]; - for (var k = 0; k < coordinates.Length; ++k) - { - await buf.Ensure(SizeOfPoint(type), async); - var position = ReadPosition(buf, type, littleEndian); - boundingBox?.Accumulate(position); - coordinates[k] = position; - } - lines[j] = new LineString(coordinates); - } - polygons[i] = new Polygon(lines); - } - geometry = new MultiPolygon(polygons); - break; - } - - case EwkbGeometryType.GeometryCollection: - { - await buf.Ensure(SizeOfLength, async); - var elements = new IGeometryObject[buf.ReadInt32(littleEndian)]; - for (var i = 0; i < elements.Length; ++i) - elements[i] = (IGeometryObject)await ReadGeometryCore(buf, async, boundingBox); - geometry = new GeometryCollection(elements); - break; - } - - default: - throw UnknownPostGisType(); - } - - geometry.CRS = crs; - return geometry; - } - - static Position ReadPosition(NpgsqlReadBuffer buf, EwkbGeometryType type, bool littleEndian) - { - var position = new Position( - longitude: buf.ReadDouble(littleEndian), - latitude: buf.ReadDouble(littleEndian), - altitude: HasZ(type) 
? buf.ReadDouble() : (double?)null); - if (HasM(type)) buf.ReadDouble(littleEndian); - return position; - } - - #endregion - - #region Write - - public override int ValidateAndGetLength(GeoJSONObject value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Type switch - { - GeoJSONObjectType.Point => ValidateAndGetLength((Point)value, ref lengthCache, parameter), - GeoJSONObjectType.LineString => ValidateAndGetLength((LineString)value, ref lengthCache, parameter), - GeoJSONObjectType.Polygon => ValidateAndGetLength((Polygon)value, ref lengthCache, parameter), - GeoJSONObjectType.MultiPoint => ValidateAndGetLength((MultiPoint)value, ref lengthCache, parameter), - GeoJSONObjectType.MultiLineString => ValidateAndGetLength((MultiLineString)value, ref lengthCache, parameter), - GeoJSONObjectType.MultiPolygon => ValidateAndGetLength((MultiPolygon)value, ref lengthCache, parameter), - GeoJSONObjectType.GeometryCollection => ValidateAndGetLength((GeometryCollection)value, ref lengthCache, parameter), - _ => throw UnknownPostGisType() - }; - - public int ValidateAndGetLength(Point value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var length = SizeOfHeader + SizeOfPoint(HasZ(value.Coordinates)); - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - return length; - } - - public int ValidateAndGetLength(LineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var coordinates = value.Coordinates; - if (NotValid(coordinates, out var hasZ)) - throw AllOrNoneCoordiantesMustHaveZ(parameter, nameof(LineString)); - - var length = SizeOfHeaderWithLength + coordinates.Count * SizeOfPoint(hasZ); - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - return length; - } - - public int ValidateAndGetLength(Polygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - var lines = value.Coordinates; - var length = SizeOfHeaderWithLength + SizeOfLength * lines.Count; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var hasZ = false; - for (var i = 0; i < lines.Count; ++i) - { - var coordinates = lines[i].Coordinates; - if (NotValid(coordinates, out var lineHasZ)) - throw AllOrNoneCoordiantesMustHaveZ(parameter, nameof(Polygon)); - - if (hasZ != lineHasZ) - { - if (i == 0) hasZ = lineHasZ; - else throw AllOrNoneCoordiantesMustHaveZ(parameter, nameof(LineString)); - } - - length += coordinates.Count * SizeOfPoint(hasZ); - } - - return length; - } - - static bool NotValid(ReadOnlyCollection coordinates, out bool hasZ) - { - if (coordinates.Count == 0) - hasZ = false; - else - { - hasZ = HasZ(coordinates[0]); - for (var i = 1; i < coordinates.Count; ++i) - if (HasZ(coordinates[i]) != hasZ) return true; - } - return false; - } - - public int ValidateAndGetLength(MultiPoint value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var length = SizeOfHeaderWithLength; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var coordinates = value.Coordinates; - for (var i = 0; i < coordinates.Count; ++i) - length += ValidateAndGetLength(coordinates[i], ref lengthCache, parameter); - - return length; - } - - public int ValidateAndGetLength(MultiLineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var length = SizeOfHeaderWithLength; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var coordinates = value.Coordinates; - for (var i = 0; i < coordinates.Count; ++i) - length += ValidateAndGetLength(coordinates[i], ref lengthCache, parameter); - - return length; - } - - public int ValidateAndGetLength(MultiPolygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - var length = SizeOfHeaderWithLength; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var coordinates = value.Coordinates; - for (var i = 0; i < coordinates.Count; ++i) - length += ValidateAndGetLength(coordinates[i], ref lengthCache, parameter); - - return length; - } - - public int ValidateAndGetLength(GeometryCollection value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var length = SizeOfHeaderWithLength; - if (GetSrid(value.CRS) != 0) - length += sizeof(int); - - var geometries = value.Geometries; - for (var i = 0; i < geometries.Count; ++i) - length += ValidateAndGetLength((GeoJSONObject)geometries[i], ref lengthCache, parameter); - - return length; - } - - int INpgsqlTypeHandler.ValidateAndGetLength(IGeoJSONObject value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((GeoJSONObject)value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(IGeometryObject value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((GeoJSONObject)value, ref lengthCache, parameter); - - public override Task Write(GeoJSONObject value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => value.Type switch - { - GeoJSONObjectType.Point => Write((Point)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.LineString => Write((LineString)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.Polygon => Write((Polygon)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.MultiPoint => Write((MultiPoint)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.MultiLineString => Write((MultiLineString)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.MultiPolygon => Write((MultiPolygon)value, buf, lengthCache, parameter, async, cancellationToken), - GeoJSONObjectType.GeometryCollection => Write((GeometryCollection)value, buf, lengthCache, parameter, async, cancellationToken), - _ => throw UnknownPostGisType() - }; - - public async Task Write(Point value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.Point; - var size = SizeOfHeader; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - - if (srid != 0) - buf.WriteInt32(srid); - - await WritePosition(value.Coordinates, buf, async, cancellationToken); - } - - public async Task Write(LineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.LineString; - var size = SizeOfHeader; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var coordinates = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(coordinates.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < coordinates.Count; ++i) - await WritePosition(coordinates[i], buf, async, cancellationToken); - } - - public async Task Write(Polygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.Polygon; - var size = SizeOfHeader; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var lines = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(lines.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < lines.Count; ++i) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - var coordinates = lines[i].Coordinates; - buf.WriteInt32(coordinates.Count); - for (var j = 0; j < coordinates.Count; ++j) - await WritePosition(coordinates[j], buf, async, cancellationToken); - } - } - - public async Task Write(MultiPoint value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.MultiPoint; - var size = SizeOfHeader; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var coordinates = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(coordinates.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < coordinates.Count; ++i) - await Write(coordinates[i], buf, lengthCache, parameter, async, cancellationToken); - } - - public async Task Write(MultiLineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.MultiLineString; - var size = SizeOfHeader; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var coordinates = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(coordinates.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < coordinates.Count; ++i) - await Write(coordinates[i], buf, lengthCache, parameter, async, cancellationToken); - } - - public async Task Write(MultiPolygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.MultiPolygon; - var size = SizeOfHeader; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var coordinates = value.Coordinates; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(coordinates.Count); - - if (srid != 0) - buf.WriteInt32(srid); - for (var i = 0; i < coordinates.Count; ++i) - await Write(coordinates[i], buf, lengthCache, parameter, async, cancellationToken); - } - - public async Task Write(GeometryCollection value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var type = EwkbGeometryType.GeometryCollection; - var size = SizeOfHeader; - var srid = GetSrid(value.CRS); - if (srid != 0) - { - size += sizeof(int); - type |= EwkbGeometryType.HasSrid; - } - - if (buf.WriteSpaceLeft < size) - await buf.Flush(async, cancellationToken); - - var geometries = value.Geometries; - - buf.WriteByte(0); // Most significant byte first - buf.WriteInt32((int)type); - buf.WriteInt32(geometries.Count); - - if (srid != 0) - buf.WriteInt32(srid); - - for (var i = 0; i < geometries.Count; ++i) - await Write((GeoJSONObject) geometries[i], buf, lengthCache, parameter, async, cancellationToken); - } - - Task INpgsqlTypeHandler.Write(IGeoJSONObject value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => Write((GeoJSONObject)value, buf, lengthCache, parameter, async, cancellationToken); - - Task INpgsqlTypeHandler.Write(IGeometryObject value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken) - => Write((GeoJSONObject)value, buf, lengthCache, parameter, async, cancellationToken); - - static async Task WritePosition(IPosition coordinate, NpgsqlWriteBuffer buf, bool async, CancellationToken cancellationToken = default) - { - var altitude = coordinate.Altitude; - if (buf.WriteSpaceLeft < SizeOfPoint(altitude.HasValue)) - await buf.Flush(async, cancellationToken); - buf.WriteDouble(coordinate.Longitude); - buf.WriteDouble(coordinate.Latitude); - if (altitude.HasValue) - buf.WriteDouble(altitude.Value); - } - - #endregion - - #region Crs - - NamedCRS? GetCrs(int srid) - { - var crsType = CrsType; - if (crsType == GeoJSONOptions.None) - return null; - - if (_lastSrid == srid && _lastCrs != null) - return _lastCrs; - - var authority = _crsMap.GetAuthority(srid); - if (authority == null) - throw new InvalidOperationException($"SRID {srid} unknown in spatial_ref_sys table"); - - _lastCrs = new NamedCRS(crsType == GeoJSONOptions.LongCRS - ? "urn:ogc:def:crs:" + authority + "::" + srid : authority + ":" + srid); - _lastSrid = srid; - return _lastCrs; - } - - static int GetSrid(ICRSObject crs) - { - if (crs == null || crs is UnspecifiedCRS) - return 0; - - var namedCrs = crs as NamedCRS; - if (namedCrs == null) - throw new NotSupportedException("The LinkedCRS class isn't supported"); - - if (namedCrs.Properties.TryGetValue("name", out var value) && value != null) - { - var name = value.ToString()!; - if (string.Equals(name, "urn:ogc:def:crs:OGC::CRS84", StringComparison.Ordinal)) - return 4326; - - var index = name.LastIndexOf(':'); - if (index != -1 && int.TryParse(name.Substring(index + 1), out var srid)) - return srid; - - throw new FormatException("The specified CRS isn't properly named"); - } - - return 0; - } - - #endregion - } - - /// - /// Represents the identifier of the Well Known Binary representation of a geographical feature specified by the OGC. 
- /// http://portal.opengeospatial.org/files/?artifact_id=13227 Chapter 6.3.2.7 - /// - [Flags] - enum EwkbGeometryType : uint - { - // Types - Point = 1, - LineString = 2, - Polygon = 3, - MultiPoint = 4, - MultiLineString = 5, - MultiPolygon = 6, - GeometryCollection = 7, - - // Masks - BaseType = Point | LineString | Polygon | MultiPoint | MultiLineString | MultiPolygon | GeometryCollection, - - // Flags - HasSrid = 0x20000000, - HasM = 0x40000000, - HasZ = 0x80000000 - } -} diff --git a/src/Npgsql.GeoJSON/GeoJSONOptions.cs b/src/Npgsql.GeoJSON/GeoJSONOptions.cs new file mode 100644 index 0000000000..9aa8797529 --- /dev/null +++ b/src/Npgsql.GeoJSON/GeoJSONOptions.cs @@ -0,0 +1,15 @@ +using System; + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + +// ReSharper disable once CheckNamespace +namespace Npgsql; + +[Flags] +public enum GeoJSONOptions +{ + None = 0, + BoundingBox = 1, + ShortCRS = 2, + LongCRS = 4 +} \ No newline at end of file diff --git a/src/Npgsql.GeoJSON/Internal/BoundingBoxBuilder.cs b/src/Npgsql.GeoJSON/Internal/BoundingBoxBuilder.cs new file mode 100644 index 0000000000..7702a7e0b3 --- /dev/null +++ b/src/Npgsql.GeoJSON/Internal/BoundingBoxBuilder.cs @@ -0,0 +1,53 @@ +using GeoJSON.Net.Geometry; + +namespace Npgsql.GeoJSON.Internal; + +sealed class BoundingBoxBuilder +{ + bool _hasAltitude; + double _minLongitude, _maxLongitude; + double _minLatitude, _maxLatitude; + double _minAltitude, _maxAltitude; + + internal BoundingBoxBuilder() + { + _hasAltitude = false; + + _minLongitude = double.PositiveInfinity; + _minLatitude = double.PositiveInfinity; + _minAltitude = double.PositiveInfinity; + + _maxLongitude = double.NegativeInfinity; + _maxLatitude = double.NegativeInfinity; + _maxAltitude = double.NegativeInfinity; + } + + internal void Accumulate(Position position) + { + if (_minLongitude > position.Longitude) + _minLongitude = position.Longitude; + if (_maxLongitude < position.Longitude) + 
_maxLongitude = position.Longitude; + + if (_minLatitude > position.Latitude) + _minLatitude = position.Latitude; + if (_maxLatitude < position.Latitude) + _maxLatitude = position.Latitude; + + if (position.Altitude.HasValue) + { + var altitude = position.Altitude.Value; + if (_minAltitude > altitude) + _minAltitude = altitude; + if (_maxAltitude < altitude) + _maxAltitude = altitude; + + _hasAltitude = true; + } + } + + internal double[] Build() + => _hasAltitude + ? new[] { _minLongitude, _minLatitude, _minAltitude, _maxLongitude, _maxLatitude, _maxAltitude } + : new[] { _minLongitude, _minLatitude, _maxLongitude, _maxLatitude }; +} \ No newline at end of file diff --git a/src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs b/src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs new file mode 100644 index 0000000000..44829761c9 --- /dev/null +++ b/src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs @@ -0,0 +1,54 @@ +using System; + +namespace Npgsql.GeoJSON.Internal; + +struct CrsMapBuilder +{ + CrsMapEntry[] _overrides; + int _overridenIndex; + int _wellKnownIndex; + + internal void Add(in CrsMapEntry entry) + { + var wellKnown = CrsMap.WellKnown[_wellKnownIndex]; + if (wellKnown.MinSrid == entry.MinSrid && + wellKnown.MaxSrid == entry.MaxSrid && + string.Equals(wellKnown.Authority, entry.Authority, StringComparison.Ordinal)) + { + _wellKnownIndex++; + return; + } + + if (wellKnown.MinSrid < entry.MinSrid) + { + do + _wellKnownIndex++; + while (CrsMap.WellKnown.Length < _wellKnownIndex && + CrsMap.WellKnown[_wellKnownIndex].MaxSrid < entry.MaxSrid); + AddCore(new CrsMapEntry(wellKnown.MinSrid, Math.Min(wellKnown.MaxSrid, entry.MinSrid - 1), null)); + } + + AddCore(entry); + } + + void AddCore(in CrsMapEntry entry) + { + var index = _overridenIndex + 1; + if (_overrides == null) + _overrides = new CrsMapEntry[4]; + else + if (_overrides.Length == index) + Array.Resize(ref _overrides, _overrides.Length << 1); + + _overrides[_overridenIndex] = entry; + _overridenIndex = index; + } + + 
internal CrsMap Build() + { + if (_overrides != null && _overrides.Length < _overridenIndex) + Array.Resize(ref _overrides, _overridenIndex); + + return new CrsMap(_overrides); + } +} diff --git a/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs b/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs new file mode 100644 index 0000000000..544fb306e8 --- /dev/null +++ b/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs @@ -0,0 +1,746 @@ +using System; +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Collections.ObjectModel; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using GeoJSON.Net; +using GeoJSON.Net.CoordinateReferenceSystem; +using GeoJSON.Net.Geometry; +using Npgsql.Internal; + +namespace Npgsql.GeoJSON.Internal; + +sealed class GeoJSONConverter : PgStreamingConverter where T : IGeoJSONObject +{ + readonly ConcurrentDictionary _cachedCrs = new(); + readonly GeoJSONOptions _options; + readonly Func _getCrs; + + public GeoJSONConverter(GeoJSONOptions options, CrsMap crsMap) + { + _options = options; + _getCrs = GetCrs( + crsMap, + _cachedCrs, + crsType: _options & (GeoJSONOptions.ShortCRS | GeoJSONOptions.LongCRS) + ); + } + + bool BoundingBox => (_options & GeoJSONOptions.BoundingBox) != 0; + + public override T Read(PgReader reader) + => (T)GeoJSONConverter.Read(async: false, reader, BoundingBox ? new BoundingBoxBuilder() : null, _getCrs, CancellationToken.None).GetAwaiter().GetResult(); + + public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => (T)await GeoJSONConverter.Read(async: true, reader, BoundingBox ? new BoundingBoxBuilder() : null, _getCrs, cancellationToken).ConfigureAwait(false); + + public override Size GetSize(SizeContext context, T value, ref object? 
writeState) + => GeoJSONConverter.GetSize(context, value, ref writeState); + + public override void Write(PgWriter writer, T value) + => GeoJSONConverter.Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => GeoJSONConverter.Write(async: true, writer, value, CancellationToken.None); + + static Func GetCrs(CrsMap crsMap, ConcurrentDictionary cachedCrs, GeoJSONOptions crsType) + => srid => + { + if (crsType == GeoJSONOptions.None) + return null; + +#if NETSTANDARD2_0 + return cachedCrs.GetOrAdd(srid, srid => + { + var authority = crsMap.GetAuthority(srid); + + return authority is null + ? throw new InvalidOperationException($"SRID {srid} unknown in spatial_ref_sys table") + : new NamedCRS(crsType == GeoJSONOptions.LongCRS + ? "urn:ogc:def:crs:" + authority + "::" + srid + : authority + ":" + srid); + }); +#else + return cachedCrs.GetOrAdd(srid, static (srid, state) => + { + var (crsMap, crsType) = state; + var authority = crsMap.GetAuthority(srid); + + return authority is null + ? throw new InvalidOperationException($"SRID {srid} unknown in spatial_ref_sys table") + : new NamedCRS(crsType == GeoJSONOptions.LongCRS + ? "urn:ogc:def:crs:" + authority + "::" + srid + : authority + ":" + srid); + }, (crsMap, crsType)); +#endif + }; +} + +static class GeoJSONConverter +{ + public static async ValueTask Read(bool async, PgReader reader, BoundingBoxBuilder? boundingBox, Func getCrs, CancellationToken cancellationToken) + { + var geometry = await Core(async, reader, boundingBox, getCrs, cancellationToken).ConfigureAwait(false); + geometry.BoundingBoxes = boundingBox?.Build(); + return geometry; + + static async ValueTask Core(bool async, PgReader reader, BoundingBoxBuilder? 
boundingbox, Func getCrs, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(SizeOfHeader)) + await reader.BufferData(async, SizeOfHeader, cancellationToken).ConfigureAwait(false); + + var littleEndian = reader.ReadByte() > 0; + var type = (EwkbGeometryType)ReadUInt32(littleEndian); + + GeoJSONObject geometry; + NamedCRS? crs = null; + + if (HasSrid(type)) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.BufferData(async, sizeof(int), cancellationToken).ConfigureAwait(false); + crs = getCrs(ReadInt32(littleEndian)); + } + + switch (type & EwkbGeometryType.BaseType) + { + case EwkbGeometryType.Point: + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + geometry = new Point(position); + break; + } + + case EwkbGeometryType.LineString: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var coordinates = new Position[ReadInt32(littleEndian)]; + for (var i = 0; i < coordinates.Length; ++i) + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + coordinates[i] = position; + } + geometry = new LineString(coordinates); + break; + } + + case EwkbGeometryType.Polygon: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var lines = new LineString[ReadInt32(littleEndian)]; + for (var i = 0; i < lines.Length; ++i) + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var coordinates = new 
Position[ReadInt32(littleEndian)]; + for (var j = 0; j < coordinates.Length; ++j) + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + coordinates[j] = position; + } + lines[i] = new LineString(coordinates); + } + geometry = new Polygon(lines); + break; + } + + case EwkbGeometryType.MultiPoint: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var points = new Point[ReadInt32(littleEndian)]; + for (var i = 0; i < points.Length; ++i) + { + if (SizeOfHeader + SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + + if (async) + await reader.ConsumeAsync(SizeOfHeader, cancellationToken).ConfigureAwait(false); + else + reader.Consume(SizeOfHeader); + + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + points[i] = new Point(position); + } + geometry = new MultiPoint(points); + break; + } + + case EwkbGeometryType.MultiLineString: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var lines = new LineString[ReadInt32(littleEndian)]; + for (var i = 0; i < lines.Length; ++i) + { + if (reader.ShouldBuffer(SizeOfHeaderWithLength)) + await reader.BufferData(async, SizeOfHeaderWithLength, cancellationToken).ConfigureAwait(false); + + if (async) + await reader.ConsumeAsync(SizeOfHeader, cancellationToken).ConfigureAwait(false); + else + reader.Consume(SizeOfHeader); + + var coordinates = new Position[ReadInt32(littleEndian)]; + for (var j = 0; j < coordinates.Length; ++j) + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, 
cancellationToken).ConfigureAwait(false); + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + coordinates[j] = position; + } + lines[i] = new LineString(coordinates); + } + geometry = new MultiLineString(lines); + break; + } + + case EwkbGeometryType.MultiPolygon: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var polygons = new Polygon[ReadInt32(littleEndian)]; + for (var i = 0; i < polygons.Length; ++i) + { + if (reader.ShouldBuffer(SizeOfHeaderWithLength)) + await reader.BufferData(async, SizeOfHeaderWithLength, cancellationToken).ConfigureAwait(false); + + if (async) + await reader.ConsumeAsync(SizeOfHeader, cancellationToken).ConfigureAwait(false); + else + reader.Consume(SizeOfHeader); + + var lines = new LineString[ReadInt32(littleEndian)]; + for (var j = 0; j < lines.Length; ++j) + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + var coordinates = new Position[ReadInt32(littleEndian)]; + for (var k = 0; k < coordinates.Length; ++k) + { + if (SizeOfPoint(type) is var size && reader.ShouldBuffer(size)) + await reader.BufferData(async, size, cancellationToken).ConfigureAwait(false); + var position = ReadPosition(reader, type, littleEndian); + boundingbox?.Accumulate(position); + coordinates[k] = position; + } + lines[j] = new LineString(coordinates); + } + polygons[i] = new Polygon(lines); + } + geometry = new MultiPolygon(polygons); + break; + } + + case EwkbGeometryType.GeometryCollection: + { + if (reader.ShouldBuffer(SizeOfLength)) + await reader.BufferData(async, SizeOfLength, cancellationToken).ConfigureAwait(false); + + var elements = new IGeometryObject[ReadInt32(littleEndian)]; + for (var i = 0; i < elements.Length; ++i) + elements[i] = (IGeometryObject)await Core(async, reader, boundingbox, getCrs, 
cancellationToken).ConfigureAwait(false); + geometry = new GeometryCollection(elements); + break; + } + + default: + throw UnknownPostGisType(); + } + + geometry.CRS = crs; + return geometry; + + int ReadInt32(bool littleEndian) + => littleEndian ? BinaryPrimitives.ReverseEndianness(reader.ReadInt32()) : reader.ReadInt32(); + uint ReadUInt32(bool littleEndian) + => littleEndian ? BinaryPrimitives.ReverseEndianness(reader.ReadUInt32()) : reader.ReadUInt32(); + } + + static Position ReadPosition(PgReader reader, EwkbGeometryType type, bool littleEndian) + { + var position = new Position( + longitude: ReadDouble(littleEndian), + latitude: ReadDouble(littleEndian), + altitude: HasZ(type) ? reader.ReadDouble() : null); + if (HasM(type)) ReadDouble(littleEndian); + return position; + + double ReadDouble(bool littleEndian) + => littleEndian + ? BitConverter.Int64BitsToDouble(BinaryPrimitives.ReverseEndianness(BitConverter.DoubleToInt64Bits(reader.ReadDouble()))) + : reader.ReadDouble(); + } + } + + public static Size GetSize(SizeContext context, IGeoJSONObject value, ref object? 
writeState) + => value.Type switch + { + GeoJSONObjectType.Point => GetSize((Point)value), + GeoJSONObjectType.LineString => GetSize((LineString)value), + GeoJSONObjectType.Polygon => GetSize((Polygon)value), + GeoJSONObjectType.MultiPoint => GetSize((MultiPoint)value), + GeoJSONObjectType.MultiLineString => GetSize((MultiLineString)value), + GeoJSONObjectType.MultiPolygon => GetSize((MultiPolygon)value), + GeoJSONObjectType.GeometryCollection => GetSize(context, (GeometryCollection)value, ref writeState), + _ => throw UnknownPostGisType() + }; + + static bool NotValid(ReadOnlyCollection coordinates, out bool hasZ) + { + if (coordinates.Count == 0) + hasZ = false; + else + { + hasZ = HasZ(coordinates[0]); + for (var i = 1; i < coordinates.Count; ++i) + if (HasZ(coordinates[i]) != hasZ) return true; + } + return false; + } + + static Size GetSize(Point value) + { + var length = Size.Create(SizeOfHeader + SizeOfPoint(HasZ(value.Coordinates))); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + return length; + } + + static Size GetSize(LineString value) + { + var coordinates = value.Coordinates; + if (NotValid(coordinates, out var hasZ)) + throw AllOrNoneCoordiantesMustHaveZ(nameof(LineString)); + + var length = Size.Create(SizeOfHeaderWithLength + coordinates.Count * SizeOfPoint(hasZ)); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + return length; + } + + static Size GetSize(Polygon value) + { + var lines = value.Coordinates; + var length = Size.Create(SizeOfHeaderWithLength + SizeOfLength * lines.Count); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var hasZ = false; + for (var i = 0; i < lines.Count; ++i) + { + var coordinates = lines[i].Coordinates; + if (NotValid(coordinates, out var lineHasZ)) + throw AllOrNoneCoordiantesMustHaveZ(nameof(Polygon)); + + if (hasZ != lineHasZ) + { + if (i == 0) hasZ = lineHasZ; + else throw AllOrNoneCoordiantesMustHaveZ(nameof(LineString)); + } + + 
length = length.Combine(coordinates.Count * SizeOfPoint(hasZ)); + } + + return length; + } + + static Size GetSize(MultiPoint value) + { + var length = Size.Create(SizeOfHeaderWithLength); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var coordinates = value.Coordinates; + foreach (var t in coordinates) + length = length.Combine(GetSize(t)); + + return length; + } + + static Size GetSize(MultiLineString value) + { + var length = Size.Create(SizeOfHeaderWithLength); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var coordinates = value.Coordinates; + foreach (var t in coordinates) + length = length.Combine(GetSize(t)); + + return length; + } + + static Size GetSize(MultiPolygon value) + { + var length = Size.Create(SizeOfHeaderWithLength); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var coordinates = value.Coordinates; + foreach (var t in coordinates) + length = length.Combine(GetSize(t)); + + return length; + } + + static Size GetSize(SizeContext context, GeometryCollection value, ref object? 
writeState) + { + var length = Size.Create(SizeOfHeaderWithLength); + if (GetSrid(value.CRS) != 0) + length = length.Combine(sizeof(int)); + + var geometries = value.Geometries; + foreach (var t in geometries) + length = length.Combine(GetSize(context, (IGeoJSONObject)t, ref writeState)); + + return length; + } + + public static ValueTask Write(bool async, PgWriter writer, IGeoJSONObject value, CancellationToken cancellationToken = default) + => value.Type switch + { + GeoJSONObjectType.Point => Write(async, writer, (Point)value, cancellationToken), + GeoJSONObjectType.LineString => Write(async, writer, (LineString)value, cancellationToken), + GeoJSONObjectType.Polygon => Write(async, writer, (Polygon)value, cancellationToken), + GeoJSONObjectType.MultiPoint => Write(async, writer, (MultiPoint)value, cancellationToken), + GeoJSONObjectType.MultiLineString => Write(async, writer, (MultiLineString)value, cancellationToken), + GeoJSONObjectType.MultiPolygon => Write(async, writer, (MultiPolygon)value, cancellationToken), + GeoJSONObjectType.GeometryCollection => Write(async, writer, (GeometryCollection)value, cancellationToken), + _ => throw UnknownPostGisType() + }; + + static async ValueTask Write(bool async, PgWriter writer, Point value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.Point; + var size = SizeOfHeader; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + + if (srid != 0) + writer.WriteInt32(srid); + + await WritePosition(async, writer, value.Coordinates, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, LineString value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.LineString; + var size = 
SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var coordinates = value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(coordinates.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in coordinates) + await WritePosition(async, writer, t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, Polygon value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.Polygon; + var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var lines = value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(lines.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in lines) + { + if (writer.ShouldFlush(SizeOfLength)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + var coordinates = t.Coordinates; + writer.WriteInt32(coordinates.Count); + foreach (var t1 in coordinates) + await WritePosition(async, writer, t1, cancellationToken).ConfigureAwait(false); + } + } + + static async ValueTask Write(bool async, PgWriter writer, MultiPoint value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.MultiPoint; + var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var coordinates = 
value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(coordinates.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in coordinates) + await Write(async, writer, t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, MultiLineString value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.MultiLineString; + var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var coordinates = value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(coordinates.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in coordinates) + await Write(async, writer, t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, MultiPolygon value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.MultiPolygon; + var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var coordinates = value.Coordinates; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(coordinates.Count); + + if (srid != 0) + writer.WriteInt32(srid); + foreach (var t in coordinates) + await Write(async, writer, t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask Write(bool async, PgWriter writer, GeometryCollection value, CancellationToken cancellationToken) + { + var type = EwkbGeometryType.GeometryCollection; + 
var size = SizeOfHeaderWithLength; + var srid = GetSrid(value.CRS); + if (srid != 0) + { + size += sizeof(int); + type |= EwkbGeometryType.HasSrid; + } + + if (writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var geometries = value.Geometries; + + writer.WriteByte(0); // Most significant byte first + writer.WriteInt32((int)type); + writer.WriteInt32(geometries.Count); + + if (srid != 0) + writer.WriteInt32(srid); + + foreach (var t in geometries) + await Write(async, writer, (IGeoJSONObject)t, cancellationToken).ConfigureAwait(false); + } + + static async ValueTask WritePosition(bool async, PgWriter writer, IPosition coordinate, CancellationToken cancellationToken) + { + var altitude = coordinate.Altitude; + if (SizeOfPoint(altitude.HasValue) is var size && writer.ShouldFlush(size)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteDouble(coordinate.Longitude); + writer.WriteDouble(coordinate.Latitude); + if (altitude.HasValue) + writer.WriteDouble(altitude.Value); + } + + static ValueTask BufferData(this PgReader reader, bool async, int byteCount, CancellationToken cancellationToken) + { + if (async) + return reader.BufferAsync(byteCount, cancellationToken); + + reader.Buffer(byteCount); + return new(); + } + + static ValueTask Flush(this PgWriter writer, bool async, CancellationToken cancellationToken) + { + if (async) + return writer.FlushAsync(cancellationToken); + + writer.Flush(); + return new(); + } + + static bool HasSrid(EwkbGeometryType type) + => (type & EwkbGeometryType.HasSrid) != 0; + + static bool HasZ(EwkbGeometryType type) + => (type & EwkbGeometryType.HasZ) != 0; + + static bool HasM(EwkbGeometryType type) + => (type & EwkbGeometryType.HasM) != 0; + + static bool HasZ(IPosition coordinates) + => coordinates.Altitude.HasValue; + + const int SizeOfLength = sizeof(int); + const int SizeOfHeader = sizeof(byte) + sizeof(EwkbGeometryType); + const int 
SizeOfHeaderWithLength = SizeOfHeader + SizeOfLength; + const int SizeOfPoint2D = 2 * sizeof(double); + const int SizeOfPoint3D = 3 * sizeof(double); + + static int SizeOfPoint(bool hasZ) + => hasZ ? SizeOfPoint3D : SizeOfPoint2D; + + static int SizeOfPoint(EwkbGeometryType type) + { + var size = SizeOfPoint2D; + if (HasZ(type)) + size += sizeof(double); + if (HasM(type)) + size += sizeof(double); + return size; + } + + static Exception UnknownPostGisType() + => throw new InvalidOperationException("Invalid PostGIS type"); + + static Exception AllOrNoneCoordiantesMustHaveZ(string typeName) + => new ArgumentException($"The Z coordinate must be specified for all or none elements of {typeName}"); + + static int GetSrid(ICRSObject crs) + { + if (crs is null or UnspecifiedCRS) + return 0; + + var namedCrs = crs as NamedCRS; + if (namedCrs == null) + throw new NotSupportedException("The LinkedCRS class isn't supported"); + + if (namedCrs.Properties.TryGetValue("name", out var value) && value != null) + { + var name = value.ToString()!; + if (string.Equals(name, "urn:ogc:def:crs:OGC::CRS84", StringComparison.Ordinal)) + return 4326; + + var index = name.LastIndexOf(':'); + if (index != -1 && int.TryParse(name.Substring(index + 1), out var srid)) + return srid; + + throw new FormatException("The specified CRS isn't properly named"); + } + + return 0; + } +} + +/// +/// Represents the identifier of the Well Known Binary representation of a geographical feature specified by the OGC. 
+/// http://portal.opengeospatial.org/files/?artifact_id=13227 Chapter 6.3.2.7 +/// +[Flags] +enum EwkbGeometryType : uint +{ + // Types + Point = 1, + LineString = 2, + Polygon = 3, + MultiPoint = 4, + MultiLineString = 5, + MultiPolygon = 6, + GeometryCollection = 7, + + // Masks + BaseType = Point | LineString | Polygon | MultiPoint | MultiLineString | MultiPolygon | GeometryCollection, + + // Flags + HasSrid = 0x20000000, + HasM = 0x40000000, + HasZ = 0x80000000 +} diff --git a/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolverFactory.cs b/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..c25118f1d7 --- /dev/null +++ b/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolverFactory.cs @@ -0,0 +1,116 @@ +using System; +using GeoJSON.Net; +using GeoJSON.Net.Geometry; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; + +namespace Npgsql.GeoJSON.Internal; + +sealed class GeoJSONTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + readonly GeoJSONOptions _options; + readonly bool _geographyAsDefault; + readonly CrsMap? _crsMap; + + public GeoJSONTypeInfoResolverFactory(GeoJSONOptions options, bool geographyAsDefault, CrsMap? crsMap = null) + { + _options = options; + _geographyAsDefault = geographyAsDefault; + _crsMap = crsMap; + } + + public override IPgTypeInfoResolver CreateResolver() => new Resolver(_options, _geographyAsDefault, _crsMap); + public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(_options, _geographyAsDefault, _crsMap); + + class Resolver : IPgTypeInfoResolver + { + readonly GeoJSONOptions _options; + readonly bool _geographyAsDefault; + readonly CrsMap? _crsMap; + + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _options, _geographyAsDefault, _crsMap); + + public Resolver(GeoJSONOptions options, bool geographyAsDefault, CrsMap? 
crsMap = null) + { + _options = options; + _geographyAsDefault = geographyAsDefault; + _crsMap = crsMap; + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, GeoJSONOptions geoJsonOptions, + bool geographyAsDefault, CrsMap? crsMap) + { + crsMap ??= new CrsMap(CrsMap.WellKnown); + + var geometryMatchRequirement = !geographyAsDefault ? MatchRequirement.Single : MatchRequirement.DataTypeName; + var geographyMatchRequirement = geographyAsDefault ? MatchRequirement.Single : MatchRequirement.DataTypeName; + + foreach (var dataTypeName in new[] { "geometry", "geography" }) + { + var matchRequirement = dataTypeName == "geometry" ? geometryMatchRequirement : geographyMatchRequirement; + + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + 
mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new GeoJSONConverter(geoJsonOptions, crsMap)), + matchRequirement); + } + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public ArrayResolver(GeoJSONOptions options, bool geographyAsDefault, CrsMap? crsMap = null) + : base(options, geographyAsDefault, crsMap) + { + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + foreach (var dataTypeName in new[] { "geometry", "geography" }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + } + + return mappings; + } + } +} diff --git a/src/Npgsql.GeoJSON/Npgsql.GeoJSON.csproj b/src/Npgsql.GeoJSON/Npgsql.GeoJSON.csproj index fbbd4e029a..d0b66b8460 100644 --- a/src/Npgsql.GeoJSON/Npgsql.GeoJSON.csproj +++ b/src/Npgsql.GeoJSON/Npgsql.GeoJSON.csproj @@ -1,16 +1,20 @@  - Yoh Deadfall, Shay Rojansky + Yoh Deadfall;Shay Rojansky GeoJSON plugin for Npgsql, allowing mapping of PostGIS geometry types to GeoJSON types. 
- npgsql postgresql postgres postgis geojson spatial ado ado.net database sql + npgsql;postgresql;postgres;postgis;geojson;spatial;ado;ado.net;database;sql netstandard2.0 - net5.0 - false + net8.0 + $(NoWarn);NPG9001 + + + - + + diff --git a/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs b/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs index 5068de4589..b47a9b211f 100644 --- a/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs +++ b/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs @@ -1,53 +1,37 @@ -using System; -using System.Data; -using GeoJSON.Net; -using GeoJSON.Net.Geometry; -using Npgsql.GeoJSON; +using Npgsql.GeoJSON; +using Npgsql.GeoJSON.Internal; using Npgsql.TypeMapping; -using NpgsqlTypes; // ReSharper disable once CheckNamespace -namespace Npgsql +namespace Npgsql; + +/// +/// Extension allowing adding the GeoJSON plugin to an Npgsql type mapper. +/// +public static class NpgsqlGeoJSONExtensions { /// - /// Extension allowing adding the GeoJSON plugin to an Npgsql type mapper. + /// Sets up GeoJSON mappings for the PostGIS types. /// - public static class NpgsqlGeoJSONExtensions + /// The type mapper to set up (global or connection-specific) + /// Options to use when constructing objects. + /// Specifies that the geography type is used for mapping by default. + public static INpgsqlTypeMapper UseGeoJson(this INpgsqlTypeMapper mapper, GeoJSONOptions options = GeoJSONOptions.None, bool geographyAsDefault = false) { - static readonly Type[] ClrTypes = new[] - { - typeof(GeoJSONObject), typeof(IGeoJSONObject), typeof(IGeometryObject), - typeof(Point), typeof(LineString), typeof(Polygon), - typeof(MultiPoint), typeof(MultiLineString), typeof(MultiPolygon), - typeof(GeometryCollection) - }; + mapper.AddTypeInfoResolverFactory(new GeoJSONTypeInfoResolverFactory(options, geographyAsDefault, crsMap: null)); + return mapper; + } - /// - /// Sets up GeoJSON mappings for the PostGIS types. 
- /// - /// The type mapper to set up (global or connection-specific) - /// Options to use when constructing objects. - /// Specifies that the geography type is used for mapping by default. - public static INpgsqlTypeMapper UseGeoJson(this INpgsqlTypeMapper mapper, GeoJSONOptions options = GeoJSONOptions.None, bool geographyAsDefault = false) - { - var factory = new GeoJSONHandlerFactory(options); - return mapper - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "geometry", - NpgsqlDbType = NpgsqlDbType.Geometry, - ClrTypes = geographyAsDefault ? Type.EmptyTypes : ClrTypes, - InferredDbType = DbType.Object, - TypeHandlerFactory = factory - }.Build()) - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "geography", - NpgsqlDbType = NpgsqlDbType.Geography, - ClrTypes = geographyAsDefault ? ClrTypes : Type.EmptyTypes, - InferredDbType = DbType.Object, - TypeHandlerFactory = factory - }.Build()); - } + /// + /// Sets up GeoJSON mappings for the PostGIS types. + /// + /// The type mapper to set up (global or connection-specific) + /// A custom crs map that might contain more or less entries than the default well-known crs map. + /// Options to use when constructing objects. + /// Specifies that the geography type is used for mapping by default. 
+ public static INpgsqlTypeMapper UseGeoJson(this INpgsqlTypeMapper mapper, CrsMap crsMap, GeoJSONOptions options = GeoJSONOptions.None, bool geographyAsDefault = false) + { + mapper.AddTypeInfoResolverFactory(new GeoJSONTypeInfoResolverFactory(options, geographyAsDefault, crsMap)); + return mapper; } } diff --git a/src/Npgsql.GeoJSON/Properties/AssemblyInfo.cs b/src/Npgsql.GeoJSON/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..1a340b1a15 --- /dev/null +++ b/src/Npgsql.GeoJSON/Properties/AssemblyInfo.cs @@ -0,0 +1,5 @@ +using System.Runtime.CompilerServices; + +#if NET5_0_OR_GREATER +[module: SkipLocalsInit] +#endif diff --git a/src/Npgsql.GeoJSON/PublicAPI.Shipped.txt b/src/Npgsql.GeoJSON/PublicAPI.Shipped.txt new file mode 100644 index 0000000000..7f92ef111d --- /dev/null +++ b/src/Npgsql.GeoJSON/PublicAPI.Shipped.txt @@ -0,0 +1,13 @@ +#nullable enable +Npgsql.GeoJSON.CrsMap +Npgsql.GeoJSON.CrsMapExtensions +Npgsql.GeoJSONOptions +Npgsql.GeoJSONOptions.BoundingBox = 1 -> Npgsql.GeoJSONOptions +Npgsql.GeoJSONOptions.LongCRS = 4 -> Npgsql.GeoJSONOptions +Npgsql.GeoJSONOptions.None = 0 -> Npgsql.GeoJSONOptions +Npgsql.GeoJSONOptions.ShortCRS = 2 -> Npgsql.GeoJSONOptions +Npgsql.NpgsqlGeoJSONExtensions +static Npgsql.GeoJSON.CrsMapExtensions.GetCrsMap(this Npgsql.NpgsqlDataSource! dataSource) -> Npgsql.GeoJSON.CrsMap! +static Npgsql.GeoJSON.CrsMapExtensions.GetCrsMapAsync(this Npgsql.NpgsqlDataSource! dataSource) -> System.Threading.Tasks.Task! +static Npgsql.NpgsqlGeoJSONExtensions.UseGeoJson(this Npgsql.TypeMapping.INpgsqlTypeMapper! mapper, Npgsql.GeoJSON.CrsMap! crsMap, Npgsql.GeoJSONOptions options = Npgsql.GeoJSONOptions.None, bool geographyAsDefault = false) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +static Npgsql.NpgsqlGeoJSONExtensions.UseGeoJson(this Npgsql.TypeMapping.INpgsqlTypeMapper! mapper, Npgsql.GeoJSONOptions options = Npgsql.GeoJSONOptions.None, bool geographyAsDefault = false) -> Npgsql.TypeMapping.INpgsqlTypeMapper! 
\ No newline at end of file diff --git a/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt b/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt new file mode 100644 index 0000000000..ab058de62d --- /dev/null +++ b/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs b/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs new file mode 100644 index 0000000000..42b7c88e0d --- /dev/null +++ b/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs @@ -0,0 +1,121 @@ +using System; +using System.Globalization; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Newtonsoft.Json; +using Npgsql.Internal; +using JsonSerializer = Newtonsoft.Json.JsonSerializer; + +namespace Npgsql.Json.NET.Internal; + +sealed class JsonNetJsonConverter : PgStreamingConverter +{ + readonly bool _jsonb; + readonly Encoding _textEncoding; + readonly JsonSerializerSettings _settings; + + public JsonNetJsonConverter(bool jsonb, Encoding textEncoding, JsonSerializerSettings settings) + { + _jsonb = jsonb; + _textEncoding = textEncoding; + _settings = settings; + } + + public override T? Read(PgReader reader) + => (T?)JsonNetJsonConverter.Read(async: false, _jsonb, reader, typeof(T), _settings, _textEncoding, CancellationToken.None).GetAwaiter().GetResult(); + public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => (T?)await JsonNetJsonConverter.Read(async: true, _jsonb, reader, typeof(T), _settings, _textEncoding, cancellationToken).ConfigureAwait(false); + + public override Size GetSize(SizeContext context, T? value, ref object? writeState) + => JsonNetJsonConverter.GetSize(_jsonb, context, typeof(T), _settings, _textEncoding, value, ref writeState); + + public override void Write(PgWriter writer, T? 
value) + => JsonNetJsonConverter.Write(_jsonb, async: false, writer, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T? value, CancellationToken cancellationToken = default) + => JsonNetJsonConverter.Write(_jsonb, async: true, writer, cancellationToken); +} + +// Split out to avoid unneccesary code duplication. +static class JsonNetJsonConverter +{ + public const byte JsonbProtocolVersion = 1; + + public static async ValueTask Read(bool async, bool jsonb, PgReader reader, Type type, JsonSerializerSettings settings, Encoding encoding, CancellationToken cancellationToken) + { + if (jsonb) + { + if (reader.ShouldBuffer(sizeof(byte))) + { + if (async) + await reader.BufferAsync(sizeof(byte), cancellationToken).ConfigureAwait(false); + else + reader.Buffer(sizeof(byte)); + } + var version = reader.ReadByte(); + if (version != JsonbProtocolVersion) + throw new InvalidCastException($"Unknown jsonb wire format version {version}"); + } + + using var stream = reader.GetStream(); + var mem = new MemoryStream(); + if (async) + await stream.CopyToAsync(mem, Math.Min((int)mem.Length, 81920), cancellationToken).ConfigureAwait(false); + else + stream.CopyTo(mem); + mem.Position = 0; + var jsonSerializer = JsonSerializer.CreateDefault(settings); + using var textReader = new JsonTextReader(new StreamReader(mem, encoding)); + return jsonSerializer.Deserialize(textReader, type); + } + + public static Size GetSize(bool jsonb, SizeContext context, Type type, JsonSerializerSettings settings, Encoding encoding, object? value, ref object? 
writeState) + { + var jsonSerializer = JsonSerializer.CreateDefault(settings); + var sb = new StringBuilder(256); + var sw = new StringWriter(sb, CultureInfo.InvariantCulture); + using (var jsonWriter = new JsonTextWriter(sw)) + { + jsonWriter.Formatting = jsonSerializer.Formatting; + + jsonSerializer.Serialize(jsonWriter, value, type); + } + + var str = sw.ToString(); + var bytes = encoding.GetBytes(str); + writeState = bytes; + return bytes.Length + (jsonb ? sizeof(byte) : 0); + } + + public static async ValueTask Write(bool jsonb, bool async, PgWriter writer, CancellationToken cancellationToken) + { + if (jsonb) + { + if (writer.ShouldFlush(sizeof(byte))) + { + if (async) + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + else + writer.Flush(); + } + writer.WriteByte(JsonbProtocolVersion); + } + + ArraySegment buffer; + switch (writer.Current.WriteState) + { + case byte[] bytes: + buffer = new ArraySegment(bytes); + break; + default: + throw new InvalidCastException($"Invalid state {writer.Current.WriteState?.GetType().FullName}."); + } + + if (async) + await writer.WriteBytesAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false); + else + writer.WriteBytes(buffer.AsSpan()); + } +} diff --git a/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolverFactory.cs b/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..27f719deca --- /dev/null +++ b/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolverFactory.cs @@ -0,0 +1,132 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using Newtonsoft.Json; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Json.NET.Internal; + +[RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] +[RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. 
This may not work when AOT compiling.")] +sealed class JsonNetPocoTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + readonly Type[]? _jsonbClrTypes; + readonly Type[]? _jsonClrTypes; + readonly JsonSerializerSettings? _serializerSettings; + + public JsonNetPocoTypeInfoResolverFactory(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) + { + _jsonbClrTypes = jsonbClrTypes; + _jsonClrTypes = jsonClrTypes; + _serializerSettings = serializerSettings; + } + + public override IPgTypeInfoResolver CreateResolver() => new Resolver(_jsonbClrTypes, _jsonClrTypes, _serializerSettings); + public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(_jsonbClrTypes, _jsonClrTypes, _serializerSettings); + + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + class Resolver : DynamicTypeInfoResolver, IPgTypeInfoResolver + { + readonly Type[]? _jsonbClrTypes; + readonly Type[]? _jsonClrTypes; + readonly JsonSerializerSettings _serializerSettings; + + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _jsonbClrTypes ?? Array.Empty(), _jsonClrTypes ?? Array.Empty(), _serializerSettings); + + const string JsonDataTypeName = "pg_catalog.json"; + const string JsonbDataTypeName = "pg_catalog.jsonb"; + + public Resolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) + { + _jsonbClrTypes = jsonbClrTypes; + _jsonClrTypes = jsonClrTypes; + // Capture default settings during construction. + _serializerSettings = serializerSettings ?? JsonConvert.DefaultSettings?.Invoke() ?? 
new JsonSerializerSettings(); + } + + TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, Type[] jsonbClrTypes, Type[] jsonClrTypes, JsonSerializerSettings serializerSettings) + { + AddUserMappings(mappings, jsonb: true, jsonbClrTypes, serializerSettings); + AddUserMappings(mappings, jsonb: false, jsonClrTypes, serializerSettings); + return mappings; + + static void AddUserMappings(TypeInfoMappingCollection mappings, bool jsonb, Type[] clrTypes, JsonSerializerSettings serializerSettings) + { + var dynamicMappings = CreateCollection(); + var dataTypeName = jsonb ? JsonbDataTypeName : JsonDataTypeName; + foreach (var jsonType in clrTypes) + { + dynamicMappings.AddMapping(jsonType, dataTypeName, + factory: (options, mapping, _) => mapping.CreateInfo(options, + CreateConverter(mapping.Type, jsonb, options.TextEncoding, serializerSettings))); + } + mappings.AddRange(dynamicMappings.ToTypeInfoMappingCollection()); + } + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options) ?? base.GetTypeInfo(type, dataTypeName, options); + + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + // Match all types except null, object and text types as long as DataTypeName (json/jsonb) is present. 
+ if (type is null || type == typeof(object) || PgSerializerOptions.IsWellKnownTextType(type) + || dataTypeName != JsonbDataTypeName && dataTypeName != JsonDataTypeName) + return null; + + return CreateCollection().AddMapping(type, dataTypeName, (options, mapping, _) => + { + var jsonb = dataTypeName == JsonbDataTypeName; + return mapping.CreateInfo(options, + CreateConverter(mapping.Type, jsonb, options.TextEncoding, _serializerSettings)); + }); + } + + static PgConverter CreateConverter(Type valueType, bool jsonb, Encoding textEncoding, JsonSerializerSettings settings) + => (PgConverter)Activator.CreateInstance( + typeof(JsonNetJsonConverter<>).MakeGenericType(valueType), + jsonb, + textEncoding, + settings + )!; + } + + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings), base.Mappings); + + public ArrayResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) + : base(jsonbClrTypes, jsonClrTypes, serializerSettings) + { + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options) ?? 
base.GetTypeInfo(type, dataTypeName, options); + + TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, TypeInfoMappingCollection baseMappings) + { + if (baseMappings.Items.Count == 0) + return mappings; + + var dynamicMappings = CreateCollection(baseMappings); + foreach (var mapping in baseMappings.Items) + dynamicMappings.AddArrayMapping(mapping.Type, mapping.DataTypeName); + mappings.AddRange(dynamicMappings.ToTypeInfoMappingCollection()); + + return mappings; + } + + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + => type is not null && IsArrayLikeType(type, out var elementType) && IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName) + ? base.GetMappings(elementType, elementDataTypeName, options)?.AddArrayMapping(elementType, elementDataTypeName) + : null; + } + +} + diff --git a/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolverFactory.cs b/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..1f07bf0252 --- /dev/null +++ b/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolverFactory.cs @@ -0,0 +1,77 @@ +using System; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Json.NET.Internal; + +sealed class JsonNetTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + readonly JsonSerializerSettings? _settings; + + public JsonNetTypeInfoResolverFactory(JsonSerializerSettings? settings = null) => _settings = settings; + + public override IPgTypeInfoResolver CreateResolver() => new Resolver(_settings); + public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(_settings); + + class Resolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? 
_mappings; + readonly JsonSerializerSettings _serializerSettings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _serializerSettings); + + public Resolver(JsonSerializerSettings? settings = null) + { + // Capture default settings during construction. + _serializerSettings = settings ?? JsonConvert.DefaultSettings?.Invoke() ?? new JsonSerializerSettings(); + } + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, JsonSerializerSettings settings) + { + // Jsonb is the first default for JToken etc. + foreach (var dataTypeName in new[] { "jsonb", "json" }) + { + var jsonb = dataTypeName == "jsonb"; + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings)), isDefault: true); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings))); + } + + return mappings; + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public ArrayResolver(JsonSerializerSettings? settings = null) : base(settings) {} + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + foreach (var dataTypeName in new[] { "jsonb", "json" }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + } + + return mappings; + } + } +} + diff --git a/src/Npgsql.Json.NET/JsonHandler.cs b/src/Npgsql.Json.NET/JsonHandler.cs deleted file mode 100644 index 5d1b7522ff..0000000000 --- a/src/Npgsql.Json.NET/JsonHandler.cs +++ /dev/null @@ -1,110 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Newtonsoft.Json; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.Json.NET -{ - public class JsonHandlerFactory : NpgsqlTypeHandlerFactory - { - readonly JsonSerializerSettings _settings; - - public JsonHandlerFactory(JsonSerializerSettings? settings = null) - => _settings = settings ?? new JsonSerializerSettings(); - - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new JsonHandler(postgresType, conn, _settings); - } - - class JsonHandler : TypeHandlers.TextHandler - { - readonly JsonSerializerSettings _settings; - - public JsonHandler(PostgresType postgresType, NpgsqlConnection connection, JsonSerializerSettings settings) - : base(postgresType, connection) => _settings = settings; - - protected override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - if (typeof(T) == typeof(string) || - typeof(T) == typeof(char[]) || - typeof(T) == typeof(ArraySegment) || - typeof(T) == typeof(char) || - typeof(T) == typeof(byte[])) - { - return await base.Read(buf, len, async, fieldDescription); - } - - return JsonConvert.DeserializeObject(await base.Read(buf, len, async, fieldDescription), _settings); - } - - protected override int ValidateAndGetLength(T2 value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (typeof(T2) == typeof(string) || - typeof(T2) == typeof(char[]) || - typeof(T2) == typeof(ArraySegment) || - typeof(T2) == typeof(char) || - typeof(T2) == typeof(byte[])) - { - return base.ValidateAndGetLength(value, ref lengthCache, parameter); - } - - var serialized = JsonConvert.SerializeObject(value, _settings); - if (parameter != null) - parameter.ConvertedValue = serialized; - return base.ValidateAndGetLength(serialized, ref lengthCache, parameter); - } - - protected override Task WriteWithLength(T2 value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (typeof(T2) == typeof(string) || - typeof(T2) == typeof(char[]) || - typeof(T2) == typeof(ArraySegment) || - typeof(T2) == typeof(char) || - typeof(T2) == typeof(byte[])) - { - return base.WriteWithLength(value, buf, lengthCache, parameter, async, cancellationToken); - } - - // User POCO, read serialized representation from the validation phase - var serialized = parameter?.ConvertedValue != null - ? (string)parameter.ConvertedValue - : JsonConvert.SerializeObject(value, _settings); - return base.WriteWithLength(serialized, buf, lengthCache, parameter, async, cancellationToken); - } - - protected override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - if (value is DBNull || - value is string || - value is char[] || - value is ArraySegment || - value is char || - value is byte[]) - { - return base.ValidateObjectAndGetLength(value, ref lengthCache, parameter); - } - - return ValidateAndGetLength(value, ref lengthCache, parameter); - } - - protected override Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (value is DBNull || - value is string || - value is char[] || - value is ArraySegment || - value is char || - value is byte[]) - { - return base.WriteObjectWithLength(value, buf, lengthCache, parameter, async, cancellationToken); - } - - return WriteWithLength(value, buf, lengthCache, parameter, async, cancellationToken); - } - } -} diff --git a/src/Npgsql.Json.NET/JsonbHandler.cs b/src/Npgsql.Json.NET/JsonbHandler.cs deleted file mode 100644 index 72d5a870e3..0000000000 --- a/src/Npgsql.Json.NET/JsonbHandler.cs +++ /dev/null @@ -1,110 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Newtonsoft.Json; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.Json.NET -{ - public class JsonbHandlerFactory : NpgsqlTypeHandlerFactory - { - readonly JsonSerializerSettings _settings; - - public JsonbHandlerFactory(JsonSerializerSettings? settings = null) - => _settings = settings ?? 
new JsonSerializerSettings(); - - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new JsonbHandler(postgresType, conn, _settings); - } - - class JsonbHandler : Npgsql.TypeHandlers.JsonHandler - { - readonly JsonSerializerSettings _settings; - - public JsonbHandler(PostgresType postgresType, NpgsqlConnection connection, JsonSerializerSettings settings) - : base(postgresType, connection, isJsonb: true) => _settings = settings; - - protected override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - if (typeof(T) == typeof(string) || - typeof(T) == typeof(char[]) || - typeof(T) == typeof(ArraySegment) || - typeof(T) == typeof(char) || - typeof(T) == typeof(byte[])) - { - return await base.Read(buf, len, async, fieldDescription); - } - - return JsonConvert.DeserializeObject(await base.Read(buf, len, async, fieldDescription), _settings); - } - - protected override int ValidateAndGetLength(T2 value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (typeof(T2) == typeof(string) || - typeof(T2) == typeof(char[]) || - typeof(T2) == typeof(ArraySegment) || - typeof(T2) == typeof(char) || - typeof(T2) == typeof(byte[])) - { - return base.ValidateAndGetLength(value, ref lengthCache, parameter); - } - - var serialized = JsonConvert.SerializeObject(value, _settings); - if (parameter != null) - parameter.ConvertedValue = serialized; - return base.ValidateAndGetLength(serialized, ref lengthCache, parameter); - } - - protected override Task WriteWithLength(T2 value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (typeof(T2) == typeof(string) || - typeof(T2) == typeof(char[]) || - typeof(T2) == typeof(ArraySegment) || - typeof(T2) == typeof(char) || - typeof(T2) == typeof(byte[])) - { - return base.WriteWithLength(value, buf, lengthCache, parameter, async, cancellationToken); - } - - // User POCO, read serialized representation from the validation phase - var serialized = parameter?.ConvertedValue != null - ? (string)parameter.ConvertedValue - : JsonConvert.SerializeObject(value, _settings); - return base.WriteWithLength(serialized, buf, lengthCache, parameter, async, cancellationToken); - } - - protected override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (value is DBNull || - value is string || - value is char[] || - value is ArraySegment || - value is char || - value is byte[]) - { - return base.ValidateObjectAndGetLength(value, ref lengthCache, parameter); - } - - return ValidateAndGetLength(value, ref lengthCache, parameter); - } - - protected override Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (value is DBNull || - value is string || - value is char[] || - value is ArraySegment || - value is char || - value is byte[]) - { - return base.WriteObjectWithLength(value, buf, lengthCache, parameter, async, cancellationToken); - } - - return WriteWithLength(value, buf, lengthCache, parameter, async, cancellationToken); - } - } -} diff --git a/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj b/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj index 3a3b6b1a7c..b6c50353b7 100644 --- a/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj +++ b/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj @@ -2,13 +2,18 @@ Shay Rojansky Json.NET plugin for Npgsql, allowing transparent serialization/deserialization of JSON objects directly to and from the database. - npgsql postgresql json postgres ado ado.net database sql - netstandard2.0 - net5.0 + npgsql;postgresql;json;postgres;ado;ado.net;database;sql + netstandard2.0;net6.0 + net8.0 + enable + $(NoWarn);NPG9001 + + + diff --git a/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs b/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs index 1a65b62206..f2b33933b8 100644 --- a/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs +++ b/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs @@ -1,48 +1,40 @@ using System; -using Npgsql.Json.NET; +using System.Diagnostics.CodeAnalysis; using Npgsql.TypeMapping; using NpgsqlTypes; using Newtonsoft.Json; +using Npgsql.Json.NET.Internal; // ReSharper disable once CheckNamespace -namespace Npgsql +namespace Npgsql; + +/// +/// Extension allowing adding the Json.NET plugin to an Npgsql type mapper. +/// +public static class NpgsqlJsonNetExtensions { /// - /// Extension allowing adding the Json.NET plugin to an Npgsql type mapper. + /// Sets up JSON.NET mappings for the PostgreSQL json and jsonb types. /// - public static class NpgsqlJsonNetExtensions + /// The type mapper to set up. + /// Optional settings to customize JSON serialization. 
+ /// + /// A list of CLR types to map to PostgreSQL jsonb (no need to specify ). + /// + /// + /// A list of CLR types to map to PostgreSQL json (no need to specify ). + /// + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + public static INpgsqlTypeMapper UseJsonNet( + this INpgsqlTypeMapper mapper, + JsonSerializerSettings? settings = null, + Type[]? jsonbClrTypes = null, + Type[]? jsonClrTypes = null) { - /// - /// Sets up JSON.NET mappings for the PostgreSQL json and jsonb types. - /// - /// The type mapper to set up (global or connection-specific) - /// A list of CLR types to map to PostgreSQL jsonb (no need to specify NpgsqlDbType.Jsonb) - /// A list of CLR types to map to PostgreSQL json (no need to specify NpgsqlDbType.Json) - /// Optional settings to customize JSON serialization - public static INpgsqlTypeMapper UseJsonNet( - this INpgsqlTypeMapper mapper, - Type[]? jsonbClrTypes = null, - Type[]? jsonClrTypes = null, - JsonSerializerSettings? 
settings = null - ) - { - mapper.AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "jsonb", - NpgsqlDbType = NpgsqlDbType.Jsonb, - ClrTypes = jsonbClrTypes, - TypeHandlerFactory = new JsonbHandlerFactory(settings) - }.Build()); - - mapper.AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "json", - NpgsqlDbType = NpgsqlDbType.Json, - ClrTypes = jsonClrTypes, - TypeHandlerFactory = new JsonHandlerFactory(settings) - }.Build()); - - return mapper; - } + // Reverse order + mapper.AddTypeInfoResolverFactory(new JsonNetPocoTypeInfoResolverFactory(jsonbClrTypes, jsonClrTypes, settings)); + mapper.AddTypeInfoResolverFactory(new JsonNetTypeInfoResolverFactory(settings)); + return mapper; } } diff --git a/src/Npgsql.Json.NET/Properties/AssemblyInfo.cs b/src/Npgsql.Json.NET/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..1a340b1a15 --- /dev/null +++ b/src/Npgsql.Json.NET/Properties/AssemblyInfo.cs @@ -0,0 +1,5 @@ +using System.Runtime.CompilerServices; + +#if NET5_0_OR_GREATER +[module: SkipLocalsInit] +#endif diff --git a/src/Npgsql.Json.NET/PublicAPI.Shipped.txt b/src/Npgsql.Json.NET/PublicAPI.Shipped.txt new file mode 100644 index 0000000000..912eb76bcb --- /dev/null +++ b/src/Npgsql.Json.NET/PublicAPI.Shipped.txt @@ -0,0 +1,3 @@ +#nullable enable +Npgsql.NpgsqlJsonNetExtensions +static Npgsql.NpgsqlJsonNetExtensions.UseJsonNet(this Npgsql.TypeMapping.INpgsqlTypeMapper! mapper, Newtonsoft.Json.JsonSerializerSettings? settings = null, System.Type![]? jsonbClrTypes = null, System.Type![]? jsonClrTypes = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! 
diff --git a/src/Npgsql.Json.NET/PublicAPI.Unshipped.txt b/src/Npgsql.Json.NET/PublicAPI.Unshipped.txt new file mode 100644 index 0000000000..ab058de62d --- /dev/null +++ b/src/Npgsql.Json.NET/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/Npgsql.LegacyPostgis/CodeAnnotations.cs b/src/Npgsql.LegacyPostgis/CodeAnnotations.cs deleted file mode 100644 index 63c38a4d87..0000000000 --- a/src/Npgsql.LegacyPostgis/CodeAnnotations.cs +++ /dev/null @@ -1,237 +0,0 @@ -using System; - -#pragma warning disable 1591 -// ReSharper disable UnusedMember.Global -// ReSharper disable MemberCanBePrivate.Global -// ReSharper disable UnusedAutoPropertyAccessor.Global -// ReSharper disable IntroduceOptionalParameters.Global -// ReSharper disable MemberCanBeProtected.Global -// ReSharper disable InconsistentNaming -// ReSharper disable CheckNamespace - -namespace JetBrains.Annotations -{ - /// - /// Indicates that the value of the marked element could be null sometimes, - /// so the check for null is necessary before its usage. - /// - /// - /// [CanBeNull] public object Test() { return null; } - /// public void UseTest() { - /// var p = Test(); - /// var s = p.ToString(); // Warning: Possible 'System.NullReferenceException' - /// } - /// - [AttributeUsage( - AttributeTargets.Method | AttributeTargets.Parameter | AttributeTargets.Property | - AttributeTargets.Delegate | AttributeTargets.Field | AttributeTargets.Event)] - sealed class CanBeNullAttribute : Attribute { - // ReSharper disable once EmptyConstructor - public CanBeNullAttribute() {} - } - - /// - /// Indicates that the value of the marked element could never be null. 
- /// - /// - /// [NotNull] public object Foo() { - /// return null; // Warning: Possible 'null' assignment - /// } - /// - [AttributeUsage( - AttributeTargets.Method | AttributeTargets.Parameter | AttributeTargets.Property | - AttributeTargets.Delegate | AttributeTargets.Field | AttributeTargets.Event)] - sealed class NotNullAttribute : Attribute { } - - /// - /// Can be appplied to symbols of types derived from IEnumerable as well as to symbols of Task - /// and Lazy classes to indicate that the value of a collection item, of the Task.Result property - /// or of the Lazy.Value property can never be null. - /// - [AttributeUsage( - AttributeTargets.Method | AttributeTargets.Parameter | AttributeTargets.Property | - AttributeTargets.Delegate | AttributeTargets.Field)] - sealed class ItemNotNullAttribute : Attribute { } - - /// - /// Can be appplied to symbols of types derived from IEnumerable as well as to symbols of Task - /// and Lazy classes to indicate that the value of a collection item, of the Task.Result property - /// or of the Lazy.Value property can be null. - /// - [AttributeUsage( - AttributeTargets.Method | AttributeTargets.Parameter | AttributeTargets.Property | - AttributeTargets.Delegate | AttributeTargets.Field)] - sealed class ItemCanBeNullAttribute : Attribute { } - - /// - /// Indicates that the marked symbol is used implicitly (e.g. via reflection, in external library), - /// so this symbol will not be marked as unused (as well as by other usage inspections). 
- /// - [AttributeUsage(AttributeTargets.All)] - sealed class UsedImplicitlyAttribute : Attribute - { - public UsedImplicitlyAttribute() - : this(ImplicitUseKindFlags.Default, ImplicitUseTargetFlags.Default) - { } - - public UsedImplicitlyAttribute(ImplicitUseKindFlags useKindFlags) - : this(useKindFlags, ImplicitUseTargetFlags.Default) - { } - - public UsedImplicitlyAttribute(ImplicitUseTargetFlags targetFlags) - : this(ImplicitUseKindFlags.Default, targetFlags) - { } - - public UsedImplicitlyAttribute(ImplicitUseKindFlags useKindFlags, ImplicitUseTargetFlags targetFlags) - { - UseKindFlags = useKindFlags; - TargetFlags = targetFlags; - } - - public ImplicitUseKindFlags UseKindFlags { get; private set; } - public ImplicitUseTargetFlags TargetFlags { get; private set; } - } - - /// - /// Should be used on attributes and causes ReSharper to not mark symbols marked with such attributes - /// as unused (as well as by other usage inspections) - /// - [AttributeUsage(AttributeTargets.Class | AttributeTargets.GenericParameter)] - sealed class MeansImplicitUseAttribute : Attribute - { - public MeansImplicitUseAttribute() - : this(ImplicitUseKindFlags.Default, ImplicitUseTargetFlags.Default) - { } - - public MeansImplicitUseAttribute(ImplicitUseKindFlags useKindFlags) - : this(useKindFlags, ImplicitUseTargetFlags.Default) - { } - - public MeansImplicitUseAttribute(ImplicitUseTargetFlags targetFlags) - : this(ImplicitUseKindFlags.Default, targetFlags) - { } - - public MeansImplicitUseAttribute(ImplicitUseKindFlags useKindFlags, ImplicitUseTargetFlags targetFlags) - { - UseKindFlags = useKindFlags; - TargetFlags = targetFlags; - } - - [UsedImplicitly] - public ImplicitUseKindFlags UseKindFlags { get; private set; } - [UsedImplicitly] - public ImplicitUseTargetFlags TargetFlags { get; private set; } - } - - [Flags] - internal enum ImplicitUseKindFlags - { - Default = Access | Assign | InstantiatedWithFixedConstructorSignature, - /// Only entity marked with attribute 
considered used. - Access = 1, - /// Indicates implicit assignment to a member. - Assign = 2, - /// - /// Indicates implicit instantiation of a type with fixed constructor signature. - /// That means any unused constructor parameters won't be reported as such. - /// - InstantiatedWithFixedConstructorSignature = 4, - /// Indicates implicit instantiation of a type. - InstantiatedNoFixedConstructorSignature = 8, - } - - /// - /// Specify what is considered used implicitly when marked - /// with or . - /// - [Flags] - internal enum ImplicitUseTargetFlags - { - Default = Itself, - Itself = 1, - /// Members of entity marked with attribute are considered used. - Members = 2, - /// Entity marked with attribute and all its members considered used. - WithMembers = Itself | Members - } - - /// - /// Describes dependency between method input and output. - /// - /// - ///

Function Definition Table syntax:

- /// - /// FDT ::= FDTRow [;FDTRow]* - /// FDTRow ::= Input => Output | Output <= Input - /// Input ::= ParameterName: Value [, Input]* - /// Output ::= [ParameterName: Value]* {halt|stop|void|nothing|Value} - /// Value ::= true | false | null | notnull | canbenull - /// - /// If method has single input parameter, it's name could be omitted.
- /// Using halt (or void/nothing, which is the same) - /// for method output means that the methos doesn't return normally.
- /// canbenull annotation is only applicable for output parameters.
- /// You can use multiple [ContractAnnotation] for each FDT row, - /// or use single attribute with rows separated by semicolon.
- ///
- /// - /// - /// [ContractAnnotation("=> halt")] - /// public void TerminationMethod() - /// - /// - /// [ContractAnnotation("halt <= condition: false")] - /// public void Assert(bool condition, string text) // regular assertion method - /// - /// - /// [ContractAnnotation("s:null => true")] - /// public bool IsNullOrEmpty(string s) // string.IsNullOrEmpty() - /// - /// - /// // A method that returns null if the parameter is null, - /// // and not null if the parameter is not null - /// [ContractAnnotation("null => null; notnull => notnull")] - /// public object Transform(object data) - /// - /// - /// [ContractAnnotation("s:null=>false; =>true,result:notnull; =>false, result:null")] - /// public bool TryParse(string s, out Person result) - /// - /// - [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] - sealed class ContractAnnotationAttribute : Attribute - { - public ContractAnnotationAttribute([NotNull] string contract) - : this(contract, false) - { } - - public ContractAnnotationAttribute([NotNull] string contract, bool forceFullStates) - { - Contract = contract; - ForceFullStates = forceFullStates; - } - - public string Contract { get; private set; } - public bool ForceFullStates { get; private set; } - } - - /// - /// Indicates that the function argument should be string literal and match one - /// of the parameters of the caller function. For example, ReSharper annotates - /// the parameter of . - /// - /// - /// public void Foo(string param) { - /// if (param == null) - /// throw new ArgumentNullException("par"); // Warning: Cannot resolve symbol - /// } - /// - [AttributeUsage(AttributeTargets.Parameter)] - sealed class InvokerParameterNameAttribute : Attribute { } - - /// - /// Indicates that IEnumerable, passed as parameter, is not enumerated. 
- /// - [AttributeUsage(AttributeTargets.Parameter)] - sealed class NoEnumerationAttribute : Attribute { } -} diff --git a/src/Npgsql.LegacyPostgis/LegacyPostgisHandler.cs b/src/Npgsql.LegacyPostgis/LegacyPostgisHandler.cs deleted file mode 100644 index e386021880..0000000000 --- a/src/Npgsql.LegacyPostgis/LegacyPostgisHandler.cs +++ /dev/null @@ -1,378 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.LegacyPostgis -{ - public class LegacyPostgisHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new LegacyPostgisHandler(postgresType); - } - - class LegacyPostgisHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler - { - public LegacyPostgisHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - await buf.Ensure(5, async); - var le = buf.ReadByte() != 0; - var id = buf.ReadUInt32(le); - - var srid = 0u; - if ((id & (uint)EwkbModifiers.HasSRID) != 0) - { - await buf.Ensure(4, async); - srid = buf.ReadUInt32(le); - } - - var geom = await DoRead(buf, (WkbIdentifier)(id & 7), le, async); - geom.SRID = srid; - return geom; - } - - async ValueTask DoRead(NpgsqlReadBuffer buf, WkbIdentifier id, bool le, bool async) - { - switch (id) - { - case WkbIdentifier.Point: - await buf.Ensure(16, async); - return new PostgisPoint(buf.ReadDouble(le), buf.ReadDouble(le)); - - case WkbIdentifier.LineString: - { - await buf.Ensure(4, async); - var points = new Coordinate2D[buf.ReadInt32(le)]; - for (var ipts = 0; ipts < points.Length; ipts++) - { - await buf.Ensure(16, async); - points[ipts] = new Coordinate2D(buf.ReadDouble(le), buf.ReadDouble(le)); - } - return new PostgisLineString(points); - } - - case WkbIdentifier.Polygon: - { - await buf.Ensure(4, async); - var rings = new Coordinate2D[buf.ReadInt32(le)][]; - - for (var irng = 0; irng < rings.Length; irng++) - { - await buf.Ensure(4, async); - rings[irng] = new Coordinate2D[buf.ReadInt32(le)]; - for (var ipts = 0; ipts < rings[irng].Length; ipts++) - { - await buf.Ensure(16, async); - rings[irng][ipts] = new Coordinate2D(buf.ReadDouble(le), buf.ReadDouble(le)); - } - } - return new PostgisPolygon(rings); - } - - case WkbIdentifier.MultiPoint: - { - await buf.Ensure(4, async); - var points = new Coordinate2D[buf.ReadInt32(le)]; - for (var ipts = 0; ipts < points.Length; ipts++) - { - await buf.Ensure(21, async); - await buf.Skip(5, async); - points[ipts] = new Coordinate2D(buf.ReadDouble(le), buf.ReadDouble(le)); - } - return new PostgisMultiPoint(points); - } - - case WkbIdentifier.MultiLineString: - { - await buf.Ensure(4, async); - var rings = new Coordinate2D[buf.ReadInt32(le)][]; - - for (var irng = 0; irng < rings.Length; irng++) - { - await buf.Ensure(9, async); - await buf.Skip(5, 
async); - rings[irng] = new Coordinate2D[buf.ReadInt32(le)]; - for (var ipts = 0; ipts < rings[irng].Length; ipts++) - { - await buf.Ensure(16, async); - rings[irng][ipts] = new Coordinate2D(buf.ReadDouble(le), buf.ReadDouble(le)); - } - } - return new PostgisMultiLineString(rings); - } - - case WkbIdentifier.MultiPolygon: - { - await buf.Ensure(4, async); - var pols = new Coordinate2D[buf.ReadInt32(le)][][]; - - for (var ipol = 0; ipol < pols.Length; ipol++) - { - await buf.Ensure(9, async); - await buf.Skip(5, async); - pols[ipol] = new Coordinate2D[buf.ReadInt32(le)][]; - for (var irng = 0; irng < pols[ipol].Length; irng++) - { - await buf.Ensure(4, async); - pols[ipol][irng] = new Coordinate2D[buf.ReadInt32(le)]; - for (var ipts = 0; ipts < pols[ipol][irng].Length; ipts++) - { - await buf.Ensure(16, async); - pols[ipol][irng][ipts] = new Coordinate2D(buf.ReadDouble(le), buf.ReadDouble(le)); - } - } - } - return new PostgisMultiPolygon(pols); - } - - case WkbIdentifier.GeometryCollection: - { - await buf.Ensure(4, async); - var g = new PostgisGeometry[buf.ReadInt32(le)]; - - for (var i = 0; i < g.Length; i++) - { - await buf.Ensure(5, async); - var elemLe = buf.ReadByte() != 0; - var elemId = (WkbIdentifier)(buf.ReadUInt32(le) & 7); - - g[i] = await DoRead(buf, elemId, elemLe, async); - } - return new PostgisGeometryCollection(g); - } - - default: - throw new InvalidOperationException("Unknown Postgis identifier."); - } - } - - #endregion Read - - #region Read concrete types - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (PostgisPoint)await Read(buf, len, async, fieldDescription); - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (PostgisMultiPoint)await Read(buf, len, async, fieldDescription); - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => (PostgisLineString)await Read(buf, len, async, fieldDescription); - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (PostgisMultiLineString)await Read(buf, len, async, fieldDescription); - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (PostgisPolygon)await Read(buf, len, async, fieldDescription); - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (PostgisMultiPolygon)await Read(buf, len, async, fieldDescription); - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (PostgisGeometryCollection)await Read(buf, len, async, fieldDescription); - - #endregion - - #region Write - - public override int ValidateAndGetLength(PostgisGeometry value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.GetLen(true); - - public int ValidateAndGetLength(PostgisPoint value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.GetLen(true); - - public int ValidateAndGetLength(PostgisMultiPoint value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.GetLen(true); - - public int ValidateAndGetLength(PostgisPolygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.GetLen(true); - - public int ValidateAndGetLength(PostgisMultiPolygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.GetLen(true); - - public int ValidateAndGetLength(PostgisLineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.GetLen(true); - - public int ValidateAndGetLength(PostgisMultiLineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - => value.GetLen(true); - - public int ValidateAndGetLength(PostgisGeometryCollection value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.GetLen(true); - - public int ValidateAndGetLength(byte[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Length; - - public override async Task Write(PostgisGeometry value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - // Common header - if (value.SRID == 0) - { - if (buf.WriteSpaceLeft < 5) - await buf.Flush(async, cancellationToken); - buf.WriteByte(0); // We choose to ouput only XDR structure - buf.WriteInt32((int)value.Identifier); - } - else - { - if (buf.WriteSpaceLeft < 9) - await buf.Flush(async, cancellationToken); - buf.WriteByte(0); - buf.WriteInt32((int) ((uint)value.Identifier | (uint)EwkbModifiers.HasSRID)); - buf.WriteInt32((int) value.SRID); - } - - switch (value.Identifier) - { - case WkbIdentifier.Point: - if (buf.WriteSpaceLeft < 16) - await buf.Flush(async, cancellationToken); - var p = (PostgisPoint)value; - buf.WriteDouble(p.X); - buf.WriteDouble(p.Y); - return; - - case WkbIdentifier.LineString: - var l = (PostgisLineString)value; - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(l.PointCount); - for (var ipts = 0; ipts < l.PointCount; ipts++) - { - if (buf.WriteSpaceLeft < 16) - await buf.Flush(async, cancellationToken); - buf.WriteDouble(l[ipts].X); - buf.WriteDouble(l[ipts].Y); - } - return; - - case WkbIdentifier.Polygon: - var pol = (PostgisPolygon)value; - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(pol.RingCount); - for (var irng = 0; irng < pol.RingCount; irng++) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(pol[irng].Length); - for (var ipts = 0; ipts < pol[irng].Length; ipts++) - { - if 
(buf.WriteSpaceLeft < 16) - await buf.Flush(async, cancellationToken); - buf.WriteDouble(pol[irng][ipts].X); - buf.WriteDouble(pol[irng][ipts].Y); - } - } - return; - - case WkbIdentifier.MultiPoint: - var mp = (PostgisMultiPoint)value; - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(mp.PointCount); - for (var ipts = 0; ipts < mp.PointCount; ipts++) - { - if (buf.WriteSpaceLeft < 21) - await buf.Flush(async, cancellationToken); - buf.WriteByte(0); - buf.WriteInt32((int)WkbIdentifier.Point); - buf.WriteDouble(mp[ipts].X); - buf.WriteDouble(mp[ipts].Y); - } - return; - - case WkbIdentifier.MultiLineString: - var ml = (PostgisMultiLineString)value; - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(ml.LineCount); - for (var irng = 0; irng < ml.LineCount; irng++) - { - if (buf.WriteSpaceLeft < 9) - await buf.Flush(async, cancellationToken); - buf.WriteByte(0); - buf.WriteInt32((int)WkbIdentifier.LineString); - buf.WriteInt32(ml[irng].PointCount); - for (var ipts = 0; ipts < ml[irng].PointCount; ipts++) - { - if (buf.WriteSpaceLeft < 16) - await buf.Flush(async, cancellationToken); - buf.WriteDouble(ml[irng][ipts].X); - buf.WriteDouble(ml[irng][ipts].Y); - } - } - return; - - case WkbIdentifier.MultiPolygon: - var mpl = (PostgisMultiPolygon)value; - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(mpl.PolygonCount); - for (var ipol = 0; ipol < mpl.PolygonCount; ipol++) - { - if (buf.WriteSpaceLeft < 9) - await buf.Flush(async, cancellationToken); - buf.WriteByte(0); - buf.WriteInt32((int)WkbIdentifier.Polygon); - buf.WriteInt32(mpl[ipol].RingCount); - for (var irng = 0; irng < mpl[ipol].RingCount; irng++) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(mpl[ipol][irng].Length); - for (var ipts = 0; ipts < mpl[ipol][irng].Length; ipts++) - { - if (buf.WriteSpaceLeft < 16) - await buf.Flush(async, 
cancellationToken); - buf.WriteDouble(mpl[ipol][irng][ipts].X); - buf.WriteDouble(mpl[ipol][irng][ipts].Y); - } - } - } - return; - - case WkbIdentifier.GeometryCollection: - var coll = (PostgisGeometryCollection)value; - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(coll.GeometryCount); - - foreach (var x in coll) - await Write(x, buf, lengthCache, null, async, cancellationToken); - return; - - default: - throw new InvalidOperationException("Unknown Postgis identifier."); - } - } - - public Task Write(PostgisPoint value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((PostgisGeometry)value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(PostgisMultiPoint value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((PostgisGeometry)value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(PostgisPolygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((PostgisGeometry)value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(PostgisMultiPolygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((PostgisGeometry)value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(PostgisLineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => Write((PostgisGeometry)value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(PostgisMultiLineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((PostgisGeometry)value, buf, lengthCache, parameter, async, cancellationToken); - - public Task Write(PostgisGeometryCollection value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((PostgisGeometry)value, buf, lengthCache, parameter, async, cancellationToken); - - #endregion Write - } -} diff --git a/src/Npgsql.LegacyPostgis/Npgsql.LegacyPostgis.csproj b/src/Npgsql.LegacyPostgis/Npgsql.LegacyPostgis.csproj deleted file mode 100644 index 157730e96d..0000000000 --- a/src/Npgsql.LegacyPostgis/Npgsql.LegacyPostgis.csproj +++ /dev/null @@ -1,12 +0,0 @@ - - - Shay Rojansky - PostGIS plugin for Npgsql, allowing mapping of PostGIS types to the legacy types (e.g. PostgisPoint). - npgsql postgresql postgres postgis spatial geometry geography ado ado.net database sql - netstandard2.0 - net5.0 - - - - - diff --git a/src/Npgsql.LegacyPostgis/NpgsqlLegacyPostgisExtensions.cs b/src/Npgsql.LegacyPostgis/NpgsqlLegacyPostgisExtensions.cs deleted file mode 100644 index d9c49eaed3..0000000000 --- a/src/Npgsql.LegacyPostgis/NpgsqlLegacyPostgisExtensions.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System; -using System.Data; -using Npgsql.LegacyPostgis; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -// ReSharper disable once CheckNamespace -namespace Npgsql -{ - /// - /// Extension adding the legacy PostGIS types to an Npgsql type mapper. - /// - public static class NpgsqlLegacyPostgisExtensions - { - /// - /// Sets up the legacy PostGIS types to an Npgsql type mapper. 
- /// - /// The type mapper to set up (global or connection-specific) - public static INpgsqlTypeMapper UseLegacyPostgis(this INpgsqlTypeMapper mapper) - { - var typeHandlerFactory = new LegacyPostgisHandlerFactory(); - - return mapper - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "geometry", - NpgsqlDbType = NpgsqlDbType.Geometry, - ClrTypes = new[] - { - typeof(PostgisGeometry), - typeof(PostgisPoint), - typeof(PostgisMultiPoint), - typeof(PostgisLineString), - typeof(PostgisMultiLineString), - typeof(PostgisPolygon), - typeof(PostgisMultiPolygon), - typeof(PostgisGeometryCollection), - }, - TypeHandlerFactory = typeHandlerFactory - }.Build()) - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "geography", - NpgsqlDbType = NpgsqlDbType.Geography, - DbTypes = new DbType[0], - ClrTypes = new Type[0], - InferredDbType = DbType.Object, - TypeHandlerFactory = typeHandlerFactory - }.Build()); - } - } -} diff --git a/src/Npgsql.LegacyPostgis/PostgisTypes.cs b/src/Npgsql.LegacyPostgis/PostgisTypes.cs deleted file mode 100644 index 6f7c059ba9..0000000000 --- a/src/Npgsql.LegacyPostgis/PostgisTypes.cs +++ /dev/null @@ -1,504 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using JetBrains.Annotations; - -#pragma warning disable CA1710 - -// ReSharper disable once CheckNamespace -namespace Npgsql.LegacyPostgis -{ -#pragma warning disable 1591 - /// - /// Represents the identifier of the Well Known Binary representation of a geographical feature specified by the OGC. 
- /// http://portal.opengeospatial.org/files/?artifact_id=13227 Chapter 6.3.2.7 - /// - enum WkbIdentifier : uint - { - Point = 1, - LineString = 2, - Polygon = 3, - MultiPoint = 4, - MultiLineString = 5, - MultiPolygon = 6, - GeometryCollection = 7 - } - - /// - /// The modifiers used by postgis to extend the geomtry's binary representation - /// - [Flags] - enum EwkbModifiers : uint - { - HasSRID = 0x20000000, - HasMDim = 0x40000000, - HasZDim = 0x80000000 - } - - /// - /// A structure representing a 2D double precision floating point coordinate; - /// - public struct Coordinate2D : IEquatable - { - /// - /// X coordinate. - /// - public double X { get; } - - /// - /// Y coordinate. - /// - public double Y { get; } - - /// - /// Generates a new BBpoint with the specified coordinates. - /// - /// X coordinate - /// Y coordinate - public Coordinate2D(double x, double y) { X = x; Y = y;} - - // ReSharper disable CompareOfFloatsByEqualityOperator - public bool Equals(Coordinate2D c) - => X == c.X && Y == c.Y; - // ReSharper restore CompareOfFloatsByEqualityOperator - - public override int GetHashCode() - => X.GetHashCode() ^ Util.RotateShift(Y.GetHashCode(), Util.BitsInInt / 2); - - public override bool Equals(object? obj) => obj is Coordinate2D coord && Equals(coord); - - public static bool operator ==(Coordinate2D left, Coordinate2D right) - => Equals(left, right); - - public static bool operator !=(Coordinate2D left, Coordinate2D right) - => !Equals(left, right); - } - - /// - /// Represents an Postgis feature. - /// - public abstract class PostgisGeometry - { - /// - /// returns the binary length of the data structure without header. 
- /// - /// - protected abstract int GetLenHelper(); - internal abstract WkbIdentifier Identifier { get;} - - internal int GetLen(bool includeSRID) - { - // header = - // 1 byte for the endianness of the structure - // + 4 bytes for the type identifier - // (+ 4 bytes for the SRID if present and included) - return 5 + (SRID == 0 || !includeSRID ? 0 : 4) + GetLenHelper(); - } - - /// - /// The Spatial Reference System Identifier of the geometry (0 if unspecified). - /// - public uint SRID { get; set; } - } - - /// - /// Represents an Postgis 2D Point - /// - public class PostgisPoint : PostgisGeometry, IEquatable - { - Coordinate2D _coord; - - internal override WkbIdentifier Identifier => WkbIdentifier.Point; - protected override int GetLenHelper() => 16; - - public PostgisPoint(double x, double y) => _coord = new Coordinate2D(x, y); - - public double X => _coord.X; - public double Y => _coord.Y; - - public bool Equals(PostgisPoint? other) - => !(other is null) && _coord.Equals(other._coord); - - public override bool Equals(object? obj) => Equals(obj as PostgisPoint); - - public static bool operator ==(PostgisPoint x, PostgisPoint y) - => x is null ? 
y is null : x.Equals(y); - - public static bool operator !=(PostgisPoint x, PostgisPoint y) => !(x == y); - - public override int GetHashCode() => X.GetHashCode() ^ Util.RotateShift(Y.GetHashCode(), Util.BitsInInt / 2); - } - - /// - /// Represents an Ogc 2D LineString - /// - public class PostgisLineString : PostgisGeometry, IEquatable, IEnumerable - { - readonly Coordinate2D[] _points; - - internal override WkbIdentifier Identifier => WkbIdentifier.LineString; - protected override int GetLenHelper() => 4 + _points.Length * 16; - - public IEnumerator GetEnumerator() - => ((IEnumerable)_points).GetEnumerator(); - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - public Coordinate2D this[int index] => _points[index]; - - public PostgisLineString(IEnumerable points) => _points = points.ToArray(); - - public PostgisLineString(Coordinate2D[] points) => _points = points; - - public int PointCount => _points.Length; - - public bool Equals(PostgisLineString? other) - { - if (ReferenceEquals(other , null)) - return false ; - - if (_points.Length != other._points.Length) - return false; - for (var i = 0; i < _points.Length; i++) - if (!_points[i].Equals(other._points[i])) - return false; - return true; - } - - public override bool Equals(object? obj) => Equals(obj as PostgisLineString); - - public static bool operator ==(PostgisLineString x, PostgisLineString y) - => x is null ? y is null : x.Equals(y); - - public static bool operator !=(PostgisLineString x, PostgisLineString y) => !(x == y); - - public override int GetHashCode() - { - var ret = 266370105;//seed with something other than zero to make paths of all zeros hash differently. - foreach (var t in _points) - ret ^= Util.RotateShift(t.GetHashCode(), ret % Util.BitsInInt); - return ret; - } - } - - /// - /// Represents an Postgis 2D Polygon. 
- /// - public class PostgisPolygon : PostgisGeometry, IEquatable, IEnumerable> - { - readonly Coordinate2D[][] _rings; - - internal override WkbIdentifier Identifier => WkbIdentifier.Polygon; - protected override int GetLenHelper() => 4 + _rings.Length * 4 + TotalPointCount * 16; - - public Coordinate2D this[int ringIndex, int pointIndex] => _rings[ringIndex][pointIndex]; - public Coordinate2D[] this[int ringIndex] => _rings[ringIndex]; - - public PostgisPolygon(Coordinate2D[][] rings) => _rings = rings; - - public PostgisPolygon(IEnumerable> rings) - => _rings = rings.Select(x => x.ToArray()).ToArray(); - - public IEnumerator> GetEnumerator() - => ((IEnumerable>)_rings).GetEnumerator(); - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - public bool Equals(PostgisPolygon? other) - { - if (other is null) - return false; - - if (_rings.Length != other._rings.Length) - return false; - for (var i = 0; i < _rings.Length; i++) - { - if (_rings[i].Length != other._rings[i].Length) - return false; - for (var j = 0; j < _rings[i].Length; j++) - if (!_rings[i][j].Equals (other._rings[i][j])) - return false; - } - return true; - } - - public override bool Equals(object? obj) - => Equals(obj as PostgisPolygon); - - public static bool operator ==(PostgisPolygon x, PostgisPolygon y) - => x is null ? y is null : x.Equals(y); - - public static bool operator !=(PostgisPolygon x, PostgisPolygon y) => !(x == y); - - public int RingCount => _rings.Length; - public int TotalPointCount => _rings.Sum(r => r.Length); - - public override int GetHashCode() - { - var ret = 266370105;//seed with something other than zero to make paths of all zeros hash differently. 
- for (var i = 0; i < _rings.Length; i++) - for (var j = 0; j < _rings[i].Length; j++) - ret ^= Util.RotateShift(_rings[i][j].GetHashCode(), ret % Util.BitsInInt); - return ret; - } - } - - /// - /// Represents a Postgis 2D MultiPoint - /// - public class PostgisMultiPoint : PostgisGeometry, IEquatable, IEnumerable - { - readonly Coordinate2D[] _points; - - internal override WkbIdentifier Identifier => WkbIdentifier.MultiPoint; - - //each point of a multipoint is a postgispoint, not a building block point. - protected override int GetLenHelper() => 4 + _points.Length * 21; - - public IEnumerator GetEnumerator() => ((IEnumerable)_points).GetEnumerator(); - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - public PostgisMultiPoint (Coordinate2D[] points) - => _points = points; - - public PostgisMultiPoint(IEnumerable points) - => _points = points.Select(x => new Coordinate2D(x.X, x.Y)).ToArray(); - - public PostgisMultiPoint(IEnumerable points) - => _points = points.ToArray(); - - public Coordinate2D this[int indexer] => _points[indexer]; - - public bool Equals(PostgisMultiPoint? other) - { - if (ReferenceEquals(other ,null)) - return false ; - - if (_points.Length != other._points.Length) - return false; - for (var i = 0; i < _points.Length; i++) - if (!_points[i].Equals(other._points[i])) - return false; - return true; - } - - public override bool Equals(object? obj) => Equals(obj as PostgisMultiPoint); - - public static bool operator ==(PostgisMultiPoint x, PostgisMultiPoint y) - => x is null ? y is null : x.Equals(y); - - public static bool operator !=(PostgisMultiPoint x, PostgisMultiPoint y) => !(x == y); - - public override int GetHashCode() - { - var ret = 266370105;//seed with something other than zero to make paths of all zeros hash differently. 
- for (var i = 0; i < _points.Length; i++) - ret ^= Util.RotateShift(_points[i].GetHashCode(), ret % Util.BitsInInt); - return ret; - } - - public int PointCount => _points.Length; - } - - /// - /// Represents a Postgis 2D MultiLineString - /// - public sealed class PostgisMultiLineString : PostgisGeometry, - IEquatable, IEnumerable - { - readonly PostgisLineString[] _lineStrings; - - internal PostgisMultiLineString(Coordinate2D[][] pointArray) - { - _lineStrings = new PostgisLineString[pointArray.Length]; - for (var i = 0; i < pointArray.Length; i++) - _lineStrings[i] = new PostgisLineString(pointArray[i]); - } - - internal override WkbIdentifier Identifier => WkbIdentifier.MultiLineString; - - protected override int GetLenHelper() - { - var n = 4; - for (var i = 0; i < _lineStrings.Length; i++) - n += _lineStrings[i].GetLen(false); - return n; - } - - public IEnumerator GetEnumerator() => ((IEnumerable)_lineStrings).GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - public PostgisMultiLineString(PostgisLineString[] linestrings) - => _lineStrings = linestrings; - - public PostgisMultiLineString(IEnumerable linestrings) - => _lineStrings = linestrings.ToArray(); - - public PostgisMultiLineString(IEnumerable> pointList) - => _lineStrings = pointList.Select(x => new PostgisLineString(x)).ToArray(); - - public PostgisLineString this[int index] => _lineStrings[index]; - - public bool Equals(PostgisMultiLineString? other) - { - if (other is null) - return false ; - - if (_lineStrings.Length != other._lineStrings.Length) return false; - for (var i = 0; i < _lineStrings.Length; i++) - { - if (_lineStrings[i] != other._lineStrings[i]) return false; - } - return true; - } - - public override bool Equals(object? obj) - => obj is PostgisMultiLineString multiLineString && Equals(multiLineString); - - public static bool operator ==(PostgisMultiLineString x, PostgisMultiLineString y) - => x is null ? 
y is null : x.Equals(y); - - public static bool operator !=(PostgisMultiLineString x, PostgisMultiLineString y) => !(x == y); - - public override int GetHashCode() - { - var ret = 266370105;//seed with something other than zero to make paths of all zeros hash differently. - for (var i = 0; i < _lineStrings.Length; i++) - ret ^= Util.RotateShift(_lineStrings[i].GetHashCode(), ret % Util.BitsInInt); - return ret; - } - - public int LineCount => _lineStrings.Length; - } - - /// - /// Represents a Postgis 2D MultiPolygon. - /// - public class PostgisMultiPolygon : PostgisGeometry, IEquatable, IEnumerable - { - readonly PostgisPolygon[] _polygons; - - public IEnumerator GetEnumerator() => ((IEnumerable)_polygons).GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - internal override WkbIdentifier Identifier => WkbIdentifier.MultiPolygon; - - public PostgisPolygon this[int index] => _polygons[index]; - - public PostgisMultiPolygon(PostgisPolygon[] polygons) - => _polygons = polygons; - - public PostgisMultiPolygon(IEnumerable polygons) - => _polygons = polygons.ToArray(); - - public PostgisMultiPolygon(IEnumerable>> ringList) - => _polygons = ringList.Select(x => new PostgisPolygon(x)).ToArray(); - - public bool Equals(PostgisMultiPolygon? other) - { - if (other is null) - return false; - if (_polygons.Length != other._polygons.Length) - return false; - for (var i = 0; i < _polygons.Length; i++) - if (_polygons[i] != other._polygons[i]) return false; - return true; - } - - public override bool Equals(object? obj) - => obj is PostgisMultiPolygon polygon && Equals(polygon); - - public static bool operator ==(PostgisMultiPolygon x, PostgisMultiPolygon y) - => x is null ? y is null : x.Equals(y); - - public static bool operator !=(PostgisMultiPolygon x, PostgisMultiPolygon y) => !(x == y); - - public override int GetHashCode() - { - var ret = 266370105;//seed with something other than zero to make paths of all zeros hash differently. 
- for (var i = 0; i < _polygons.Length; i++) - ret ^= Util.RotateShift(_polygons[i].GetHashCode(), ret % Util.BitsInInt); - return ret; - } - - protected override int GetLenHelper() - { - var n = 4; - for (var i = 0; i < _polygons.Length; i++) - n += _polygons[i].GetLen(false); - return n; - } - - - public int PolygonCount => _polygons.Length; - } - - /// - /// Represents a collection of Postgis feature. - /// - public class PostgisGeometryCollection : PostgisGeometry, IEquatable, IEnumerable - { - readonly PostgisGeometry[] _geometries; - - public PostgisGeometry this[int index] => _geometries[index]; - - internal override WkbIdentifier Identifier => WkbIdentifier.GeometryCollection; - - public IEnumerator GetEnumerator() => ((IEnumerable)_geometries).GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - public PostgisGeometryCollection(PostgisGeometry[] geometries) => _geometries = geometries; - - public PostgisGeometryCollection(IEnumerable geometries) => _geometries = geometries.ToArray(); - - public bool Equals(PostgisGeometryCollection? other) - { - if (other is null) - return false; - if (_geometries.Length != other._geometries.Length) - return false; - for (var i = 0; i < _geometries.Length; i++) - if (!_geometries[i].Equals(other._geometries[i])) - return false; - return true; - } - - public override bool Equals(object? obj) - => obj is PostgisGeometryCollection collection && Equals(collection); - - public static bool operator ==(PostgisGeometryCollection x, PostgisGeometryCollection y) - => x is null ? y is null : x.Equals(y); - - public static bool operator !=(PostgisGeometryCollection x, PostgisGeometryCollection y) => !(x == y); - - public override int GetHashCode() - { - var ret = 266370105;//seed with something other than zero to make paths of all zeros hash differently. 
- for (var i = 0; i < _geometries.Length; i++) - ret ^= Util.RotateShift(_geometries[i].GetHashCode(), ret % Util.BitsInInt); - return ret; - } - - protected override int GetLenHelper() - { - var n = 4; - for (var i = 0; i < _geometries.Length; i++) - n += _geometries[i].GetLen(true); - return n; - } - - public int GeometryCount => _geometries.Length; - } - - static class Util - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static int RotateShift(int val, int shift) - => (val << shift) | (val >> (BitsInInt - shift)); - - internal const int BitsInInt = sizeof(int) * 8; - } -} diff --git a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteConverter.cs b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteConverter.cs new file mode 100644 index 0000000000..45597e7059 --- /dev/null +++ b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteConverter.cs @@ -0,0 +1,79 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using NetTopologySuite.Geometries; +using NetTopologySuite.IO; +using Npgsql.Internal; + +namespace Npgsql.NetTopologySuite.Internal; + +sealed class NetTopologySuiteConverter : PgStreamingConverter + where T : Geometry +{ + readonly PostGisReader _reader; + readonly PostGisWriter _writer; + + internal NetTopologySuiteConverter(PostGisReader reader, PostGisWriter writer) + => (_reader, _writer) = (reader, writer); + + public override T Read(PgReader reader) + => (T)_reader.Read(reader.GetStream()); + + // PostGisReader/PostGisWriter doesn't support async + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => new(Read(reader)); + + public override Size GetSize(SizeContext context, T value, ref object? 
writeState) + { + var lengthStream = new LengthStream(); + lengthStream.SetLength(0); + _writer.Write(value, lengthStream); + return (int)lengthStream.Length; + } + + public override void Write(PgWriter writer, T value) + => _writer.Write(value, writer.GetStream(allowMixedIO: true)); + + // PostGisReader/PostGisWriter doesn't support async + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + { + Write(writer, value); + return default; + } + + sealed class LengthStream : Stream + { + long _length; + + public override bool CanRead => false; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length => _length; + + public override long Position + { + get => _length; + set => throw new NotSupportedException(); + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override long Seek(long offset, SeekOrigin origin) + => throw new NotSupportedException(); + + public override void SetLength(long value) + => _length = value; + + public override void Write(byte[] buffer, int offset, int count) + => _length += count; + } +} diff --git a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolverFactory.cs b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..b9a559c12f --- /dev/null +++ b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolverFactory.cs @@ -0,0 +1,117 @@ +using System; +using NetTopologySuite; +using NetTopologySuite.Geometries; +using NetTopologySuite.IO; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; + +namespace Npgsql.NetTopologySuite.Internal; + +sealed class NetTopologySuiteTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + readonly CoordinateSequenceFactory? _coordinateSequenceFactory; + readonly PrecisionModel? 
_precisionModel; + readonly Ordinates _handleOrdinates; + readonly bool _geographyAsDefault; + + public NetTopologySuiteTypeInfoResolverFactory(CoordinateSequenceFactory? coordinateSequenceFactory, PrecisionModel? precisionModel, + Ordinates handleOrdinates, bool geographyAsDefault) + { + _coordinateSequenceFactory = coordinateSequenceFactory; + _precisionModel = precisionModel; + _handleOrdinates = handleOrdinates; + _geographyAsDefault = geographyAsDefault; + } + + public override IPgTypeInfoResolver CreateResolver() => new Resolver(_coordinateSequenceFactory, _precisionModel, _handleOrdinates, _geographyAsDefault); + public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(_coordinateSequenceFactory, _precisionModel, _handleOrdinates, _geographyAsDefault); + + class Resolver : IPgTypeInfoResolver + { + readonly PostGisReader _gisReader; + protected readonly bool _geographyAsDefault; + + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _gisReader, new(), _geographyAsDefault); + + public Resolver( + CoordinateSequenceFactory? coordinateSequenceFactory, + PrecisionModel? precisionModel, + Ordinates handleOrdinates, + bool geographyAsDefault) + { + coordinateSequenceFactory ??= NtsGeometryServices.Instance.DefaultCoordinateSequenceFactory; + precisionModel ??= NtsGeometryServices.Instance.DefaultPrecisionModel; + handleOrdinates = handleOrdinates == Ordinates.None ? coordinateSequenceFactory.Ordinates : handleOrdinates; + + _geographyAsDefault = geographyAsDefault; + _gisReader = new PostGisReader(coordinateSequenceFactory, precisionModel, handleOrdinates); + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, PostGisReader reader, PostGisWriter writer, + bool geographyAsDefault) + { + foreach (var dataTypeName in geographyAsDefault ? new[] {"geography", "geometry"} : new[] { "geometry", "geography" }) + { + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), + isDefault: true); + + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer))); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer))); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer))); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer))); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer))); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer))); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer))); + mappings.AddType(dataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer))); + } + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings), _geographyAsDefault); + + public ArrayResolver(CoordinateSequenceFactory? coordinateSequenceFactory, PrecisionModel? 
precisionModel, + Ordinates handleOrdinates, bool geographyAsDefault) + : base(coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault) + { + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, bool geographyAsDefault) + { + foreach (var dataTypeName in geographyAsDefault ? new[] { "geography", "geometry" } : new[] { "geometry", "geography" }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + } + + return mappings; + } + } +} diff --git a/src/Npgsql.NetTopologySuite/NetTopologySuiteHandler.cs b/src/Npgsql.NetTopologySuite/NetTopologySuiteHandler.cs deleted file mode 100644 index 3e226d9970..0000000000 --- a/src/Npgsql.NetTopologySuite/NetTopologySuiteHandler.cs +++ /dev/null @@ -1,170 +0,0 @@ -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using NetTopologySuite.Geometries; -using NetTopologySuite.IO; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.NetTopologySuite -{ - class NetTopologySuiteHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler, - INpgsqlTypeHandler - { - readonly PostGisReader _reader; - readonly PostGisWriter _writer; - readonly LengthStream _lengthStream = new LengthStream(); - - internal NetTopologySuiteHandler(PostgresType postgresType, 
PostGisReader reader, PostGisWriter writer) - : base(postgresType) - { - _reader = reader; - _writer = writer; - } - - #region Read - - public override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => ReadCore(buf, len); - - ValueTask ReadCore(NpgsqlReadBuffer buf, int len) - where T : Geometry - => new ValueTask((T)_reader.Read(buf.GetStream(len, false))); - - #endregion - - #region ValidateAndGetLength - - public override int ValidateAndGetLength(Geometry value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLengthCore(value); - - int INpgsqlTypeHandler.ValidateAndGetLength(Point value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(LineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(Polygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(MultiPoint value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(MultiLineString value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(MultiPolygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int INpgsqlTypeHandler.ValidateAndGetLength(GeometryCollection value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - int ValidateAndGetLengthCore(Geometry value) - { - _lengthStream.SetLength(0); - _writer.Write(value, _lengthStream); - return (int)_lengthStream.Length; - } - - sealed class LengthStream : Stream - { - long _length; - - public override bool CanRead => false; - - public override bool CanSeek => false; - - public override bool CanWrite => true; - - public override long Length => _length; - - public override long Position - { - get => _length; - set => throw new NotSupportedException(); - } - - public override void Flush() - { } - - public override int Read(byte[] buffer, int offset, int count) - => throw new NotSupportedException(); - - public override long Seek(long offset, SeekOrigin origin) - => throw new NotSupportedException(); - - public override void SetLength(long value) - => _length = value; - - public override void Write(byte[] buffer, int offset, int count) - => _length += count; - } - - #endregion - - #region Write - - public 
override Task Write(Geometry value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(Point value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(LineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(Polygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(MultiPoint value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToke) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(MultiLineString value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(MultiPolygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task INpgsqlTypeHandler.Write(GeometryCollection value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken) - => WriteCore(value, buf); - - Task WriteCore(Geometry value, NpgsqlWriteBuffer buf) - { - _writer.Write(value, buf.GetStream()); - return Task.CompletedTask; - } - - #endregion - } -} diff --git a/src/Npgsql.NetTopologySuite/NetTopologySuiteHandlerFactory.cs b/src/Npgsql.NetTopologySuite/NetTopologySuiteHandlerFactory.cs deleted file mode 100644 index eb09b597b3..0000000000 --- a/src/Npgsql.NetTopologySuite/NetTopologySuiteHandlerFactory.cs +++ /dev/null @@ -1,25 +0,0 @@ -using System; -using NetTopologySuite.Geometries; -using NetTopologySuite.IO; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.NetTopologySuite -{ - public class NetTopologySuiteHandlerFactory : NpgsqlTypeHandlerFactory - { - readonly PostGisReader _reader; - readonly PostGisWriter _writer; - - internal NetTopologySuiteHandlerFactory(PostGisReader reader, PostGisWriter writer) - { - _reader = reader ?? throw new ArgumentNullException(nameof(reader)); - _writer = writer ?? throw new ArgumentNullException(nameof(writer)); - } - - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new NetTopologySuiteHandler(postgresType, _reader, _writer); - } -} diff --git a/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj b/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj index 46b00e1190..2e9451a0d1 100644 --- a/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj +++ b/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj @@ -1,16 +1,30 @@  - Yoh Deadfall, Shay Rojansky + Shay Rojansky;Yoh Deadfall NetTopologySuite plugin for Npgsql, allowing mapping of PostGIS geometry types to NetTopologySuite types. 
- npgsql postgresql postgres postgis nts ado ado.net database sql + npgsql;postgresql;postgres;postgis;spatial;nettopologysuite;nts;ado;ado.net;database;sql + README.md netstandard2.0 - net5.0 + net8.0 $(NoWarn);NU5104 + $(NoWarn);NPG9001 + + + + + + + + + + + + - \ No newline at end of file + diff --git a/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs b/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs index 0d81eeab6c..a30d023891 100644 --- a/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs +++ b/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs @@ -1,79 +1,33 @@ -using System; -using System.Data; -using NetTopologySuite; -using NetTopologySuite.Geometries; -using NetTopologySuite.IO; -using Npgsql.NetTopologySuite; +using NetTopologySuite.Geometries; +using Npgsql.NetTopologySuite.Internal; using Npgsql.TypeMapping; -using NpgsqlTypes; // ReSharper disable once CheckNamespace -namespace Npgsql +namespace Npgsql; + +/// +/// Extension allowing adding the NetTopologySuite plugin to an Npgsql type mapper. +/// +public static class NpgsqlNetTopologySuiteExtensions { /// - /// Extension allowing adding the NetTopologySuite plugin to an Npgsql type mapper. + /// Sets up NetTopologySuite mappings for the PostGIS types. /// - public static class NpgsqlNetTopologySuiteExtensions + /// The type mapper to set up (global or connection-specific). + /// The factory which knows how to build a particular implementation of ICoordinateSequence from an array of Coordinates. + /// Specifies the grid of allowable points. + /// Specifies the ordinates which will be handled. Not specified ordinates will be ignored. + /// If is specified, an actual value will be taken from + /// the property of . + /// Specifies that the geography type is used for mapping by default. + public static INpgsqlTypeMapper UseNetTopologySuite( + this INpgsqlTypeMapper mapper, + CoordinateSequenceFactory? 
coordinateSequenceFactory = null, + PrecisionModel? precisionModel = null, + Ordinates handleOrdinates = Ordinates.None, + bool geographyAsDefault = false) { - static readonly Type[] ClrTypes = - { - typeof(Geometry), - typeof(Point), - typeof(LineString), - typeof(Polygon), - typeof(MultiPoint), - typeof(MultiLineString), - typeof(MultiPolygon), - typeof(GeometryCollection), - }; - - /// - /// Sets up NetTopologySuite mappings for the PostGIS types. - /// - /// The type mapper to set up (global or connection-specific). - /// The factory which knows how to build a particular implementation of ICoordinateSequence from an array of Coordinates. - /// Specifies the grid of allowable points. - /// Specifies the ordinates which will be handled. Not specified ordinates will be ignored. - /// If is specified, an actual value will be taken from - /// the property of . - /// Specifies that the geography type is used for mapping by default. - public static INpgsqlTypeMapper UseNetTopologySuite( - this INpgsqlTypeMapper mapper, - CoordinateSequenceFactory? coordinateSequenceFactory = null, - PrecisionModel? precisionModel = null, - Ordinates handleOrdinates = Ordinates.None, - bool geographyAsDefault = false) - { - if (coordinateSequenceFactory == null) - coordinateSequenceFactory = NtsGeometryServices.Instance.DefaultCoordinateSequenceFactory; - - if (precisionModel == null) - precisionModel = NtsGeometryServices.Instance.DefaultPrecisionModel; - - if (handleOrdinates == Ordinates.None) - handleOrdinates = coordinateSequenceFactory.Ordinates; - - var typeHandlerFactory = new NetTopologySuiteHandlerFactory( - new PostGisReader(coordinateSequenceFactory, precisionModel, handleOrdinates), - new PostGisWriter()); - - return mapper - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "geometry", - NpgsqlDbType = NpgsqlDbType.Geometry, - ClrTypes = geographyAsDefault ? 
Type.EmptyTypes : ClrTypes, - InferredDbType = DbType.Object, - TypeHandlerFactory = typeHandlerFactory - }.Build()) - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "geography", - NpgsqlDbType = NpgsqlDbType.Geography, - ClrTypes = geographyAsDefault ? ClrTypes : Type.EmptyTypes, - InferredDbType = DbType.Object, - TypeHandlerFactory = typeHandlerFactory - }.Build()); - } + mapper.AddTypeInfoResolverFactory(new NetTopologySuiteTypeInfoResolverFactory(coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault)); + return mapper; } } diff --git a/src/Npgsql.NetTopologySuite/Properties/AssemblyInfo.cs b/src/Npgsql.NetTopologySuite/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..1a340b1a15 --- /dev/null +++ b/src/Npgsql.NetTopologySuite/Properties/AssemblyInfo.cs @@ -0,0 +1,5 @@ +using System.Runtime.CompilerServices; + +#if NET5_0_OR_GREATER +[module: SkipLocalsInit] +#endif diff --git a/src/Npgsql.NetTopologySuite/PublicAPI.Shipped.txt b/src/Npgsql.NetTopologySuite/PublicAPI.Shipped.txt new file mode 100644 index 0000000000..a9ca3382e6 --- /dev/null +++ b/src/Npgsql.NetTopologySuite/PublicAPI.Shipped.txt @@ -0,0 +1,3 @@ +#nullable enable +Npgsql.NpgsqlNetTopologySuiteExtensions +static Npgsql.NpgsqlNetTopologySuiteExtensions.UseNetTopologySuite(this Npgsql.TypeMapping.INpgsqlTypeMapper! mapper, NetTopologySuite.Geometries.CoordinateSequenceFactory? coordinateSequenceFactory = null, NetTopologySuite.Geometries.PrecisionModel? precisionModel = null, NetTopologySuite.Geometries.Ordinates handleOrdinates = NetTopologySuite.Geometries.Ordinates.None, bool geographyAsDefault = false) -> Npgsql.TypeMapping.INpgsqlTypeMapper! 
diff --git a/src/Npgsql.NetTopologySuite/PublicAPI.Unshipped.txt b/src/Npgsql.NetTopologySuite/PublicAPI.Unshipped.txt new file mode 100644 index 0000000000..ab058de62d --- /dev/null +++ b/src/Npgsql.NetTopologySuite/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/Npgsql.NetTopologySuite/README.md b/src/Npgsql.NetTopologySuite/README.md new file mode 100644 index 0000000000..c38f46c10b --- /dev/null +++ b/src/Npgsql.NetTopologySuite/README.md @@ -0,0 +1,34 @@ +Npgsql is the open source .NET data provider for PostgreSQL. It allows you to connect and interact with PostgreSQL server using .NET. + +This package is an Npgsql plugin which allows you to interact with spatial data provided by the PostgreSQL [PostGIS extension](https://postgis.net); PostGIS is a mature, standard extension considered to provide top-of-the-line database spatial features. On the .NET side, the plugin adds support for the types from the [NetTopologySuite library](https://github.com/NetTopologySuite/NetTopologySuite), allowing you to read and write them directly to PostgreSQL. + +To use the NetTopologySuite plugin, add a dependency on this package and create a NpgsqlDataSource. 
+ +```csharp +using Npgsql; +using NetTopologySuite.Geometries; + +var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString); + +dataSourceBuilder.UseNetTopologySuite(); + +var dataSource = dataSourceBuilder.Build(); +var conn = await dataSource.OpenConnectionAsync(); + +var point = new Point(new Coordinate(1d, 1d)); +conn.ExecuteNonQuery("CREATE TEMP TABLE data (geom GEOMETRY)"); +using (var cmd = new NpgsqlCommand("INSERT INTO data (geom) VALUES (@p)", conn)) +{ + cmd.Parameters.AddWithValue("@p", point); + cmd.ExecuteNonQuery(); +} + +using (var cmd = new NpgsqlCommand("SELECT geom FROM data", conn)) +using (var reader = cmd.ExecuteReader()) +{ + reader.Read(); + Assert.That(reader[0], Is.EqualTo(point)); +} +``` + +For more information, [visit the NetTopologySuite plugin documentation page](https://www.npgsql.org/doc/types/nts.html). diff --git a/src/Npgsql.NodaTime/DateHandler.cs b/src/Npgsql.NodaTime/DateHandler.cs deleted file mode 100644 index c4f82a74c4..0000000000 --- a/src/Npgsql.NodaTime/DateHandler.cs +++ /dev/null @@ -1,92 +0,0 @@ -using System; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using NpgsqlTypes; -using BclDateHandler = Npgsql.TypeHandlers.DateTimeHandlers.DateHandler; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.NodaTime -{ - public class DateHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - { - var csb = new NpgsqlConnectionStringBuilder(conn.ConnectionString); - return new DateHandler(postgresType, csb.ConvertInfinityDateTime); - } - } - - sealed class DateHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - /// - /// Whether to convert positive and negative infinity values to Instant.{Max,Min}Value when - /// an Instant is requested - /// - readonly bool 
_convertInfinityDateTime; - readonly BclDateHandler _bclHandler; - - internal DateHandler(PostgresType postgresType, bool convertInfinityDateTime) - : base(postgresType) - { - _convertInfinityDateTime = convertInfinityDateTime; - _bclHandler = new BclDateHandler(postgresType, convertInfinityDateTime); - } - - public override LocalDate Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var value = buf.ReadInt32(); - if (_convertInfinityDateTime) - { - if (value == int.MaxValue) - return LocalDate.MaxIsoValue; - if (value == int.MinValue) - return LocalDate.MinIsoValue; - } - return new LocalDate().PlusDays(value + 730119); - } - - public override int ValidateAndGetLength(LocalDate value, NpgsqlParameter? parameter) - => 4; - - public override void Write(LocalDate value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (_convertInfinityDateTime) - { - if (value == LocalDate.MaxIsoValue) - { - buf.WriteInt32(int.MaxValue); - return; - } - if (value == LocalDate.MinIsoValue) - { - buf.WriteInt32(int.MinValue); - return; - } - } - - var totalDaysSinceEra = Period.Between(default(LocalDate), value, PeriodUnits.Days).Days; - buf.WriteInt32(totalDaysSinceEra - 730119); - } - - NpgsqlDate INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(NpgsqlDate value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(NpgsqlDate value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTime value, NpgsqlParameter? 
parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - } -} diff --git a/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs b/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs new file mode 100644 index 0000000000..5e25d8bfcc --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs @@ -0,0 +1,49 @@ +using System.Threading; +using System.Threading.Tasks; +using NodaTime; +using Npgsql.Internal; +using NpgsqlTypes; + +namespace Npgsql.NodaTime.Internal; + +public class DateIntervalConverter : PgStreamingConverter +{ + readonly bool _dateTimeInfinityConversions; + readonly PgConverter> _rangeConverter; + + public DateIntervalConverter(PgConverter> rangeConverter, bool dateTimeInfinityConversions) + { + _rangeConverter = rangeConverter; + _dateTimeInfinityConversions = dateTimeInfinityConversions; + } + + public override DateInterval Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + var range = async + ? await _rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : _rangeConverter.Read(reader); + + var upperBound = range.UpperBound; + + if (upperBound != LocalDate.MaxIsoValue || !_dateTimeInfinityConversions) + upperBound -= Period.FromDays(1); + + return new(range.LowerBound, upperBound); + } + + public override Size GetSize(SizeContext context, DateInterval value, ref object? 
writeState) + => _rangeConverter.GetSize(context, new NpgsqlRange(value.Start, value.End), ref writeState); + + public override void Write(PgWriter writer, DateInterval value) + => _rangeConverter.Write(writer, new NpgsqlRange(value.Start, value.End)); + + public override ValueTask WriteAsync(PgWriter writer, DateInterval value, CancellationToken cancellationToken = default) + => _rangeConverter.WriteAsync(writer, new NpgsqlRange(value.Start, value.End), cancellationToken); +} diff --git a/src/Npgsql.NodaTime/Internal/DurationConverter.cs b/src/Npgsql.NodaTime/Internal/DurationConverter.cs new file mode 100644 index 0000000000..940ef29464 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/DurationConverter.cs @@ -0,0 +1,42 @@ +using System; +using NodaTime; +using Npgsql.Internal; +using Npgsql.NodaTime.Properties; + +namespace Npgsql.NodaTime.Internal; + +sealed class DurationConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override Duration ReadCore(PgReader reader) + { + var microsecondsInDay = reader.ReadInt64(); + var days = reader.ReadInt32(); + var totalMonths = reader.ReadInt32(); + + if (totalMonths != 0) + throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadIntervalWithMonthsAsDuration); + + return Duration.FromDays(days) + Duration.FromNanoseconds(microsecondsInDay * 1000); + } + + protected override void WriteCore(PgWriter writer, Duration value) + { + const long microsecondsPerSecond = 1_000_000; + + // Note that the end result must be long + // see #3438 + var microsecondsInDay = + (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * + microsecondsPerSecond + value.SubsecondNanoseconds / 1000); // Take the microseconds, discard the nanosecond 
remainder + + writer.WriteInt64(microsecondsInDay); + writer.WriteInt32(value.Days); // days + writer.WriteInt32(0); // months + } +} diff --git a/src/Npgsql.NodaTime/Internal/IntervalConverter.cs b/src/Npgsql.NodaTime/Internal/IntervalConverter.cs new file mode 100644 index 0000000000..3ca9ca9ab0 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/IntervalConverter.cs @@ -0,0 +1,57 @@ +using System.Threading; +using System.Threading.Tasks; +using NodaTime; +using Npgsql.Internal; +using NpgsqlTypes; + +namespace Npgsql.NodaTime.Internal; + +public class IntervalConverter : PgStreamingConverter +{ + readonly PgConverter> _rangeConverter; + + public IntervalConverter(PgConverter> rangeConverter) + => _rangeConverter = rangeConverter; + + public override Interval Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + var range = async + ? await _rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : _rangeConverter.Read(reader); + + // NodaTime Interval includes the start instant and excludes the end instant. + Instant? start = range.LowerBoundInfinite + ? null + : range.LowerBoundIsInclusive + ? range.LowerBound + : range.LowerBound + Duration.Epsilon; + Instant? end = range.UpperBoundInfinite + ? null + : range.UpperBoundIsInclusive + ? range.UpperBound + Duration.Epsilon + : range.UpperBound; + + return new(start, end); + } + + public override Size GetSize(SizeContext context, Interval value, ref object? 
writeState) + => _rangeConverter.GetSize(context, IntervalToNpgsqlRange(value), ref writeState); + + public override void Write(PgWriter writer, Interval value) + => _rangeConverter.Write(writer, IntervalToNpgsqlRange(value)); + + public override ValueTask WriteAsync(PgWriter writer, Interval value, CancellationToken cancellationToken = default) + => _rangeConverter.WriteAsync(writer, IntervalToNpgsqlRange(value), cancellationToken); + + static NpgsqlRange IntervalToNpgsqlRange(Interval interval) + => new( + interval.HasStart ? interval.Start : default, true, !interval.HasStart, + interval.HasEnd ? interval.End : default, false, !interval.HasEnd); +} diff --git a/src/Npgsql.NodaTime/Internal/LegacyConverters.cs b/src/Npgsql.NodaTime/Internal/LegacyConverters.cs new file mode 100644 index 0000000000..54393a4821 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/LegacyConverters.cs @@ -0,0 +1,78 @@ +using System; +using NodaTime; +using Npgsql.Internal; +using static Npgsql.NodaTime.Internal.NodaTimeUtils; + +namespace Npgsql.NodaTime.Internal; + +sealed class LegacyTimestampTzZonedDateTimeConverter : PgBufferedConverter +{ + readonly DateTimeZone _dateTimeZone; + readonly bool _dateTimeInfinityConversions; + + public LegacyTimestampTzZonedDateTimeConverter(DateTimeZone dateTimeZone, bool dateTimeInfinityConversions) + { + _dateTimeZone = dateTimeZone; + _dateTimeInfinityConversions = dateTimeInfinityConversions; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override ZonedDateTime ReadCore(PgReader reader) + { + var instant = DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions); + if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + throw new InvalidCastException("Infinity values not supported for timestamp with time zone"); + 
+ return instant.InZone(_dateTimeZone); + } + + protected override void WriteCore(PgWriter writer, ZonedDateTime value) + { + var instant = value.ToInstant(); + if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + throw new ArgumentException("Infinity values not supported for timestamp with time zone"); + + writer.WriteInt64(EncodeInstant(instant, _dateTimeInfinityConversions)); + } +} + +sealed class LegacyTimestampTzOffsetDateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + readonly DateTimeZone _dateTimeZone; + + public LegacyTimestampTzOffsetDateTimeConverter(DateTimeZone dateTimeZone, bool dateTimeInfinityConversions) + { + _dateTimeInfinityConversions = dateTimeInfinityConversions; + _dateTimeZone = dateTimeZone; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override OffsetDateTime ReadCore(PgReader reader) + { + var instant = DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions); + if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + throw new InvalidCastException("Infinity values not supported for timestamp with time zone"); + + return instant.InZone(_dateTimeZone).ToOffsetDateTime(); + } + + protected override void WriteCore(PgWriter writer, OffsetDateTime value) + { + var instant = value.ToInstant(); + if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + throw new ArgumentException("Infinity values not supported for timestamp with time zone"); + + writer.WriteInt64(EncodeInstant(instant, true)); + } +} diff --git a/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs b/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs new file mode 100644 index 0000000000..e6be7fe69b --- /dev/null +++ 
b/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs @@ -0,0 +1,52 @@ +using System; +using NodaTime; +using Npgsql.Internal; +using Npgsql.NodaTime.Properties; + +namespace Npgsql.NodaTime.Internal; + +sealed class LocalDateConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public LocalDateConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int)); + return format is DataFormat.Binary; + } + + protected override LocalDate ReadCore(PgReader reader) + => reader.ReadInt32() switch + { + int.MaxValue => _dateTimeInfinityConversions + ? LocalDate.MaxIsoValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), + int.MinValue => _dateTimeInfinityConversions + ? LocalDate.MinIsoValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), + var value => new LocalDate().PlusDays(value + 730119) + }; + + protected override void WriteCore(PgWriter writer, LocalDate value) + { + if (_dateTimeInfinityConversions) + { + if (value == LocalDate.MaxIsoValue) + { + writer.WriteInt32(int.MaxValue); + return; + } + if (value == LocalDate.MinIsoValue) + { + writer.WriteInt32(int.MinValue); + return; + } + } + + var totalDaysSinceEra = Period.Between(default, value, PeriodUnits.Days).Days; + writer.WriteInt32(totalDaysSinceEra - 730119); + } +} diff --git a/src/Npgsql.NodaTime/Internal/LocalTimeConverter.cs b/src/Npgsql.NodaTime/Internal/LocalTimeConverter.cs new file mode 100644 index 0000000000..5849f45dfc --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/LocalTimeConverter.cs @@ -0,0 +1,20 @@ +using NodaTime; +using Npgsql.Internal; + +namespace Npgsql.NodaTime.Internal; + +sealed class LocalTimeConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat 
format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + // PostgreSQL time resolution == 1 microsecond == 10 ticks + protected override LocalTime ReadCore(PgReader reader) + => LocalTime.FromTicksSinceMidnight(reader.ReadInt64() * 10); + + protected override void WriteCore(PgWriter writer, LocalTime value) + => writer.WriteInt64(value.TickOfDay / 10); +} diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Multirange.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Multirange.cs new file mode 100644 index 0000000000..42c6360dad --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Multirange.cs @@ -0,0 +1,149 @@ +using System; +using System.Collections.Generic; +using NodaTime; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; +using NpgsqlTypes; +using static Npgsql.Internal.PgConverterFactory; + +namespace Npgsql.NodaTime.Internal; + +sealed partial class NodaTimeTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver? CreateMultirangeResolver() => new MultirangeResolver(); + public override IPgTypeInfoResolver? CreateMultirangeArrayResolver() => new MultirangeArrayResolver(); + + class MultirangeResolver : IPgTypeInfoResolver + { + protected static DataTypeName DateMultirangeDataTypeName => new("pg_catalog.datemultirange"); + protected static DataTypeName TimestampTzMultirangeDataTypeName => new("pg_catalog.tstzmultirange"); + protected static DataTypeName TimestampMultirangeDataTypeName => new("pg_catalog.tsmultirange"); + + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // tstzmultirange + mappings.AddType(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(new IntervalConverter( + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options)), options)), + isDefault: true); + mappings.AddType>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(new IntervalConverter( + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options)), options))); + mappings.AddType[]>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateArrayMultirangeConverter( + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options), options))); + mappings.AddType>>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter( + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options), options))); + mappings.AddType[]>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateArrayMultirangeConverter( + CreateRangeConverter(new ZonedDateTimeConverter(options.EnableDateTimeInfinityConversions), options), + options))); + mappings.AddType>>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter( + CreateRangeConverter(new ZonedDateTimeConverter(options.EnableDateTimeInfinityConversions), options), + options))); + mappings.AddType[]>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + 
CreateArrayMultirangeConverter( + CreateRangeConverter(new OffsetDateTimeConverter(options.EnableDateTimeInfinityConversions), options), + options))); + mappings.AddType>>(TimestampTzMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter( + CreateRangeConverter(new OffsetDateTimeConverter(options.EnableDateTimeInfinityConversions), options), + options))); + + // tsmultirange + mappings.AddType[]>(TimestampMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateArrayMultirangeConverter( + CreateRangeConverter(new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions), options), options)), + isDefault: true); + mappings.AddType>>(TimestampMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter( + CreateRangeConverter(new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions), options), + options))); + + // datemultirange + mappings.AddType(DateMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter(new DateIntervalConverter( + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), + options.EnableDateTimeInfinityConversions), options)), + isDefault: true); + mappings.AddType>(DateMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter(new DateIntervalConverter( + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), + options.EnableDateTimeInfinityConversions), options))); + mappings.AddType[]>(DateMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateArrayMultirangeConverter( + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), options))); + 
mappings.AddType>>(DateMultirangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter( + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), options))); + + return mappings; + } + } + + sealed class MultirangeArrayResolver : MultirangeResolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // tstzmultirange + mappings.AddArrayType(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType[]>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType>>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType[]>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType>>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType[]>(TimestampTzMultirangeDataTypeName); + mappings.AddArrayType>>(TimestampTzMultirangeDataTypeName); + + // tsmultirange + mappings.AddArrayType[]>(TimestampMultirangeDataTypeName); + mappings.AddArrayType>>(TimestampMultirangeDataTypeName); + + // datemultirange + mappings.AddArrayType(DateMultirangeDataTypeName); + mappings.AddArrayType>(DateMultirangeDataTypeName); + mappings.AddArrayType[]>(DateMultirangeDataTypeName); + mappings.AddArrayType>>(DateMultirangeDataTypeName); + + return mappings; + } + } +} diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Range.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Range.cs new file mode 100644 index 0000000000..f62669333c --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Range.cs @@ -0,0 +1,93 @@ +using 
System; +using NodaTime; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; +using NpgsqlTypes; +using static Npgsql.Internal.PgConverterFactory; + +namespace Npgsql.NodaTime.Internal; + +sealed partial class NodaTimeTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver? CreateRangeResolver() => new RangeResolver(); + public override IPgTypeInfoResolver? CreateRangeArrayResolver() => new RangeArrayResolver(); + + class RangeResolver : IPgTypeInfoResolver + { + protected static DataTypeName DateRangeDataTypeName => new("pg_catalog.daterange"); + protected static DataTypeName TimestampTzRangeDataTypeName => new("pg_catalog.tstzrange"); + protected static DataTypeName TimestampRangeDataTypeName => new("pg_catalog.tsrange"); + + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // tstzrange + mappings.AddStructType(TimestampTzRangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + new IntervalConverter( + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options))), + isDefault: true); + mappings.AddStructType>(TimestampTzRangeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options))); + mappings.AddStructType>(TimestampTzRangeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new ZonedDateTimeConverter(options.EnableDateTimeInfinityConversions), options))); + mappings.AddStructType>(TimestampTzRangeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new 
OffsetDateTimeConverter(options.EnableDateTimeInfinityConversions), options))); + + // tsrange + mappings.AddStructType>(TimestampRangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateRangeConverter(new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions), options)), + isDefault: true); + + // daterange + mappings.AddType(DateRangeDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new DateIntervalConverter( + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options), + options.EnableDateTimeInfinityConversions)), isDefault: true); + mappings.AddStructType>(DateRangeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new LocalDateConverter(options.EnableDateTimeInfinityConversions), options))); + + return mappings; + } + } + + sealed class RangeArrayResolver : RangeResolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // tstzrange + mappings.AddStructArrayType(TimestampTzRangeDataTypeName); + mappings.AddStructArrayType>(TimestampTzRangeDataTypeName); + mappings.AddStructArrayType>(TimestampTzRangeDataTypeName); + mappings.AddStructArrayType>(TimestampTzRangeDataTypeName); + + // tsrange + mappings.AddStructArrayType>(TimestampRangeDataTypeName); + + // daterange + mappings.AddArrayType(DateRangeDataTypeName); + mappings.AddStructArrayType>(DateRangeDataTypeName); + + return mappings; + } + } +} diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..de5548a569 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.cs @@ -0,0 +1,142 @@ +using System; +using NodaTime; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; +using static Npgsql.NodaTime.Internal.NodaTimeUtils; + +namespace Npgsql.NodaTime.Internal; + +sealed partial class NodaTimeTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(); + + class Resolver : IPgTypeInfoResolver + { + protected static DataTypeName TimestampTzDataTypeName => new("pg_catalog.timestamptz"); + protected static DataTypeName TimestampDataTypeName => new("pg_catalog.timestamp"); + protected static DataTypeName DateDataTypeName => new("pg_catalog.date"); + protected static DataTypeName TimeDataTypeName => new("pg_catalog.time"); + protected static DataTypeName TimeTzDataTypeName => new("pg_catalog.timetz"); + protected static DataTypeName IntervalDataTypeName => new("pg_catalog.interval"); + + TypeInfoMappingCollection? 
_mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // timestamp and timestamptz, legacy and non-legacy modes + if (LegacyTimestampBehavior) + { + // timestamp is the default for writing an Instant. + + // timestamp + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions))); + + // timestamptz + mappings.AddStructType(TimestampTzDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); + mappings.AddStructType(TimestampTzDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LegacyTimestampTzZonedDateTimeConverter( + DateTimeZoneProviders.Tzdb[options.TimeZone], options.EnableDateTimeInfinityConversions))); + mappings.AddStructType(TimestampTzDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LegacyTimestampTzOffsetDateTimeConverter( + DateTimeZoneProviders.Tzdb[options.TimeZone], options.EnableDateTimeInfinityConversions))); + } + else + { + // timestamp + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions)), + isDefault: true); + + // timestamptz + mappings.AddStructType(TimestampTzDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new 
InstantConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); + mappings.AddStructType(TimestampTzDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new ZonedDateTimeConverter(options.EnableDateTimeInfinityConversions))); + mappings.AddStructType(TimestampTzDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new OffsetDateTimeConverter(options.EnableDateTimeInfinityConversions))); + } + + // date + mappings.AddStructType(DateDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LocalDateConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); + + // time + mappings.AddStructType(TimeDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new LocalTimeConverter()), isDefault: true); + + // timetz + mappings.AddStructType(TimeTzDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new OffsetTimeConverter()), isDefault: true); + + // interval + mappings.AddType(IntervalDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new PeriodConverter()), isDefault: true); + mappings.AddStructType(IntervalDataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new DurationConverter())); + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + if (LegacyTimestampBehavior) + { + // timestamp + mappings.AddStructArrayType(TimestampDataTypeName); + mappings.AddStructArrayType(TimestampDataTypeName); + + // timestamptz + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); + } + else + { + // timestamp + mappings.AddStructArrayType(TimestampDataTypeName); + + // timestamptz + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); + } + + // other + mappings.AddStructArrayType(DateDataTypeName); + mappings.AddStructArrayType(TimeDataTypeName); + mappings.AddStructArrayType(TimeTzDataTypeName); + mappings.AddArrayType(IntervalDataTypeName); + mappings.AddStructArrayType(IntervalDataTypeName); + + return mappings; + } + } +} diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeUtils.cs b/src/Npgsql.NodaTime/Internal/NodaTimeUtils.cs new file mode 100644 index 0000000000..1cf433759a --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/NodaTimeUtils.cs @@ -0,0 +1,63 @@ +using System; +using NodaTime; +using Npgsql.NodaTime.Properties; + +namespace Npgsql.NodaTime.Internal; + +static class NodaTimeUtils +{ +#if DEBUG + internal static bool LegacyTimestampBehavior; +#else + internal static readonly bool LegacyTimestampBehavior; +#endif + + static NodaTimeUtils() => LegacyTimestampBehavior = AppContext.TryGetSwitch("Npgsql.EnableLegacyTimestampBehavior", out var enabled) && enabled; + + static readonly Instant Instant2000 = Instant.FromUtc(2000, 1, 1, 0, 0, 0); + static readonly Duration Plus292Years = Duration.FromDays(292 * 365); + static readonly Duration Minus292Years = -Plus292Years; + + /// + /// Decodes a PostgreSQL 
timestamp/timestamptz into a NodaTime Instant. + /// + /// The number of microseconds from 2000-01-01T00:00:00. + /// Whether infinity date/time conversions are enabled. + /// + /// Unfortunately NodaTime doesn't have Duration.FromMicroseconds(), so we decompose into milliseconds and nanoseconds. + /// + internal static Instant DecodeInstant(long value, bool dateTimeInfinityConversions) + => value switch + { + long.MaxValue => dateTimeInfinityConversions + ? Instant.MaxValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), + long.MinValue => dateTimeInfinityConversions + ? Instant.MinValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), + _ => Instant2000 + Duration.FromMilliseconds(value / 1000) + Duration.FromNanoseconds(value % 1000 * 1000) + }; + + /// + /// Encodes a NodaTime Instant to a PostgreSQL timestamp/timestamptz. + /// + internal static long EncodeInstant(Instant instant, bool dateTimeInfinityConversions) + { + if (dateTimeInfinityConversions) + { + if (instant == Instant.MaxValue) + return long.MaxValue; + + if (instant == Instant.MinValue) + return long.MinValue; + } + + // We need to write the number of microseconds from 2000-01-01T00:00:00. + var since2000 = instant - Instant2000; + + // The nanoseconds may overflow, so fallback to BigInteger where necessary. + return since2000 >= Minus292Years && since2000 <= Plus292Years + ? 
since2000.ToInt64Nanoseconds() / 1000 + : (long)(since2000.ToBigIntegerNanoseconds() / 1000); + } +} diff --git a/src/Npgsql.NodaTime/Internal/OffsetTimeConverter.cs b/src/Npgsql.NodaTime/Internal/OffsetTimeConverter.cs new file mode 100644 index 0000000000..7c5499c2f8 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/OffsetTimeConverter.cs @@ -0,0 +1,23 @@ +using NodaTime; +using Npgsql.Internal; + +namespace Npgsql.NodaTime.Internal; + +sealed class OffsetTimeConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int)); + return format is DataFormat.Binary; + } + + // Adjust from 1 microsecond to 100ns. Time zone (in seconds) is inverted. + protected override OffsetTime ReadCore(PgReader reader) + => new(LocalTime.FromTicksSinceMidnight(reader.ReadInt64() * 10), Offset.FromSeconds(-reader.ReadInt32())); + + protected override void WriteCore(PgWriter writer, OffsetTime value) + { + writer.WriteInt64(value.TickOfDay / 10); + writer.WriteInt32(-(int)(value.Offset.Ticks / NodaConstants.TicksPerSecond)); + } +} diff --git a/src/Npgsql.NodaTime/Internal/PeriodConverter.cs b/src/Npgsql.NodaTime/Internal/PeriodConverter.cs new file mode 100644 index 0000000000..4dbde48dbc --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/PeriodConverter.cs @@ -0,0 +1,46 @@ +using NodaTime; +using Npgsql.Internal; + +namespace Npgsql.NodaTime.Internal; + +sealed class PeriodConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override Period ReadCore(PgReader reader) + { + var microsecondsInDay = reader.ReadInt64(); + var days = reader.ReadInt32(); + var totalMonths = reader.ReadInt32(); + + // 
NodaTime will normalize most things (i.e. nanoseconds to milliseconds, seconds...) + // but it will not normalize months to years. + var months = totalMonths % 12; + var years = totalMonths / 12; + + return new PeriodBuilder + { + Nanoseconds = microsecondsInDay * 1000, + Days = days, + Months = months, + Years = years + }.Build().Normalize(); + } + + protected override void WriteCore(PgWriter writer, Period value) + { + // Note that the end result must be long + // see #3438 + var microsecondsInDay = + (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * NodaConstants.MillisecondsPerSecond + value.Milliseconds) * 1000 + + value.Nanoseconds / 1000; // Take the microseconds, discard the nanosecond remainder + + writer.WriteInt64(microsecondsInDay); + writer.WriteInt32(value.Weeks * 7 + value.Days); // days + writer.WriteInt32(value.Years * 12 + value.Months); // months + } +} diff --git a/src/Npgsql.NodaTime/Internal/TimestampConverters.cs b/src/Npgsql.NodaTime/Internal/TimestampConverters.cs new file mode 100644 index 0000000000..6808503638 --- /dev/null +++ b/src/Npgsql.NodaTime/Internal/TimestampConverters.cs @@ -0,0 +1,106 @@ +using System; +using NodaTime; +using Npgsql.Internal; +using static Npgsql.NodaTime.Internal.NodaTimeUtils; + +namespace Npgsql.NodaTime.Internal; + +sealed class InstantConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public InstantConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override Instant ReadCore(PgReader reader) + => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions); + + protected override void WriteCore(PgWriter writer, Instant 
value) + => writer.WriteInt64(EncodeInstant(value, _dateTimeInfinityConversions)); +} + +sealed class ZonedDateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public ZonedDateTimeConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override ZonedDateTime ReadCore(PgReader reader) + => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions).InUtc(); + + protected override void WriteCore(PgWriter writer, ZonedDateTime value) + { + if (value.Zone != DateTimeZone.Utc && !LegacyTimestampBehavior) + { + throw new ArgumentException( + $"Cannot write ZonedDateTime with Zone={value.Zone} to PostgreSQL type 'timestamp with time zone', " + + "only UTC is supported. " + + "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); + } + + writer.WriteInt64(EncodeInstant(value.ToInstant(), _dateTimeInfinityConversions)); + } +} + +sealed class OffsetDateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public OffsetDateTimeConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override OffsetDateTime ReadCore(PgReader reader) + => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions).WithOffset(Offset.Zero); + + protected override void WriteCore(PgWriter writer, OffsetDateTime value) + { + if (value.Offset != Offset.Zero && !LegacyTimestampBehavior) + { + throw new ArgumentException( + 
$"Cannot write OffsetDateTime with Offset={value.Offset} to PostgreSQL type 'timestamp with time zone', " + + "only offset 0 (UTC) is supported. " + + "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); + } + + writer.WriteInt64(EncodeInstant(value.ToInstant(), _dateTimeInfinityConversions)); + } +} + +sealed class LocalDateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public LocalDateTimeConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override LocalDateTime ReadCore(PgReader reader) + => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions).InUtc().LocalDateTime; + + protected override void WriteCore(PgWriter writer, LocalDateTime value) + => writer.WriteInt64(EncodeInstant(value.InUtc().ToInstant(), _dateTimeInfinityConversions)); +} diff --git a/src/Npgsql.NodaTime/IntervalHandler.cs b/src/Npgsql.NodaTime/IntervalHandler.cs deleted file mode 100644 index 6155456fe1..0000000000 --- a/src/Npgsql.NodaTime/IntervalHandler.cs +++ /dev/null @@ -1,112 +0,0 @@ -using System; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using NpgsqlTypes; -using BclIntervalHandler = Npgsql.TypeHandlers.DateTimeHandlers.IntervalHandler; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.NodaTime -{ - public class IntervalHandlerFactory : NpgsqlTypeHandlerFactory - { - // Check for the legacy floating point timestamps feature - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => conn.HasIntegerDateTimes - ? 
new IntervalHandler(postgresType) - : throw new NotSupportedException($"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - } - - sealed class IntervalHandler : - NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler - { - readonly BclIntervalHandler _bclHandler; - - internal IntervalHandler(PostgresType postgresType) : base(postgresType) - => _bclHandler = new BclIntervalHandler(postgresType); - - public override Period Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var microsecondsInDay = buf.ReadInt64(); - var days = buf.ReadInt32(); - var totalMonths = buf.ReadInt32(); - - // NodaTime will normalize most things (i.e. nanoseconds to milliseconds, seconds...) - // but it will not normalize months to years. - var months = totalMonths % 12; - var years = totalMonths / 12; - - return new PeriodBuilder - { - Nanoseconds = microsecondsInDay * 1000, - Days = days, - Months = months, - Years = years - }.Build().Normalize(); - } - - public override int ValidateAndGetLength(Period value, NpgsqlParameter? parameter) - => 16; - - public override void Write(Period value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - var microsecondsInDay = - (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * NodaConstants.MillisecondsPerSecond + value.Milliseconds) * 1000 + - value.Nanoseconds / 1000; // Take the microseconds, discard the nanosecond remainder - - buf.WriteInt64(microsecondsInDay); - buf.WriteInt32(value.Weeks * 7 + value.Days); // days - buf.WriteInt32(value.Years * 12 + value.Months); // months - } - - Duration INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - { - var microsecondsInDay = buf.ReadInt64(); - var days = buf.ReadInt32(); - var totalMonths = buf.ReadInt32(); - - if (totalMonths != 0) - throw new NpgsqlException("Cannot read PostgreSQL interval with non-zero months to NodaTime Duration. Try reading as a NodaTime Period instead."); - - return Duration.FromDays(days) + Duration.FromNanoseconds(microsecondsInDay * 1000); - } - - public int ValidateAndGetLength(Duration value, NpgsqlParameter? parameter) => 16; - - public void Write(Duration value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - const int microsecondsPerSecond = 1_000_000; - - var microsecondsInDay = - (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * - microsecondsPerSecond + value.SubsecondNanoseconds / 1000); // Take the microseconds, discard the nanosecond remainder - - buf.WriteInt64(microsecondsInDay); - buf.WriteInt32(value.Days); // days - buf.WriteInt32(0); // months - } - - NpgsqlTimeSpan INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(NpgsqlTimeSpan value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(NpgsqlTimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - - TimeSpan INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Write(value, buf, parameter); - } -} diff --git a/src/Npgsql.NodaTime/Npgsql.NodaTime.csproj b/src/Npgsql.NodaTime/Npgsql.NodaTime.csproj index fe1cdacf9c..b4ce274125 100644 --- a/src/Npgsql.NodaTime/Npgsql.NodaTime.csproj +++ b/src/Npgsql.NodaTime/Npgsql.NodaTime.csproj @@ -2,14 +2,42 @@ Shay Rojansky NodaTime plugin for Npgsql, allowing mapping of PostgreSQL date/time types to NodaTime types. - npgsql postgresql postgres nodatime date time ado ado.net database sql - netstandard2.0 - net5.0 + npgsql;postgresql;postgres;nodatime;date;time;ado;ado;net;database;sql + README.md + netstandard2.0;net6.0 + net8.0 + $(NoWarn);NPG9001 + + + + + + + + + + + + + + + + ResXFileCodeGenerator + NpgsqlNodaTimeStrings.Designer.cs + + + + + + True + True + NpgsqlNodaTimeStrings.resx + diff --git a/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs b/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs index def715bd12..9ebf42e83f 100644 --- a/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs +++ b/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs @@ -1,69 +1,21 @@ -using System; -using System.Data; -using NodaTime; -using Npgsql.NodaTime; +using Npgsql.NodaTime.Internal; using Npgsql.TypeMapping; -using NpgsqlTypes; // ReSharper disable once CheckNamespace -namespace Npgsql +namespace Npgsql; + +/// +/// Extension adding the NodaTime plugin to an Npgsql type mapper. +/// +public static class NpgsqlNodaTimeExtensions { /// - /// Extension adding the NodaTime plugin to an Npgsql type mapper. + /// Sets up NodaTime mappings for the PostgreSQL date/time types. /// - public static class NpgsqlNodaTimeExtensions + /// The type mapper to set up (global or connection-specific) + public static INpgsqlTypeMapper UseNodaTime(this INpgsqlTypeMapper mapper) { - /// - /// Sets up NodaTime mappings for the PostgreSQL date/time types. 
- /// - /// The type mapper to set up (global or connection-specific) - public static INpgsqlTypeMapper UseNodaTime(this INpgsqlTypeMapper mapper) - => mapper - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "timestamp without time zone", - NpgsqlDbType = NpgsqlDbType.Timestamp, - DbTypes = new[] { DbType.DateTime, DbType.DateTime2 }, - ClrTypes = new[] { typeof(Instant), typeof(LocalDateTime), typeof(DateTime) }, - InferredDbType = DbType.DateTime, - TypeHandlerFactory = new TimestampHandlerFactory() - }.Build()) - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "timestamp with time zone", - NpgsqlDbType = NpgsqlDbType.TimestampTz, - ClrTypes = new[] { typeof(ZonedDateTime), typeof(OffsetDateTime), typeof(DateTimeOffset) }, - TypeHandlerFactory = new TimestampTzHandlerFactory() - }.Build()) - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "date", - NpgsqlDbType = NpgsqlDbType.Date, - DbTypes = new[] { DbType.Date }, - ClrTypes = new[] { typeof(LocalDate), typeof(NpgsqlDate) }, - TypeHandlerFactory = new DateHandlerFactory() - }.Build()) - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "time without time zone", - NpgsqlDbType = NpgsqlDbType.Time, - DbTypes = new[] { DbType.Time }, - ClrTypes = new[] { typeof(LocalTime) }, - TypeHandlerFactory = new TimeHandlerFactory() - }.Build()) - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "time with time zone", - NpgsqlDbType = NpgsqlDbType.TimeTz, - ClrTypes = new[] { typeof(OffsetTime) }, - TypeHandlerFactory = new TimeTzHandlerFactory() - }.Build()) - .AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "interval", - NpgsqlDbType = NpgsqlDbType.Interval, - ClrTypes = new[] { typeof(Period), typeof(Duration), typeof(TimeSpan), typeof(NpgsqlTimeSpan) }, - TypeHandlerFactory = new IntervalHandlerFactory() - }.Build()); + mapper.AddTypeInfoResolverFactory(new NodaTimeTypeInfoResolverFactory()); + return mapper; } } diff --git 
a/src/Npgsql.NodaTime/Properties/AssemblyInfo.cs b/src/Npgsql.NodaTime/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..a03d5a93d6 --- /dev/null +++ b/src/Npgsql.NodaTime/Properties/AssemblyInfo.cs @@ -0,0 +1,12 @@ +using System.Runtime.CompilerServices; + +#if NET5_0_OR_GREATER +[module: SkipLocalsInit] +#endif + +[assembly: InternalsVisibleTo("Npgsql.PluginTests, PublicKey=" + +"0024000004800000940000000602000000240000525341310004000001000100" + +"2b3c590b2a4e3d347e6878dc0ff4d21eb056a50420250c6617044330701d35c9" + +"8078a5df97a62d83c9a2db2d072523a8fc491398254c6b89329b8c1dcef43a1e" + +"7aa16153bcea2ae9a471145624826f60d7c8e71cd025b554a0177bd935a78096" + +"29f0a7afc778ebb4ad033e1bf512c1a9c6ceea26b077bc46cac93800435e77ee")] diff --git a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs new file mode 100644 index 0000000000..bc6511ea9a --- /dev/null +++ b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs @@ -0,0 +1,60 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. 
+// +//------------------------------------------------------------------------------ + +namespace Npgsql.NodaTime.Properties { + using System; + + + [System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] + [System.Diagnostics.DebuggerNonUserCodeAttribute()] + [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class NpgsqlNodaTimeStrings { + + private static System.Resources.ResourceManager resourceMan; + + private static System.Globalization.CultureInfo resourceCulture; + + [System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal NpgsqlNodaTimeStrings() { + } + + [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] + internal static System.Resources.ResourceManager ResourceManager { + get { + if (object.Equals(null, resourceMan)) { + System.Resources.ResourceManager temp = new System.Resources.ResourceManager("Npgsql.NodaTime.Properties.NpgsqlNodaTimeStrings", typeof(NpgsqlNodaTimeStrings).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] + internal static System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + internal static string CannotReadInfinityValue { + get { + return ResourceManager.GetString("CannotReadInfinityValue", resourceCulture); + } + } + + internal static string CannotReadIntervalWithMonthsAsDuration { + get { + return ResourceManager.GetString("CannotReadIntervalWithMonthsAsDuration", resourceCulture); + } + } + } +} diff --git a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx new file mode 100644 index 0000000000..d3329f2a80 --- /dev/null +++ 
b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx @@ -0,0 +1,27 @@ + + + + + + + + + + text/microsoft-resx + + + 1.3 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Cannot read infinity value since Npgsql.DisableDateTimeInfinityConversions is enabled. + + + Cannot read PostgreSQL interval with non-zero months to NodaTime Duration. Try reading as a NodaTime Period instead. + + diff --git a/src/Npgsql.NodaTime/PublicAPI.Shipped.txt b/src/Npgsql.NodaTime/PublicAPI.Shipped.txt new file mode 100644 index 0000000000..998522184e --- /dev/null +++ b/src/Npgsql.NodaTime/PublicAPI.Shipped.txt @@ -0,0 +1,3 @@ +#nullable enable +Npgsql.NpgsqlNodaTimeExtensions +static Npgsql.NpgsqlNodaTimeExtensions.UseNodaTime(this Npgsql.TypeMapping.INpgsqlTypeMapper! mapper) -> Npgsql.TypeMapping.INpgsqlTypeMapper! diff --git a/src/Npgsql.NodaTime/PublicAPI.Unshipped.txt b/src/Npgsql.NodaTime/PublicAPI.Unshipped.txt new file mode 100644 index 0000000000..ab058de62d --- /dev/null +++ b/src/Npgsql.NodaTime/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/Npgsql.NodaTime/README.md b/src/Npgsql.NodaTime/README.md new file mode 100644 index 0000000000..d24070920b --- /dev/null +++ b/src/Npgsql.NodaTime/README.md @@ -0,0 +1,33 @@ +Npgsql is the open source .NET data provider for PostgreSQL. It allows you to connect and interact with PostgreSQL server using .NET. + +This package is an Npgsql plugin which allows you to use the [NodaTime](https://nodatime.org) date/time library when interacting with PostgreSQL; this provides a better and safer API for dealing with date and time data. + +To use the NodaTime plugin, add a dependency on this package and create a NpgsqlDataSource. 
Once this is done, you can use NodaTime types when interacting with PostgreSQL, just as you would use e.g. `DateTime`: + +```csharp +using Npgsql; + +var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString); + +dataSourceBuilder.UseNodaTime(); + +var dataSource = dataSourceBuilder.Build(); +var conn = await dataSource.OpenConnectionAsync(); + +// Write NodaTime Instant to PostgreSQL "timestamp with time zone" (UTC) +using (var cmd = new NpgsqlCommand(@"INSERT INTO mytable (my_timestamptz) VALUES (@p)", conn)) +{ + cmd.Parameters.Add(new NpgsqlParameter("p", Instant.FromUtc(2011, 1, 1, 10, 30))); + cmd.ExecuteNonQuery(); +} + +// Read timestamp back from the database as an Instant +using (var cmd = new NpgsqlCommand(@"SELECT my_timestamptz FROM mytable", conn)) +using (var reader = cmd.ExecuteReader()) +{ + reader.Read(); + var instant = reader.GetFieldValue(0); +} +``` + +For more information, [visit the NodaTime plugin documentation page](https://www.npgsql.org/doc/types/nodatime.html). diff --git a/src/Npgsql.NodaTime/TimeHandler.cs b/src/Npgsql.NodaTime/TimeHandler.cs deleted file mode 100644 index 1b1fef821c..0000000000 --- a/src/Npgsql.NodaTime/TimeHandler.cs +++ /dev/null @@ -1,47 +0,0 @@ -using System; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using BclTimeHandler = Npgsql.TypeHandlers.DateTimeHandlers.TimeHandler; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.NodaTime -{ - public class TimeHandlerFactory : NpgsqlTypeHandlerFactory - { - // Check for the legacy floating point timestamps feature - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => conn.HasIntegerDateTimes - ? 
new TimeHandler(postgresType) - : throw new NotSupportedException($"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - } - - sealed class TimeHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - readonly BclTimeHandler _bclHandler; - - internal TimeHandler(PostgresType postgresType) : base(postgresType) - => _bclHandler = new BclTimeHandler(postgresType); - - // PostgreSQL time resolution == 1 microsecond == 10 ticks - public override LocalTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => LocalTime.FromTicksSinceMidnight(buf.ReadInt64() * 10); - - public override int ValidateAndGetLength(LocalTime value, NpgsqlParameter? parameter) - => 8; - - public override void Write(LocalTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteInt64(value.TickOfDay / 10); - - TimeSpan INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => _bclHandler.Write(value, buf, parameter); - } -} diff --git a/src/Npgsql.NodaTime/TimeTzHandler.cs b/src/Npgsql.NodaTime/TimeTzHandler.cs deleted file mode 100644 index 79e2cfd86c..0000000000 --- a/src/Npgsql.NodaTime/TimeTzHandler.cs +++ /dev/null @@ -1,70 +0,0 @@ -using System; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using BclTimeTzHandler = Npgsql.TypeHandlers.DateTimeHandlers.TimeTzHandler; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.NodaTime -{ - public class TimeTzHandlerFactory : NpgsqlTypeHandlerFactory - { - // Check for the legacy floating point timestamps feature - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => conn.HasIntegerDateTimes - ? new TimeTzHandler(postgresType) - : throw new NotSupportedException($"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - } - - sealed class TimeTzHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - readonly BclTimeTzHandler _bclHandler; - - internal TimeTzHandler(PostgresType postgresType) : base(postgresType) - => _bclHandler = new BclTimeTzHandler(postgresType); - - // Adjust from 1 microsecond to 100ns. Time zone (in seconds) is inverted. - public override OffsetTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new OffsetTime( - LocalTime.FromTicksSinceMidnight(buf.ReadInt64() * 10), - Offset.FromSeconds(-buf.ReadInt32())); - - public override int ValidateAndGetLength(OffsetTime value, NpgsqlParameter? parameter) => 12; - - public override void Write(OffsetTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteInt64(value.TickOfDay / 10); - buf.WriteInt32(-(int)(value.Offset.Ticks / NodaConstants.TicksPerSecond)); - } - - DateTimeOffset INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTimeOffset value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTimeOffset value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - - TimeSpan INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => _bclHandler.Write(value, buf, parameter); - } -} diff --git a/src/Npgsql.NodaTime/TimestampHandler.cs b/src/Npgsql.NodaTime/TimestampHandler.cs deleted file mode 100644 index 3804cc56ee..0000000000 --- a/src/Npgsql.NodaTime/TimestampHandler.cs +++ /dev/null @@ -1,186 +0,0 @@ -using System; -using System.Diagnostics; -using NodaTime; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using BclTimestampHandler = Npgsql.TypeHandlers.DateTimeHandlers.TimestampHandler; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.NodaTime -{ - public class TimestampHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - { - if (!conn.HasIntegerDateTimes) - throw new NotSupportedException($"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - - var csb = new NpgsqlConnectionStringBuilder(conn.ConnectionString); - return new TimestampHandler(postgresType, csb.ConvertInfinityDateTime); - } - } - - sealed class TimestampHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - static readonly Instant Instant0 = Instant.FromUtc(1, 1, 1, 0, 0, 0); - static readonly Instant Instant2000 = Instant.FromUtc(2000, 1, 1, 0, 0, 0); - static readonly Duration Plus292Years = Duration.FromDays(292 * 365); - static readonly Duration Minus292Years = -Plus292Years; - - /// - /// Whether to convert positive and negative infinity values to Instant.{Max,Min}Value when - /// an Instant is requested - /// - readonly bool _convertInfinityDateTime; - readonly BclTimestampHandler _bclHandler; - - internal TimestampHandler(PostgresType postgresType, bool convertInfinityDateTime) - : base(postgresType) - { - _convertInfinityDateTime = convertInfinityDateTime; - _bclHandler = new BclTimestampHandler(postgresType, convertInfinityDateTime); - } - 
- #region Read - - public override Instant Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var value = buf.ReadInt64(); - if (_convertInfinityDateTime) - { - if (value == long.MaxValue) - return Instant.MaxValue; - if (value == long.MinValue) - return Instant.MinValue; - } - - return Decode(value); - } - - LocalDateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - { - var value = buf.ReadInt64(); - if (value == long.MaxValue || value == long.MinValue) - throw new NotSupportedException("Infinity values not supported when reading LocalDateTime, read as Instant instead"); - return Decode(value).InUtc().LocalDateTime; - } - - // value is the number of microseconds from 2000-01-01T00:00:00. - // Unfortunately NodaTime doesn't have Duration.FromMicroseconds(), so we decompose into milliseconds - // and nanoseconds - internal static Instant Decode(long value) - => Instant2000 + Duration.FromMilliseconds(value / 1000) + Duration.FromNanoseconds(value % 1000 * 1000); - - // This is legacy support for PostgreSQL's old floating-point timestamp encoding - finally removed in PG 10 and not used for a long - // time. Unfortunately CrateDB seems to use this for some reason. 
- internal static Instant Decode(double value) - { - Debug.Assert(!double.IsPositiveInfinity(value) && !double.IsNegativeInfinity(value)); - - if (value >= 0d) - { - var date = (int)value / 86400; - date += 730119; // 730119 = days since era (0001-01-01) for 2000-01-01 - var microsecondOfDay = (long)((value % 86400d) * 1000000d); - - return Instant0 + Duration.FromDays(date) + Duration.FromNanoseconds(microsecondOfDay * 1000); - } - else - { - value = -value; - var date = (int)value / 86400; - var microsecondOfDay = (long)((value % 86400d) * 1000000d); - if (microsecondOfDay != 0) - { - ++date; - microsecondOfDay = 86400000000L - microsecondOfDay; - } - - date = 730119 - date; // 730119 = days since era (0001-01-01) for 2000-01-01 - - return Instant0 + Duration.FromDays(date) + Duration.FromNanoseconds(microsecondOfDay * 1000); - } - } - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(Instant value, NpgsqlParameter? parameter) - => 8; - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(LocalDateTime value, NpgsqlParameter? parameter) - => 8; - - public override void Write(Instant value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (_convertInfinityDateTime) - { - if (value == Instant.MaxValue) - { - buf.WriteInt64(long.MaxValue); - return; - } - - if (value == Instant.MinValue) - { - buf.WriteInt64(long.MinValue); - return; - } - } - - WriteInteger(value, buf); - } - - void INpgsqlSimpleTypeHandler.Write(LocalDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => WriteInteger(value.InUtc().ToInstant(), buf); - - // We need to write the number of microseconds from 2000-01-01T00:00:00. - internal static void WriteInteger(Instant instant, NpgsqlWriteBuffer buf) - { - var since2000 = instant - Instant2000; - - // The nanoseconds may overflow, so fallback to BigInteger where necessary. - var microseconds = - since2000 >= Minus292Years && - since2000 <= Plus292Years - ? 
since2000.ToInt64Nanoseconds() / 1000 - : (long)(since2000.ToBigIntegerNanoseconds() / 1000); - - buf.WriteInt64(microseconds); - } - - // This is legacy support for PostgreSQL's old floating-point timestamp encoding - finally removed in PG 10 and not used for a long - // time. Unfortunately CrateDB seems to use this for some reason. - internal static void WriteDouble(Instant instant, NpgsqlWriteBuffer buf) - { - var localDateTime = instant.InUtc().LocalDateTime; - var totalDaysSinceEra = Period.Between(default(LocalDateTime), localDateTime, PeriodUnits.Days).Days; - var secondOfDay = localDateTime.NanosecondOfDay / 1000000000d; - - if (totalDaysSinceEra >= 730119) - { - var uSecsDate = (totalDaysSinceEra - 730119) * 86400d; - buf.WriteDouble(uSecsDate + secondOfDay); - } - else - { - var uSecsDate = (730119 - totalDaysSinceEra) * 86400d; - buf.WriteDouble(-(uSecsDate - secondOfDay)); - } - } - - #endregion Write - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => ((INpgsqlSimpleTypeHandler)_bclHandler).Write(value, buf, parameter); - } -} diff --git a/src/Npgsql.NodaTime/TimestampTzHandler.cs b/src/Npgsql.NodaTime/TimestampTzHandler.cs deleted file mode 100644 index 9e775f4cd9..0000000000 --- a/src/Npgsql.NodaTime/TimestampTzHandler.cs +++ /dev/null @@ -1,143 +0,0 @@ -using System; -using NodaTime; -using NodaTime.TimeZones; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using BclTimestampTzHandler = Npgsql.TypeHandlers.DateTimeHandlers.TimestampTzHandler; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql.NodaTime -{ - public class TimestampTzHandlerFactory : NpgsqlTypeHandlerFactory - { - // Check for the legacy floating point timestamps feature - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - { - var csb = new NpgsqlConnectionStringBuilder(conn.ConnectionString); - return conn.HasIntegerDateTimes - ? 
new TimestampTzHandler(postgresType, csb.ConvertInfinityDateTime) - : throw new NotSupportedException( - $"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - } - } - - sealed class TimestampTzHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler - { - readonly IDateTimeZoneProvider _dateTimeZoneProvider; - readonly BclTimestampTzHandler _bclHandler; - - /// - /// Whether to convert positive and negative infinity values to Instant.{Max,Min}Value when - /// an Instant is requested - /// - readonly bool _convertInfinityDateTime; - - public TimestampTzHandler(PostgresType postgresType, bool convertInfinityDateTime) - : base(postgresType) - { - _dateTimeZoneProvider = DateTimeZoneProviders.Tzdb; - _convertInfinityDateTime = convertInfinityDateTime; - _bclHandler = new BclTimestampTzHandler(postgresType, convertInfinityDateTime); - } - - #region Read - - public override Instant Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var value = buf.ReadInt64(); - if (_convertInfinityDateTime) - { - if (value == long.MaxValue) - return Instant.MaxValue; - if (value == long.MinValue) - return Instant.MinValue; - } - return TimestampHandler.Decode(value); - } - - ZonedDateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - { - try - { - var value = buf.ReadInt64(); - if (value == long.MaxValue || value == long.MinValue) - throw new NotSupportedException("Infinity values not supported for timestamp with time zone"); - return TimestampHandler.Decode(value).InZone(_dateTimeZoneProvider[buf.Connection.Timezone]); - } - catch (Exception e) when ( - string.Equals(buf.Connection.Timezone, "localtime", StringComparison.OrdinalIgnoreCase) && - (e is TimeZoneNotFoundException || e is DateTimeZoneNotFoundException)) - { - throw new TimeZoneNotFoundException( - "The special PostgreSQL timezone 'localtime' is not supported when reading values of type 'timestamp with time zone'. " + - "Please specify a real timezone in 'postgresql.conf' on the server, or set the 'PGTZ' environment variable on the client.", - e); - } - } - - OffsetDateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => ((INpgsqlSimpleTypeHandler)this).Read(buf, len, fieldDescription).ToOffsetDateTime(); - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(Instant value, NpgsqlParameter? parameter) - => 8; - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(ZonedDateTime value, NpgsqlParameter? parameter) - => 8; - - public int ValidateAndGetLength(OffsetDateTime value, NpgsqlParameter? parameter) - => 8; - - public override void Write(Instant value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (_convertInfinityDateTime) - { - if (value == Instant.MaxValue) - { - buf.WriteInt64(long.MaxValue); - return; - } - - if (value == Instant.MinValue) - { - buf.WriteInt64(long.MinValue); - return; - } - } - TimestampHandler.WriteInteger(value, buf); - } - - void INpgsqlSimpleTypeHandler.Write(ZonedDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => Write(value.ToInstant(), buf, parameter); - - public void Write(OffsetDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => Write(value.ToInstant(), buf, parameter); - - #endregion Write - - DateTimeOffset INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTimeOffset value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTimeOffset value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => _bclHandler.Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) - => _bclHandler.ValidateAndGetLength(value, parameter); - - void INpgsqlSimpleTypeHandler.Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => _bclHandler.Write(value, buf, parameter); - } -} diff --git a/src/Npgsql.OpenTelemetry/Npgsql.OpenTelemetry.csproj b/src/Npgsql.OpenTelemetry/Npgsql.OpenTelemetry.csproj new file mode 100644 index 0000000000..d2b8e620a7 --- /dev/null +++ b/src/Npgsql.OpenTelemetry/Npgsql.OpenTelemetry.csproj @@ -0,0 +1,22 @@ + + + + Shay Rojansky + netstandard2.0 + net8.0 + npgsql;postgresql;postgres;ado;ado.net;database;sql;opentelemetry;tracing;diagnostics;instrumentation + README.md + + + + + + + + + + + + + + diff --git a/src/Npgsql.OpenTelemetry/Properties/AssemblyInfo.cs b/src/Npgsql.OpenTelemetry/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..1a340b1a15 --- /dev/null +++ b/src/Npgsql.OpenTelemetry/Properties/AssemblyInfo.cs @@ -0,0 +1,5 @@ +using System.Runtime.CompilerServices; + +#if NET5_0_OR_GREATER +[module: SkipLocalsInit] +#endif diff --git a/src/Npgsql.OpenTelemetry/README.md b/src/Npgsql.OpenTelemetry/README.md new file mode 100644 index 0000000000..c4ebf9778e --- 
/dev/null +++ b/src/Npgsql.OpenTelemetry/README.md @@ -0,0 +1,22 @@ +Npgsql is the open source .NET data provider for PostgreSQL. It allows you to connect and interact with PostgreSQL server using .NET. + +This package helps set up Npgsql's support for OpenTelemetry tracing, which allows you to observe database commands as they are being executed. + +You can drop the following code snippet in your application's startup, and you should start seeing tracing information on the console: + +```csharp +using var tracerProvider = Sdk.CreateTracerProviderBuilder() + .SetResourceBuilder(ResourceBuilder.CreateDefault().AddService("npgsql-tester")) + .SetSampler(new AlwaysOnSampler()) + // This optional activates tracing for your application, if you trace your own activities: + .AddSource("MyApp") + // This activates up Npgsql's tracing: + .AddNpgsql() + // This prints tracing data to the console: + .AddConsoleExporter() + .Build(); +``` + +Once this is done, you should start seeing Npgsql trace data appearing in your application's console. At this point, you can look into exporting your trace data to a more useful destination: systems such as [Zipkin](https://zipkin.io/) or [Jaeger](https://www.jaegertracing.io/) can efficiently collect and store your data, and provide user interfaces for querying and exploring it. + +For more information, [visit the diagnostics documentation page](https://www.npgsql.org/doc/diagnostics/tracing.html). diff --git a/src/Npgsql.OpenTelemetry/TracerProviderBuilderExtensions.cs b/src/Npgsql.OpenTelemetry/TracerProviderBuilderExtensions.cs new file mode 100644 index 0000000000..0c34138278 --- /dev/null +++ b/src/Npgsql.OpenTelemetry/TracerProviderBuilderExtensions.cs @@ -0,0 +1,19 @@ +using System; +using OpenTelemetry.Trace; + +// ReSharper disable once CheckNamespace +namespace Npgsql; + +/// +/// Extension method for setting up Npgsql OpenTelemetry tracing. 
+/// +public static class TracerProviderBuilderExtensions +{ + /// + /// Subscribes to the Npgsql activity source to enable OpenTelemetry tracing. + /// + public static TracerProviderBuilder AddNpgsql( + this TracerProviderBuilder builder, + Action? options = null) + => builder.AddSource("Npgsql"); +} \ No newline at end of file diff --git a/src/Npgsql.SourceGenerators/AnalyzerReleases.Shipped.md b/src/Npgsql.SourceGenerators/AnalyzerReleases.Shipped.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/Npgsql.SourceGenerators/AnalyzerReleases.Unshipped.md b/src/Npgsql.SourceGenerators/AnalyzerReleases.Unshipped.md new file mode 100644 index 0000000000..a5a24bb6ee --- /dev/null +++ b/src/Npgsql.SourceGenerators/AnalyzerReleases.Unshipped.md @@ -0,0 +1,4 @@ +### New Rules +Rule ID | Category | Severity | Notes +--------|----------|----------|------- +PGXXXX | Internal | Error | diff --git a/src/Npgsql.SourceGenerators/EmbeddedResource.cs b/src/Npgsql.SourceGenerators/EmbeddedResource.cs new file mode 100644 index 0000000000..6f019d3962 --- /dev/null +++ b/src/Npgsql.SourceGenerators/EmbeddedResource.cs @@ -0,0 +1,26 @@ +using System; +using System.IO; +using System.Reflection; + +namespace Npgsql.SourceGenerators; + +static class EmbeddedResource +{ + public static string GetContent(string relativePath) + { + var baseName = Assembly.GetExecutingAssembly().GetName().Name; + var resourceName = relativePath + .TrimStart('.') + .Replace(Path.DirectorySeparatorChar, '.') + .Replace(Path.AltDirectorySeparatorChar, '.'); + + using var stream = Assembly.GetExecutingAssembly() + .GetManifestResourceStream(baseName + "." 
+ resourceName); + + if (stream == null) + throw new NotSupportedException(); + + using var reader = new StreamReader(stream); + return reader.ReadToEnd(); + } +} \ No newline at end of file diff --git a/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj b/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj new file mode 100644 index 0000000000..bc0f37e9bb --- /dev/null +++ b/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj @@ -0,0 +1,32 @@ + + + + netstandard2.0 + 1591 + true + + false + + + + + + + + + + + + $(GetTargetPathDependsOn);GetDependencyTargetPaths + + + + + + + + + + + + diff --git a/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilder.snbtxt b/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilder.snbtxt new file mode 100644 index 0000000000..9ad343124c --- /dev/null +++ b/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilder.snbtxt @@ -0,0 +1,91 @@ +using System; +using System.Collections.Generic; + +#nullable disable +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member +#pragma warning disable RS0016 // Add public types and members to the declared API +#pragma warning disable CS0618 // Member is obsolete + +namespace Npgsql +{ + public sealed partial class NpgsqlConnectionStringBuilder + { + private partial int Init() + { + // Set the strongly-typed properties to their default values + {{~ + for p in properties + if p.is_obsolete + continue + end + + if (p.default_value != null) + ~}} + {{ p.name }} = {{ p.default_value }}; + {{~ + end + end ~}} + + // Setting the strongly-typed properties here also set the string-based properties in the base class. 
+ // Clear them (default settings = empty connection string) + base.Clear(); + + return 0; + } + + private partial bool GeneratedActions(GeneratedAction action, string keyword, ref object value) + { + switch (keyword) + { + {{~ for kv in properties_by_keyword ~}} + case "{{ kv.key }}": + {{~ for alternative in kv.value.alternatives ~}} + case "{{ alternative }}": + {{~ end ~}} + { + {{~ p = kv.value ~}} + const string canonicalName = "{{ p.canonical_name }}"; + switch(action) + { + case GeneratedAction.Remove: + var removed = base.ContainsKey(canonicalName); + {{~ if p.default_value == null ~}} + {{ p.name }} = default; + {{~ else ~}} + {{ p.name }} = {{ p.default_value }}; + {{~ end ~}} + {{~ if p.type_name != "String" ~}} + base.Remove(canonicalName); + {{~ else ~}} + // String property setters call SetValue, which itself calls base.Remove(). + {{~ end ~}} + return removed; + case GeneratedAction.Set: + {{~ if p.is_enum ~}} + {{ p.name }} = ({{ p.type_name }})GetValue(typeof({{ p.type_name }}), value); + {{~ else ~}} + {{ p.name }} = ({{ p.type_name }})Convert.ChangeType(value, typeof({{ p.type_name }})); + {{~ end ~}} + break; + case GeneratedAction.Get: + value = (object){{ p.name }} ?? ""; + break; + case GeneratedAction.GetCanonical: + value = canonicalName; + break; + } + return true; + } + {{~ end ~}} + } + if (action is GeneratedAction.Get or GeneratedAction.GetCanonical) + return false; + throw new KeyNotFoundException(); + + static object GetValue(Type type, object value) + => value is string s + ? 
Enum.Parse(type, s, ignoreCase: true) + : Convert.ChangeType(value, type); + } + } +} diff --git a/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs b/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs new file mode 100644 index 0000000000..665789e74e --- /dev/null +++ b/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs @@ -0,0 +1,146 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; +using Scriban; + +namespace Npgsql.SourceGenerators; + +[Generator] +public class NpgsqlConnectionStringBuilderSourceGenerator : ISourceGenerator +{ + static readonly DiagnosticDescriptor InternalError = new DiagnosticDescriptor( + id: "PGXXXX", + title: "Internal issue when source-generating NpgsqlConnectionStringBuilder", + messageFormat: "{0}", + category: "Internal", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public void Initialize(GeneratorInitializationContext context) {} + + public void Execute(GeneratorExecutionContext context) + { + if (context.Compilation.Assembly.GetTypeByMetadataName("Npgsql.NpgsqlConnectionStringBuilder") is not { } type) + return; + + if (context.Compilation.Assembly.GetTypeByMetadataName("Npgsql.NpgsqlConnectionStringPropertyAttribute") is not + { } connectionStringPropertyAttribute) + { + context.ReportDiagnostic(Diagnostic.Create( + InternalError, + location: null, + "Could not find Npgsql.NpgsqlConnectionStringPropertyAttribute")); + return; + } + + var obsoleteAttribute = context.Compilation.GetTypeByMetadataName("System.ObsoleteAttribute"); + var displayNameAttribute = context.Compilation.GetTypeByMetadataName("System.ComponentModel.DisplayNameAttribute"); + var defaultValueAttribute = context.Compilation.GetTypeByMetadataName("System.ComponentModel.DefaultValueAttribute"); + + if (obsoleteAttribute is null || displayNameAttribute is null || 
defaultValueAttribute is null) + { + context.ReportDiagnostic(Diagnostic.Create( + InternalError, + location: null, + "Could not find ObsoleteAttribute, DisplayNameAttribute or DefaultValueAttribute")); + return; + } + + var properties = new List(); + var propertiesByKeyword = new Dictionary(); + foreach (var member in type.GetMembers()) + { + if (member is not IPropertySymbol property || + property.GetAttributes().FirstOrDefault(a => connectionStringPropertyAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)) is not { } propertyAttribute || + property.GetAttributes() + .FirstOrDefault(a => displayNameAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)) + ?.ConstructorArguments[0].Value is not string displayName) + { + continue; + } + + var explicitDefaultValue = property.GetAttributes() + .FirstOrDefault(a => defaultValueAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)) + ?.ConstructorArguments[0].Value; + + if (explicitDefaultValue is string s) + explicitDefaultValue = '"' + s.Replace("\"", "\"\"") + '"'; + + if (explicitDefaultValue is not null && property.Type.TypeKind == TypeKind.Enum) + { + explicitDefaultValue = $"({property.Type.Name}){explicitDefaultValue}"; + // var foo = property.Type.Name; + // explicitDefaultValue += $"/* {foo} */"; + } + + var propertyDetails = new PropertyDetails + { + Name = property.Name, + CanonicalName = displayName, + TypeName = property.Type.Name, + IsEnum = property.Type.TypeKind == TypeKind.Enum, + IsObsolete = property.GetAttributes().Any(a => obsoleteAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)), + DefaultValue = explicitDefaultValue + }; + + properties.Add(propertyDetails); + + propertiesByKeyword[displayName.ToUpperInvariant()] = propertyDetails; + if (property.Name != displayName) + { + var propertyName = property.Name.ToUpperInvariant(); + if (!propertiesByKeyword.ContainsKey(propertyName)) + propertyDetails.Alternatives.Add(propertyName); + } + + if 
(propertyAttribute.ConstructorArguments.Length == 1) + { + foreach (var synonymArg in propertyAttribute.ConstructorArguments[0].Values) + { + if (synonymArg.Value is string synonym) + { + var synonymName = synonym.ToUpperInvariant(); + if (!propertiesByKeyword.ContainsKey(synonymName)) + propertyDetails.Alternatives.Add(synonymName); + } + } + } + } + + var template = Template.Parse(EmbeddedResource.GetContent("NpgsqlConnectionStringBuilder.snbtxt"), "NpgsqlConnectionStringBuilder.snbtxt"); + + var output = template.Render(new + { + Properties = properties, + PropertiesByKeyword = propertiesByKeyword + }); + + context.AddSource(type.Name + ".Generated.cs", SourceText.From(output, Encoding.UTF8)); + } + + sealed class PropertyDetails + { + public string Name { get; set; } = null!; + public string CanonicalName { get; set; } = null!; + public string TypeName { get; set; } = null!; + public bool IsEnum { get; set; } + public bool IsObsolete { get; set; } + public object? DefaultValue { get; set; } + + public HashSet Alternatives { get; } = new(StringComparer.Ordinal); + + public PropertyDetails Clone() + => new() + { + Name = Name, + CanonicalName = CanonicalName, + TypeName = TypeName, + IsEnum = IsEnum, + IsObsolete = IsObsolete, + DefaultValue = DefaultValue + }; + } +} diff --git a/src/Npgsql/BackendMessages/AuthenticationMessages.cs b/src/Npgsql/BackendMessages/AuthenticationMessages.cs index 8a6be71ef1..b6320e87b8 100644 --- a/src/Npgsql/BackendMessages/AuthenticationMessages.cs +++ b/src/Npgsql/BackendMessages/AuthenticationMessages.cs @@ -1,233 +1,229 @@ -using System.Collections.Generic; -using Npgsql.Logging; -using Npgsql.Util; +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; -namespace Npgsql.BackendMessages -{ - abstract class AuthenticationRequestMessage : IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.AuthenticationRequest; - internal abstract 
AuthenticationRequestType AuthRequestType { get; } - } +namespace Npgsql.BackendMessages; - class AuthenticationOkMessage : AuthenticationRequestMessage - { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationOk; +abstract class AuthenticationRequestMessage : IBackendMessage +{ + public BackendMessageCode Code => BackendMessageCode.AuthenticationRequest; + internal abstract AuthenticationRequestType AuthRequestType { get; } +} - internal static readonly AuthenticationOkMessage Instance = new AuthenticationOkMessage(); - AuthenticationOkMessage() { } - } +sealed class AuthenticationOkMessage : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationOk; - class AuthenticationKerberosV5Message : AuthenticationRequestMessage - { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationKerberosV5; + internal static readonly AuthenticationOkMessage Instance = new(); + AuthenticationOkMessage() { } +} - internal static readonly AuthenticationKerberosV5Message Instance = new AuthenticationKerberosV5Message(); - AuthenticationKerberosV5Message() { } - } +sealed class AuthenticationKerberosV5Message : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationKerberosV5; - class AuthenticationCleartextPasswordMessage : AuthenticationRequestMessage - { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationCleartextPassword; + internal static readonly AuthenticationKerberosV5Message Instance = new(); + AuthenticationKerberosV5Message() { } +} - internal static readonly AuthenticationCleartextPasswordMessage Instance = new AuthenticationCleartextPasswordMessage(); - AuthenticationCleartextPasswordMessage() { } - } +sealed class AuthenticationCleartextPasswordMessage : 
AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationCleartextPassword; - class AuthenticationMD5PasswordMessage : AuthenticationRequestMessage - { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationMD5Password; + internal static readonly AuthenticationCleartextPasswordMessage Instance = new(); + AuthenticationCleartextPasswordMessage() { } +} - internal byte[] Salt { get; private set; } +sealed class AuthenticationMD5PasswordMessage : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationMD5Password; - internal static AuthenticationMD5PasswordMessage Load(NpgsqlReadBuffer buf) - { - var salt = new byte[4]; - buf.ReadBytes(salt, 0, 4); - return new AuthenticationMD5PasswordMessage(salt); - } + internal byte[] Salt { get; } - AuthenticationMD5PasswordMessage(byte[] salt) - { - Salt = salt; - } + internal static AuthenticationMD5PasswordMessage Load(NpgsqlReadBuffer buf) + { + var salt = new byte[4]; + buf.ReadBytes(salt, 0, 4); + return new AuthenticationMD5PasswordMessage(salt); } - class AuthenticationSCMCredentialMessage : AuthenticationRequestMessage + AuthenticationMD5PasswordMessage(byte[] salt) { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSCMCredential; - - internal static readonly AuthenticationSCMCredentialMessage Instance = new AuthenticationSCMCredentialMessage(); - AuthenticationSCMCredentialMessage() { } + Salt = salt; } +} - class AuthenticationGSSMessage : AuthenticationRequestMessage - { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationGSS; +sealed class AuthenticationSCMCredentialMessage : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => 
AuthenticationRequestType.AuthenticationSCMCredential; - internal static readonly AuthenticationGSSMessage Instance = new AuthenticationGSSMessage(); - AuthenticationGSSMessage() { } - } + internal static readonly AuthenticationSCMCredentialMessage Instance = new(); + AuthenticationSCMCredentialMessage() { } +} - class AuthenticationGSSContinueMessage : AuthenticationRequestMessage - { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationGSSContinue; +sealed class AuthenticationGSSMessage : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationGSS; - internal byte[] AuthenticationData { get; private set; } + internal static readonly AuthenticationGSSMessage Instance = new(); + AuthenticationGSSMessage() { } +} - internal static AuthenticationGSSContinueMessage Load(NpgsqlReadBuffer buf, int len) - { - len -= 4; // The AuthRequestType code - var authenticationData = new byte[len]; - buf.ReadBytes(authenticationData, 0, len); - return new AuthenticationGSSContinueMessage(authenticationData); - } +sealed class AuthenticationGSSContinueMessage : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationGSSContinue; - AuthenticationGSSContinueMessage(byte[] authenticationData) - { - AuthenticationData = authenticationData; - } - } + internal byte[] AuthenticationData { get; } - class AuthenticationSSPIMessage : AuthenticationRequestMessage + internal static AuthenticationGSSContinueMessage Load(NpgsqlReadBuffer buf, int len) { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSSPI; + len -= 4; // The AuthRequestType code + var authenticationData = new byte[len]; + buf.ReadBytes(authenticationData, 0, len); + return new AuthenticationGSSContinueMessage(authenticationData); + } - internal static readonly 
AuthenticationSSPIMessage Instance = new AuthenticationSSPIMessage(); - AuthenticationSSPIMessage() { } + AuthenticationGSSContinueMessage(byte[] authenticationData) + { + AuthenticationData = authenticationData; } +} - #region SASL +sealed class AuthenticationSSPIMessage : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSSPI; - class AuthenticationSASLMessage : AuthenticationRequestMessage - { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSASL; - internal List Mechanisms { get; } = new List(); + internal static readonly AuthenticationSSPIMessage Instance = new(); + AuthenticationSSPIMessage() { } +} - internal AuthenticationSASLMessage(NpgsqlReadBuffer buf) - { - while (buf.Buffer[buf.ReadPosition] != 0) - Mechanisms.Add(buf.ReadNullTerminatedString()); - buf.ReadByte(); - if (Mechanisms.Count == 0) - throw new NpgsqlException("Received AuthenticationSASL message with 0 mechanisms!"); - } - } +#region SASL - class AuthenticationSASLContinueMessage : AuthenticationRequestMessage - { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSASLContinue; - internal byte[] Payload { get; } +sealed class AuthenticationSASLMessage : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSASL; + internal List Mechanisms { get; } = new(); - internal AuthenticationSASLContinueMessage(NpgsqlReadBuffer buf, int len) - { - Payload = new byte[len]; - buf.ReadBytes(Payload, 0, len); - } + internal AuthenticationSASLMessage(NpgsqlReadBuffer buf) + { + while (buf.Buffer[buf.ReadPosition] != 0) + Mechanisms.Add(buf.ReadNullTerminatedString()); + buf.ReadByte(); + if (Mechanisms.Count == 0) + throw new NpgsqlException("Received AuthenticationSASL message with 0 mechanisms!"); } +} - class 
AuthenticationSCRAMServerFirstMessage +sealed class AuthenticationSASLContinueMessage : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSASLContinue; + internal byte[] Payload { get; } + + internal AuthenticationSASLContinueMessage(NpgsqlReadBuffer buf, int len) { - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(AuthenticationSCRAMServerFirstMessage)); + Payload = new byte[len]; + buf.ReadBytes(Payload, 0, len); + } +} - internal string Nonce { get; } - internal string Salt { get; } - internal int Iteration { get; } +sealed class AuthenticationSCRAMServerFirstMessage +{ + internal string Nonce { get; } + internal string Salt { get; } + internal int Iteration { get; } - internal static AuthenticationSCRAMServerFirstMessage Load(byte[] bytes) - { - var data = PGUtil.UTF8Encoding.GetString(bytes); - string? nonce = null, salt = null; - var iteration = -1; - - foreach (var part in data.Split(',')) - { - if (part.StartsWith("r=")) - nonce = part.Substring(2); - else if (part.StartsWith("s=")) - salt = part.Substring(2); - else if (part.StartsWith("i=")) - iteration = int.Parse(part.Substring(2)); - else - Log.Debug("Unknown part in SCRAM server-first message:" + part); - } - - if (nonce == null) - throw new NpgsqlException("Server nonce not received in SCRAM server-first message"); - if (salt == null) - throw new NpgsqlException("Server salt not received in SCRAM server-first message"); - if (iteration == -1) - throw new NpgsqlException("Server iterations not received in SCRAM server-first message"); - - return new AuthenticationSCRAMServerFirstMessage(nonce, salt, iteration); - } + internal static AuthenticationSCRAMServerFirstMessage Load(byte[] bytes, ILogger connectionLogger) + { + var data = NpgsqlWriteBuffer.UTF8Encoding.GetString(bytes); + string? 
nonce = null, salt = null; + var iteration = -1; - AuthenticationSCRAMServerFirstMessage(string nonce, string salt, int iteration) + foreach (var part in data.Split(',')) { - Nonce = nonce; - Salt = salt; - Iteration = iteration; + if (part.StartsWith("r=", StringComparison.Ordinal)) + nonce = part.Substring(2); + else if (part.StartsWith("s=", StringComparison.Ordinal)) + salt = part.Substring(2); + else if (part.StartsWith("i=", StringComparison.Ordinal)) + iteration = int.Parse(part.Substring(2)); + else + connectionLogger.LogDebug("Unknown part in SCRAM server-first message:" + part); } - } - class AuthenticationSASLFinalMessage : AuthenticationRequestMessage - { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSASLFinal; - internal byte[] Payload { get; } + if (nonce == null) + throw new NpgsqlException("Server nonce not received in SCRAM server-first message"); + if (salt == null) + throw new NpgsqlException("Server salt not received in SCRAM server-first message"); + if (iteration == -1) + throw new NpgsqlException("Server iterations not received in SCRAM server-first message"); - internal AuthenticationSASLFinalMessage(NpgsqlReadBuffer buf, int len) - { - Payload = new byte[len]; - buf.ReadBytes(Payload, 0, len); - } + return new AuthenticationSCRAMServerFirstMessage(nonce, salt, iteration); } - class AuthenticationSCRAMServerFinalMessage + AuthenticationSCRAMServerFirstMessage(string nonce, string salt, int iteration) { - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(AuthenticationSCRAMServerFinalMessage)); + Nonce = nonce; + Salt = salt; + Iteration = iteration; + } +} - internal string ServerSignature { get; } +sealed class AuthenticationSASLFinalMessage : AuthenticationRequestMessage +{ + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSASLFinal; + internal byte[] Payload { get; } - internal static 
AuthenticationSCRAMServerFinalMessage Load(byte[] bytes) - { - var data = PGUtil.UTF8Encoding.GetString(bytes); - string? serverSignature = null; + internal AuthenticationSASLFinalMessage(NpgsqlReadBuffer buf, int len) + { + Payload = new byte[len]; + buf.ReadBytes(Payload, 0, len); + } +} - foreach (var part in data.Split(',')) - { - if (part.StartsWith("v=")) - serverSignature = part.Substring(2); - else - Log.Debug("Unknown part in SCRAM server-first message:" + part); - } +sealed class AuthenticationSCRAMServerFinalMessage +{ + internal string ServerSignature { get; } - if (serverSignature == null) - throw new NpgsqlException("Server signature not received in SCRAM server-final message"); + internal static AuthenticationSCRAMServerFinalMessage Load(byte[] bytes, ILogger connectionLogger) + { + var data = NpgsqlWriteBuffer.UTF8Encoding.GetString(bytes); + string? serverSignature = null; - return new AuthenticationSCRAMServerFinalMessage(serverSignature); + foreach (var part in data.Split(',')) + { + if (part.StartsWith("v=", StringComparison.Ordinal)) + serverSignature = part.Substring(2); + else + connectionLogger.LogDebug("Unknown part in SCRAM server-first message:" + part); } - internal AuthenticationSCRAMServerFinalMessage(string serverSignature) - => ServerSignature = serverSignature; + if (serverSignature == null) + throw new NpgsqlException("Server signature not received in SCRAM server-final message"); + + return new AuthenticationSCRAMServerFinalMessage(serverSignature); } - #endregion SASL + internal AuthenticationSCRAMServerFinalMessage(string serverSignature) + => ServerSignature = serverSignature; +} - // TODO: Remove Authentication prefix from everything - enum AuthenticationRequestType - { - AuthenticationOk = 0, - AuthenticationKerberosV4 = 1, - AuthenticationKerberosV5 = 2, - AuthenticationCleartextPassword = 3, - AuthenticationCryptPassword = 4, - AuthenticationMD5Password = 5, - AuthenticationSCMCredential = 6, - AuthenticationGSS = 7, - 
AuthenticationGSSContinue = 8, - AuthenticationSSPI = 9, - AuthenticationSASL = 10, - AuthenticationSASLContinue = 11, - AuthenticationSASLFinal = 12 - } +#endregion SASL + +// TODO: Remove Authentication prefix from everything +enum AuthenticationRequestType +{ + AuthenticationOk = 0, + AuthenticationKerberosV4 = 1, + AuthenticationKerberosV5 = 2, + AuthenticationCleartextPassword = 3, + AuthenticationCryptPassword = 4, + AuthenticationMD5Password = 5, + AuthenticationSCMCredential = 6, + AuthenticationGSS = 7, + AuthenticationGSSContinue = 8, + AuthenticationSSPI = 9, + AuthenticationSASL = 10, + AuthenticationSASLContinue = 11, + AuthenticationSASLFinal = 12 } diff --git a/src/Npgsql/BackendMessages/BackendKeyDataMessage.cs b/src/Npgsql/BackendMessages/BackendKeyDataMessage.cs index 3c5440a5d2..2140048c38 100644 --- a/src/Npgsql/BackendMessages/BackendKeyDataMessage.cs +++ b/src/Npgsql/BackendMessages/BackendKeyDataMessage.cs @@ -1,16 +1,17 @@ -namespace Npgsql.BackendMessages +using Npgsql.Internal; + +namespace Npgsql.BackendMessages; + +sealed class BackendKeyDataMessage : IBackendMessage { - class BackendKeyDataMessage : IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.BackendKeyData; + public BackendMessageCode Code => BackendMessageCode.BackendKeyData; - internal int BackendProcessId { get; private set; } - internal int BackendSecretKey { get; private set; } + internal int BackendProcessId { get; } + internal int BackendSecretKey { get; } - internal BackendKeyDataMessage(NpgsqlReadBuffer buf) - { - BackendProcessId = buf.ReadInt32(); - BackendSecretKey = buf.ReadInt32(); - } + internal BackendKeyDataMessage(NpgsqlReadBuffer buf) + { + BackendProcessId = buf.ReadInt32(); + BackendSecretKey = buf.ReadInt32(); } -} +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/BindCompleteMessage.cs b/src/Npgsql/BackendMessages/BindCompleteMessage.cs index 7f43486f5f..f6dbfce1bb 100644 --- 
a/src/Npgsql/BackendMessages/BindCompleteMessage.cs +++ b/src/Npgsql/BackendMessages/BindCompleteMessage.cs @@ -1,9 +1,8 @@ -namespace Npgsql.BackendMessages +namespace Npgsql.BackendMessages; + +sealed class BindCompleteMessage : IBackendMessage { - class BindCompleteMessage : IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.BindComplete; - internal static readonly BindCompleteMessage Instance = new BindCompleteMessage(); - BindCompleteMessage() { } - } -} + public BackendMessageCode Code => BackendMessageCode.BindComplete; + internal static readonly BindCompleteMessage Instance = new(); + BindCompleteMessage() { } +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/CloseCompletedMessage.cs b/src/Npgsql/BackendMessages/CloseCompletedMessage.cs index 55aab247b3..9443fd3e97 100644 --- a/src/Npgsql/BackendMessages/CloseCompletedMessage.cs +++ b/src/Npgsql/BackendMessages/CloseCompletedMessage.cs @@ -1,9 +1,8 @@ -namespace Npgsql.BackendMessages +namespace Npgsql.BackendMessages; + +sealed class CloseCompletedMessage : IBackendMessage { - class CloseCompletedMessage : IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.CloseComplete; - internal static readonly CloseCompletedMessage Instance = new CloseCompletedMessage(); - CloseCompletedMessage() { } - } -} + public BackendMessageCode Code => BackendMessageCode.CloseComplete; + internal static readonly CloseCompletedMessage Instance = new(); + CloseCompletedMessage() { } +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/CommandCompleteMessage.cs b/src/Npgsql/BackendMessages/CommandCompleteMessage.cs index 4090e38e55..98154d1a7e 100644 --- a/src/Npgsql/BackendMessages/CommandCompleteMessage.cs +++ b/src/Npgsql/BackendMessages/CommandCompleteMessage.cs @@ -1,108 +1,62 @@ -using System.Diagnostics; +using System; +using System.Buffers.Text; +using Npgsql.Internal; -namespace Npgsql.BackendMessages -{ - class CommandCompleteMessage : 
IBackendMessage - { - internal StatementType StatementType { get; private set; } - internal uint OID { get; private set; } - internal ulong Rows { get; private set; } - - internal CommandCompleteMessage Load(NpgsqlReadBuffer buf, int len) - { - Rows = 0; - OID = 0; - - var bytes = buf.Buffer; - var i = buf.ReadPosition; - buf.Skip(len); - switch (bytes[i]) - { - case (byte)'I': - if (!AreEqual(bytes, i, "INSERT ")) - goto default; - StatementType = StatementType.Insert; - i += 7; - OID = (uint) ParseNumber(bytes, ref i); - i++; - Rows = ParseNumber(bytes, ref i); - return this; - - case (byte)'D': - if (!AreEqual(bytes, i, "DELETE ")) - goto default; - StatementType = StatementType.Delete; - i += 7; - Rows = ParseNumber(bytes, ref i); - return this; - - case (byte)'U': - if (!AreEqual(bytes, i, "UPDATE ")) - goto default; - StatementType = StatementType.Update; - i += 7; - Rows = ParseNumber(bytes, ref i); - return this; - - case (byte)'S': - if (!AreEqual(bytes, i, "SELECT ")) - goto default; - StatementType = StatementType.Select; - i += 7; - Rows = ParseNumber(bytes, ref i); - return this; +namespace Npgsql.BackendMessages; - case (byte)'M': - if (!AreEqual(bytes, i, "MOVE ")) - goto default; - StatementType = StatementType.Move; - i += 5; - Rows = ParseNumber(bytes, ref i); - return this; - - case (byte)'F': - if (!AreEqual(bytes, i, "FETCH ")) - goto default; - StatementType = StatementType.Fetch; - i += 6; - Rows = ParseNumber(bytes, ref i); - return this; +sealed class CommandCompleteMessage : IBackendMessage +{ + uint _oid; + ulong _rows; + internal StatementType StatementType { get; private set; } - case (byte)'C': - if (!AreEqual(bytes, i, "COPY ")) - goto default; - StatementType = StatementType.Copy; - i += 5; - Rows = ParseNumber(bytes, ref i); - return this; + internal uint OID => _oid; + internal ulong Rows => _rows; - default: - StatementType = StatementType.Other; - return this; - } - } + internal CommandCompleteMessage Load(NpgsqlReadBuffer buf, 
int len) + { + var bytes = buf.Span.Slice(0, len); + buf.Skip(len); - static bool AreEqual(byte[] bytes, int pos, string s) + // PostgreSQL always writes these strings as ASCII, see https://github.com/postgres/postgres/blob/c8e1ba736b2b9e8c98d37a5b77c4ed31baf94147/src/backend/tcop/cmdtag.c#L130-L133 + (StatementType, var argumentsStart) = Convert.ToChar(bytes[0]) switch { - for (var i = 0; i < s.Length; i++) - { - if (bytes[pos+i] != s[i]) - return false; - } - return true; - } - - static ulong ParseNumber(byte[] bytes, ref int pos) + 'S' when bytes.StartsWith("SELECT "u8) => (StatementType.Select, "SELECT ".Length), + 'I' when bytes.StartsWith("INSERT "u8) => (StatementType.Insert, "INSERT ".Length), + 'U' when bytes.StartsWith("UPDATE "u8) => (StatementType.Update, "UPDATE ".Length), + 'D' when bytes.StartsWith("DELETE "u8) => (StatementType.Delete, "DELETE ".Length), + 'M' when bytes.StartsWith("MERGE "u8) => (StatementType.Merge, "MERGE ".Length), + 'C' when bytes.StartsWith("COPY "u8) => (StatementType.Copy, "COPY ".Length), + 'C' when bytes.StartsWith("CALL"u8) => (StatementType.Call, "CALL".Length), + 'M' when bytes.StartsWith("MOVE "u8) => (StatementType.Move, "MOVE ".Length), + 'F' when bytes.StartsWith("FETCH "u8) => (StatementType.Fetch, "FETCH ".Length), + 'C' when bytes.StartsWith("CREATE TABLE AS "u8) => (StatementType.CreateTableAs, "CREATE TABLE AS ".Length), + _ => (StatementType.Other, 0) + }; + + _oid = 0; + _rows = 0; + + // Slice away the null terminator. 
+ var arguments = bytes.Slice(argumentsStart, bytes.Length - argumentsStart - 1); + switch (StatementType) { - Debug.Assert(bytes[pos] >= '0' && bytes[pos] <= '9'); - uint result = 0; - do - { - result = result * 10 + bytes[pos++] - '0'; - } while (bytes[pos] >= '0' && bytes[pos] <= '9'); - return result; + case StatementType.Other: + case StatementType.Call: + break; + case StatementType.Insert: + if (!Utf8Parser.TryParse(arguments, out _oid, out var nextArgumentOffset)) + throw new InvalidOperationException("Invalid bytes in command complete message."); + arguments = arguments.Slice(nextArgumentOffset + 1); + goto default; + default: + if (!Utf8Parser.TryParse(arguments, out _rows, out _)) + throw new InvalidOperationException("Invalid bytes in command complete message."); + break; } - public BackendMessageCode Code => BackendMessageCode.CommandComplete; + return this; } + + public BackendMessageCode Code => BackendMessageCode.CommandComplete; } diff --git a/src/Npgsql/BackendMessages/CopyMessages.cs b/src/Npgsql/BackendMessages/CopyMessages.cs index f98f0904d2..1aa8aec0c2 100644 --- a/src/Npgsql/BackendMessages/CopyMessages.cs +++ b/src/Npgsql/BackendMessages/CopyMessages.cs @@ -1,94 +1,93 @@ using System; using System.Collections.Generic; -using Npgsql.Util; +using Npgsql.Internal; -namespace Npgsql.BackendMessages -{ - abstract class CopyResponseMessageBase : IBackendMessage - { - public abstract BackendMessageCode Code { get; } +namespace Npgsql.BackendMessages; - internal bool IsBinary { get; private set; } - internal short NumColumns { get; private set; } - internal List ColumnFormatCodes { get; } +abstract class CopyResponseMessageBase : IBackendMessage +{ + public abstract BackendMessageCode Code { get; } - internal CopyResponseMessageBase() - { - ColumnFormatCodes = new List(); - } + internal bool IsBinary { get; private set; } + internal short NumColumns { get; private set; } + internal List ColumnFormatCodes { get; } - internal void 
Load(NpgsqlReadBuffer buf) - { - ColumnFormatCodes.Clear(); - - var binaryIndicator = buf.ReadByte(); - IsBinary = binaryIndicator switch - { - 0 => false, - 1 => true, - _ => throw new Exception("Invalid binary indicator in CopyInResponse message: " + binaryIndicator) - }; - - NumColumns = buf.ReadInt16(); - for (var i = 0; i < NumColumns; i++) - ColumnFormatCodes.Add((FormatCode)buf.ReadInt16()); - } + internal CopyResponseMessageBase() + { + ColumnFormatCodes = new List(); } - class CopyInResponseMessage : CopyResponseMessageBase + internal void Load(NpgsqlReadBuffer buf) { - public override BackendMessageCode Code => BackendMessageCode.CopyInResponse; + ColumnFormatCodes.Clear(); - internal new CopyInResponseMessage Load(NpgsqlReadBuffer buf) + var binaryIndicator = buf.ReadByte(); + IsBinary = binaryIndicator switch { - base.Load(buf); - return this; - } + 0 => false, + 1 => true, + _ => throw new Exception("Invalid binary indicator in CopyInResponse message: " + binaryIndicator) + }; + + NumColumns = buf.ReadInt16(); + for (var i = 0; i < NumColumns; i++) + ColumnFormatCodes.Add(DataFormatUtils.Create(buf.ReadInt16())); } +} - class CopyOutResponseMessage : CopyResponseMessageBase - { - public override BackendMessageCode Code => BackendMessageCode.CopyOutResponse; +sealed class CopyInResponseMessage : CopyResponseMessageBase +{ + public override BackendMessageCode Code => BackendMessageCode.CopyInResponse; - internal new CopyOutResponseMessage Load(NpgsqlReadBuffer buf) - { - base.Load(buf); - return this; - } + internal new CopyInResponseMessage Load(NpgsqlReadBuffer buf) + { + base.Load(buf); + return this; } +} - class CopyBothResponseMessage : CopyResponseMessageBase - { - public override BackendMessageCode Code => BackendMessageCode.CopyBothResponse; +sealed class CopyOutResponseMessage : CopyResponseMessageBase +{ + public override BackendMessageCode Code => BackendMessageCode.CopyOutResponse; - internal new CopyBothResponseMessage Load(NpgsqlReadBuffer 
buf) - { - base.Load(buf); - return this; - } + internal new CopyOutResponseMessage Load(NpgsqlReadBuffer buf) + { + base.Load(buf); + return this; } +} - /// - /// Note that this message doesn't actually contain the data, but only the length. Data is processed - /// directly from the connector's buffer. - /// - class CopyDataMessage : IBackendMessage +sealed class CopyBothResponseMessage : CopyResponseMessageBase +{ + public override BackendMessageCode Code => BackendMessageCode.CopyBothResponse; + + internal new CopyBothResponseMessage Load(NpgsqlReadBuffer buf) { - public BackendMessageCode Code => BackendMessageCode.CopyData; + base.Load(buf); + return this; + } +} - public int Length { get; private set; } +/// +/// Note that this message doesn't actually contain the data, but only the length. Data is processed +/// directly from the connector's buffer. +/// +sealed class CopyDataMessage : IBackendMessage +{ + public BackendMessageCode Code => BackendMessageCode.CopyData; - internal CopyDataMessage Load(int len) - { - Length = len; - return this; - } - } + public int Length { get; private set; } - class CopyDoneMessage : IBackendMessage + internal CopyDataMessage Load(int len) { - public BackendMessageCode Code => BackendMessageCode.CopyDone; - internal static readonly CopyDoneMessage Instance = new CopyDoneMessage(); - CopyDoneMessage() { } + Length = len; + return this; } } + +sealed class CopyDoneMessage : IBackendMessage +{ + public BackendMessageCode Code => BackendMessageCode.CopyDone; + internal static readonly CopyDoneMessage Instance = new(); + CopyDoneMessage() { } +} diff --git a/src/Npgsql/BackendMessages/DataRowMessage.cs b/src/Npgsql/BackendMessages/DataRowMessage.cs index 51ddd4f762..b4fddf9789 100644 --- a/src/Npgsql/BackendMessages/DataRowMessage.cs +++ b/src/Npgsql/BackendMessages/DataRowMessage.cs @@ -1,20 +1,19 @@ -namespace Npgsql.BackendMessages +namespace Npgsql.BackendMessages; + +/// +/// DataRow is special in that it does not parse the 
actual contents of the backend message, +/// because in sequential mode the message will be traversed and processed sequentially by +/// . +/// +sealed class DataRowMessage : IBackendMessage { - /// - /// DataRow is special in that it does not parse the actual contents of the backend message, - /// because in sequential mode the message will be traversed and processed sequentially by - /// . - /// - class DataRowMessage : IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.DataRow; + public BackendMessageCode Code => BackendMessageCode.DataRow; - internal int Length { get; private set; } + internal int Length { get; private set; } - internal DataRowMessage Load(int len) - { - Length = len; - return this; - } + internal DataRowMessage Load(int len) + { + Length = len; + return this; } -} +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/EmptyQueryMessage.cs b/src/Npgsql/BackendMessages/EmptyQueryMessage.cs index 70ee5d7677..ef190f3678 100644 --- a/src/Npgsql/BackendMessages/EmptyQueryMessage.cs +++ b/src/Npgsql/BackendMessages/EmptyQueryMessage.cs @@ -1,9 +1,8 @@ -namespace Npgsql.BackendMessages +namespace Npgsql.BackendMessages; + +sealed class EmptyQueryMessage : IBackendMessage { - class EmptyQueryMessage : IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.EmptyQueryResponse; - internal static readonly EmptyQueryMessage Instance = new EmptyQueryMessage(); - EmptyQueryMessage() { } - } -} + public BackendMessageCode Code => BackendMessageCode.EmptyQueryResponse; + internal static readonly EmptyQueryMessage Instance = new(); + EmptyQueryMessage() { } +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/ErrorOrNoticeMessage.cs b/src/Npgsql/BackendMessages/ErrorOrNoticeMessage.cs index 56249042a2..8a22139a94 100644 --- a/src/Npgsql/BackendMessages/ErrorOrNoticeMessage.cs +++ b/src/Npgsql/BackendMessages/ErrorOrNoticeMessage.cs @@ -1,188 +1,188 @@ using System; -using 
Npgsql.Logging; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; -namespace Npgsql.BackendMessages +namespace Npgsql.BackendMessages; + +[Serializable] +sealed class ErrorOrNoticeMessage { - [Serializable] - class ErrorOrNoticeMessage - { - internal string Severity { get; } - internal string InvariantSeverity { get; } - internal string SqlState { get; } - internal string Message { get; } - internal string? Detail { get; } - internal string? Hint { get; } - internal int Position { get; } - internal int InternalPosition { get; } - internal string? InternalQuery { get; } - internal string? Where { get; } - internal string? SchemaName { get; } - internal string? TableName { get; } - internal string? ColumnName { get; } - internal string? DataTypeName { get; } - internal string? ConstraintName { get; } - internal string? File { get; } - internal string? Line { get; } - internal string? Routine { get; } + internal string Severity { get; } + internal string InvariantSeverity { get; } + internal string SqlState { get; } + internal string Message { get; } + internal string? Detail { get; } + internal string? Hint { get; } + internal int Position { get; } + internal int InternalPosition { get; } + internal string? InternalQuery { get; } + internal string? Where { get; } + internal string? SchemaName { get; } + internal string? TableName { get; } + internal string? ColumnName { get; } + internal string? DataTypeName { get; } + internal string? ConstraintName { get; } + internal string? File { get; } + internal string? Line { get; } + internal string? Routine { get; } - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(ErrorOrNoticeMessage)); + // ReSharper disable once FunctionComplexityOverflow + internal static ErrorOrNoticeMessage Load(NpgsqlReadBuffer buf, bool includeDetail, ILogger exceptionLogger) + { + (string? severity, string? invariantSeverity, string? code, string? message, string? detail, string? 
hint) = (null, null, null, null, null, null); + var (position, internalPosition) = (0, 0); + (string? internalQuery, string? where) = (null, null); + (string? schemaName, string? tableName, string? columnName, string? dataTypeName, string? constraintName) = + (null, null, null, null, null); + (string? file, string? line, string? routine) = (null, null, null); - // ReSharper disable once FunctionComplexityOverflow - internal static ErrorOrNoticeMessage Load(NpgsqlReadBuffer buf, bool includeDetail) + while (true) { - (string? severity, string? invariantSeverity, string? code, string? message, string? detail, string? hint) = (null, null, null, null, null, null); - var (position, internalPosition) = (0, 0); - (string? internalQuery, string? where) = (null, null); - (string? schemaName, string? tableName, string? columnName, string? dataTypeName, string? constraintName) = - (null, null, null, null, null); - (string? file, string? line, string? routine) = (null, null, null); - - while (true) - { - var fieldCode = (ErrorFieldTypeCode)buf.ReadByte(); - switch (fieldCode) { - case ErrorFieldTypeCode.Done: - // Null terminator; error message fully consumed. - goto End; - case ErrorFieldTypeCode.Severity: - severity = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.InvariantSeverity: - invariantSeverity = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.Code: - code = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.Message: - message = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.Detail: - detail = buf.ReadNullTerminatedStringRelaxed(); - if (!includeDetail && !string.IsNullOrEmpty(detail)) - detail = $"Detail redacted as it may contain sensitive data. 
Specify '{NpgsqlConnectionStringBuilder.IncludeExceptionDetailDisplayName}' in the connection string to include this information."; - break; - case ErrorFieldTypeCode.Hint: - hint = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.Position: - var positionStr = buf.ReadNullTerminatedStringRelaxed(); - if (!int.TryParse(positionStr, out var tmpPosition)) { - Log.Warn("Non-numeric position in ErrorResponse: " + positionStr); - continue; - } - position = tmpPosition; - break; - case ErrorFieldTypeCode.InternalPosition: - var internalPositionStr = buf.ReadNullTerminatedStringRelaxed(); - if (!int.TryParse(internalPositionStr, out var internalPositionTmp)) { - Log.Warn("Non-numeric position in ErrorResponse: " + internalPositionStr); - continue; - } - internalPosition = internalPositionTmp; - break; - case ErrorFieldTypeCode.InternalQuery: - internalQuery = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.Where: - where = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.File: - file = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.Line: - line = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.Routine: - routine = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.SchemaName: - schemaName = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.TableName: - tableName = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.ColumnName: - columnName = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.DataTypeName: - dataTypeName = buf.ReadNullTerminatedStringRelaxed(); - break; - case ErrorFieldTypeCode.ConstraintName: - constraintName = buf.ReadNullTerminatedStringRelaxed(); - break; - default: - // Unknown error field; consume and discard. 
- buf.ReadNullTerminatedStringRelaxed(); - break; + var fieldCode = (ErrorFieldTypeCode)buf.ReadByte(); + switch (fieldCode) { + case ErrorFieldTypeCode.Done: + // Null terminator; error message fully consumed. + goto End; + case ErrorFieldTypeCode.Severity: + severity = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.InvariantSeverity: + invariantSeverity = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.Code: + code = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.Message: + message = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.Detail: + detail = buf.ReadNullTerminatedStringRelaxed(); + if (!includeDetail && !string.IsNullOrEmpty(detail)) + detail = $"Detail redacted as it may contain sensitive data. Specify '{NpgsqlConnectionStringBuilder.IncludeExceptionDetailDisplayName}' in the connection string to include this information."; + break; + case ErrorFieldTypeCode.Hint: + hint = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.Position: + var positionStr = buf.ReadNullTerminatedStringRelaxed(); + if (!int.TryParse(positionStr, out var tmpPosition)) + { + exceptionLogger.LogWarning("Non-numeric position in ErrorResponse: " + positionStr); + continue; } + position = tmpPosition; + break; + case ErrorFieldTypeCode.InternalPosition: + var internalPositionStr = buf.ReadNullTerminatedStringRelaxed(); + if (!int.TryParse(internalPositionStr, out var internalPositionTmp)) + { + exceptionLogger.LogWarning("Non-numeric position in ErrorResponse: " + internalPositionStr); + continue; + } + internalPosition = internalPositionTmp; + break; + case ErrorFieldTypeCode.InternalQuery: + internalQuery = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.Where: + where = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.File: + file = buf.ReadNullTerminatedStringRelaxed(); + break; + case 
ErrorFieldTypeCode.Line: + line = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.Routine: + routine = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.SchemaName: + schemaName = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.TableName: + tableName = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.ColumnName: + columnName = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.DataTypeName: + dataTypeName = buf.ReadNullTerminatedStringRelaxed(); + break; + case ErrorFieldTypeCode.ConstraintName: + constraintName = buf.ReadNullTerminatedStringRelaxed(); + break; + default: + // Unknown error field; consume and discard. + buf.ReadNullTerminatedStringRelaxed(); + break; } + } - End: - if (severity == null) - throw new NpgsqlException("Severity not received in server error message"); - if (code == null) - throw new NpgsqlException("Code not received in server error message"); - if (message == null) - throw new NpgsqlException("Message not received in server error message"); + End: + if (severity == null) + throw new NpgsqlException("Severity not received in server error message"); + if (code == null) + throw new NpgsqlException("Code not received in server error message"); + if (message == null) + throw new NpgsqlException("Message not received in server error message"); - return new ErrorOrNoticeMessage( - severity, invariantSeverity ?? severity, code, message, - detail, hint, position, internalPosition, internalQuery, where, - schemaName, tableName, columnName, dataTypeName, constraintName, - file, line, routine); + return new ErrorOrNoticeMessage( + severity, invariantSeverity ?? 
severity, code, message, + detail, hint, position, internalPosition, internalQuery, where, + schemaName, tableName, columnName, dataTypeName, constraintName, + file, line, routine); - } + } - internal ErrorOrNoticeMessage( - string severity, string invariantSeverity, string sqlState, string message, - string? detail = null, string? hint = null, int position = 0, int internalPosition = 0, string? internalQuery = null, string? where = null, - string? schemaName = null, string? tableName = null, string? columnName = null, string? dataTypeName = null, string? constraintName = null, - string? file = null, string? line = null, string? routine = null) - { - Severity = severity; - InvariantSeverity = invariantSeverity; - SqlState = sqlState; - Message = message; - Detail = detail; - Hint = hint; - Position = position; - InternalPosition = internalPosition; - InternalQuery = internalQuery; - Where = where; - SchemaName = schemaName; - TableName = tableName; - ColumnName = columnName; - DataTypeName = dataTypeName; - ConstraintName = constraintName; - File = file; - Line = line; - Routine = routine; - } + internal ErrorOrNoticeMessage( + string severity, string invariantSeverity, string sqlState, string message, + string? detail = null, string? hint = null, int position = 0, int internalPosition = 0, string? internalQuery = null, string? where = null, + string? schemaName = null, string? tableName = null, string? columnName = null, string? dataTypeName = null, string? constraintName = null, + string? file = null, string? line = null, string? 
routine = null) + { + Severity = severity; + InvariantSeverity = invariantSeverity; + SqlState = sqlState; + Message = message; + Detail = detail; + Hint = hint; + Position = position; + InternalPosition = internalPosition; + InternalQuery = internalQuery; + Where = where; + SchemaName = schemaName; + TableName = tableName; + ColumnName = columnName; + DataTypeName = dataTypeName; + ConstraintName = constraintName; + File = file; + Line = line; + Routine = routine; + } - /// - /// Error and notice message field codes - /// - internal enum ErrorFieldTypeCode : byte - { - Done = 0, - Severity = (byte)'S', - InvariantSeverity = (byte)'V', - Code = (byte)'C', - Message = (byte)'M', - Detail = (byte)'D', - Hint = (byte)'H', - Position = (byte)'P', - InternalPosition = (byte)'p', - InternalQuery = (byte)'q', - Where = (byte)'W', - SchemaName = (byte)'s', - TableName = (byte)'t', - ColumnName = (byte)'c', - DataTypeName = (byte)'d', - ConstraintName = (byte)'n', - File = (byte)'F', - Line = (byte)'L', - Routine = (byte)'R' - } + /// + /// Error and notice message field codes + /// + internal enum ErrorFieldTypeCode : byte + { + Done = 0, + Severity = (byte)'S', + InvariantSeverity = (byte)'V', + Code = (byte)'C', + Message = (byte)'M', + Detail = (byte)'D', + Hint = (byte)'H', + Position = (byte)'P', + InternalPosition = (byte)'p', + InternalQuery = (byte)'q', + Where = (byte)'W', + SchemaName = (byte)'s', + TableName = (byte)'t', + ColumnName = (byte)'c', + DataTypeName = (byte)'d', + ConstraintName = (byte)'n', + File = (byte)'F', + Line = (byte)'L', + Routine = (byte)'R' } } diff --git a/src/Npgsql/BackendMessages/NoDataMessage.cs b/src/Npgsql/BackendMessages/NoDataMessage.cs index 4202c1efd2..884d5c4d5e 100644 --- a/src/Npgsql/BackendMessages/NoDataMessage.cs +++ b/src/Npgsql/BackendMessages/NoDataMessage.cs @@ -1,9 +1,8 @@ -namespace Npgsql.BackendMessages +namespace Npgsql.BackendMessages; + +sealed class NoDataMessage : IBackendMessage { - class NoDataMessage : 
IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.NoData; - internal static readonly NoDataMessage Instance = new NoDataMessage(); - NoDataMessage() { } - } -} + public BackendMessageCode Code => BackendMessageCode.NoData; + internal static readonly NoDataMessage Instance = new(); + NoDataMessage() { } +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/ParameterDescriptionMessage.cs b/src/Npgsql/BackendMessages/ParameterDescriptionMessage.cs index 176fa78673..ebda485331 100644 --- a/src/Npgsql/BackendMessages/ParameterDescriptionMessage.cs +++ b/src/Npgsql/BackendMessages/ParameterDescriptionMessage.cs @@ -1,26 +1,26 @@ using System.Collections.Generic; +using Npgsql.Internal; -namespace Npgsql.BackendMessages -{ - class ParameterDescriptionMessage : IBackendMessage - { - // ReSharper disable once InconsistentNaming - internal List TypeOIDs { get; } +namespace Npgsql.BackendMessages; - internal ParameterDescriptionMessage() - { - TypeOIDs = new List(); - } +sealed class ParameterDescriptionMessage : IBackendMessage +{ + // ReSharper disable once InconsistentNaming + internal List TypeOIDs { get; } - internal ParameterDescriptionMessage Load(NpgsqlReadBuffer buf) - { - var numParams = buf.ReadUInt16(); - TypeOIDs.Clear(); - for (var i = 0; i < numParams; i++) - TypeOIDs.Add(buf.ReadUInt32()); - return this; - } + internal ParameterDescriptionMessage() + { + TypeOIDs = new List(); + } - public BackendMessageCode Code => BackendMessageCode.ParameterDescription; + internal ParameterDescriptionMessage Load(NpgsqlReadBuffer buf) + { + var numParams = buf.ReadUInt16(); + TypeOIDs.Clear(); + for (var i = 0; i < numParams; i++) + TypeOIDs.Add(buf.ReadUInt32()); + return this; } -} + + public BackendMessageCode Code => BackendMessageCode.ParameterDescription; +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/ParseCompleteMessage.cs b/src/Npgsql/BackendMessages/ParseCompleteMessage.cs index 
8651ace674..bb011f821a 100644 --- a/src/Npgsql/BackendMessages/ParseCompleteMessage.cs +++ b/src/Npgsql/BackendMessages/ParseCompleteMessage.cs @@ -1,9 +1,8 @@ -namespace Npgsql.BackendMessages +namespace Npgsql.BackendMessages; + +sealed class ParseCompleteMessage : IBackendMessage { - class ParseCompleteMessage : IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.ParseComplete; - internal static readonly ParseCompleteMessage Instance = new ParseCompleteMessage(); - ParseCompleteMessage() { } - } -} + public BackendMessageCode Code => BackendMessageCode.ParseComplete; + internal static readonly ParseCompleteMessage Instance = new(); + ParseCompleteMessage() { } +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/PortalSuspendedMessage.cs b/src/Npgsql/BackendMessages/PortalSuspendedMessage.cs index 551b9265e2..5da91ea831 100644 --- a/src/Npgsql/BackendMessages/PortalSuspendedMessage.cs +++ b/src/Npgsql/BackendMessages/PortalSuspendedMessage.cs @@ -1,9 +1,8 @@ -namespace Npgsql.BackendMessages +namespace Npgsql.BackendMessages; + +sealed class PortalSuspendedMessage : IBackendMessage { - class PortalSuspendedMessage : IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.PortalSuspended; - internal static readonly PortalSuspendedMessage Instance = new PortalSuspendedMessage(); - PortalSuspendedMessage() { } - } -} + public BackendMessageCode Code => BackendMessageCode.PortalSuspended; + internal static readonly PortalSuspendedMessage Instance = new(); + PortalSuspendedMessage() { } +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/ReadyForQueryMessage.cs b/src/Npgsql/BackendMessages/ReadyForQueryMessage.cs index 09ca4b2d6c..4d7225c422 100644 --- a/src/Npgsql/BackendMessages/ReadyForQueryMessage.cs +++ b/src/Npgsql/BackendMessages/ReadyForQueryMessage.cs @@ -1,14 +1,15 @@ -namespace Npgsql.BackendMessages +using Npgsql.Internal; + +namespace Npgsql.BackendMessages; + +sealed class 
ReadyForQueryMessage : IBackendMessage { - class ReadyForQueryMessage : IBackendMessage - { - public BackendMessageCode Code => BackendMessageCode.ReadyForQuery; + public BackendMessageCode Code => BackendMessageCode.ReadyForQuery; - internal TransactionStatus TransactionStatusIndicator { get; private set; } + internal TransactionStatus TransactionStatusIndicator { get; private set; } - internal ReadyForQueryMessage Load(NpgsqlReadBuffer buf) { - TransactionStatusIndicator = (TransactionStatus)buf.ReadByte(); - return this; - } + internal ReadyForQueryMessage Load(NpgsqlReadBuffer buf) { + TransactionStatusIndicator = (TransactionStatus)buf.ReadByte(); + return this; } -} +} \ No newline at end of file diff --git a/src/Npgsql/BackendMessages/RowDescriptionMessage.cs b/src/Npgsql/BackendMessages/RowDescriptionMessage.cs index 292097d581..1dd1045e21 100644 --- a/src/Npgsql/BackendMessages/RowDescriptionMessage.cs +++ b/src/Npgsql/BackendMessages/RowDescriptionMessage.cs @@ -1,261 +1,391 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Globalization; +using System.Runtime.CompilerServices; +using System.Threading; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; using Npgsql.PostgresTypes; -using Npgsql.TypeHandlers; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using Npgsql.Util; +using Npgsql.Replication.PgOutput.Messages; -namespace Npgsql.BackendMessages +namespace Npgsql.BackendMessages; + +readonly struct ColumnInfo { - /// - /// A RowDescription message sent from the backend. - /// - /// - /// See https://www.postgresql.org/docs/current/static/protocol-message-formats.html - /// - sealed class RowDescriptionMessage : IBackendMessage + public ColumnInfo(PgConverterInfo converterInfo, DataFormat dataFormat, bool asObject) { - public List Fields { get; } - readonly Dictionary _nameIndex; - Dictionary? 
_insensitiveIndex; + ConverterInfo = converterInfo; + DataFormat = dataFormat; + AsObject = asObject; + } - internal RowDescriptionMessage() - { - Fields = new List(); - _nameIndex = new Dictionary(); - } + public PgConverterInfo ConverterInfo { get; } + public DataFormat DataFormat { get; } + public bool AsObject { get; } +} - RowDescriptionMessage(RowDescriptionMessage source) - { - Fields = new List(source.Fields.Count); - foreach (var f in source.Fields) - Fields.Add(f.Clone()); - _nameIndex = new Dictionary(source._nameIndex); - if (source._insensitiveIndex?.Count > 0) - _insensitiveIndex = new Dictionary(source._insensitiveIndex); - } +/// +/// A RowDescription message sent from the backend. +/// +/// +/// See https://www.postgresql.org/docs/current/static/protocol-message-formats.html +/// +sealed class RowDescriptionMessage : IBackendMessage +{ + // We should really have CompareOptions.IgnoreKanaType here, but see + // https://github.com/dotnet/corefx/issues/12518#issuecomment-389658716 + static readonly StringComparer InvariantIgnoreCaseAndKanaWidthComparer = + CultureInfo.InvariantCulture.CompareInfo.GetStringComparer( + CompareOptions.IgnoreWidth | CompareOptions.IgnoreCase | CompareOptions.IgnoreKanaType); + + readonly bool _connectorOwned; + FieldDescription?[] _fields; + readonly Dictionary _nameIndex; + Dictionary? _insensitiveIndex; + ColumnInfo[]? 
_lastConverterInfoCache; + + internal RowDescriptionMessage(bool connectorOwned, int numFields = 10) + { + _connectorOwned = connectorOwned; + _fields = new FieldDescription[numFields]; + _nameIndex = new Dictionary(); + } - internal RowDescriptionMessage Load(NpgsqlReadBuffer buf, ConnectorTypeMapper typeMapper) - { - Fields.Clear(); - _nameIndex.Clear(); - _insensitiveIndex?.Clear(); + RowDescriptionMessage(RowDescriptionMessage source) + { + Count = source.Count; + _fields = new FieldDescription?[Count]; + for (var i = 0; i < Count; i++) + _fields[i] = source._fields[i]!.Clone(); + _nameIndex = new Dictionary(source._nameIndex); + if (source._insensitiveIndex?.Count > 0) + _insensitiveIndex = new Dictionary(source._insensitiveIndex, InvariantIgnoreCaseAndKanaWidthComparer); + } - var numFields = buf.ReadInt16(); - for (var i = 0; i != numFields; ++i) - { - FieldDescription field; - if (i >= Fields.Count) - { - field = new FieldDescription(); - Fields.Add(field); - } - else - field = Fields[i]; - - field.Populate( - typeMapper, - buf.ReadNullTerminatedString(), // Name - buf.ReadUInt32(), // TableOID - buf.ReadInt16(), // ColumnAttributeNumber - buf.ReadUInt32(), // TypeOID - buf.ReadInt16(), // TypeSize - buf.ReadInt32(), // TypeModifier - (FormatCode)buf.ReadInt16() // FormatCode - ); - - if (!_nameIndex.ContainsKey(field.Name)) - _nameIndex.Add(field.Name, i); - } + internal RowDescriptionMessage Load(NpgsqlReadBuffer buf, PgSerializerOptions options) + { + _nameIndex.Clear(); + _insensitiveIndex?.Clear(); - return this; + var numFields = Count = buf.ReadInt16(); + if (_fields.Length < numFields) + { + var oldFields = _fields; + _fields = new FieldDescription[numFields]; + Array.Copy(oldFields, _fields, oldFields.Length); } - internal FieldDescription this[int index] => Fields[index]; + for (var i = 0; i < numFields; ++i) + { + var field = _fields[i] ??= new(); + + field.Populate( + options, + name: buf.ReadNullTerminatedString(), + tableOID: buf.ReadUInt32(), 
+ columnAttributeNumber: buf.ReadInt16(), + oid: buf.ReadUInt32(), + typeSize: buf.ReadInt16(), + typeModifier: buf.ReadInt32(), + dataFormat: DataFormatUtils.Create(buf.ReadInt16()) + ); + + _nameIndex.TryAdd(field.Name, i); + } - internal int NumFields => Fields.Count; + return this; + } - /// - /// Given a string name, returns the field's ordinal index in the row. - /// - internal int GetFieldIndex(string name) - => TryGetFieldIndex(name, out var ret) - ? ret - : throw new IndexOutOfRangeException("Field not found in row: " + name); + internal static RowDescriptionMessage CreateForReplication( + PgSerializerOptions options, uint tableOID, DataFormat dataFormat, IReadOnlyList columns) + { + var msg = new RowDescriptionMessage(false, columns.Count); + var numFields = msg.Count = columns.Count; - /// - /// Given a string name, returns the field's ordinal index in the row. - /// - internal bool TryGetFieldIndex(string name, out int fieldIndex) + for (var i = 0; i < numFields; ++i) { - if (_nameIndex.TryGetValue(name, out fieldIndex)) - return true; + var field = msg._fields[i] = new(); + var column = columns[i]; + + field.Populate( + options, + name: column.ColumnName, + tableOID: tableOID, + columnAttributeNumber: checked((short)i), + oid: column.DataTypeId, + typeSize: 0, // TODO: Confirm we don't have this in replication + typeModifier: column.TypeModifier, + dataFormat: dataFormat + ); + + if (!msg._nameIndex.ContainsKey(field.Name)) + msg._nameIndex.Add(field.Name, i); + } - if (_insensitiveIndex is null || _insensitiveIndex.Count == 0) - { - if (_insensitiveIndex == null) - _insensitiveIndex = new Dictionary(InsensitiveComparer.Instance); + return msg; + } - foreach (var kv in _nameIndex) - if (!_insensitiveIndex.ContainsKey(kv.Key)) - _insensitiveIndex[kv.Key] = kv.Value; - } + public FieldDescription this[int index] + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + Debug.Assert(index < Count); + Debug.Assert(_fields[index] != null); - 
return _insensitiveIndex.TryGetValue(name, out fieldIndex); + return _fields[index]!; } + } - public BackendMessageCode Code => BackendMessageCode.RowDescription; + internal void SetColumnInfoCache(ReadOnlySpan values) + { + if (_connectorOwned || _lastConverterInfoCache is not null) + return; + Interlocked.CompareExchange(ref _lastConverterInfoCache, values.ToArray(), null); + } - internal RowDescriptionMessage Clone() => new RowDescriptionMessage(this); + internal void LoadColumnInfoCache(PgSerializerOptions options, ColumnInfo[] values) + { + if (_lastConverterInfoCache is not { } cache) + return; - /// - /// Comparer that's case-insensitive and Kana width-insensitive - /// - sealed class InsensitiveComparer : IEqualityComparer + // If the options have changed (for instance due to ReloadTypes) we need to invalidate the cache. + if (Count > 0 && !ReferenceEquals(options, _fields[0]!._serializerOptions)) { - public static readonly InsensitiveComparer Instance = new InsensitiveComparer(); - static readonly CompareInfo CompareInfo = CultureInfo.InvariantCulture.CompareInfo; + Interlocked.CompareExchange(ref _lastConverterInfoCache, null, cache); + return; + } - InsensitiveComparer() {} + cache.CopyTo(values.AsSpan()); + } - // We should really have CompareOptions.IgnoreKanaType here, but see - // https://github.com/dotnet/corefx/issues/12518#issuecomment-389658716 - public bool Equals(string? x, string? y) - => CompareInfo.Compare(x, y, CompareOptions.IgnoreWidth | CompareOptions.IgnoreCase | CompareOptions.IgnoreKanaType) == 0; + public int Count { get; private set; } - public int GetHashCode(string o) - => CompareInfo.GetSortKey(o, CompareOptions.IgnoreWidth | CompareOptions.IgnoreCase | CompareOptions.IgnoreKanaType).GetHashCode(); - } + /// + /// Given a string name, returns the field's ordinal index in the row. 
+ /// + internal int GetFieldIndex(string name) + { + if (!TryGetFieldIndex(name, out var ret)) + ThrowHelper.ThrowIndexOutOfRangeException($"Field not found in row: {name}"); + return ret; } /// - /// A descriptive record on a single field received from PostgreSQL. - /// See RowDescription in https://www.postgresql.org/docs/current/static/protocol-message-formats.html + /// Given a string name, returns the field's ordinal index in the row. /// - public sealed class FieldDescription + internal bool TryGetFieldIndex(string name, out int fieldIndex) { -#pragma warning disable CS8618 // Lazy-initialized type - internal FieldDescription() {} - - internal FieldDescription(uint oid) - : this("?", 0, 0, oid, 0, 0, FormatCode.Binary) {} + if (_nameIndex.TryGetValue(name, out fieldIndex)) + return true; - internal FieldDescription( - string name, uint tableOID, short columnAttributeNumber, - uint oid, short typeSize, int typeModifier, FormatCode formatCode) + if (_insensitiveIndex is null || _insensitiveIndex.Count == 0) { - Name = name; - TableOID = tableOID; - ColumnAttributeNumber = columnAttributeNumber; - TypeOID = oid; - TypeSize = typeSize; - TypeModifier = typeModifier; - FormatCode = formatCode; + if (_insensitiveIndex == null) + _insensitiveIndex = new Dictionary(InvariantIgnoreCaseAndKanaWidthComparer); + + foreach (var kv in _nameIndex) + _insensitiveIndex.TryAdd(kv.Key, kv.Value); } + + return _insensitiveIndex.TryGetValue(name, out fieldIndex); + } + + public BackendMessageCode Code => BackendMessageCode.RowDescription; + + internal RowDescriptionMessage Clone() => new(this); +} + +/// +/// A descriptive record on a single field received from PostgreSQL. 
+/// See RowDescription in https://www.postgresql.org/docs/current/static/protocol-message-formats.html +/// +public sealed class FieldDescription +{ +#pragma warning disable CS8618 // Lazy-initialized type + internal FieldDescription() { } + + internal FieldDescription(uint oid) + : this("?", 0, 0, oid, 0, 0, DataFormat.Binary) { } + + internal FieldDescription( + string name, uint tableOID, short columnAttributeNumber, + uint oid, short typeSize, int typeModifier, DataFormat dataFormat) + { + Name = name; + TableOID = tableOID; + ColumnAttributeNumber = columnAttributeNumber; + TypeOID = oid; + TypeSize = typeSize; + TypeModifier = typeModifier; + DataFormat = dataFormat; + } #pragma warning restore CS8618 - internal FieldDescription(FieldDescription source) - { - _typeMapper = source._typeMapper; - Name = source.Name; - TableOID = source.TableOID; - ColumnAttributeNumber = source.ColumnAttributeNumber; - TypeOID = source.TypeOID; - TypeSize = source.TypeSize; - TypeModifier = source.TypeModifier; - FormatCode = source.FormatCode; - Handler = source.Handler; - } + internal FieldDescription(FieldDescription source) + { + _serializerOptions = source._serializerOptions; + Name = source.Name; + TableOID = source.TableOID; + ColumnAttributeNumber = source.ColumnAttributeNumber; + TypeOID = source.TypeOID; + TypeSize = source.TypeSize; + TypeModifier = source.TypeModifier; + DataFormat = source.DataFormat; + PostgresType = source.PostgresType; + Field = source.Field; + _objectOrDefaultInfo = source._objectOrDefaultInfo; + } - internal void Populate( - ConnectorTypeMapper typeMapper, string name, uint tableOID, short columnAttributeNumber, - uint oid, short typeSize, int typeModifier, FormatCode formatCode - ) - { - _typeMapper = typeMapper; - Name = name; - TableOID = tableOID; - ColumnAttributeNumber = columnAttributeNumber; - TypeOID = oid; - TypeSize = typeSize; - TypeModifier = typeModifier; - FormatCode = formatCode; - - ResolveHandler(); - } + internal void 
Populate( + PgSerializerOptions serializerOptions, string name, uint tableOID, short columnAttributeNumber, + uint oid, short typeSize, int typeModifier, DataFormat dataFormat + ) + { + _serializerOptions = serializerOptions; + Name = name; + TableOID = tableOID; + ColumnAttributeNumber = columnAttributeNumber; + TypeOID = oid; + TypeSize = typeSize; + TypeModifier = typeModifier; + DataFormat = dataFormat; + PostgresType = _serializerOptions.DatabaseInfo.FindPostgresType((Oid)TypeOID)?.GetRepresentationalType() ?? UnknownBackendType.Instance; + Field = new(Name, _serializerOptions.ToCanonicalTypeId(PostgresType), TypeModifier); + _objectOrDefaultInfo = default; + } + + /// + /// The field name. + /// + internal string Name { get; set; } - /// - /// The field name. - /// - internal string Name { get; set; } + /// + /// The object ID of the field's data type. + /// + internal uint TypeOID { get; set; } - /// - /// The object ID of the field's data type. - /// - internal uint TypeOID { get; set; } + /// + /// The data type size (see pg_type.typlen). Note that negative values denote variable-width types. + /// + public short TypeSize { get; set; } - /// - /// The data type size (see pg_type.typlen). Note that negative values denote variable-width types. - /// - public short TypeSize { get; set; } + /// + /// The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. + /// + public int TypeModifier { get; set; } + + /// + /// If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero. + /// + internal uint TableOID { get; set; } + + /// + /// If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero. + /// + internal short ColumnAttributeNumber { get; set; } - /// - /// The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. 
- /// - public int TypeModifier { get; set; } + /// + /// The format code being used for the field. + /// Currently will be text or binary. + /// In a RowDescription returned from the statement variant of Describe, the format code is not yet known and will always be zero. + /// + internal DataFormat DataFormat { get; set; } - /// - /// If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero. - /// - internal uint TableOID { get; set; } + internal Field Field { get; private set; } - /// - /// If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero. - /// - internal short ColumnAttributeNumber { get; set; } + internal string TypeDisplayName => PostgresType.GetDisplayNameWithFacets(TypeModifier); - /// - /// The format code being used for the field. - /// Currently will be zero (text) or one (binary). - /// In a RowDescription returned from the statement variant of Describe, the format code is not yet known and will always be zero. - /// - internal FormatCode FormatCode { get; set; } + internal PostgresType PostgresType { get; private set; } - internal string TypeDisplayName => PostgresType.GetDisplayNameWithFacets(TypeModifier); + internal Type FieldType => ObjectOrDefaultInfo.TypeToConvert; - /// - /// The Npgsql type handler assigned to handle this field. - /// Returns for fields with format text. - /// - internal NpgsqlTypeHandler Handler { get; private set; } + ColumnInfo _objectOrDefaultInfo; + internal PgConverterInfo ObjectOrDefaultInfo + { + get + { + if (!_objectOrDefaultInfo.ConverterInfo.IsDefault) + return _objectOrDefaultInfo.ConverterInfo; - internal PostgresType PostgresType - => _typeMapper.DatabaseInfo.ByOID.TryGetValue(TypeOID, out var postgresType) - ? 
postgresType - : UnknownBackendType.Instance; + ref var info = ref _objectOrDefaultInfo; + GetInfo(null, ref _objectOrDefaultInfo); + return info.ConverterInfo; + } + } - internal Type FieldType => Handler.GetFieldType(this); + internal PgSerializerOptions _serializerOptions; - internal void ResolveHandler() - => Handler = IsBinaryFormat ? _typeMapper.GetByOID(TypeOID) : _typeMapper.UnrecognizedTypeHandler; + internal FieldDescription Clone() + { + var field = new FieldDescription(this); + return field; + } - ConnectorTypeMapper _typeMapper; + internal void GetInfo(Type? type, ref ColumnInfo lastColumnInfo) + { + Debug.Assert(lastColumnInfo.ConverterInfo.IsDefault || ( + ReferenceEquals(_serializerOptions, lastColumnInfo.ConverterInfo.TypeInfo.Options) && + lastColumnInfo.ConverterInfo.TypeInfo.PgTypeId == _serializerOptions.ToCanonicalTypeId(PostgresType)), "Cache is bleeding over"); - internal bool IsBinaryFormat => FormatCode == FormatCode.Binary; - internal bool IsTextFormat => FormatCode == FormatCode.Text; + if (!lastColumnInfo.ConverterInfo.IsDefault && lastColumnInfo.ConverterInfo.TypeToConvert == type) + return; - internal FieldDescription Clone() + var odfInfo = DataFormat is DataFormat.Text && type is not null ? ObjectOrDefaultInfo : _objectOrDefaultInfo.ConverterInfo; + if (odfInfo is { IsDefault: false }) { - var field = new FieldDescription(this); - field.ResolveHandler(); - return field; + if (typeof(object) == type) + { + lastColumnInfo = new(odfInfo, DataFormat, true); + return; + } + if (odfInfo.TypeToConvert == type) + { + // As TypeInfoMappingCollection is always adding object mappings for + // default/datatypename mappings, we'll also check Converter.TypeToConvert. + // If we have an exact match we are still able to use e.g. a converter for ints in an unboxed fashion. 
+ lastColumnInfo = new(odfInfo, DataFormat, odfInfo.IsBoxingConverter && odfInfo.Converter.TypeToConvert != type); + return; + } } - /// - /// Returns a string that represents the current object. - /// - public override string ToString() => Name + (Handler == null ? "" : $"({Handler.PgDisplayName})"); + GetInfoSlow(type, out lastColumnInfo); + + [MethodImpl(MethodImplOptions.NoInlining)] + void GetInfoSlow(Type? type, out ColumnInfo lastColumnInfo) + { + var typeInfo = AdoSerializerHelpers.GetTypeInfoForReading(type ?? typeof(object), PostgresType, _serializerOptions); + PgConverterInfo converterInfo; + switch (DataFormat) + { + case DataFormat.Binary: + // If we don't support binary we'll just throw. + converterInfo = typeInfo.Bind(Field, DataFormat); + lastColumnInfo = new(converterInfo, DataFormat.Binary, typeof(object) == type || converterInfo.IsBoxingConverter); + break; + default: + // For text we'll fall back to any available text converter for the expected clr type or throw. + if (!typeInfo.TryBind(Field, DataFormat, out converterInfo)) + { + typeInfo = AdoSerializerHelpers.GetTypeInfoForReading(type ?? typeof(string), _serializerOptions.TextPgType, _serializerOptions); + converterInfo = typeInfo.Bind(Field, DataFormat); + lastColumnInfo = new(converterInfo, DataFormat, type != converterInfo.TypeToConvert || converterInfo.IsBoxingConverter); + } + else + lastColumnInfo = new(converterInfo, DataFormat, typeof(object) == type || converterInfo.IsBoxingConverter); + break; + } + + // We delay initializing ObjectOrDefaultInfo until after the first lookup (unless it is itself the first lookup). + // When passed in an unsupported type it allows the error to be more specific, instead of just having object/null to deal with. + if (_objectOrDefaultInfo.ConverterInfo.IsDefault && type is not null) + _ = ObjectOrDefaultInfo; + } } + + /// + /// Returns a string that represents the current object. 
+ /// + public override string ToString() => Name + $"({PostgresType.DisplayName})"; } diff --git a/src/Npgsql/Common.cs b/src/Npgsql/Common.cs index 75953ee9a6..4a6d72757b 100644 --- a/src/Npgsql/Common.cs +++ b/src/Npgsql/Common.cs @@ -1,80 +1,81 @@ -namespace Npgsql +namespace Npgsql; + +/// +/// Base class for all classes which represent a message sent by the PostgreSQL backend. +/// +interface IBackendMessage { - /// - /// Base class for all classes which represent a message sent by the PostgreSQL backend. - /// - interface IBackendMessage - { - BackendMessageCode Code { get; } - } + BackendMessageCode Code { get; } +} - enum BackendMessageCode : byte - { - AuthenticationRequest = (byte)'R', - BackendKeyData = (byte)'K', - BindComplete = (byte)'2', - CloseComplete = (byte)'3', - CommandComplete = (byte)'C', - CopyData = (byte)'d', - CopyDone = (byte)'c', - CopyBothResponse = (byte)'W', - CopyInResponse = (byte)'G', - CopyOutResponse = (byte)'H', - DataRow = (byte)'D', - EmptyQueryResponse = (byte)'I', - ErrorResponse = (byte)'E', - FunctionCall = (byte)'F', - FunctionCallResponse = (byte)'V', - NoData = (byte)'n', - NoticeResponse = (byte)'N', - NotificationResponse = (byte)'A', - ParameterDescription = (byte)'t', - ParameterStatus = (byte)'S', - ParseComplete = (byte)'1', - PasswordPacket = (byte)' ', - PortalSuspended = (byte)'s', - ReadyForQuery = (byte)'Z', - RowDescription = (byte)'T', - } +enum BackendMessageCode : byte +{ + AuthenticationRequest = (byte)'R', + BackendKeyData = (byte)'K', + BindComplete = (byte)'2', + CloseComplete = (byte)'3', + CommandComplete = (byte)'C', + CopyData = (byte)'d', + CopyDone = (byte)'c', + CopyBothResponse = (byte)'W', + CopyInResponse = (byte)'G', + CopyOutResponse = (byte)'H', + DataRow = (byte)'D', + EmptyQueryResponse = (byte)'I', + ErrorResponse = (byte)'E', + FunctionCall = (byte)'F', + FunctionCallResponse = (byte)'V', + NoData = (byte)'n', + NoticeResponse = (byte)'N', + NotificationResponse = (byte)'A', + 
ParameterDescription = (byte)'t', + ParameterStatus = (byte)'S', + ParseComplete = (byte)'1', + PasswordPacket = (byte)' ', + PortalSuspended = (byte)'s', + ReadyForQuery = (byte)'Z', + RowDescription = (byte)'T', +} - static class FrontendMessageCode - { - internal const byte Describe = (byte)'D'; - internal const byte Sync = (byte)'S'; - internal const byte Execute = (byte)'E'; - internal const byte Parse = (byte)'P'; - internal const byte Bind = (byte)'B'; - internal const byte Close = (byte)'C'; - internal const byte Query = (byte)'Q'; - internal const byte CopyData = (byte)'d'; - internal const byte CopyDone = (byte)'c'; - internal const byte CopyFail = (byte)'f'; - internal const byte Terminate = (byte)'X'; - internal const byte Password = (byte)'p'; - } +static class FrontendMessageCode +{ + internal const byte Describe = (byte)'D'; + internal const byte Sync = (byte)'S'; + internal const byte Execute = (byte)'E'; + internal const byte Parse = (byte)'P'; + internal const byte Bind = (byte)'B'; + internal const byte Close = (byte)'C'; + internal const byte Query = (byte)'Q'; + internal const byte CopyData = (byte)'d'; + internal const byte CopyDone = (byte)'c'; + internal const byte CopyFail = (byte)'f'; + internal const byte Terminate = (byte)'X'; + internal const byte Password = (byte)'p'; +} - enum StatementOrPortal : byte - { - Statement = (byte)'S', - Portal = (byte)'P' - } +enum StatementOrPortal : byte +{ + Statement = (byte)'S', + Portal = (byte)'P' +} - /// - /// Specifies the type of SQL statement, e.g. SELECT - /// - public enum StatementType - { +/// +/// Specifies the type of SQL statement, e.g. 
SELECT +/// +public enum StatementType +{ #pragma warning disable 1591 - Unknown, - Select, - Insert, - Delete, - Update, - CreateTableAs, - Move, - Fetch, - Copy, - Other + Unknown, + Select, + Insert, + Delete, + Update, + CreateTableAs, + Move, + Fetch, + Copy, + Other, + Merge, + Call #pragma warning restore 1591 - } } diff --git a/src/Npgsql/ConnectorPool.Multiplexing.cs b/src/Npgsql/ConnectorPool.Multiplexing.cs deleted file mode 100644 index d5984d688d..0000000000 --- a/src/Npgsql/ConnectorPool.Multiplexing.cs +++ /dev/null @@ -1,439 +0,0 @@ -using System; -using System.Diagnostics; -using System.Threading; -using System.Threading.Channels; -using System.Threading.Tasks; -using Npgsql.TypeMapping; -using Npgsql.Util; -using static Npgsql.Util.Statics; - -namespace Npgsql -{ - sealed partial class ConnectorPool - { - readonly ChannelReader? _multiplexCommandReader; - internal ChannelWriter? MultiplexCommandWriter { get; } - - const int WriteCoalescineDelayAdaptivityUs = 10; - - /// - /// A pool-wide type mapper used when multiplexing. This is necessary because binding parameters - /// to their type handlers happens *before* the command is enqueued for execution, so there's no - /// connector yet at that stage. - /// - internal ConnectorTypeMapper? MultiplexingTypeMapper { get; private set; } - - /// - /// When multiplexing is enabled, determines the maximum amount of time to wait for further - /// commands before flushing to the network. In ticks (100ns), 0 disables waiting. - /// This is in 100ns ticks, not ticks whose meaning vary across platforms. - /// - readonly long _writeCoalescingDelayTicks; - - /// - /// When multiplexing is enabled, determines the maximum number of outgoing bytes to buffer before - /// flushing to the network. - /// - readonly int _writeCoalescingBufferThresholdBytes; - - readonly SemaphoreSlim? _bootstrapSemaphore; - - /// - /// Called exactly once per multiplexing pool, when the first connection is opened, with two goals: - /// 1. 
Load types and bind the pool-wide type mapper (necessary for binding parameters) - /// 2. Cause any connection exceptions (e.g. bad username) to be thrown from NpgsqlConnection.Open - /// - internal async Task BootstrapMultiplexing(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken = default) - { - Debug.Assert(_multiplexing); - - var hasSemaphore = async - ? await _bootstrapSemaphore!.WaitAsync(timeout.TimeLeft, cancellationToken) - : _bootstrapSemaphore!.Wait(timeout.TimeLeft, cancellationToken); - - // We've timed out - calling Check, to throw the correct exception - if (!hasSemaphore) - timeout.Check(); - - try - { - if (IsBootstrapped) - return; - - var connector = await conn.StartBindingScope(ConnectorBindingScope.Connection, timeout, async, cancellationToken); - using var _ = Defer(static conn => conn.EndBindingScope(ConnectorBindingScope.Connection), conn); - - // Somewhat hacky. Extract the connector's type mapper as our pool-wide mapper, - // and have the connector rebind to ensure it has a different instance. - // The latter isn't strictly necessary (type mappers should always be usable - // concurrently) but just in case. - MultiplexingTypeMapper = connector.TypeMapper; - await connector.LoadDatabaseInfo(false, timeout, async, cancellationToken); - - IsBootstrapped = true; - } - finally - { - _bootstrapSemaphore!.Release(); - } - } - - async Task MultiplexingWriteLoop() - { - // This method is async, but only ever yields when there are no pending commands in the command channel. - // No I/O should ever be performed asynchronously, as that would block further writing for the entire - // application; whenever I/O cannot complete immediately, we chain a callback with ContinueWith and move - // on to the next connector. 
- Debug.Assert(_multiplexCommandReader != null); - - var timeout = _writeCoalescingDelayTicks / 2; - var timeoutTokenSource = new ResettableCancellationTokenSource(TimeSpan.FromTicks(timeout)); - var timeoutToken = timeout == 0 ? CancellationToken.None : timeoutTokenSource.Token; - - while (true) - { - var stats = new MultiplexingStats { Stopwatch = new Stopwatch() }; - NpgsqlConnector? connector; - - // Get a first command out. - if (!_multiplexCommandReader.TryRead(out var command)) - command = await _multiplexCommandReader.ReadAsync(); - - try - { - // First step is to get a connector on which to execute - var spinwait = new SpinWait(); - while (true) - { - if (TryGetIdleConnector(out connector)) - { - // See increment under over-capacity mode below - Interlocked.Increment(ref connector.CommandsInFlightCount); - break; - } - - connector = await OpenNewConnector( - command.Connection!, - new NpgsqlTimeout(TimeSpan.FromSeconds(Settings.Timeout)), - async: true, - CancellationToken.None); - - if (connector != null) - { - // Managed to created a new connector - connector.Connection = null; - - // See increment under over-capacity mode below - Interlocked.Increment(ref connector.CommandsInFlightCount); - - break; - } - - // There were no idle connectors and we're at max capacity, so we can't open a new one. - // Enter over-capacity mode - find an unlocked connector with the least currently in-flight - // commands and sent on it, even though there are already pending commands. - var minInFlight = int.MaxValue; - foreach (var c in _connectors) - { - if (c?.MultiplexAsyncWritingLock == 0 && c.CommandsInFlightCount < minInFlight) - { - minInFlight = c.CommandsInFlightCount; - connector = c; - } - } - - // There could be no writable connectors (all stuck in transaction or flushing). - if (connector == null) - { - // TODO: This is problematic - when absolutely all connectors are both busy *and* currently - // performing (async) I/O, this will spin-wait. 
- // We could call WaitAsync, but that would wait for an idle connector, whereas we want any - // writeable (non-writing) connector even if it has in-flight commands. Maybe something - // with better back-off. - // On the other hand, this is exactly *one* thread doing spin-wait, maybe not that bad. - spinwait.SpinOnce(); - continue; - } - - // We may be in a race condition with the connector read loop, which may be currently returning - // the connector to the Idle channel (because it has completed all commands). - // Increment the in-flight count to make sure the connector isn't returned as idle. - var newInFlight = Interlocked.Increment(ref connector.CommandsInFlightCount); - if (newInFlight == 1) - { - // The connector's in-flight was 0, so it was idle - abort over-capacity read - // and retry the normal flow. - Interlocked.Decrement(ref connector.CommandsInFlightCount); - spinwait.SpinOnce(); - continue; - } - - break; - } - } - catch (Exception ex) - { - Log.Error("Exception opening a connection", ex); - - // Fail the first command in the channel as a way of bubbling the exception up to the user - command.ExecutionCompletion.SetException(ex); - - continue; - } - - // We now have a ready connector, and can start writing commands to it. - Debug.Assert(connector != null); - - try - { - stats.Reset(); - connector.FlagAsNotWritableForMultiplexing(); - - // Read queued commands and write them to the connector's buffer, for as long as we're - // under our write threshold and timer delay. - // Note we already have one command we read above, and have already updated the connector's - // CommandsInFlightCount. Now write that command. 
- var writtenSynchronously = WriteCommand(connector, command, ref stats); - - if (timeout == 0) - { - while (connector.WriteBuffer.WritePosition < _writeCoalescingBufferThresholdBytes && - writtenSynchronously && - _multiplexCommandReader.TryRead(out command)) - { - Interlocked.Increment(ref connector.CommandsInFlightCount); - writtenSynchronously = WriteCommand(connector, command, ref stats); - } - } - else - { - timeoutToken = timeoutTokenSource.Start(); - - try - { - while (connector.WriteBuffer.WritePosition < _writeCoalescingBufferThresholdBytes && - writtenSynchronously) - { - if (!_multiplexCommandReader.TryRead(out command)) - { - stats.Waits++; - command = await _multiplexCommandReader.ReadAsync(timeoutToken); - } - - Interlocked.Increment(ref connector.CommandsInFlightCount); - writtenSynchronously = WriteCommand(connector, command, ref stats); - } - - // The cancellation token (presumably!) has not fired, reset its timer so - // we can reuse the cancellation token source instead of reallocating - timeoutTokenSource.Stop(); - - // Increase the timeout slightly for next time: we're under load, so allow more - // commands to get coalesced into the same packet (up to the hard limit) - timeout = Math.Min(timeout + WriteCoalescineDelayAdaptivityUs, _writeCoalescingDelayTicks); - } - catch (OperationCanceledException) - { - // Timeout fired, we're done writing. - // Reduce the timeout slightly for next time: we're under little load, so reduce impact - // on latency - timeout = Math.Max(timeout - WriteCoalescineDelayAdaptivityUs, 0); - } - } - - // If all commands were written synchronously (good path), complete the write here, flushing - // and updating statistics. If not, CompleteRewrite is scheduled to run later, when the async - // operations complete, so skip it and continue. 
- if (writtenSynchronously) - Flush(connector, ref stats); - } - catch (Exception ex) - { - FailWrite(connector, ex); - } - } - - bool WriteCommand(NpgsqlConnector connector, NpgsqlCommand command, ref MultiplexingStats stats) - { - // Note: this method *never* awaits on I/O - doing so would suspend all outgoing multiplexing commands - // for the entire pool. In the normal/fast case, writing the command is purely synchronous (serialize - // to buffer in memory), and the actual flush will occur at the level above. For cases where the - // command overflows the buffer, async I/O is done, and we schedule continuations separately - - // but the main thread continues to handle other commands on other connectors. - if (_autoPrepare) - { - var numPrepared = 0; - foreach (var statement in command._statements) - { - // If this statement isn't prepared, see if it gets implicitly prepared. - // Note that this may return null (not enough usages for automatic preparation). - if (!statement.IsPrepared) - statement.PreparedStatement = connector.PreparedStatementManager.TryGetAutoPrepared(statement); - if (statement.PreparedStatement is PreparedStatement pStatement) - { - numPrepared++; - if (pStatement?.State == PreparedState.NotPrepared) - { - pStatement.State = PreparedState.BeingPrepared; - statement.IsPreparing = true; - } - } - } - } - - var written = connector.CommandsInFlightWriter!.TryWrite(command); - Debug.Assert(written, $"Failed to enqueue command to {connector.CommandsInFlightWriter}"); - - // Purposefully don't wait for I/O to complete - var task = command.Write(connector, async: true); - stats.NumCommands++; - - switch (task.Status) - { - case TaskStatus.RanToCompletion: - return true; - - case TaskStatus.Faulted: - task.GetAwaiter().GetResult(); // Throw the exception - return true; - - case TaskStatus.WaitingForActivation: - case TaskStatus.Running: - { - // Asynchronous completion, which means the writing is flushing to network and there's actual I/O - // (i.e. 
a big command which overflowed our buffer). - // We don't (ever) await in the write loop, so remove the connector from the writable list (as it's - // still flushing) and schedule a continuation to continue taking care of this connector. - // The write loop continues to the next connector. - - // Create a copy of the statistics and purposefully box it via the closure. We need a separate - // copy of the stats for the async writing that will continue in parallel with this loop. - var clonedStats = stats.Clone(); - - // ReSharper disable once MethodSupportsCancellation - task.ContinueWith((t, o) => - { - var conn = (NpgsqlConnector)o!; - - if (t.IsFaulted) - { - FailWrite(conn, t.Exception!.UnwrapAggregate()); - return; - } - - // There's almost certainly more buffered outgoing data for the command, after the flush - // occured. Complete the write, which will flush again (and update statistics). - try - { - Flush(conn, ref clonedStats); - } - catch (Exception e) - { - FailWrite(conn, e); - } - }, connector); - - return false; - } - - default: - Debug.Fail("When writing command to connector, task is in invalid state " + task.Status); - throw new Exception("When writing command to connector, task is in invalid state " + task.Status); - } - } - - void Flush(NpgsqlConnector connector, ref MultiplexingStats stats) - { - var task = connector.Flush(async: true); - switch (task.Status) - { - case TaskStatus.RanToCompletion: - CompleteWrite(connector, ref stats); - return; - - case TaskStatus.Faulted: - task.GetAwaiter().GetResult(); // Throw the exception - return; - - case TaskStatus.WaitingForActivation: - case TaskStatus.Running: - { - // Asynchronous completion - the flush didn't complete immediately (e.g. TCP zero window). - - // Create a copy of the statistics and purposefully box it via the closure. We need a separate - // copy of the stats for the async writing that will continue in parallel with this loop. 
- var clonedStats = stats.Clone(); - - task.ContinueWith((t, o) => - { - var conn = (NpgsqlConnector)o!; - if (t.IsFaulted) - { - FailWrite(conn, t.Exception!.UnwrapAggregate()); - return; - } - - CompleteWrite(conn, ref clonedStats); - }, connector); - - return; - } - - default: - Debug.Fail("When flushing, task is in invalid state " + task.Status); - throw new Exception("When flushing, task is in invalid state " + task.Status); - } - } - - void FailWrite(NpgsqlConnector connector, Exception exception) - { - // Note that all commands already passed validation before being enqueued. This means any error - // here is either an unrecoverable network issue (in which case we're already broken), or some other - // issue while writing (e.g. invalid UTF8 characters in the SQL query) - unrecoverable in any case. - - // All commands enqueued in CommandsInFlightWriter will be drained by the reader and failed. - // Note that some of these commands where only written to the connector's buffer, but never - // actually sent - because of a later exception. - // In theory, we could track commands that were only enqueued and not sent, and retry those - // (on another connector), but that would add some book-keeping and complexity, and in any case - // if one connector was broken, chances are that all are (networking). - Debug.Assert(connector.IsBroken); - - Log.Error("Exception while writing commands", exception, connector.Id); - } - - static void CompleteWrite(NpgsqlConnector connector, ref MultiplexingStats stats) - { - // All I/O has completed, mark this connector as safe for writing again. - // This will allow the connector to be returned to the pool by its read loop, and also to be selected - // for over-capacity write. 
- connector.FlagAsWritableForMultiplexing(); - - NpgsqlEventSource.Log.MultiplexingBatchSent(stats.NumCommands, stats.Waits, stats.Stopwatch!); - } - - // ReSharper disable once FunctionNeverReturns - } - - struct MultiplexingStats - { - internal Stopwatch Stopwatch; - internal int NumCommands; - internal int Waits; - - internal void Reset() - { - Stopwatch.Restart(); - NumCommands = 0; - Waits = 0; - } - - internal MultiplexingStats Clone() - { - var clone = new MultiplexingStats { Stopwatch = Stopwatch, NumCommands = NumCommands }; - Stopwatch = new Stopwatch(); - return clone; - } - } - } -} diff --git a/src/Npgsql/ConnectorPool.cs b/src/Npgsql/ConnectorPool.cs deleted file mode 100644 index a4f211ef66..0000000000 --- a/src/Npgsql/ConnectorPool.cs +++ /dev/null @@ -1,530 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Channels; -using System.Threading.Tasks; -using System.Transactions; -using Npgsql.Logging; -using Npgsql.Util; -using static Npgsql.Util.Statics; - -namespace Npgsql -{ - sealed partial class ConnectorPool - { - #region Fields and properties - - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(ConnectorPool)); - - internal NpgsqlConnectionStringBuilder Settings { get; } - - /// - /// Contains the connection string returned to the user from - /// after the connection has been opened. Does not contain the password unless Persist Security Info=true. 
- /// - internal string UserFacingConnectionString { get; } - - readonly int _max; - readonly int _min; - readonly bool _autoPrepare; - readonly TimeSpan _connectionLifetime; - volatile int _numConnectors; - - public bool IsBootstrapped - { - get => _isBootstrapped; - set => _isBootstrapped = value; - } - - volatile bool _isBootstrapped; - - volatile int _idleCount; - - /// - /// Tracks all connectors currently managed by this pool, whether idle or busy. - /// Only updated rarely - when physical connections are opened/closed - but is read in perf-sensitive contexts. - /// - readonly NpgsqlConnector?[] _connectors; - - readonly bool _multiplexing; - - /// - /// Reader side for the idle connector channel. Contains nulls in order to release waiting attempts after - /// a connector has been physically closed/broken. - /// - readonly ChannelReader _idleConnectorReader; - internal ChannelWriter IdleConnectorWriter { get; } - - /// - /// Incremented every time this pool is cleared via or - /// . Allows us to identify connections which were - /// created before the clear. - /// - volatile int _clearCounter; - - static readonly TimerCallback PruningTimerCallback = PruneIdleConnectors; - readonly Timer _pruningTimer; - readonly TimeSpan _pruningSamplingInterval; - readonly int _pruningSampleSize; - readonly int[] _pruningSamples; - readonly int _pruningMedianIndex; - volatile bool _pruningTimerEnabled; - int _pruningSampleIndex; - - // Note that while the dictionary is protected by locking, we assume that the lists it contains don't need to be - // (i.e. 
access to connectors of a specific transaction won't be concurrent) - readonly Dictionary> _pendingEnlistedConnectors - = new Dictionary>(); - - static readonly SingleThreadSynchronizationContext SingleThreadSynchronizationContext = new SingleThreadSynchronizationContext("NpgsqlRemainingAsyncSendWorker"); - - // TODO: Make this configurable - const int MultiexingCommandChannelBound = 4096; - - #endregion - - internal (int Total, int Idle, int Busy) Statistics - { - get - { - var numConnectors = _numConnectors; - var idleCount = _idleCount; - return (numConnectors, idleCount, numConnectors - idleCount); - } - } - - internal ConnectorPool(NpgsqlConnectionStringBuilder settings, string connString) - { - if (settings.MaxPoolSize < settings.MinPoolSize) - throw new ArgumentException($"Connection can't have MaxPoolSize {settings.MaxPoolSize} under MinPoolSize {settings.MinPoolSize}"); - - // We enforce Max Pool Size, so no need to to create a bounded channel (which is less efficient) - // On the consuming side, we have the multiplexing write loop but also non-multiplexing Rents - // On the producing side, we have connections being released back into the pool (both multiplexing and not) - var idleChannel = Channel.CreateUnbounded(); - _idleConnectorReader = idleChannel.Reader; - IdleConnectorWriter = idleChannel.Writer; - - _max = settings.MaxPoolSize; - _min = settings.MinPoolSize; - - UserFacingConnectionString = settings.PersistSecurityInfo - ? 
connString - : settings.ToStringWithoutPassword(); - - Settings = settings; - - if (settings.ConnectionPruningInterval == 0) - throw new ArgumentException("ConnectionPruningInterval can't be 0."); - var connectionIdleLifetime = TimeSpan.FromSeconds(settings.ConnectionIdleLifetime); - var pruningSamplingInterval = TimeSpan.FromSeconds(settings.ConnectionPruningInterval); - if (connectionIdleLifetime < pruningSamplingInterval) - throw new ArgumentException($"Connection can't have ConnectionIdleLifetime {connectionIdleLifetime} under ConnectionPruningInterval {_pruningSamplingInterval}"); - - _pruningTimer = new Timer(PruningTimerCallback, this, Timeout.Infinite, Timeout.Infinite); - _pruningSampleSize = DivideRoundingUp(settings.ConnectionIdleLifetime, settings.ConnectionPruningInterval); - _pruningMedianIndex = DivideRoundingUp(_pruningSampleSize, 2) - 1; // - 1 to go from length to index - _pruningSamplingInterval = pruningSamplingInterval; - _pruningSamples = new int[_pruningSampleSize]; - _pruningTimerEnabled = false; - - _max = settings.MaxPoolSize; - _min = settings.MinPoolSize; - _autoPrepare = settings.MaxAutoPrepare > 0; - _connectionLifetime = TimeSpan.FromSeconds(settings.ConnectionLifetime); - _connectors = new NpgsqlConnector[_max]; - - // TODO: Validate multiplexing options are set only when Multiplexing is on - - if (Settings.Multiplexing) - { - _multiplexing = true; - - _bootstrapSemaphore = new SemaphoreSlim(1); - - // Translate microseconds to ticks for cancellation token - _writeCoalescingDelayTicks = Settings.WriteCoalescingDelayUs * 100; - _writeCoalescingBufferThresholdBytes = Settings.WriteCoalescingBufferThresholdBytes; - - var multiplexCommandChannel = Channel.CreateBounded( - new BoundedChannelOptions(MultiexingCommandChannelBound) - { - FullMode = BoundedChannelFullMode.Wait, - SingleReader = true - }); - _multiplexCommandReader = multiplexCommandChannel.Reader; - MultiplexCommandWriter = multiplexCommandChannel.Writer; - - // TODO: Think 
about cleanup for this, e.g. completing the channel at application shutdown and/or - // pool clearing - - _ = Task.Run(MultiplexingWriteLoop) - .ContinueWith(t => - { - // Note that we *must* observe the exception if the task is faulted. - Log.Error("Exception in multiplexing write loop, this is an Npgsql bug, please file an issue.", - t.Exception!); - }, TaskContinuationOptions.OnlyOnFaulted); - } - } - - internal ValueTask Rent( - NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) - { - return TryGetIdleConnector(out var connector) - ? new ValueTask(AssignConnection(conn, connector)) - : RentAsync(); - - async ValueTask RentAsync() - { - // First, try to open a new physical connector. This will fail if we're at max capacity. - connector = await OpenNewConnector(conn, timeout, async, cancellationToken); - if (connector != null) - return AssignConnection(conn, connector); - - // We're at max capacity. Block on the idle channel with a timeout. - // Note that Channels guarantee fair FIFO behavior to callers of ReadAsync (first-come first- - // served), which is crucial to us. - using var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - var finalToken = linkedSource.Token; - linkedSource.CancelAfter(timeout.TimeLeft); - - while (true) - { - try - { - if (async) - { - connector = await _idleConnectorReader.ReadAsync(finalToken); - if (CheckIdleConnector(connector)) - return AssignConnection(conn, connector); - } - else - { - // Channels don't have a sync API. To avoid sync-over-async issues, we use a special single- - // thread synchronization context which ensures that callbacks are executed on a dedicated - // thread. - // Note that AsTask isn't safe here for getting the result, since it still causes some continuation code - // to get executed on the TP (which can cause deadlocks). 
- using (SingleThreadSynchronizationContext.Enter()) - using (var mre = new ManualResetEventSlim()) - { - _idleConnectorReader.WaitToReadAsync(finalToken).GetAwaiter().OnCompleted(() => mre.Set()); - mre.Wait(finalToken); - } - } - } - catch (OperationCanceledException) - { - cancellationToken.ThrowIfCancellationRequested(); - Debug.Assert(finalToken.IsCancellationRequested); - throw new NpgsqlException( - $"The connection pool has been exhausted, either raise MaxPoolSize (currently {_max}) " + - $"or Timeout (currently {Settings.Timeout} seconds)"); - } - catch (ChannelClosedException) - { - throw new NpgsqlException("The connection pool has been shut down."); - } - - // If we're here, our waiting attempt on the idle connector channel was released with a null - // (or bad connector), or we're in sync mode. Check again if a new idle connector has appeared since we last checked. - if (TryGetIdleConnector(out connector)) - return AssignConnection(conn, connector); - - // We might have closed a connector in the meantime and no longer be at max capacity - // so try to open a new connector and if that fails, loop again. - connector = await OpenNewConnector(conn, timeout, async, cancellationToken); - if (connector != null) - return AssignConnection(conn, connector); - } - } - - static NpgsqlConnector AssignConnection(NpgsqlConnection connection, NpgsqlConnector connector) - { - connector.Connection = connection; - connection.Connector = connector; - return connector; - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - bool TryGetIdleConnector([NotNullWhen(true)] out NpgsqlConnector? connector) - { - while (_idleConnectorReader.TryRead(out var nullableConnector)) - { - if (CheckIdleConnector(nullableConnector)) - { - connector = nullableConnector; - return true; - } - } - - connector = null; - return false; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - bool CheckIdleConnector([NotNullWhen(true)] NpgsqlConnector? 
connector) - { - if (connector is null) - return false; - - // Only decrement when the connector has a value. - Interlocked.Decrement(ref _idleCount); - - // An connector could be broken because of a keepalive that occurred while it was - // idling in the pool - // TODO: Consider removing the pool from the keepalive code. The following branch is simply irrelevant - // if keepalive isn't turned on. - if (connector.IsBroken) - { - CloseConnector(connector); - return false; - } - - if (_connectionLifetime != TimeSpan.Zero && DateTime.UtcNow > connector.OpenTimestamp + _connectionLifetime) - { - Log.Debug("Connection has exceeded its maximum lifetime and will be closed.", connector.Id); - CloseConnector(connector); - return false; - } - - Debug.Assert(connector.State == ConnectorState.Ready, - $"Got idle connector but {nameof(connector.State)} is {connector.State}"); - Debug.Assert(connector.CommandsInFlightCount == 0, - $"Got idle connector but {nameof(connector.CommandsInFlightCount)} is {connector.CommandsInFlightCount}"); - Debug.Assert(connector.MultiplexAsyncWritingLock == 0, - $"Got idle connector but {nameof(connector.MultiplexAsyncWritingLock)} is 1"); - - return true; - } - - async ValueTask OpenNewConnector( - NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) - { - // As long as we're under max capacity, attempt to increase the connector count and open a new connection. - for (var numConnectors = _numConnectors; numConnectors < _max; numConnectors = _numConnectors) - { - // Note that we purposefully don't use SpinWait for this: https://github.com/dotnet/coreclr/pull/21437 - if (Interlocked.CompareExchange(ref _numConnectors, numConnectors + 1, numConnectors) != numConnectors) - continue; - - try - { - // We've managed to increase the open counter, open a physical connections. 
- var connector = new NpgsqlConnector(conn) { ClearCounter = _clearCounter }; - await connector.Open(timeout, async, cancellationToken); - - var i = 0; - for (; i < _max; i++) - if (Interlocked.CompareExchange(ref _connectors[i], connector, null) == null) - break; - - Debug.Assert(i < _max, $"Could not find free slot in {_connectors} when opening."); - if (i == _max) - throw new NpgsqlException($"Could not find free slot in {_connectors} when opening. Please report a bug."); - - // Only start pruning if it was this thread that incremented open count past _min. - if (numConnectors == _min) - EnablePruning(); - - return connector; - } - catch - { - // Physical open failed, decrement the open and busy counter back down. - conn.Connector = null; - Interlocked.Decrement(ref _numConnectors); - - // In case there's a waiting attempt on the channel, we write a null to the idle connector channel - // to wake it up, so it will try opening (and probably throw immediately) - IdleConnectorWriter.TryWrite(null); - - throw; - } - } - - return null; - } - - internal void Return(NpgsqlConnector connector) - { - Debug.Assert(!connector.InTransaction); - Debug.Assert(connector.MultiplexAsyncWritingLock == 0 || connector.IsBroken || connector.IsClosed, - $"About to return multiplexing connector to the pool, but {nameof(connector.MultiplexAsyncWritingLock)} is {connector.MultiplexAsyncWritingLock}"); - - // If Clear/ClearAll has been been called since this connector was first opened, - // throw it away. The same if it's broken (in which case CloseConnector is only - // used to update state/perf counter). - if (connector.ClearCounter < _clearCounter || connector.IsBroken) - { - CloseConnector(connector); - return; - } - - // Statement order is important since we have synchronous completions on the channel. 
- Interlocked.Increment(ref _idleCount); - var written = IdleConnectorWriter.TryWrite(connector); - Debug.Assert(written); - } - - internal void Clear() - { - Interlocked.Increment(ref _clearCounter); - - var count = _idleCount; - while (count > 0 && _idleConnectorReader.TryRead(out var connector)) - { - if (CheckIdleConnector(connector)) - { - CloseConnector(connector); - count--; - } - } - } - - void CloseConnector(NpgsqlConnector connector) - { - try - { - connector.Close(); - } - catch (Exception e) - { - Log.Warn("Exception while closing connector", e, connector.Id); - } - - // If a connector has been closed for any reason, we write a null to the idle connector channel to wake up - // a waiter, who will open a new physical connection - IdleConnectorWriter.TryWrite(null); - - var i = 0; - for (; i < _max; i++) - if (Interlocked.CompareExchange(ref _connectors[i], null, connector) == connector) - break; - - Debug.Assert(i < _max, $"Could not find free slot in {_connectors} when closing."); - if (i == _max) - throw new NpgsqlException($"Could not find free slot in {_connectors} when closing. Please report a bug."); - - var numConnectors = Interlocked.Decrement(ref _numConnectors); - Debug.Assert(numConnectors >= 0); - // Only turn off the timer one time, when it was this Close that brought Open back to _min. 
- if (numConnectors == _min) - DisablePruning(); - } - - #region Pending Enlisted Connections - - internal void AddPendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) - { - lock (_pendingEnlistedConnectors) - { - if (!_pendingEnlistedConnectors.TryGetValue(transaction, out var list)) - list = _pendingEnlistedConnectors[transaction] = new List(); - list.Add(connector); - } - } - - internal void TryRemovePendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) - { - lock (_pendingEnlistedConnectors) - { - if (!_pendingEnlistedConnectors.TryGetValue(transaction, out var list)) - return; - list.Remove(connector); - if (list.Count == 0) - _pendingEnlistedConnectors.Remove(transaction); - } - } - - internal bool TryRentEnlistedPending(Transaction transaction, [NotNullWhen(true)] out NpgsqlConnector? connector) - { - lock (_pendingEnlistedConnectors) - { - if (!_pendingEnlistedConnectors.TryGetValue(transaction, out var list)) - { - connector = null; - return false; - } - connector = list[list.Count - 1]; - list.RemoveAt(list.Count - 1); - if (list.Count == 0) - _pendingEnlistedConnectors.Remove(transaction); - return true; - } - } - - #endregion - - #region Pruning - - // Manual reactivation of timer happens in callback - void EnablePruning() - { - lock (_pruningTimer) - { - _pruningTimerEnabled = true; - _pruningTimer.Change(_pruningSamplingInterval, Timeout.InfiniteTimeSpan); - } - } - - void DisablePruning() - { - lock (_pruningTimer) - { - _pruningTimer.Change(Timeout.Infinite, Timeout.Infinite); - _pruningSampleIndex = 0; - _pruningTimerEnabled = false; - } - } - - static void PruneIdleConnectors(object? state) - { - var pool = (ConnectorPool)state!; - var samples = pool._pruningSamples; - int toPrune; - lock (pool._pruningTimer) - { - // Check if we might have been contending with DisablePruning. 
- if (!pool._pruningTimerEnabled) - return; - - var sampleIndex = pool._pruningSampleIndex; - samples[sampleIndex] = pool._idleCount; - if (sampleIndex != pool._pruningSampleSize - 1) - { - pool._pruningSampleIndex = sampleIndex + 1; - pool._pruningTimer.Change(pool._pruningSamplingInterval, Timeout.InfiniteTimeSpan); - return; - } - - // Calculate median value for pruning, reset index and timer, and release the lock. - Array.Sort(samples); - toPrune = samples[pool._pruningMedianIndex]; - pool._pruningSampleIndex = 0; - pool._pruningTimer.Change(pool._pruningSamplingInterval, Timeout.InfiniteTimeSpan); - } - - while (toPrune > 0 && - pool._numConnectors > pool._min && - pool._idleConnectorReader.TryRead(out var connector) && - connector != null) - { - if (pool.CheckIdleConnector(connector)) - { - pool.CloseConnector(connector); - toPrune--; - } - } - } - - static int DivideRoundingUp(int value, int divisor) => 1 + (value - 1) / divisor; - - #endregion - } -} diff --git a/src/Npgsql/DatabaseState.cs b/src/Npgsql/DatabaseState.cs new file mode 100644 index 0000000000..d8f6dfd4f1 --- /dev/null +++ b/src/Npgsql/DatabaseState.cs @@ -0,0 +1,10 @@ +namespace Npgsql; + +enum DatabaseState : byte +{ + Unknown = 0, + Offline = 1, + PrimaryReadWrite = 2, + PrimaryReadOnly = 3, + Standby = 4 +} diff --git a/src/Npgsql/GlobalSuppressions.cs b/src/Npgsql/GlobalSuppressions.cs index 07bef71ab3..580c453b9d 100644 --- a/src/Npgsql/GlobalSuppressions.cs +++ b/src/Npgsql/GlobalSuppressions.cs @@ -1,7 +1,7 @@  -// This file is used by Code Analysis to maintain SuppressMessage +// This file is used by Code Analysis to maintain SuppressMessage // attributes that are applied to this project. -// Project-level suppressions either have no target or are given +// Project-level suppressions either have no target or are given // a specific target and scoped to a namespace, type, member, etc. 
using System.Diagnostics.CodeAnalysis; @@ -10,6 +10,5 @@ [assembly: SuppressMessage("Design", "CA1032:Implement standard exception constructors", Justification = "We have several exception classes where this makes no sense")] [assembly: SuppressMessage("Design", "CA1710:Identifiers should have correct suffix", Justification = "Disagree")] [assembly: SuppressMessage("Design", "CA1707:Remove the underscores from member name", Justification = "Seems to cause some false positives on implicit/explicit cast operators, strange")] -[assembly: SuppressMessage("Reliability", "CA2007:Do not directly await a Task", Justification = "Npgsql uses NoSynchronizationContextScope instead of ConfigureAwait(false)")] [assembly: SuppressMessage("Style", "IDE1006:Naming Styles", Justification = "All I/O methods are both sync and async, avoid clutter")] diff --git a/src/Npgsql/ICancelable.cs b/src/Npgsql/ICancelable.cs index 301e15f5f2..460f17c171 100644 --- a/src/Npgsql/ICancelable.cs +++ b/src/Npgsql/ICancelable.cs @@ -1,9 +1,11 @@ using System; +using System.Threading.Tasks; -namespace Npgsql +namespace Npgsql; + +interface ICancelable : IDisposable, IAsyncDisposable { - interface ICancelable : IDisposable - { - void Cancel(); - } -} + void Cancel(); + + Task CancelAsync(); +} \ No newline at end of file diff --git a/src/Npgsql/INpgsqlDatabaseInfoFactory.cs b/src/Npgsql/INpgsqlDatabaseInfoFactory.cs deleted file mode 100644 index 3a8a9d6861..0000000000 --- a/src/Npgsql/INpgsqlDatabaseInfoFactory.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System.Threading.Tasks; -using Npgsql.Util; - -namespace Npgsql -{ - /// - /// A factory which get generate instances of , which describe a database - /// and the types it contains. When first connecting to a database, Npgsql will attempt to load information - /// about it via this factory. - /// - public interface INpgsqlDatabaseInfoFactory - { - /// - /// Given a connection, loads all necessary information about the connected database, e.g. its types. 
- /// A factory should only handle the exact database type it was meant for, and return null otherwise. - /// - /// - /// An object describing the database to which is connected, or null if the - /// database isn't of the correct type and isn't handled by this factory. - /// - Task Load(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async); - } -} diff --git a/src/Npgsql/Internal/AdoSerializerHelpers.cs b/src/Npgsql/Internal/AdoSerializerHelpers.cs new file mode 100644 index 0000000000..d0ea19c7a8 --- /dev/null +++ b/src/Npgsql/Internal/AdoSerializerHelpers.cs @@ -0,0 +1,62 @@ +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.Internal; + +static class AdoSerializerHelpers +{ + public static PgTypeInfo GetTypeInfoForReading(Type type, PostgresType postgresType, PgSerializerOptions options) + { + PgTypeInfo? typeInfo = null; + Exception? inner = null; + try + { + typeInfo = type == typeof(object) ? options.GetObjectOrDefaultTypeInfo(postgresType) : options.GetTypeInfo(type, postgresType); + } + catch (Exception ex) + { + inner = ex; + } + return typeInfo ?? ThrowReadingNotSupported(type, postgresType.DisplayName, inner); + + // InvalidCastException thrown to align with ADO.NET convention. + [DoesNotReturn] + static PgTypeInfo ThrowReadingNotSupported(Type? type, string displayName, Exception? inner = null) + => throw new InvalidCastException($"Reading{(type is null ? "" : $" as '{type.FullName}'")} is not supported for fields having DataTypeName '{displayName}'", inner); + } + + public static PgTypeInfo GetTypeInfoForWriting(Type? type, PgTypeId? pgTypeId, PgSerializerOptions options, NpgsqlDbType? npgsqlDbType = null) + { + Debug.Assert(type != typeof(object), "Parameters of type object are not supported."); + + PgTypeInfo? typeInfo = null; + Exception? inner = null; + try + { + typeInfo = type is null ? 
options.GetDefaultTypeInfo(pgTypeId!.Value) : options.GetTypeInfo(type, pgTypeId); + } + catch (Exception ex) + { + inner = ex; + } + return typeInfo ?? ThrowWritingNotSupported(type, options, pgTypeId, npgsqlDbType, inner); + + // InvalidCastException thrown to align with ADO.NET convention. + [DoesNotReturn] + static PgTypeInfo ThrowWritingNotSupported(Type? type, PgSerializerOptions options, PgTypeId? pgTypeId, NpgsqlDbType? npgsqlDbType, Exception? inner = null) + { + var pgTypeString = pgTypeId is null + ? "no NpgsqlDbType or DataTypeName. Try setting one of these values to the expected database type." + : npgsqlDbType is null + ? $"DataTypeName '{options.DatabaseInfo.FindPostgresType(pgTypeId.GetValueOrDefault())?.DisplayName ?? "unknown"}'" + : $"NpgsqlDbType '{npgsqlDbType}'"; + + throw new InvalidCastException( + $"Writing{(type is null ? "" : $" values of '{type.FullName}'")} is not supported for parameters having {pgTypeString}.", inner); + } + } +} diff --git a/src/Npgsql/Internal/BufferRequirements.cs b/src/Npgsql/Internal/BufferRequirements.cs new file mode 100644 index 0000000000..14ffabc52b --- /dev/null +++ b/src/Npgsql/Internal/BufferRequirements.cs @@ -0,0 +1,45 @@ +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public readonly struct BufferRequirements : IEquatable +{ + readonly Size _read; + readonly Size _write; + + BufferRequirements(Size read, Size write) + { + _read = read; + _write = write; + } + + public Size Read => _read; + public Size Write => _write; + + /// Streaming + public static BufferRequirements None => new(Size.Unknown, Size.Unknown); + /// Entire value should be buffered + public static BufferRequirements Value => new(Size.CreateUpperBound(int.MaxValue), Size.CreateUpperBound(int.MaxValue)); + /// Fixed size value should be buffered + public static BufferRequirements CreateFixedSize(int byteCount) => new(byteCount, byteCount); + 
/// Custom requirements + public static BufferRequirements Create(Size value) => new(value, value); + public static BufferRequirements Create(Size read, Size write) => new(read, write); + + public BufferRequirements Combine(Size read, Size write) + => new(_read.Combine(read), _write.Combine(write)); + + public BufferRequirements Combine(BufferRequirements other) + => Combine(other._read, other._write); + + public BufferRequirements Combine(int byteCount) + => Combine(CreateFixedSize(byteCount)); + + public bool Equals(BufferRequirements other) => _read.Equals(other._read) && _write.Equals(other._write); + public override bool Equals(object? obj) => obj is BufferRequirements other && Equals(other); + public override int GetHashCode() => HashCode.Combine(_read, _write); + public static bool operator ==(BufferRequirements left, BufferRequirements right) => left.Equals(right); + public static bool operator !=(BufferRequirements left, BufferRequirements right) => !left.Equals(right); +} diff --git a/src/Npgsql/Internal/ChainTypeInfoResolver.cs b/src/Npgsql/Internal/ChainTypeInfoResolver.cs new file mode 100644 index 0000000000..18c39d80b6 --- /dev/null +++ b/src/Npgsql/Internal/ChainTypeInfoResolver.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +sealed class ChainTypeInfoResolver : IPgTypeInfoResolver +{ + readonly IPgTypeInfoResolver[] _resolvers; + + public ChainTypeInfoResolver(IEnumerable resolvers) + => _resolvers = new List(resolvers).ToArray(); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + { + foreach (var resolver in _resolvers) + { + if (resolver.GetTypeInfo(type, dataTypeName, options) is { } info) + return info; + } + + return null; + } +} diff --git a/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs b/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs new file mode 100644 index 0000000000..c51c0dafa0 --- /dev/null +++ b/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs @@ -0,0 +1,109 @@ +using System; +using System.Buffers; +using Npgsql.Util; + +namespace Npgsql.Internal.Composites; + +abstract class CompositeBuilder +{ + protected StrongBox[] _tempBoxes; + protected int _currentField; + + protected CompositeBuilder(StrongBox[] tempBoxes) => _tempBoxes = tempBoxes; + + protected abstract void Construct(); + protected abstract void SetField(TValue value); + + public void AddValue(TValue value) + { + var tempBoxes = _tempBoxes; + var currentField = _currentField; + if (currentField >= tempBoxes.Length) + { + if (currentField == tempBoxes.Length) + Construct(); + SetField(value); + } + else + { + ((StrongBox)tempBoxes[currentField]).TypedValue = value; + if (currentField + 1 == tempBoxes.Length) + Construct(); + } + + _currentField++; + } +} + +sealed class CompositeBuilder : CompositeBuilder, IDisposable +{ + readonly CompositeInfo _compositeInfo; + T _instance = default!; + object? _boxedInstance; + + public CompositeBuilder(CompositeInfo compositeInfo) + : base(compositeInfo.CreateTempBoxes()) + => _compositeInfo = compositeInfo; + + public T Complete() + { + if (_currentField < _compositeInfo.Fields.Count) + throw new InvalidOperationException($"Missing values, expected: {_compositeInfo.Fields.Count} got: {_currentField}"); + + return (T)(_boxedInstance ?? 
_instance!); + } + + public void Reset() + { + _instance = default!; + _boxedInstance = null; + _currentField = 0; + foreach (var box in _tempBoxes) + box.Clear(); + } + + public void Dispose() => Reset(); + + protected override void Construct() + { + var tempBoxes = _tempBoxes; + if (_currentField < tempBoxes.Length - 1) + throw new InvalidOperationException($"Missing values, expected: {tempBoxes.Length} got: {_currentField + 1}"); + + var fields = _compositeInfo.Fields; + var args = ArrayPool.Shared.Rent(_compositeInfo.ConstructorParameters); + for (var i = 0; i < tempBoxes.Length; i++) + { + var field = fields[i]; + if (field.ConstructorParameterIndex is { } argIndex) + args[argIndex] = tempBoxes[i]; + } + _instance = _compositeInfo.Constructor(args)!; + ArrayPool.Shared.Return(args); + + if (tempBoxes.Length == _compositeInfo.Fields.Count) + return; + + // We're expecting or already have stored more fields, so box the instance once here. + _boxedInstance = _instance; + for (var i = 0; i < tempBoxes.Length; i++) + { + var field = _compositeInfo.Fields[i]; + if (field.ConstructorParameterIndex is null) + field.Set(_boxedInstance, tempBoxes[i]); + } + } + + protected override void SetField(TValue value) + { + if (_boxedInstance is null) + ThrowHelper.ThrowInvalidOperationException("Not constructed yet, or no more fields were expected."); + + var currentField = _currentField; + var fields = _compositeInfo.Fields; + if (currentField > fields.Count - 1) + ThrowHelper.ThrowIndexOutOfRangeException($"Cannot set field {value} at position {currentField} - all fields have already been set"); + + ((CompositeFieldInfo)fields[currentField]).Set(_boxedInstance, value); + } +} diff --git a/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs b/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs new file mode 100644 index 0000000000..ea8ca838d4 --- /dev/null +++ b/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs @@ -0,0 +1,261 @@ +using System; 
+using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; +using Npgsql.Util; + +namespace Npgsql.Internal.Composites; + +abstract class CompositeFieldInfo +{ + protected PgTypeInfo PgTypeInfo { get; } + protected PgConverter? Converter { get; } + protected BufferRequirements _binaryBufferRequirements; + + /// + /// CompositeFieldInfo constructor. + /// + /// Name of the field. + /// Type info for reading/writing. + /// The nominal field type, this may differ from the typeInfo.PgTypeId when the field is a domain type. + private protected CompositeFieldInfo(string name, PgTypeInfo typeInfo, PgTypeId nominalPgTypeId) + { + Name = name; + PgTypeInfo = typeInfo; + PgTypeId = nominalPgTypeId; + + if (typeInfo.PgTypeId is null) + ThrowHelper.ThrowArgumentException("PgTypeInfo must have a PgTypeId."); + + if (!typeInfo.IsResolverInfo) + { + var resolution = typeInfo.GetResolution(); + if (typeInfo.GetBufferRequirements(resolution.Converter, DataFormat.Binary) is not { } bufferRequirements) + { + ThrowHelper.ThrowInvalidOperationException("Converter must support binary format to participate in composite types."); + return; + } + _binaryBufferRequirements = bufferRequirements; + Converter = resolution.Converter; + } + } + + public PgConverter GetReadInfo(out Size readRequirement) + { + if (Converter is not null) + { + readRequirement = _binaryBufferRequirements.Read; + return Converter; + } + + if (!PgTypeInfo.TryBind(new Field(Name, PgTypeInfo.PgTypeId.GetValueOrDefault(), -1), DataFormat.Binary, out var converterInfo)) + ThrowHelper.ThrowInvalidOperationException("Converter must support binary format to participate in composite types."); + + readRequirement = converterInfo.BufferRequirement; + return converterInfo.Converter; + } + + public PgConverter GetWriteInfo(object instance, out Size writeRequirement) + { + if (Converter is null) + return BindValue(instance, out writeRequirement); + + 
writeRequirement = _binaryBufferRequirements.Write; + return Converter; + + } + + protected ValueTask ReadAsObject(bool async, PgConverter converter, CompositeBuilder builder, PgReader reader, CancellationToken cancellationToken) + { + if (async) + { + var task = converter.ReadAsObjectAsync(reader, cancellationToken); + if (!task.IsCompletedSuccessfully) + return Core(builder, task); + + AddValue(builder, task.Result); + } + else + AddValue(builder, converter.ReadAsObject(reader)); + return new(); +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] +#endif + async ValueTask Core(CompositeBuilder builder, ValueTask task) + { + builder.AddValue(await task.ConfigureAwait(false)); + } + } + + protected ValueTask WriteAsObject(bool async, PgConverter converter, PgWriter writer, object value, CancellationToken cancellationToken) + { + if (async) + return converter.WriteAsObjectAsync(writer, value, cancellationToken); + + converter.WriteAsObject(writer, value); + return new(); + } + + public string Name { get; } + public PgTypeId PgTypeId { get; } + public Size BinaryReadRequirement => Converter is not null ? _binaryBufferRequirements.Read : Size.Unknown; + public Size BinaryWriteRequirement => Converter is not null ? _binaryBufferRequirements.Write : Size.Unknown; + + public abstract Type Type { get; } + + protected abstract PgConverter BindValue(object instance, out Size writeRequirement); + protected abstract void AddValue(CompositeBuilder builder, object value); + + public abstract StrongBox CreateBox(); + public abstract void Set(object instance, StrongBox value); + public abstract int? 
ConstructorParameterIndex { get; } + public abstract bool IsDbNullable { get; } + + public abstract void ReadDbNull(CompositeBuilder builder); + public abstract ValueTask Read(bool async, PgConverter converter, CompositeBuilder builder, PgReader reader, CancellationToken cancellationToken = default); + public abstract bool IsDbNull(PgConverter converter, object instance, ref object? writeState); + public abstract Size? GetSizeOrDbNull(PgConverter converter, DataFormat format, Size writeRequirement, object instance, ref object? writeState); + public abstract ValueTask Write(bool async, PgConverter converter, PgWriter writer, object instance, CancellationToken cancellationToken); +} + +sealed class CompositeFieldInfo : CompositeFieldInfo +{ + readonly Action? _setter; + readonly int _parameterIndex; + readonly Func _getter; + readonly bool _asObject; + + CompositeFieldInfo(string name, PgTypeInfo typeInfo, PgTypeId nominalPgTypeId, Func getter) + : base(name, typeInfo, nominalPgTypeId) + { + if (typeInfo.Type != typeof(T)) + throw new InvalidOperationException($"PgTypeInfo type '{typeInfo.Type.FullName}' must be equal to field type '{typeof(T)}'."); + + if (!typeInfo.IsResolverInfo) + { + var resolution = typeInfo.GetResolution(); + var typeToConvert = resolution.Converter.TypeToConvert; + _asObject = typeToConvert != typeof(T); + if (!typeToConvert.IsAssignableFrom(typeof(T))) + throw new InvalidOperationException($"Converter type '{typeToConvert.FullName}' must be assignable from field type '{typeof(T)}'."); + } + + _getter = getter; + } + + public CompositeFieldInfo(string name, PgTypeInfo typeInfo, PgTypeId nominalPgTypeId, Func getter, int parameterIndex) + : this(name, typeInfo, nominalPgTypeId, getter) + => _parameterIndex = parameterIndex; + + public CompositeFieldInfo(string name, PgTypeInfo typeInfo, PgTypeId nominalPgTypeId, Func getter, Action setter) + : this(name, typeInfo, nominalPgTypeId, getter) + => _setter = setter; + + bool AsObject(PgConverter 
converter) + => ReferenceEquals(Converter, converter) ? _asObject : converter.TypeToConvert != typeof(T); + + public override Type Type => typeof(T); + + public override int? ConstructorParameterIndex => _setter is not null ? null : _parameterIndex; + + public T Get(object instance) => _getter(instance); + + public override StrongBox CreateBox() => new Util.StrongBox(); + + public void Set(object instance, T value) + { + if (_setter is null) + throw new InvalidOperationException("Not a composite field for a clr field."); + + _setter(instance, value); + } + + public override void Set(object instance, StrongBox value) + { + if (_setter is null) + throw new InvalidOperationException("Not a composite field for a clr field."); + + _setter(instance, ((Util.StrongBox)value).TypedValue!); + } + + public override void ReadDbNull(CompositeBuilder builder) + { + if (default(T) != null) + throw new InvalidCastException($"Type {typeof(T).FullName} does not have null as a possible value."); + + builder.AddValue((T?)default); + } + + protected override PgConverter BindValue(object instance, out Size writeRequirement) + { + var value = _getter(instance); + var resolution = PgTypeInfo.IsBoxing ? 
PgTypeInfo.GetObjectResolution(value) : PgTypeInfo.GetResolution(value); + if (PgTypeInfo.GetBufferRequirements(resolution.Converter, DataFormat.Binary) is not { } bufferRequirements) + { + ThrowHelper.ThrowInvalidOperationException("Converter must support binary format to participate in composite types."); + writeRequirement = default; + return default; + } + + writeRequirement = bufferRequirements.Write; + return resolution.Converter; + } + + protected override void AddValue(CompositeBuilder builder, object value) => builder.AddValue((T)value); + + public override ValueTask Read(bool async, PgConverter converter, CompositeBuilder builder, PgReader reader, CancellationToken cancellationToken = default) + { + if (AsObject(converter)) + return ReadAsObject(async, converter, builder, reader, cancellationToken); + + if (async) + { + var task = ((PgConverter)converter).ReadAsync(reader, cancellationToken); + if (!task.IsCompletedSuccessfully) + return Core(builder, task); + + builder.AddValue(task.Result); + } + else + builder.AddValue(((PgConverter)converter).Read(reader)); + return new(); + +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] +#endif + async ValueTask Core(CompositeBuilder builder, ValueTask task) + { + builder.AddValue(await task.ConfigureAwait(false)); + } + } + + public override bool IsDbNullable => Converter?.IsDbNullable ?? true; + + public override bool IsDbNull(PgConverter converter, object instance, ref object? writeState) + { + var value = _getter(instance); + return AsObject(converter) ? converter.IsDbNullAsObject(value, ref writeState) : ((PgConverter)converter).IsDbNull(value, ref writeState); + } + + public override Size? GetSizeOrDbNull(PgConverter converter, DataFormat format, Size writeRequirement, object instance, ref object? writeState) + { + var value = _getter(instance); + return AsObject(converter) + ? 
converter.GetSizeOrDbNullAsObject(format, writeRequirement, value, ref writeState) + : ((PgConverter)converter).GetSizeOrDbNull(format, writeRequirement, value, ref writeState); + } + + public override ValueTask Write(bool async, PgConverter converter, PgWriter writer, object instance, CancellationToken cancellationToken) + { + var value = _getter(instance); + if (AsObject(converter)) + return WriteAsObject(async, converter, writer, value!, cancellationToken); + + if (async) + return ((PgConverter)converter).WriteAsync(writer, value!, cancellationToken); + + ((PgConverter)converter).Write(writer, value!); + return new(); + } +} diff --git a/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs b/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs new file mode 100644 index 0000000000..1db91b2052 --- /dev/null +++ b/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs @@ -0,0 +1,67 @@ +using System; +using System.Collections.Generic; +using Npgsql.Util; + +namespace Npgsql.Internal.Composites; + +sealed class CompositeInfo +{ + readonly int _lastConstructorFieldIndex; + readonly CompositeFieldInfo[] _fields; + + public CompositeInfo(CompositeFieldInfo[] fields, int constructorParameters, Func constructor) + { + _lastConstructorFieldIndex = -1; + for (var i = fields.Length - 1; i >= 0; i--) + if (fields[i].ConstructorParameterIndex is not null) + { + _lastConstructorFieldIndex = i; + break; + } + + var parameterSum = 0; + for (var i = constructorParameters - 1; i > 0; i--) + parameterSum += i; + + var argumentsSum = 0; + if (parameterSum > 0) + { + foreach (var field in fields) + if (field.ConstructorParameterIndex is { } index) + argumentsSum += index; + } + + if (parameterSum != argumentsSum) + throw new InvalidOperationException($"Missing composite fields to map to the required {constructorParameters} constructor parameters."); + + _fields = fields; + var arguments = constructorParameters is 0 ? 
Array.Empty() : new CompositeFieldInfo[constructorParameters]; + foreach (var field in fields) + { + if (field.ConstructorParameterIndex is { } index) + arguments[index] = field; + } + Constructor = constructor; + ConstructorParameters = constructorParameters; + } + + public IReadOnlyList Fields => _fields; + + public int ConstructorParameters { get; } + public Func Constructor { get; } + + /// + /// Create temporary storage for all values that come before the constructor parameters can be saturated. + /// + /// + public StrongBox[] CreateTempBoxes() + { + var valueCache = _lastConstructorFieldIndex + 1 is 0 ? Array.Empty() : new StrongBox[_lastConstructorFieldIndex + 1]; + var fields = _fields; + + for (var i = 0; i < valueCache.Length; i++) + valueCache[i] = fields[i].CreateBox(); + + return valueCache; + } +} diff --git a/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs b/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs new file mode 100644 index 0000000000..4db2264235 --- /dev/null +++ b/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs @@ -0,0 +1,304 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using Npgsql.Util; +using NpgsqlTypes; + +namespace Npgsql.Internal.Composites; + +[RequiresDynamicCode("Serializing arbitrary types can require creating new generic types or methods. 
This may not work when AOT compiling.")] +static class ReflectionCompositeInfoFactory +{ + public static CompositeInfo CreateCompositeInfo<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicProperties)] T>( + PostgresCompositeType pgType, INpgsqlNameTranslator nameTranslator, PgSerializerOptions options) + { + var pgFields = pgType.Fields; + var propertyMap = MapProperties(pgFields, nameTranslator); + var fieldMap = MapFields(pgFields, nameTranslator); + + var duplicates = propertyMap.Keys.Intersect(fieldMap.Keys).ToArray(); + if (duplicates.Length > 0) + throw new AmbiguousMatchException($"Property {propertyMap[duplicates[0]].Name} and field {fieldMap[duplicates[0]].Name} map to the same '{pgFields[duplicates[0]].Name}' composite field name."); + + var (constructorInfo, parameterFieldMap) = MapBestMatchingConstructor(pgFields, nameTranslator); + var constructorParameters = constructorInfo?.GetParameters() ?? Array.Empty(); + var compositeFields = new CompositeFieldInfo?[pgFields.Count]; + for (var i = 0; i < parameterFieldMap.Length; i++) + { + var fieldIndex = parameterFieldMap[i]; + var pgField = pgFields[fieldIndex]; + var parameter = constructorParameters[i]; + PgTypeInfo pgTypeInfo; + Delegate getter; + if (propertyMap.TryGetValue(fieldIndex, out var property) && property.GetMethod is not null) + { + if (property.PropertyType != parameter.ParameterType) + throw new InvalidOperationException($"Could not find a matching getter for constructor parameter {parameter.Name} and type {parameter.ParameterType} mapped to composite field {pgFields[fieldIndex].Name}."); + + pgTypeInfo = options.GetTypeInfo(property.PropertyType, pgField.Type.GetRepresentationalType()) ?? 
throw NotSupportedField(pgType, pgField, isField: false, property.Name, property.PropertyType); + getter = CreateGetter(property); + } + else if (fieldMap.TryGetValue(fieldIndex, out var field)) + { + if (field.FieldType != parameter.ParameterType) + throw new InvalidOperationException($"Could not find a matching getter for constructor parameter {parameter.Name} and type {parameter.ParameterType} mapped to composite field {pgFields[fieldIndex].Name}."); + + pgTypeInfo = options.GetTypeInfo(field.FieldType, pgField.Type.GetRepresentationalType()) ?? throw NotSupportedField(pgType, pgField, isField: true, field.Name, field.FieldType); + getter = CreateGetter(field); + } + else + throw new InvalidOperationException($"Cannot find property or field for composite field {pgFields[fieldIndex].Name}."); + + compositeFields[fieldIndex] = CreateCompositeFieldInfo(pgField.Name, pgTypeInfo.Type, pgTypeInfo, options.ToCanonicalTypeId(pgField.Type), getter, i); + } + + for (var fieldIndex = 0; fieldIndex < pgFields.Count; fieldIndex++) + { + // Handled by constructor. + if (compositeFields[fieldIndex] is not null) + continue; + + var pgField = pgFields[fieldIndex]; + PgTypeInfo pgTypeInfo; + Delegate getter; + Delegate setter; + if (propertyMap.TryGetValue(fieldIndex, out var property)) + { + pgTypeInfo = options.GetTypeInfo(property.PropertyType, pgField.Type.GetRepresentationalType()) + ?? throw NotSupportedField(pgType, pgField, isField: false, property.Name, property.PropertyType); + getter = CreateGetter(property); + setter = CreateSetter(property); + } + else if (fieldMap.TryGetValue(fieldIndex, out var field)) + { + pgTypeInfo = options.GetTypeInfo(field.FieldType, pgField.Type.GetRepresentationalType()) + ?? 
throw NotSupportedField(pgType, pgField, isField: true, field.Name, field.FieldType); + getter = CreateGetter(field); + setter = CreateSetter(field); + } + else + throw new InvalidOperationException($"Cannot find property or field for composite field '{pgFields[fieldIndex].Name}'."); + + compositeFields[fieldIndex] = CreateCompositeFieldInfo(pgField.Name, pgTypeInfo.Type, pgTypeInfo, options.ToCanonicalTypeId(pgField.Type), getter, setter); + } + + Debug.Assert(compositeFields.All(x => x is not null)); + + var constructor = constructorInfo is null ? _ => Activator.CreateInstance() : CreateStrongBoxConstructor(constructorInfo); + return new CompositeInfo(compositeFields!, constructorInfo is null ? 0 : constructorParameters.Length, constructor); + + static NotSupportedException NotSupportedField(PostgresCompositeType composite, PostgresCompositeType.Field field, bool isField, string name, Type type) + => new($"No mapping could be found for ('{type.FullName}', '{field.Type.FullName}'). Mapping: CLR {(isField ? 
"field" : "property")} '{typeof(T).FullName}.{name}' <-> Composite field '{composite.Name}.{field.Name}'"); + } + + static Delegate CreateGetter(FieldInfo info) + { + var instance = Expression.Parameter(typeof(object), "instance"); + return Expression + .Lambda(typeof(Func<,>).MakeGenericType(typeof(object), info.FieldType), + Expression.Field(UnboxAny(instance, typeof(T)), info), + instance) + .Compile(); + } + + static Delegate CreateSetter(FieldInfo info) + { + var instance = Expression.Parameter(typeof(object), "instance"); + var value = Expression.Parameter(info.FieldType, "value"); + + return Expression + .Lambda(typeof(Action<,>).MakeGenericType(typeof(object), info.FieldType), + Expression.Assign(Expression.Field(UnboxAny(instance, typeof(T)), info), value), instance, value) + .Compile(); + } + + static Delegate CreateGetter(PropertyInfo info) + { + var invalidOpExceptionMessageConstructor = typeof(InvalidOperationException).GetConstructor(new []{ typeof(string) })!; + var instance = Expression.Parameter(typeof(object), "instance"); + var body = info.GetMethod is null || !info.GetMethod.IsPublic + ? (Expression)Expression.Throw(Expression.New(invalidOpExceptionMessageConstructor, + Expression.Constant($"No (public) getter for '{info}' on type {typeof(T)}")), info.PropertyType) + : Expression.Property(UnboxAny(instance, typeof(T)), info); + + return Expression + .Lambda(typeof(Func<,>).MakeGenericType(typeof(object), info.PropertyType), body, instance) + .Compile(); + } + + static Delegate CreateSetter(PropertyInfo info) + { + var instance = Expression.Parameter(typeof(object), "instance"); + var value = Expression.Parameter(info.PropertyType, "value"); + + var invalidOpExceptionMessageConstructor = typeof(InvalidOperationException).GetConstructor(new []{ typeof(string) })!; + var body = info.SetMethod is null || !info.SetMethod.IsPublic + ? 
(Expression)Expression.Throw(Expression.New(invalidOpExceptionMessageConstructor, + Expression.Constant($"No (public) setter for '{info}' on type {typeof(T)}")), info.PropertyType) + : Expression.Call(UnboxAny(instance, typeof(T)), info.SetMethod, value); + + return Expression + .Lambda(typeof(Action<,>).MakeGenericType(typeof(object), info.PropertyType), body, instance, value) + .Compile(); + } + + static Expression UnboxAny(Expression expression, Type type) + => type.IsValueType ? Expression.Unbox(expression, type) : Expression.Convert(expression, type, null); + +#if !NETSTANDARD + [DynamicDependency("TypedValue", typeof(StrongBox<>))] + [DynamicDependency("Length", typeof(StrongBox[]))] +#endif + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "DynamicDependencies in place for the System.Linq.Expression.Property calls")] + static Func CreateStrongBoxConstructor(ConstructorInfo constructorInfo) + { + var values = Expression.Parameter(typeof(StrongBox[]), "values"); + + var parameters = constructorInfo.GetParameters(); + var parameterCount = Expression.Constant(parameters.Length); + var argumentExceptionNameMessageConstructor = typeof(ArgumentException).GetConstructor(new []{ typeof(string), typeof(string) })!; + return Expression + .Lambda>( + Expression.Block( + Expression.IfThen( + Expression.LessThan(Expression.Property(values, "Length"), parameterCount), + + Expression.Throw(Expression.New(argumentExceptionNameMessageConstructor, + Expression.Constant("Passed fewer arguments than there are constructor parameters."), Expression.Constant(values.Name))) + ), + Expression.New(constructorInfo, parameters.Select((parameter, i) => + Expression.Property( + UnboxAny( + Expression.ArrayIndex(values, Expression.Constant(i)), + typeof(StrongBox<>).MakeGenericType(parameter.ParameterType) + ), + "TypedValue" + ) + )) + ), values) + .Compile(); + } + static CompositeFieldInfo CreateCompositeFieldInfo(string name, Type type, PgTypeInfo typeInfo, PgTypeId 
nominalPgTypeId, Delegate getter, int constructorParameterIndex) + => (CompositeFieldInfo)Activator.CreateInstance( + typeof(CompositeFieldInfo<>).MakeGenericType(type), name, typeInfo, nominalPgTypeId, getter, constructorParameterIndex)!; + + static CompositeFieldInfo CreateCompositeFieldInfo(string name, Type type, PgTypeInfo typeInfo, PgTypeId nominalPgTypeId, Delegate getter, Delegate setter) + => (CompositeFieldInfo)Activator.CreateInstance( + typeof(CompositeFieldInfo<>).MakeGenericType(type), name, typeInfo, nominalPgTypeId, getter, setter)!; + + static Dictionary MapProperties<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(IReadOnlyList fields, INpgsqlNameTranslator nameTranslator) + { + var properties = typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance); + var propertiesAndNames = properties.Select(x => + { + var attr = x.GetCustomAttribute(); + var name = attr?.PgName ?? nameTranslator.TranslateMemberName(x.Name); + return new KeyValuePair(name, x); + }).ToArray(); + + var duplicates = propertiesAndNames.GroupBy(x => x.Key).Where(g => g.Count() > 1).ToArray(); + if (duplicates.Length > 0) + throw new AmbiguousMatchException($"Multiple properties are mapped to the '{duplicates[0].Key}' field."); + + var propertiesMap = propertiesAndNames.ToDictionary(x => x.Key, x => x.Value); + var result = new Dictionary(); + for (var i = 0; i < fields.Count; i++) + { + var field = fields[i]; + if (!propertiesMap.TryGetValue(field.Name, out var value)) + continue; + + result[i] = value; + } + + return result; + } + + static Dictionary MapFields<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] T>(IReadOnlyList fields, INpgsqlNameTranslator nameTranslator) + { + var clrFields = typeof(T).GetFields(BindingFlags.Public | BindingFlags.Instance); + var clrFieldsAndNames = clrFields.Select(x => + { + var attr = x.GetCustomAttribute(); + var name = attr?.PgName ?? 
nameTranslator.TranslateMemberName(x.Name); + return new KeyValuePair(name, x); + }).ToArray(); + + var duplicates = clrFieldsAndNames.GroupBy(x => x.Key).Where(g => g.Count() > 1).ToArray(); + if (duplicates.Length > 0) + throw new AmbiguousMatchException($"Multiple properties are mapped to the '{duplicates[0].Key}' field."); + + var clrFieldsMap = clrFieldsAndNames.ToDictionary(x => x.Key, x => x.Value); + var result = new Dictionary(); + for (var i = 0; i < fields.Count; i++) + { + var field = fields[i]; + if (!clrFieldsMap.TryGetValue(field.Name, out var value)) + continue; + + result[i] = value; + } + + return result; + } + + static (ConstructorInfo? ConstructorInfo, int[] ParameterFieldMap) MapBestMatchingConstructor<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] T>(IReadOnlyList fields, INpgsqlNameTranslator nameTranslator) + { + ConstructorInfo? clrDefaultConstructor = null; + Exception? duplicatesException = null; + foreach (var constructor in typeof(T).GetConstructors().OrderByDescending(x => x.GetParameters().Length)) + { + var parameters = constructor.GetParameters(); + if (parameters.Length == 0) + clrDefaultConstructor = constructor; + + var parametersMap = new int[parameters.Length]; +#if NETSTANDARD2_0 + for (var i = 0; i < parametersMap.Length; i++) + parametersMap[i] = -1; +#else + Array.Fill(parametersMap, -1); +#endif + for (var i = 0; i < parameters.Length; i++) + { + var clrParameter = parameters[i]; + var attr = clrParameter.GetCustomAttribute(); + var name = attr?.PgName ?? (clrParameter.Name is { } clrName ? 
nameTranslator.TranslateMemberName(clrName) : null); + if (name is null) + break; + + for (var pgFieldIndex = 0; pgFieldIndex < fields.Count; pgFieldIndex++) + { + if (fields[pgFieldIndex].Name == name) + { + parametersMap[i] = pgFieldIndex; + break; + } + } + } + + if (parametersMap.Any(x => x is -1)) + continue; + + var duplicates = parametersMap.GroupBy(x => x).Where(g => g.Count() > 1).ToArray(); + if (duplicates.Length is 0) + return (constructor, parametersMap); + + duplicatesException = new AmbiguousMatchException($"Multiple parameters are mapped to the field '{fields[duplicates[0].Key].Name}' in constructor: {constructor}."); + } + + if (duplicatesException is not null) + throw duplicatesException; + + if (clrDefaultConstructor is null && !typeof(T).IsValueType) + throw new InvalidOperationException($"No parameterless constructor defined for type '{typeof(T)}'."); + + return (clrDefaultConstructor, []); + } +} diff --git a/src/Npgsql/Internal/Converters/ArrayConverter.cs b/src/Npgsql/Internal/Converters/ArrayConverter.cs new file mode 100644 index 0000000000..a43b500812 --- /dev/null +++ b/src/Npgsql/Internal/Converters/ArrayConverter.cs @@ -0,0 +1,699 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Converters; + +interface IElementOperations +{ + object CreateCollection(int[] lengths); + int GetCollectionCount(object collection, out int[]? lengths); + Size? GetSizeOrDbNull(SizeContext context, object collection, int[] indices, ref object? 
writeState); + ValueTask Read(bool async, PgReader reader, bool isDbNull, object collection, int[] indices, CancellationToken cancellationToken = default); + ValueTask Write(bool async, PgWriter writer, object collection, int[] indices, CancellationToken cancellationToken = default); +} + +readonly struct PgArrayConverter +{ + internal const string ReadNonNullableCollectionWithNullsExceptionMessage = "Cannot read a non-nullable collection of elements because the returned array contains nulls. Call GetFieldValue with a nullable collection type instead."; + + readonly IElementOperations _elemOps; + readonly int? _expectedDimensions; + readonly BufferRequirements _bufferRequirements; + public bool ElemTypeDbNullable { get; } + readonly int _pgLowerBound; + readonly PgTypeId _elemTypeId; + + public PgArrayConverter(IElementOperations elemOps, bool elemTypeDbNullable, int? expectedDimensions, BufferRequirements bufferRequirements, PgTypeId elemTypeId, int pgLowerBound = 1) + { + _elemTypeId = elemTypeId; + ElemTypeDbNullable = elemTypeDbNullable; + _pgLowerBound = pgLowerBound; + _elemOps = elemOps; + _expectedDimensions = expectedDimensions; + _bufferRequirements = bufferRequirements; + } + + bool IsDbNull(object values, int[] indices) + { + object? state = null; + return _elemOps.GetSizeOrDbNull(new(DataFormat.Binary, _bufferRequirements.Write), values, indices, ref state) is null; + } + + Size GetElemsSize(object values, (Size, object?)[] elemStates, out bool anyElementState, DataFormat format, int count, int[] indices, int[]? lengths = null) + { + Debug.Assert(elemStates.Length >= count); + var totalSize = Size.Zero; + var context = new SizeContext(format, _bufferRequirements.Write); + anyElementState = false; + var lastLength = lengths?[lengths.Length - 1] ?? 
count; + ref var lastIndex = ref indices[indices.Length - 1]; + var i = 0; + do + { + ref var elemItem = ref elemStates[i++]; + var elemState = (object?)null; + var size = _elemOps.GetSizeOrDbNull(context, values, indices, ref elemState); + anyElementState = anyElementState || elemState is not null; + elemItem = (size ?? -1, elemState); + totalSize = totalSize.Combine(size ?? 0); + } + // We can immediately continue if we didn't reach the end of the last dimension. + while (++lastIndex < lastLength || (indices.Length > 1 && CarryIndices(lengths!, indices))); + + return totalSize; + } + + Size GetFixedElemsSize(Size elemSize, object values, int count, int[] indices, int[]? lengths = null) + { + var nulls = 0; + var lastLength = lengths?[lengths.Length - 1] ?? count; + ref var lastIndex = ref indices[indices.Length - 1]; + if (ElemTypeDbNullable) + do + { + if (IsDbNull(values, indices)) + nulls++; + } + // We can immediately continue if we didn't reach the end of the last dimension. + while (++lastIndex < lastLength || (indices.Length > 1 && CarryIndices(lengths!, indices))); + + return (count - nulls) * elemSize.Value; + } + + int GetFormatSize(int count, int dimensions) + => sizeof(int) + // Dimensions + sizeof(int) + // Flags + sizeof(int) + // Element OID + dimensions * (sizeof(int) + sizeof(int)) + // Dimensions * (array length and lower bound) + sizeof(int) * count; // Element length integers + + public Size GetSize(SizeContext context, object values, ref object? writeState) + { + var count = _elemOps.GetCollectionCount(values, out var lengths); + var dimensions = lengths?.Length ?? 
1; + if (dimensions > 8) + throw new ArgumentException(nameof(values), "Postgres arrays can have at most 8 dimensions."); + + var formatSize = Size.Create(GetFormatSize(count, dimensions)); + if (count is 0) + return formatSize; + + Size elemsSize; + var indices = new int[dimensions]; + if (_bufferRequirements.Write is { Kind: SizeKind.Exact } req) + { + elemsSize = GetFixedElemsSize(req, values, count, indices, lengths); + writeState = new WriteState { Count = count, Indices = indices, Lengths = lengths, ArrayPool = null, Data = default, AnyWriteState = false }; + } + else + { + var arrayPool = ArrayPool<(Size, object?)>.Shared; + var data = ArrayPool<(Size, object?)>.Shared.Rent(count); + elemsSize = GetElemsSize(values, data, out var elemStateDisposable, context.Format, count, indices, lengths); + writeState = new WriteState + { Count = count, Indices = indices, Lengths = lengths, + ArrayPool = arrayPool, Data = new(data, 0, count), AnyWriteState = elemStateDisposable }; + } + + return formatSize.Combine(elemsSize); + } + + sealed class WriteState : MultiWriteState + { + public required int Count { get; init; } + public required int[] Indices { get; init; } + public required int[]? Lengths { get; init; } + } + + public async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken = default) + { + if (reader.ShouldBuffer(sizeof(int) + sizeof(int) + sizeof(uint))) + await reader.Buffer(async, sizeof(int) + sizeof(int) + sizeof(uint), cancellationToken).ConfigureAwait(false); + + var dimensions = reader.ReadInt32(); + var containsNulls = reader.ReadInt32() is 1; + _ = reader.ReadUInt32(); // Element OID. + + if (dimensions is not 0 && _expectedDimensions is not null && dimensions != _expectedDimensions) + ThrowHelper.ThrowInvalidCastException( + $"Cannot read an array value with {dimensions} dimension{(dimensions == 1 ? "" : "s")} into a " + + $"collection type with {_expectedDimensions} dimension{(_expectedDimensions == 1 ? "" : "s")}. 
" + + $"Call GetValue or a version of GetFieldValue with the commas being the expected amount of dimensions."); + + if (containsNulls && !ElemTypeDbNullable) + ThrowHelper.ThrowInvalidCastException(ReadNonNullableCollectionWithNullsExceptionMessage); + + // Make sure we can read length + lower bound N dimension times. + if (reader.ShouldBuffer((sizeof(int) + sizeof(int)) * dimensions)) + await reader.Buffer(async, (sizeof(int) + sizeof(int)) * dimensions, cancellationToken).ConfigureAwait(false); + + var dimLengths = new int[_expectedDimensions ?? dimensions]; + var lastDimLength = 0; + for (var i = 0; i < dimensions; i++) + { + lastDimLength = reader.ReadInt32(); + reader.ReadInt32(); // Lower bound + if (dimLengths.Length is 0) + break; + dimLengths[i] = lastDimLength; + } + + var collection = _elemOps.CreateCollection(dimLengths); + Debug.Assert(dimensions <= 1 || collection is Array a && a.Rank == dimensions); + + if (dimensions is 0 || lastDimLength is 0) + return collection; + + int[] indices; + // Reuse array for dim <= 1 + if (dimensions == 1) + { + dimLengths[0] = 0; + indices = dimLengths; + } + else + indices = new int[dimensions]; + do + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + + var length = reader.ReadInt32(); + var isDbNull = length == -1; + if (!isDbNull) + { + var scope = await reader.BeginNestedRead(async, length, _bufferRequirements.Read, cancellationToken).ConfigureAwait(false); + try + { + await _elemOps.Read(async, reader, isDbNull, collection, indices, cancellationToken).ConfigureAwait(false); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + else + await _elemOps.Read(async, reader, isDbNull, collection, indices, cancellationToken).ConfigureAwait(false); + } + // We can immediately continue if we didn't reach the end of the last dimension. 
+ while (++indices[indices.Length - 1] < lastDimLength || (dimensions > 1 && CarryIndices(dimLengths, indices))); + + return collection; + } + + static bool CarryIndices(int[] lengths, int[] indices) + { + Debug.Assert(lengths.Length > 1); + + // Find the first dimension from the end that isn't at or past its length, increment it and bring all previous dimensions to zero. + for (var dim = indices.Length - 1; dim >= 0; dim--) + { + if (indices[dim] >= lengths[dim] - 1) + continue; + + indices.AsSpan().Slice(dim + 1).Clear(); + indices[dim]++; + return true; + } + + // We're done if we can't find any dimension that isn't at its length. + return false; + } + + public async ValueTask Write(bool async, PgWriter writer, object values, CancellationToken cancellationToken) + { + var (count, dims, state) = writer.Current.WriteState switch + { + WriteState writeState => (writeState.Count, writeState.Lengths?.Length ?? 1 , writeState), + null => (0, values is Array a ? a.Rank : 1, null), + _ => throw new InvalidCastException($"Invalid write state, expected {typeof(WriteState).FullName}.") + }; + + if (writer.ShouldFlush(GetFormatSize(count, dims))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteInt32(dims); // Dimensions + writer.WriteInt32(0); // Flags (not really used) + writer.WriteAsOid(_elemTypeId); + for (var dim = 0; dim < dims; dim++) + { + writer.WriteInt32(state?.Lengths?[dim] ?? count); + writer.WriteInt32(_pgLowerBound); // Lower bound + } + + // We can stop here for empty collections. + if (state is null) + return; + + var elemTypeDbNullable = ElemTypeDbNullable; + var elemData = state.Data.Array; + + var indices = state.Indices; + Array.Clear(indices, 0 , indices.Length); + var lastLength = state.Lengths?[state.Lengths.Length - 1] ?? 
state.Count; + var i = state.Data.Offset; + do + { + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var elem = elemData?[i++]; + var size = elem?.Size ?? (elemTypeDbNullable && IsDbNull(values, indices) ? -1 : _bufferRequirements.Write); + if (size.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var length = size.Value; + writer.WriteInt32(length); + if (length != -1) + { + using var _ = await writer.BeginNestedWrite(async, _bufferRequirements.Write, length, elem?.WriteState, cancellationToken).ConfigureAwait(false); + await _elemOps.Write(async, writer, values, indices, cancellationToken).ConfigureAwait(false); + } + } + // We can immediately continue if we didn't reach the end of the last dimension. + while (++indices[indices.Length - 1] < lastLength || (indices.Length > 1 && CarryIndices(state.Lengths!, indices))); + } +} + +// Class constraint exists to make ValueTask to ValueTask reinterpretation safe, don't remove unless that is also removed. +abstract class ArrayConverter : PgStreamingConverter where T : class +{ + protected PgConverterResolution ElemResolution { get; } + protected Type ElemTypeToConvert { get; } + + readonly PgArrayConverter _pgArrayConverter; + + private protected ArrayConverter(int? 
expectedDimensions, PgConverterResolution elemResolution, int pgLowerBound = 1) + { + if (!elemResolution.Converter.CanConvert(DataFormat.Binary, out var bufferRequirements)) + throw new NotSupportedException("Element converter has to support the binary format to be compatible."); + + ElemResolution = elemResolution; + ElemTypeToConvert = elemResolution.Converter.TypeToConvert; + _pgArrayConverter = new((IElementOperations)this, elemResolution.Converter.IsDbNullable, expectedDimensions, + bufferRequirements, elemResolution.PgTypeId, pgLowerBound); + } + + public override T Read(PgReader reader) => (T)_pgArrayConverter.Read(async: false, reader).Result; + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + { + var value = _pgArrayConverter.Read(async: true, reader, cancellationToken); + // Justification: elides the async method bloat/perf cost to transition from object to T (where T : class) + Debug.Assert(typeof(T).IsClass); + return Unsafe.As, ValueTask>(ref value); + } + + public override Size GetSize(SizeContext context, T values, ref object? writeState) + => _pgArrayConverter.GetSize(context, values, ref writeState); + + public override void Write(PgWriter writer, T values) + => _pgArrayConverter.Write(async: false, writer, values, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T values, CancellationToken cancellationToken = default) + => _pgArrayConverter.Write(async: true, writer, values, cancellationToken); + + // Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is passed along. + // As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're done. + // The alternatives are: + // 1. Add a virtual method and make AwaitTask call into it (bloating the vtable of all derived types). + // 2. 
Using a delegate, meaning we add a static field + an alloc per T + metadata, slightly slower dispatch perf so overall strictly worse as well. +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] +#endif + private protected static async ValueTask AwaitTask(Task task, Continuation continuation, object collection, int[] indices) + { + await task.ConfigureAwait(false); + continuation.Invoke(task, collection, indices); + // Guarantee the type stays loaded until the function pointer call is done. + GC.KeepAlive(continuation.Handle); + } + + // Split out into a struct as unsafe and async don't mix, while we do want a nicely typed function pointer signature to prevent mistakes. + protected readonly unsafe struct Continuation + { + public object Handle { get; } + readonly delegate* _continuation; + + /// A reference to the type that houses the static method points to. + /// The continuation + public Continuation(object handle, delegate* continuation) + { + Handle = handle; + _continuation = continuation; + } + + public void Invoke(Task task, object collection, int[] indices) => _continuation(task, collection, indices); + } + + protected static int[]? GetLengths(Array array) + { + if (array.Rank == 1) + return null; + + var lengths = new int[array.Rank]; + for (var i = 0; i < lengths.Length; i++) + lengths[i] = array.GetLength(i); + + return lengths; + } +} + +sealed class ArrayBasedArrayConverter : ArrayConverter, IElementOperations where T : class +{ + readonly PgConverter _elemConverter; + + public ArrayBasedArrayConverter(PgConverterResolution elemResolution, Type? effectiveType = null, int pgLowerBound = 1) + : base( + expectedDimensions: effectiveType is null ? 1 : effectiveType.IsArray ? effectiveType.GetArrayRank() : null, + elemResolution, pgLowerBound) + => _elemConverter = elemResolution.GetConverter(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static TElement? 
GetValue(object collection, int[] indices) + { + switch (indices.Length) + { + case 1: + // Justification: avoid the cast overhead for per element calls. + Debug.Assert(collection is TElement?[]); + return Unsafe.As(collection)[indices[0]]; + default: + // Justification: avoid the cast overhead for per element calls. + Debug.Assert(collection is Array); + return (TElement?)Unsafe.As(collection).GetValue(indices); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SetValue(object collection, int[] indices, TElement? value) + { + switch (indices.Length) + { + case 1: + // Justification: avoid the cast overhead for per element calls. + Debug.Assert(collection is TElement?[]); + Unsafe.As(collection)[indices[0]] = value; + break; + default: + // Justification: avoid the cast overhead for per element calls. + Debug.Assert(collection is Array); + Unsafe.As(collection).SetValue(value, indices); + break; + } + } + + object IElementOperations.CreateCollection(int[] lengths) + => lengths.Length switch + { + 0 => Array.Empty(), + 1 when lengths[0] == 0 => Array.Empty(), + 1 => new TElement?[lengths[0]], + 2 => new TElement?[lengths[0], lengths[1]], + 3 => new TElement?[lengths[0], lengths[1], lengths[2]], + 4 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3]], + 5 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3], lengths[4]], + 6 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3], lengths[4], lengths[5]], + 7 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3], lengths[4], lengths[5], lengths[6]], + 8 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3], lengths[4], lengths[5], lengths[6], lengths[7]], + _ => throw new InvalidOperationException("Postgres arrays can have at most 8 dimensions.") + }; + + int IElementOperations.GetCollectionCount(object collection, out int[]? lengths) + { + var array = (Array)collection; + lengths = GetLengths(array); + return array.Length; + } + + Size? 
IElementOperations.GetSizeOrDbNull(SizeContext context, object collection, int[] indices, ref object? writeState) + => _elemConverter.GetSizeOrDbNull(context.Format, context.BufferRequirement, GetValue(collection, indices), ref writeState); + + ValueTask IElementOperations.Read(bool async, PgReader reader, bool isDbNull, object collection, int[] indices, CancellationToken cancellationToken) + { + if (!isDbNull && async && _elemConverter is PgStreamingConverter streamingConverter) + return ReadAsync(streamingConverter, reader, collection, indices, cancellationToken); + + SetValue(collection, indices, isDbNull ? default : _elemConverter.Read(reader)); + return new(); + } + + unsafe ValueTask ReadAsync(PgStreamingConverter converter, PgReader reader, object collection, int[] indices, CancellationToken cancellationToken) + { + if (converter.ReadAsyncAsTask(reader, cancellationToken, out var result) is { } task) + return AwaitTask(task, new(this, &SetResult), collection, indices); + + SetValue(collection, indices, result); + return new(); + + // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. 
+ static void SetResult(Task task, object collection, int[] indices) + { + SetValue(collection, indices, new ValueTask((Task)task).Result); + } + } + + ValueTask IElementOperations.Write(bool async, PgWriter writer, object collection, int[] indices, CancellationToken cancellationToken) + { + if (async) + return _elemConverter.WriteAsync(writer, GetValue(collection, indices)!, cancellationToken); + + _elemConverter.Write(writer, GetValue(collection, indices)!); + return new(); + } +} + +sealed class ListBasedArrayConverter : ArrayConverter, IElementOperations where T : class +{ + readonly PgConverter _elemConverter; + + public ListBasedArrayConverter(PgConverterResolution elemResolution, int pgLowerBound = 1) + : base(expectedDimensions: 1, elemResolution, pgLowerBound) + => _elemConverter = elemResolution.GetConverter(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static TElement? GetValue(object collection, int index) + { + // Justification: avoid the cast overhead for per element calls. + Debug.Assert(collection is IList); + return Unsafe.As>(collection)[index]; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SetValue(object collection, int index, TElement? value) + { + // Justification: avoid the cast overhead for per element calls. + Debug.Assert(collection is IList); + var list = Unsafe.As>(collection); + list.Insert(index, value); + } + + object IElementOperations.CreateCollection(int[] lengths) + => new List(lengths.Length is 0 ? 0 : lengths[0]); + + int IElementOperations.GetCollectionCount(object collection, out int[]? lengths) + { + lengths = null; + return ((IList)collection).Count; + } + + Size? IElementOperations.GetSizeOrDbNull(SizeContext context, object collection, int[] indices, ref object? 
writeState) + => _elemConverter.GetSizeOrDbNull(context.Format, context.BufferRequirement, GetValue(collection, indices[0]), ref writeState); + + ValueTask IElementOperations.Read(bool async, PgReader reader, bool isDbNull, object collection, int[] indices, CancellationToken cancellationToken) + { + Debug.Assert(indices.Length is 1); + if (!isDbNull && async && _elemConverter is PgStreamingConverter streamingConverter) + return ReadAsync(streamingConverter, reader, collection, indices, cancellationToken); + + SetValue(collection, indices[0], isDbNull ? default : _elemConverter.Read(reader)); + return new(); + } + + unsafe ValueTask ReadAsync(PgStreamingConverter converter, PgReader reader, object collection, int[] indices, CancellationToken cancellationToken) + { + if (converter.ReadAsyncAsTask(reader, cancellationToken, out var result) is { } task) + return AwaitTask(task, new(this, &SetResult), collection, indices); + + SetValue(collection, indices[0], result); + return new(); + + // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. + static void SetResult(Task task, object collection, int[] indices) + { + SetValue(collection, indices[0], new ValueTask((Task)task).Result); + } + } + + ValueTask IElementOperations.Write(bool async, PgWriter writer, object collection, int[] indices, CancellationToken cancellationToken) + { + Debug.Assert(indices.Length is 1); + if (async) + return _elemConverter.WriteAsync(writer, GetValue(collection, indices[0])!, cancellationToken); + + _elemConverter.Write(writer, GetValue(collection, indices[0])!); + return new(); + } +} + +sealed class ArrayConverterResolver : PgComposingConverterResolver where T : class +{ + readonly Type _effectiveType; + + public ArrayConverterResolver(PgResolverTypeInfo elementTypeInfo, Type effectiveType) + : base(elementTypeInfo.PgTypeId is { } id ? 
elementTypeInfo.Options.GetArrayTypeId(id) : null, elementTypeInfo) + => _effectiveType = effectiveType; + + PgSerializerOptions Options => EffectiveTypeInfo.Options; + + protected override PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId) => Options.GetArrayElementTypeId(pgTypeId); + protected override PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId) => Options.GetArrayTypeId(effectivePgTypeId); + + protected override PgConverter CreateConverter(PgConverterResolution effectiveResolution) + { + if (typeof(T) == typeof(Array) || typeof(T).IsArray) + return new ArrayBasedArrayConverter(effectiveResolution, _effectiveType); + + if (typeof(T).IsConstructedGenericType && typeof(T).GetGenericTypeDefinition() == typeof(IList<>)) + return new ListBasedArrayConverter(effectiveResolution); + + throw new NotSupportedException($"Unknown type T: {typeof(T).FullName}"); + } + + protected override PgConverterResolution? GetEffectiveResolution(T? values, PgTypeId? expectedEffectivePgTypeId) + { + PgConverterResolution? resolution = null; + if (values is null) + { + resolution = EffectiveTypeInfo.GetDefaultResolution(expectedEffectivePgTypeId); + } + else + { + switch (values) + { + case TElement[] array: + foreach (var value in array) + { + var result = EffectiveTypeInfo.GetResolution(value, resolution?.PgTypeId ?? expectedEffectivePgTypeId); + resolution ??= result; + } + break; + case List list: + foreach (var value in list) + { + var result = EffectiveTypeInfo.GetResolution(value, resolution?.PgTypeId ?? expectedEffectivePgTypeId); + resolution ??= result; + } + break; + case IList list: + foreach (var value in list) + { + var result = EffectiveTypeInfo.GetResolution(value, resolution?.PgTypeId ?? expectedEffectivePgTypeId); + resolution ??= result; + } + break; + case Array array: + foreach (var value in array) + { + var result = EffectiveTypeInfo.GetResolutionAsObject(value, resolution?.PgTypeId ?? 
expectedEffectivePgTypeId); + resolution ??= result; + } + break; + default: + throw new NotSupportedException(); + } + } + + return resolution; + } +} + +// T is Array as we only know what type it will be after reading 'contains nulls'. +sealed class PolymorphicArrayConverter : PgStreamingConverter +{ + readonly PgConverter _structElementCollectionConverter; + readonly PgConverter _nullableElementCollectionConverter; + + public PolymorphicArrayConverter(PgConverter structElementCollectionConverter, PgConverter nullableElementCollectionConverter) + { + _structElementCollectionConverter = structElementCollectionConverter; + _nullableElementCollectionConverter = nullableElementCollectionConverter; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.Create(read: sizeof(int) + sizeof(int), write: Size.Unknown); + return format is DataFormat.Binary; + } + + public override TBase Read(PgReader reader) + { + _ = reader.ReadInt32(); + var containsNulls = reader.ReadInt32() is 1; + reader.Rewind(sizeof(int) + sizeof(int)); + return containsNulls + ? _nullableElementCollectionConverter.Read(reader) + : _structElementCollectionConverter.Read(reader); + } + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + { + _ = reader.ReadInt32(); + var containsNulls = reader.ReadInt32() is 1; + reader.Rewind(sizeof(int) + sizeof(int)); + return containsNulls + ? _nullableElementCollectionConverter.ReadAsync(reader, cancellationToken) + : _structElementCollectionConverter.ReadAsync(reader, cancellationToken); + } + + public override Size GetSize(SizeContext context, TBase value, ref object? 
writeState) + => throw new NotSupportedException("Polymorphic writing is not supported"); + + public override void Write(PgWriter writer, TBase value) + => throw new NotSupportedException("Polymorphic writing is not supported"); + + public override ValueTask WriteAsync(PgWriter writer, TBase value, CancellationToken cancellationToken = default) + => throw new NotSupportedException("Polymorphic writing is not supported"); +} + +sealed class PolymorphicArrayConverterResolver : PolymorphicConverterResolver +{ + readonly PgResolverTypeInfo _effectiveInfo; + readonly PgResolverTypeInfo _effectiveNullableInfo; + readonly ConcurrentDictionary _converterCache = new(ReferenceEqualityComparer.Instance); + + public PolymorphicArrayConverterResolver(PgResolverTypeInfo effectiveInfo, PgResolverTypeInfo effectiveNullableInfo) + : base(effectiveInfo.PgTypeId!.Value) + { + if (effectiveInfo.PgTypeId is null || effectiveNullableInfo.PgTypeId is null) + throw new InvalidOperationException("Cannot accept undecided infos"); + + _effectiveInfo = effectiveInfo; + _effectiveNullableInfo = effectiveNullableInfo; + } + + protected override PgConverter Get(Field? maybeField) + { + var structResolution = maybeField is { } field + ? _effectiveInfo.GetResolution(field) + : _effectiveInfo.GetDefaultResolution(PgTypeId); + var nullableResolution = maybeField is { } field2 + ? 
_effectiveNullableInfo.GetResolution(field2) + : _effectiveNullableInfo.GetDefaultResolution(PgTypeId); + + (PgConverter StructConverter, PgConverter NullableConverter) state = (structResolution.Converter, nullableResolution.Converter); + return _converterCache.GetOrAdd(structResolution.Converter, + static (_, state) => new PolymorphicArrayConverter((PgConverter)state.StructConverter, (PgConverter)state.NullableConverter), + state); + } +} diff --git a/src/Npgsql/Internal/Converters/AsyncHelpers.cs b/src/Npgsql/Internal/Converters/AsyncHelpers.cs new file mode 100644 index 0000000000..54ac02262d --- /dev/null +++ b/src/Npgsql/Internal/Converters/AsyncHelpers.cs @@ -0,0 +1,116 @@ +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +// Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is passed along. +// As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're done. +static class AsyncHelpers +{ + static async void AwaitTask(Task task, CompletionSource tcs, Continuation continuation) + { + try + { + await task.ConfigureAwait(false); + continuation.Invoke(task, tcs); + } + catch (Exception ex) + { + tcs.SetException(ex); + } + // Guarantee the type stays loaded until the function pointer call is done. 
+ GC.KeepAlive(continuation.Handle); + } + + abstract class CompletionSource + { + public abstract void SetException(Exception exception); + } + + sealed class CompletionSource : CompletionSource + { +#if NETSTANDARD + AsyncValueTaskMethodBuilder _amb = AsyncValueTaskMethodBuilder.Create(); +#else + PoolingAsyncValueTaskMethodBuilder _amb = PoolingAsyncValueTaskMethodBuilder.Create(); +#endif + public ValueTask Task => _amb.Task; + + public void SetResult(T value) + => _amb.SetResult(value); + + public override void SetException(Exception exception) + => _amb.SetException(exception); + } + + // Split out into a struct as unsafe and async don't mix, while we do want a nicely typed function pointer signature to prevent mistakes. + readonly unsafe struct Continuation + { + public object Handle { get; } + readonly delegate* _continuation; + + /// A reference to the type that houses the static method points to. + /// The continuation + public Continuation(object handle, delegate* continuation) + { + Handle = handle; + _continuation = continuation; + } + + public void Invoke(Task task, CompletionSource tcs) => _continuation(task, tcs); + } + + public static unsafe ValueTask ReadAsyncAsNullable(this PgConverter instance, PgConverter effectiveConverter, PgReader reader, CancellationToken cancellationToken) + where T : struct + { + // Easy if we have all the data. + var task = effectiveConverter.ReadAsync(reader, cancellationToken); + if (task.IsCompletedSuccessfully) + return new(new T?(task.Result)); + + // Otherwise we do one additional allocation, this allow us to share state machine codegen for all Ts. + var source = new CompletionSource(); + AwaitTask(task.AsTask(), source, new(instance, &UnboxAndComplete)); + return source.Task; + + static void UnboxAndComplete(Task task, CompletionSource completionSource) + { + // Justification: exact type Unsafe.As used to reduce generic duplication cost. 
+ Debug.Assert(task is Task); + Debug.Assert(completionSource is CompletionSource); + Unsafe.As>(completionSource).SetResult(new T?(new ValueTask(Unsafe.As>(task)).Result)); + } + } + + public static unsafe ValueTask ReadAsObjectAsyncAsT(this PgConverter instance, PgConverter effectiveConverter, PgReader reader, CancellationToken cancellationToken) + { + if (!typeof(T).IsValueType) + { + var value = effectiveConverter.ReadAsObjectAsync(reader, cancellationToken); + // Justification: elides the async method bloat/perf cost to transition from object to T (where T : class) + Debug.Assert(typeof(T).IsClass); + return Unsafe.As, ValueTask>(ref value); + } + + // Easy if we have all the data. + var task = effectiveConverter.ReadAsObjectAsync(reader, cancellationToken); + if (task.IsCompletedSuccessfully) + return new((T)task.Result); + + // Otherwise we do one additional allocation, this allow us to share state machine codegen for all Ts. + var source = new CompletionSource(); + AwaitTask(task.AsTask(), source, new(instance, &UnboxAndComplete)); + return source.Task; + + static void UnboxAndComplete(Task task, CompletionSource completionSource) + { + // Justification: exact type Unsafe.As used to reduce generic duplication cost. 
+ Debug.Assert(task is Task); + Debug.Assert(completionSource is CompletionSource); + Unsafe.As>(completionSource).SetResult((T)new ValueTask(Unsafe.As>(task)).Result); + } + } +} diff --git a/src/Npgsql/Internal/Converters/BitStringConverters.cs b/src/Npgsql/Internal/Converters/BitStringConverters.cs new file mode 100644 index 0000000000..b7597f96d9 --- /dev/null +++ b/src/Npgsql/Internal/Converters/BitStringConverters.cs @@ -0,0 +1,249 @@ +using System; +using System.Buffers; +using System.Collections; +using System.Collections.Specialized; +using System.Diagnostics; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; +using static Npgsql.Internal.Converters.BitStringHelpers; + +namespace Npgsql.Internal.Converters; + +static class BitStringHelpers +{ + public static int GetByteLengthFromBits(int n) + { + const int BitShiftPerByte = 3; + Debug.Assert(n >= 0); + // Due to sign extension, we don't need to special case for n == 0, since ((n - 1) >> 3) + 1 = 0 + // This doesn't hold true for ((n - 1) / 8) + 1, which equals 1. 
+ return (int)((uint)(n - 1 + (1 << BitShiftPerByte)) >> BitShiftPerByte); + } + + // http://graphics.stanford.edu/~seander/bithacks.html#ReverseByteWith64Bits + public static byte ReverseBits(byte b) => (byte)(((b * 0x80200802UL) & 0x0884422110UL) * 0x0101010101UL >> 32); +} + +sealed class BitArrayBitStringConverter : PgStreamingConverter +{ + public override BitArray Read(PgReader reader) + { + if (reader.ShouldBuffer(sizeof(int))) + reader.Buffer(sizeof(int)); + + var bits = reader.ReadInt32(); + var bytes = new byte[GetByteLengthFromBits(bits)]; + reader.ReadBytes(bytes); + return ReadValue(bytes, bits); + } + public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.BufferAsync(sizeof(int), cancellationToken).ConfigureAwait(false); + + var bits = reader.ReadInt32(); + var bytes = new byte[GetByteLengthFromBits(bits)]; + await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false); + return ReadValue(bytes, bits); + } + + internal static BitArray ReadValue(byte[] bytes, int bits) + { + for (var i = 0; i < bytes.Length; i++) + { + ref var b = ref bytes[i]; + b = ReverseBits(b); + } + + return new(bytes) { Length = bits }; + } + + public override Size GetSize(SizeContext context, BitArray value, ref object? 
writeState) + => sizeof(int) + GetByteLengthFromBits(value.Length); + + public override void Write(PgWriter writer, BitArray value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + public override ValueTask WriteAsync(PgWriter writer, BitArray value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, BitArray value, CancellationToken cancellationToken = default) + { + var byteCount = writer.Current.Size.Value - sizeof(int); + var array = ArrayPool.Shared.Rent(byteCount); + for (var pos = 0; pos < byteCount; pos++) + { + var bitPos = pos*8; + var bits = Math.Min(8, value.Length - bitPos); + var b = 0; + for (var i = 0; i < bits; i++) + b += (value[bitPos + i] ? 1 : 0) << (8 - i - 1); + array[pos] = (byte)b; + } + + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteInt32(value.Length); + if (async) + await writer.WriteBytesAsync(new ReadOnlyMemory(array, 0, byteCount), cancellationToken).ConfigureAwait(false); + else + writer.WriteBytes(new ReadOnlySpan(array, 0, byteCount)); + + ArrayPool.Shared.Return(array); + } +} + +sealed class BitVector32BitStringConverter : PgBufferedConverter +{ + static int MaxSize => sizeof(int) + sizeof(int); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.Create(Size.CreateUpperBound(MaxSize)); + return format is DataFormat.Binary; + } + + protected override BitVector32 ReadCore(PgReader reader) + { + if (reader.CurrentRemaining > sizeof(int) + sizeof(int)) + throw new InvalidCastException("Can't read a BIT(N) with more than 32 bits to BitVector32, only up to BIT(32)."); + + var bits = reader.ReadInt32(); + return GetByteLengthFromBits(bits) switch + { + 4 => new(reader.ReadInt32()), + 3 => new((reader.ReadInt16() << 8) + 
reader.ReadByte()), + 2 => new(reader.ReadInt16() << 16), + 1 => new(reader.ReadByte() << 24), + _ => new(0) + }; + } + + public override Size GetSize(SizeContext context, BitVector32 value, ref object? writeState) + => value.Data is 0 ? 4 : MaxSize; + + protected override void WriteCore(PgWriter writer, BitVector32 value) + { + if (value.Data == 0) + writer.WriteInt32(0); + else + { + writer.WriteInt32(32); + writer.WriteInt32(value.Data); + } + } +} + +sealed class BoolBitStringConverter : PgBufferedConverter +{ + static int MaxSize => sizeof(int) + sizeof(byte); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.Create(read: Size.CreateUpperBound(MaxSize), write: MaxSize); + return format is DataFormat.Binary; + } + + protected override bool ReadCore(PgReader reader) + { + var bits = reader.ReadInt32(); + return bits switch + { + > 1 => throw new InvalidCastException("Can't read a BIT(N) type to bool, only BIT(1)."), + // We make an accommodation for varbit with no data. + 0 => false, + _ => (reader.ReadByte() & 128) is not 0 + }; + } + + public override Size GetSize(SizeContext context, bool value, ref object? writeState) => MaxSize; + protected override void WriteCore(PgWriter writer, bool value) + { + writer.WriteInt32(1); + writer.WriteByte(value ? 
(byte)128 : (byte)0); + } +} + +sealed class StringBitStringConverter : PgStreamingConverter +{ + public override string Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + + var bits = reader.ReadInt32(); + var bytes = new byte[GetByteLengthFromBits(bits)]; + if (async) + await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false); + else + reader.ReadBytes(bytes); + + var bitArray = BitArrayBitStringConverter.ReadValue(bytes, bits); + var sb = new StringBuilder(bits); + for (var i = 0; i < bitArray.Count; i++) + sb.Append(bitArray[i] ? '1' : '0'); + + return sb.ToString(); + } + + public override Size GetSize(SizeContext context, string value, ref object? 
writeState) + { + if (value.AsSpan().IndexOfAnyExcept('0', '1') is not -1 and var index) + throw new ArgumentException($"Invalid bitstring character '{value[index]}' at index: {index}", nameof(value)); + + return sizeof(int) + GetByteLengthFromBits(value.Length); + } + + public override void Write(PgWriter writer, string value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + public override ValueTask WriteAsync(PgWriter writer, string value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, string value, CancellationToken cancellationToken) + { + var byteCount = writer.Current.Size.Value - sizeof(int); + var array = ArrayPool.Shared.Rent(byteCount); + for (var pos = 0; pos < byteCount; pos++) + { + var bitPos = pos*8; + var bits = Math.Min(8, value.Length - bitPos); + var b = 0; + for (var i = 0; i < bits; i++) + b += (value[bitPos + i] == '1' ? 1 : 0) << (8 - i - 1); + array[pos] = (byte)b; + } + + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteInt32(value.Length); + if (async) + await writer.WriteBytesAsync(new ReadOnlyMemory(array, 0, byteCount), cancellationToken).ConfigureAwait(false); + else + writer.WriteBytes(new ReadOnlySpan(array, 0, byteCount)); + + ArrayPool.Shared.Return(array); + } +} + +/// Note that for BIT(1), this resolver will return a bool by default, to align with SqlClient +/// (see discussion https://github.com/npgsql/npgsql/pull/362#issuecomment-59622101). +sealed class PolymorphicBitStringConverterResolver : PolymorphicConverterResolver +{ + BoolBitStringConverter? _boolConverter; + BitArrayBitStringConverter? _bitArrayConverter; + + public PolymorphicBitStringConverterResolver(PgTypeId bitString) : base(bitString) { } + + protected override PgConverter Get(Field? field) + => field?.TypeModifier is 1 + ? 
_boolConverter ??= new BoolBitStringConverter() + : _bitArrayConverter ??= new BitArrayBitStringConverter(); +} diff --git a/src/Npgsql/Internal/Converters/CastingConverter.cs b/src/Npgsql/Internal/Converters/CastingConverter.cs new file mode 100644 index 0000000000..3fbfc5059d --- /dev/null +++ b/src/Npgsql/Internal/Converters/CastingConverter.cs @@ -0,0 +1,85 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Converters; + +/// A converter to map strongly typed apis onto boxed converter results to produce a strongly typed converter over T. +sealed class CastingConverter : PgConverter +{ + readonly PgConverter _effectiveConverter; + public CastingConverter(PgConverter effectiveConverter) + : base(effectiveConverter.DbNullPredicateKind is DbNullPredicate.Custom) + => _effectiveConverter = effectiveConverter; + + protected override bool IsDbNullValue(T? value, ref object? writeState) => _effectiveConverter.IsDbNullAsObject(value, ref writeState); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => _effectiveConverter.CanConvert(format, out bufferRequirements); + + public override T Read(PgReader reader) => (T)_effectiveConverter.ReadAsObject(reader); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => this.ReadAsObjectAsyncAsT(_effectiveConverter, reader, cancellationToken); + + public override Size GetSize(SizeContext context, T value, ref object? 
writeState) + => _effectiveConverter.GetSizeAsObject(context, value!, ref writeState); + + public override void Write(PgWriter writer, T value) + => _effectiveConverter.WriteAsObject(writer, value!); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => _effectiveConverter.WriteAsObjectAsync(writer, value!, cancellationToken); + + internal override ValueTask ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken) + => async + ? _effectiveConverter.ReadAsObjectAsync(reader, cancellationToken) + : new(_effectiveConverter.ReadAsObject(reader)); + + internal override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) + { + if (async) + return _effectiveConverter.WriteAsObjectAsync(writer, value, cancellationToken); + + _effectiveConverter.WriteAsObject(writer, value); + return new(); + } +} + +// Given there aren't many instantiations of converter resolvers (and it's fairly involved to write a fast one) we use the composing base class. +sealed class CastingConverterResolver : PgComposingConverterResolver +{ + public CastingConverterResolver(PgResolverTypeInfo effectiveResolverTypeInfo) + : base(effectiveResolverTypeInfo.PgTypeId, effectiveResolverTypeInfo) { } + + protected override PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId) => pgTypeId; + protected override PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId) => effectivePgTypeId; + + protected override PgConverter CreateConverter(PgConverterResolution effectiveResolution) + => new CastingConverter(effectiveResolution.Converter); + + protected override PgConverterResolution? GetEffectiveResolution(T? value, PgTypeId? 
expectedEffectiveTypeId) + => EffectiveTypeInfo.GetResolutionAsObject(value, expectedEffectiveTypeId); +} + +static class CastingTypeInfoExtensions +{ + [RequiresDynamicCode("Changing boxing converters to their non-boxing counterpart can require creating new generic types or methods, which requires creating code at runtime. This may not be AOT when AOT compiling")] + internal static PgTypeInfo ToNonBoxing(this PgTypeInfo typeInfo) + { + if (!typeInfo.IsBoxing) + return typeInfo; + + var type = typeInfo.Type; + if (typeInfo is PgResolverTypeInfo resolverTypeInfo) + return new PgResolverTypeInfo(typeInfo.Options, + (PgConverterResolver)Activator.CreateInstance(typeof(CastingConverterResolver<>).MakeGenericType(type), + resolverTypeInfo)!, typeInfo.PgTypeId); + + var resolution = typeInfo.GetResolution(); + return new PgTypeInfo(typeInfo.Options, + (PgConverter)Activator.CreateInstance(typeof(CastingConverter<>).MakeGenericType(type), resolution.Converter)!, resolution.PgTypeId); + } +} diff --git a/src/Npgsql/Internal/Converters/CompositeConverter.cs b/src/Npgsql/Internal/Converters/CompositeConverter.cs new file mode 100644 index 0000000000..24f3d36329 --- /dev/null +++ b/src/Npgsql/Internal/Converters/CompositeConverter.cs @@ -0,0 +1,229 @@ +using System; +using System.Buffers; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Composites; + +namespace Npgsql.Internal.Converters; + +sealed class CompositeConverter : PgStreamingConverter where T : notnull +{ + readonly CompositeInfo _composite; + readonly BufferRequirements _bufferRequirements; + + public CompositeConverter(CompositeInfo composite) + { + _composite = composite; + + var req = BufferRequirements.CreateFixedSize(sizeof(int) + _composite.Fields.Count * (sizeof(uint) + sizeof(int))); + foreach (var field in _composite.Fields) + { + var readReq = field.BinaryReadRequirement; + var writeReq = field.BinaryWriteRequirement; + + // If so we cannot depend on its buffer size being 
fixed. + if (field.IsDbNullable) + { + readReq = readReq.Combine(Size.CreateUpperBound(0)); + writeReq = writeReq.Combine(Size.CreateUpperBound(0)); + } + + req = req.Combine(readReq, writeReq); + } + + // We have to put a limit on the requirements we report otherwise smaller buffer sizes won't work. + req = BufferRequirements.Create(Limit(req.Read), Limit(req.Write)); + + _bufferRequirements = req; + + // Return unknown if we hit the limit. + Size Limit(Size requirement) + { + const int maxByteCount = 1024; + return requirement.GetValueOrDefault() > maxByteCount ? requirement.Combine(Size.Unknown) : requirement; + } + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = _bufferRequirements; + return format is DataFormat.Binary; + } + + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + + // TODO we can make a nice thread-static cache for this. + using var builder = new CompositeBuilder(_composite); + + var count = reader.ReadInt32(); + if (count != _composite.Fields.Count) + throw new InvalidOperationException("Cannot read composite type with mismatched number of fields."); + + foreach (var field in _composite.Fields) + { + if (reader.ShouldBuffer(sizeof(uint) + sizeof(int))) + await reader.Buffer(async, sizeof(uint) + sizeof(int), cancellationToken).ConfigureAwait(false); + + var oid = reader.ReadUInt32(); + var length = reader.ReadInt32(); + + // We're only requiring the PgTypeIds to be oids if this converter is actually used during execution. 
+ // As a result we can still introspect in the global mapper and create all the info with portable ids. + if(oid != field.PgTypeId.Oid) + // We could remove this requirement by storing a dictionary of CompositeInfos keyed by backend. + throw new InvalidCastException( + $"Cannot read oid {oid} into composite field {field.Name} with oid {field.PgTypeId}. " + + $"This could be caused by a DDL change after this DataSource loaded its types, or a difference between column order of table composites between backends, make sure these line up identically."); + + if (length is -1) + field.ReadDbNull(builder); + else + { + var converter = field.GetReadInfo(out var readRequirement); + var scope = await reader.BeginNestedRead(async, length, readRequirement, cancellationToken).ConfigureAwait(false); + try + { + await field.Read(async, converter, builder, reader, cancellationToken).ConfigureAwait(false); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + } + + return builder.Complete(); + } + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + { + var arrayPool = ArrayPool.Shared; + var data = arrayPool.Rent(_composite.Fields.Count); + + var totalSize = Size.Create(sizeof(int) + _composite.Fields.Count * (sizeof(uint) + sizeof(int))); + var boxedInstance = (object)value; + var anyWriteState = false; + for (var i = 0; i < _composite.Fields.Count; i++) + { + var field = _composite.Fields[i]; + var converter = field.GetWriteInfo(boxedInstance, out var writeRequirement); + object? fieldState = null; + var fieldSize = field.GetSizeOrDbNull(converter, context.Format, writeRequirement, boxedInstance, ref fieldState); + anyWriteState = anyWriteState || fieldState is not null; + data[i] = new() + { + Size = fieldSize ?? -1, + WriteState = fieldState, + Converter = converter, + BufferRequirement = writeRequirement + }; + totalSize = totalSize.Combine(fieldSize ?? 
0); + } + + writeState = new WriteState + { + ArrayPool = arrayPool, + Data = new(data, 0, _composite.Fields.Count), + AnyWriteState = anyWriteState, + BoxedInstance = boxedInstance, + }; + return totalSize; + } + + public override void Write(PgWriter writer, T value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, T value, CancellationToken cancellationToken) + { + if (writer.Current.WriteState is not null and not WriteState) + throw new InvalidCastException($"Invalid write state, expected {typeof(WriteState).FullName}."); + + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteInt32(_composite.Fields.Count); + + var writeState = writer.Current.WriteState as WriteState; + var boxedInstance = writeState?.BoxedInstance ?? value; + var data = writeState?.Data.Array; + for (var i = 0; i < _composite.Fields.Count; i++) + { + if (writer.ShouldFlush(sizeof(uint) + sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var field = _composite.Fields[i]; + writer.WriteAsOid(field.PgTypeId); + + ElementState elementState; + if (data?[i] is not { } state) + { + var converter = field.GetWriteInfo(boxedInstance, out var writeRequirement); + object? fieldState = null; + elementState = new() + { + Size = field.IsDbNull(converter, boxedInstance, ref fieldState) ? 
-1 : writeRequirement, + WriteState = null, + Converter = converter, + BufferRequirement = writeRequirement, + }; + } + else + elementState = state; + var length = elementState.Size.Value; + writer.WriteInt32(length); + if (length is not -1) + { + using var _ = await writer.BeginNestedWrite(async, elementState.BufferRequirement, length, elementState.WriteState, cancellationToken).ConfigureAwait(false); + await field.Write(async, elementState.Converter, writer, boxedInstance, cancellationToken).ConfigureAwait(false); + } + } + } + + readonly struct ElementState + { + public required Size Size { get; init; } + public required object? WriteState { get; init; } + public required PgConverter Converter { get; init; } + public required Size BufferRequirement { get; init; } + } + + class WriteState : IDisposable + { + public required ArrayPool? ArrayPool { get; init; } + public required ArraySegment Data { get; init; } + public required bool AnyWriteState { get; init; } + public required object BoxedInstance { get; init; } + + public void Dispose() + { + if (Data.Array is not { } array) + return; + + if (AnyWriteState) + { + for (var i = Data.Offset; i < array.Length; i++) + if (array[i].WriteState is IDisposable disposable) + disposable.Dispose(); + + Array.Clear(Data.Array, Data.Offset, Data.Count); + } + + ArrayPool?.Return(Data.Array); + } + } +} diff --git a/src/Npgsql/Internal/Converters/EnumConverter.cs b/src/Npgsql/Internal/Converters/EnumConverter.cs new file mode 100644 index 0000000000..12f85992f0 --- /dev/null +++ b/src/Npgsql/Internal/Converters/EnumConverter.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Text; + +namespace Npgsql.Internal.Converters; + +[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] +sealed class EnumConverter : PgBufferedConverter where TEnum : struct, Enum +{ + readonly Dictionary _enumToLabel; + readonly Dictionary _labelToEnum; + 
readonly Encoding _encoding; + + // Unmapped enums + public EnumConverter(Dictionary enumToLabel, Dictionary labelToEnum, Encoding encoding) + { + _enumToLabel = new(enumToLabel.Count); + foreach (var kv in enumToLabel) + _enumToLabel.Add((TEnum)kv.Key, kv.Value); + + _labelToEnum = new(labelToEnum.Count); + foreach (var kv in labelToEnum) + _labelToEnum.Add(kv.Key, (TEnum)kv.Value); + + _encoding = encoding; + } + + public EnumConverter(Dictionary enumToLabel, Dictionary labelToEnum, Encoding encoding) + { + _enumToLabel = enumToLabel; + _labelToEnum = labelToEnum; + _encoding = encoding; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.Value; + return format is DataFormat.Binary or DataFormat.Text; + } + + public override Size GetSize(SizeContext context, TEnum value, ref object? writeState) + { + if (!_enumToLabel.TryGetValue(value, out var str)) + throw new InvalidCastException($"Can't write value {value} as enum {typeof(TEnum)}"); + + return _encoding.GetByteCount(str); + } + + protected override TEnum ReadCore(PgReader reader) + { + var str = _encoding.GetString(reader.ReadBytes(reader.CurrentRemaining)); + var success = _labelToEnum.TryGetValue(str, out var value); + + if (!success) + throw new InvalidCastException($"Received enum value '{str}' from database which wasn't found on enum {typeof(TEnum)}"); + + return value; + } + + protected override void WriteCore(PgWriter writer, TEnum value) + { + if (!_enumToLabel.TryGetValue(value, out var str)) + throw new InvalidCastException($"Can't write value {value} as enum {typeof(TEnum)}"); + + writer.WriteBytes(new ReadOnlySpan(_encoding.GetBytes(str))); + } +} diff --git a/src/Npgsql/Internal/Converters/FullTextSearch/TsQueryConverter.cs b/src/Npgsql/Internal/Converters/FullTextSearch/TsQueryConverter.cs new file mode 100644 index 0000000000..220cc88894 --- /dev/null +++ 
b/src/Npgsql/Internal/Converters/FullTextSearch/TsQueryConverter.cs @@ -0,0 +1,227 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using NpgsqlTypes; +using static NpgsqlTypes.NpgsqlTsQuery.NodeKind; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class TsQueryConverter : PgStreamingConverter + where T : NpgsqlTsQuery +{ + readonly Encoding _encoding; + + public TsQueryConverter(Encoding encoding) + => _encoding = encoding; + + public override T Read(PgReader reader) + => (T)Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => (T)await Read(async: true, reader, cancellationToken).ConfigureAwait(false); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var numTokens = reader.ReadInt32(); + if (numTokens == 0) + return new NpgsqlTsQueryEmpty(); + + NpgsqlTsQuery? value = null; + var nodes = new Stack<(NpgsqlTsQuery Node, int Location)>(); + + for (var i = 0; i < numTokens; i++) + { + if (reader.ShouldBuffer(sizeof(byte))) + await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false); + + switch (reader.ReadByte()) + { + case 1: // lexeme + if (reader.ShouldBuffer(sizeof(byte) + sizeof(byte))) + await reader.Buffer(async, sizeof(byte) + sizeof(byte), cancellationToken).ConfigureAwait(false); + var weight = (NpgsqlTsQueryLexeme.Weight)reader.ReadByte(); + var prefix = reader.ReadByte() != 0; + + var str = async + ? 
await reader.ReadNullTerminatedStringAsync(_encoding, cancellationToken).ConfigureAwait(false) + : reader.ReadNullTerminatedString(_encoding); + InsertInTree(new NpgsqlTsQueryLexeme(str, weight, prefix), nodes, ref value); + continue; + + case 2: // operation + if (reader.ShouldBuffer(sizeof(byte))) + await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false); + var kind = (NpgsqlTsQuery.NodeKind)reader.ReadByte(); + + NpgsqlTsQuery node; + switch (kind) + { + case Not: + node = new NpgsqlTsQueryNot(null!); + InsertInTree(node, nodes, ref value); + nodes.Push((node, 0)); + continue; + + case And: + node = new NpgsqlTsQueryAnd(null!, null!); + break; + case Or: + node = new NpgsqlTsQueryOr(null!, null!); + break; + case Phrase: + if (reader.ShouldBuffer(sizeof(short))) + await reader.Buffer(async, sizeof(short), cancellationToken).ConfigureAwait(false); + node = new NpgsqlTsQueryFollowedBy(null!, reader.ReadInt16(), null!); + break; + default: + throw new UnreachableException( + $"Internal Npgsql bug: unexpected value {kind} of enum {nameof(NpgsqlTsQuery.NodeKind)}. Please file a bug."); + } + + InsertInTree(node, nodes, ref value); + + nodes.Push((node, 1)); + nodes.Push((node, 2)); + continue; + + case var tokenType: + throw new UnreachableException( + $"Internal Npgsql bug: unexpected token type {tokenType} when reading tsquery. Please file a bug."); + } + } + + if (nodes.Count != 0) + throw new UnreachableException("Internal Npgsql bug, please report."); + + return value!; + + static void InsertInTree(NpgsqlTsQuery node, Stack<(NpgsqlTsQuery Node, int Location)> nodes, ref NpgsqlTsQuery? 
value) + { + if (nodes.Count == 0) + value = node; + else + { + var parent = nodes.Pop(); + switch (parent.Location) + { + case 0: + ((NpgsqlTsQueryNot)parent.Node).Child = node; + break; + case 1: + ((NpgsqlTsQueryBinOp)parent.Node).Left = node; + break; + case 2: + ((NpgsqlTsQueryBinOp)parent.Node).Right = node; + break; + default: + throw new UnreachableException("Internal Npgsql bug, please report."); + } + } + } + } + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + => value.Kind is Empty + ? 4 + : 4 + GetNodeLength(value); + + int GetNodeLength(NpgsqlTsQuery node) + => node.Kind switch + { + Lexeme when _encoding.GetByteCount(((NpgsqlTsQueryLexeme)node).Text) is var strLen + => strLen > 2046 + ? throw new InvalidCastException("Lexeme text too long. Must be at most 2046 encoded bytes.") + : 4 + strLen, + And or Or => 2 + GetNodeLength(((NpgsqlTsQueryBinOp)node).Left) + GetNodeLength(((NpgsqlTsQueryBinOp)node).Right), + Not => 2 + GetNodeLength(((NpgsqlTsQueryNot)node).Child), + Empty => throw new InvalidOperationException("Empty tsquery nodes must be top-level"), + + // 2 additional bytes for uint16 phrase operator "distance" field. + Phrase => 4 + GetNodeLength(((NpgsqlTsQueryBinOp)node).Left) + GetNodeLength(((NpgsqlTsQueryBinOp)node).Right), + + _ => throw new UnreachableException( + $"Internal Npgsql bug: unexpected value {node.Kind} of enum {nameof(NpgsqlTsQuery.NodeKind)}. 
Please file a bug.") + }; + + public override void Write(PgWriter writer, T value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, NpgsqlTsQuery value, CancellationToken cancellationToken) + { + var numTokens = GetTokenCount(value); + + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(numTokens); + + if (numTokens is 0) + return; + + await WriteCore(value).ConfigureAwait(false); + + async Task WriteCore(NpgsqlTsQuery node) + { + if (writer.ShouldFlush(sizeof(byte))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteByte(node.Kind is Lexeme ? (byte)1 : (byte)2); + + if (node.Kind is Lexeme) + { + var lexemeNode = (NpgsqlTsQueryLexeme)node; + + if (writer.ShouldFlush(sizeof(byte) + sizeof(byte))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteByte((byte)lexemeNode.Weights); + writer.WriteByte(lexemeNode.IsPrefixSearch ? 
(byte)1 : (byte)0); + + if (async) + await writer.WriteCharsAsync(lexemeNode.Text.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + else + writer.WriteChars(lexemeNode.Text.AsMemory().Span, _encoding); + + if (writer.ShouldFlush(sizeof(byte))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteByte(0); + return; + } + + writer.WriteByte((byte)node.Kind); + + switch (node.Kind) + { + case Not: + await WriteCore(((NpgsqlTsQueryNot)node).Child).ConfigureAwait(false); + return; + case Phrase: + writer.WriteInt16(((NpgsqlTsQueryFollowedBy)node).Distance); + break; + } + + await WriteCore(((NpgsqlTsQueryBinOp)node).Right).ConfigureAwait(false); + await WriteCore(((NpgsqlTsQueryBinOp)node).Left).ConfigureAwait(false); + } + } + + int GetTokenCount(NpgsqlTsQuery node) + => node.Kind switch + { + Lexeme => 1, + And or Or or Phrase => 1 + GetTokenCount(((NpgsqlTsQueryBinOp)node).Left) + GetTokenCount(((NpgsqlTsQueryBinOp)node).Right), + Not => 1 + GetTokenCount(((NpgsqlTsQueryNot)node).Child), + Empty => 0, + + _ => throw new UnreachableException( + $"Internal Npgsql bug: unexpected value {node.Kind} of enum {nameof(NpgsqlTsQuery.NodeKind)}. 
Please file a bug.") + }; +} diff --git a/src/Npgsql/Internal/Converters/FullTextSearch/TsVectorConverter.cs b/src/Npgsql/Internal/Converters/FullTextSearch/TsVectorConverter.cs new file mode 100644 index 0000000000..2c431fd35b --- /dev/null +++ b/src/Npgsql/Internal/Converters/FullTextSearch/TsVectorConverter.cs @@ -0,0 +1,112 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class TsVectorConverter : PgStreamingConverter +{ + readonly Encoding _encoding; + + public TsVectorConverter(Encoding encoding) + => _encoding = encoding; + + public override NpgsqlTsVector Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + + var numLexemes = reader.ReadInt32(); + var lexemes = new List(numLexemes); + + for (var i = 0; i < numLexemes; i++) + { + var lexemeString = async + ? 
await reader.ReadNullTerminatedStringAsync(_encoding, cancellationToken).ConfigureAwait(false) + : reader.ReadNullTerminatedString(_encoding); + + if (reader.ShouldBuffer(sizeof(short))) + await reader.Buffer(async, sizeof(short), cancellationToken).ConfigureAwait(false); + var numPositions = reader.ReadInt16(); + + if (numPositions == 0) + { + lexemes.Add(new NpgsqlTsVector.Lexeme(lexemeString, wordEntryPositions: null, noCopy: true)); + continue; + } + + // There can only be a maximum of 256 positions, so we just before them all (256 * sizeof(short) = 512) + if (numPositions > 256) + throw new NpgsqlException($"Got {numPositions} lexeme positions when reading tsvector"); + + if (reader.ShouldBuffer(numPositions * sizeof(short))) + await reader.Buffer(async, numPositions * sizeof(short), cancellationToken).ConfigureAwait(false); + + var positions = new List(numPositions); + + for (var j = 0; j < numPositions; j++) + { + var wordEntryPos = reader.ReadInt16(); + positions.Add(new NpgsqlTsVector.Lexeme.WordEntryPos(wordEntryPos)); + } + + lexemes.Add(new NpgsqlTsVector.Lexeme(lexemeString, positions, noCopy: true)); + } + + return new NpgsqlTsVector(lexemes, noCheck: true); + } + + public override Size GetSize(SizeContext context, NpgsqlTsVector value, ref object? 
writeState) + { + var size = 4; + foreach (var l in value) + size += _encoding.GetByteCount(l.Text) + 1 + 2 + l.Count * 2; + + return size; + } + + public override void Write(PgWriter writer, NpgsqlTsVector value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, NpgsqlTsVector value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, NpgsqlTsVector value, CancellationToken cancellationToken) + { + if (writer.ShouldFlush(sizeof(int))) + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + writer.WriteInt32(value.Count); + + foreach (var lexeme in value) + { + if (async) + await writer.WriteCharsAsync(lexeme.Text.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + else + writer.WriteChars(lexeme.Text.AsMemory().Span, _encoding); + + if (writer.ShouldFlush(sizeof(byte) + sizeof(short))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteByte(0); + writer.WriteInt16((short)lexeme.Count); + + for (var i = 0; i < lexeme.Count; i++) + { + if (writer.ShouldFlush(sizeof(short))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteInt16(lexeme[i].Value); + } + } + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/BoxConverter.cs b/src/Npgsql/Internal/Converters/Geometric/BoxConverter.cs new file mode 100644 index 0000000000..4a7578afba --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/BoxConverter.cs @@ -0,0 +1,26 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class BoxConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double) * 4); + 
return format is DataFormat.Binary; + } + + protected override NpgsqlBox ReadCore(PgReader reader) + => new( + new NpgsqlPoint(reader.ReadDouble(), reader.ReadDouble()), + new NpgsqlPoint(reader.ReadDouble(), reader.ReadDouble())); + + protected override void WriteCore(PgWriter writer, NpgsqlBox value) + { + writer.WriteDouble(value.Right); + writer.WriteDouble(value.Top); + writer.WriteDouble(value.Left); + writer.WriteDouble(value.Bottom); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/CircleConverter.cs b/src/Npgsql/Internal/Converters/Geometric/CircleConverter.cs new file mode 100644 index 0000000000..51eea75814 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/CircleConverter.cs @@ -0,0 +1,23 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class CircleConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double) * 3); + return format is DataFormat.Binary; + } + + protected override NpgsqlCircle ReadCore(PgReader reader) + => new(reader.ReadDouble(), reader.ReadDouble(), reader.ReadDouble()); + + protected override void WriteCore(PgWriter writer, NpgsqlCircle value) + { + writer.WriteDouble(value.X); + writer.WriteDouble(value.Y); + writer.WriteDouble(value.Radius); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/LineConverter.cs b/src/Npgsql/Internal/Converters/Geometric/LineConverter.cs new file mode 100644 index 0000000000..17d89909b9 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/LineConverter.cs @@ -0,0 +1,23 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class LineConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = 
BufferRequirements.CreateFixedSize(sizeof(double) * 3); + return format is DataFormat.Binary; + } + + protected override NpgsqlLine ReadCore(PgReader reader) + => new(reader.ReadDouble(), reader.ReadDouble(), reader.ReadDouble()); + + protected override void WriteCore(PgWriter writer, NpgsqlLine value) + { + writer.WriteDouble(value.A); + writer.WriteDouble(value.B); + writer.WriteDouble(value.C); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/LineSegmentConverter.cs b/src/Npgsql/Internal/Converters/Geometric/LineSegmentConverter.cs new file mode 100644 index 0000000000..117a108379 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/LineSegmentConverter.cs @@ -0,0 +1,24 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class LineSegmentConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double) * 4); + return format is DataFormat.Binary; + } + + protected override NpgsqlLSeg ReadCore(PgReader reader) + => new(reader.ReadDouble(), reader.ReadDouble(), reader.ReadDouble(), reader.ReadDouble()); + + protected override void WriteCore(PgWriter writer, NpgsqlLSeg value) + { + writer.WriteDouble(value.Start.X); + writer.WriteDouble(value.Start.Y); + writer.WriteDouble(value.End.X); + writer.WriteDouble(value.End.Y); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs b/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs new file mode 100644 index 0000000000..c78ba84013 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs @@ -0,0 +1,68 @@ +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class PathConverter : PgStreamingConverter +{ + public 
override NpgsqlPath Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(byte) + sizeof(int))) + await reader.Buffer(async, sizeof(byte) + sizeof(int), cancellationToken).ConfigureAwait(false); + + var open = reader.ReadByte() switch + { + 1 => false, + 0 => true, + _ => throw new UnreachableException("Error decoding binary geometric path: bad open byte") + }; + + var numPoints = reader.ReadInt32(); + var result = new NpgsqlPath(numPoints, open); + + for (var i = 0; i < numPoints; i++) + { + if (reader.ShouldBuffer(sizeof(double) * 2)) + await reader.Buffer(async, sizeof(byte) + sizeof(int), cancellationToken).ConfigureAwait(false); + + result.Add(new NpgsqlPoint(reader.ReadDouble(), reader.ReadDouble())); + } + + return result; + } + + public override Size GetSize(SizeContext context, NpgsqlPath value, ref object? writeState) + => 5 + value.Count * sizeof(double) * 2; + + public override void Write(PgWriter writer, NpgsqlPath value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, NpgsqlPath value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, NpgsqlPath value, CancellationToken cancellationToken) + { + if (writer.ShouldFlush(sizeof(byte) + sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteByte((byte)(value.Open ? 
0 : 1)); + writer.WriteInt32(value.Count); + + foreach (var p in value) + { + if (writer.ShouldFlush(sizeof(double) * 2)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteDouble(p.X); + writer.WriteDouble(p.Y); + } + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/PointConverter.cs b/src/Npgsql/Internal/Converters/Geometric/PointConverter.cs new file mode 100644 index 0000000000..03e84c05bd --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/PointConverter.cs @@ -0,0 +1,22 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class PointConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double) * 2); + return format is DataFormat.Binary; + } + + protected override NpgsqlPoint ReadCore(PgReader reader) + => new(reader.ReadDouble(), reader.ReadDouble()); + + protected override void WriteCore(PgWriter writer, NpgsqlPoint value) + { + writer.WriteDouble(value.X); + writer.WriteDouble(value.Y); + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/PolygonConverter.cs b/src/Npgsql/Internal/Converters/Geometric/PolygonConverter.cs new file mode 100644 index 0000000000..9a889b4323 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/PolygonConverter.cs @@ -0,0 +1,55 @@ +using System.Threading; +using System.Threading.Tasks; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class PolygonConverter : PgStreamingConverter +{ + public override NpgsqlPolygon Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool 
async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var numPoints = reader.ReadInt32(); + var result = new NpgsqlPolygon(numPoints); + for (var i = 0; i < numPoints; i++) + { + if (reader.ShouldBuffer(sizeof(double) * 2)) + await reader.Buffer(async, sizeof(double) * 2, cancellationToken).ConfigureAwait(false); + result.Add(new NpgsqlPoint(reader.ReadDouble(), reader.ReadDouble())); + } + + return result; + } + + public override Size GetSize(SizeContext context, NpgsqlPolygon value, ref object? writeState) + => 4 + value.Count * sizeof(double) * 2; + + public override void Write(PgWriter writer, NpgsqlPolygon value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, NpgsqlPolygon value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, NpgsqlPolygon value, CancellationToken cancellationToken) + { + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(value.Count); + + foreach (var p in value) + { + if (writer.ShouldFlush(sizeof(double) * 2)) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteDouble(p.X); + writer.WriteDouble(p.Y); + } + } +} diff --git a/src/Npgsql/Internal/Converters/HstoreConverter.cs b/src/Npgsql/Internal/Converters/HstoreConverter.cs new file mode 100644 index 0000000000..e2e8762d8e --- /dev/null +++ b/src/Npgsql/Internal/Converters/HstoreConverter.cs @@ -0,0 +1,159 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +sealed class HstoreConverter : 
PgStreamingConverter where T : ICollection> +{ + readonly Encoding _encoding; + readonly Func>, T>? _convert; + + public HstoreConverter(Encoding encoding, Func>, T>? convert = null) + { + _encoding = encoding; + _convert = convert; + } + + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).Result; + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + { + // Number of lengths (count, key length, value length). + var totalSize = sizeof(int) + value.Count * (sizeof(int) + sizeof(int)); + if (value.Count is 0) + return totalSize; + + var arrayPool = ArrayPool<(Size Size, object? WriteState)>.Shared; + var data = arrayPool.Rent(value.Count * 2); + + var i = 0; + foreach (var kv in value) + { + if (kv.Key is null) + throw new ArgumentException("Hstore doesn't support null keys", nameof(value)); + + var keySize = _encoding.GetByteCount(kv.Key); + var valueSize = kv.Value is null ? -1 : _encoding.GetByteCount(kv.Value); + totalSize += keySize + (valueSize is -1 ? 
0 : valueSize); + data[i] = (keySize, null); + data[i + 1] = (valueSize, null); + i += 2; + } + writeState = new WriteState + { + ArrayPool = arrayPool, + Data = new(data, 0, value.Count * 2), + AnyWriteState = false + }; + return totalSize; + } + + public override void Write(PgWriter writer, T value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + + var count = reader.ReadInt32(); + + var result = typeof(T) == typeof(Dictionary) || typeof(T) == typeof(IDictionary) + ? (ICollection>)new Dictionary(count) + : new List>(count); + + for (var i = 0; i < count; i++) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var keySize = reader.ReadInt32(); + var key = _encoding.GetString(async + ? await reader.ReadBytesAsync(keySize, cancellationToken).ConfigureAwait(false) + : reader.ReadBytes(keySize) + ); + + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var valueSize = reader.ReadInt32(); + string? value = null; + if (valueSize is not -1) + value = _encoding.GetString(async + ? await reader.ReadBytesAsync(valueSize, cancellationToken).ConfigureAwait(false) + : reader.ReadBytes(valueSize) + ); + + result.Add(new(key, value)); + } + + if (typeof(T) == typeof(Dictionary) || typeof(T) == typeof(IDictionary)) + return (T)result; + + return _convert is null ? 
throw new NotSupportedException() : _convert(result); + } + + async ValueTask Write(bool async, PgWriter writer, T value, CancellationToken cancellationToken) + { + if (writer.Current.WriteState is not WriteState && value.Count is not 0) + throw new InvalidCastException($"Invalid write state, expected {typeof(WriteState).FullName}."); + + // Number of lengths (count, key length, value length). + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(value.Count); + + if (value.Count is 0 || writer.Current.WriteState is not WriteState writeState) + return; + + var data = writeState.Data; + var i = data.Offset; + foreach (var kv in value) + { + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var (size, _) = data.Array![i]; + if (size.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var length = size.Value; + writer.WriteInt32(length); + if (async) + await writer.WriteCharsAsync(kv.Key.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + else + writer.WriteChars(kv.Key.AsSpan(), _encoding); + + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var (valueSize, _) = data.Array![i + 1]; + if (valueSize.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var valueLength = valueSize.Value; + writer.WriteInt32(valueLength); + if (valueLength is not -1) + { + if (async) + await writer.WriteCharsAsync(kv.Value.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + else + writer.WriteChars(kv.Value.AsSpan(), _encoding); + } + i += 2; + } + } + + sealed class WriteState : MultiWriteState + { + } +} diff --git a/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs b/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs new file mode 100644 index 0000000000..5d00a26dcb --- /dev/null +++ 
b/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs @@ -0,0 +1,43 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class InternalCharConverter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(byte)); + return format is DataFormat.Binary; + } + +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadByte()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteByte(byte.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadByte(); + if (typeof(byte) == typeof(T)) + return (T)(object)value; + if (typeof(char) == typeof(T)) + return (T)(object)(char)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(byte) == typeof(T)) + writer.WriteByte((byte)(object)value!); + else if (typeof(char) == typeof(T)) + writer.WriteByte(checked((byte)(char)(object)value!)); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Internal/PgLsnConverter.cs b/src/Npgsql/Internal/Converters/Internal/PgLsnConverter.cs new file mode 100644 index 0000000000..96730c857a --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/PgLsnConverter.cs @@ -0,0 +1,15 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class PgLsnConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(ulong)); + return format is DataFormat.Binary; + } + protected override 
NpgsqlLogSequenceNumber ReadCore(PgReader reader) => new(reader.ReadUInt64()); + protected override void WriteCore(PgWriter writer, NpgsqlLogSequenceNumber value) => writer.WriteUInt64((ulong)value); +} diff --git a/src/Npgsql/Internal/Converters/Internal/TidConverter.cs b/src/Npgsql/Internal/Converters/Internal/TidConverter.cs new file mode 100644 index 0000000000..747d98fe17 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/TidConverter.cs @@ -0,0 +1,19 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class TidConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(uint) + sizeof(ushort)); + return format is DataFormat.Binary; + } + protected override NpgsqlTid ReadCore(PgReader reader) => new(reader.ReadUInt32(), reader.ReadUInt16()); + protected override void WriteCore(PgWriter writer, NpgsqlTid value) + { + writer.WriteUInt32(value.BlockNumber); + writer.WriteUInt16(value.OffsetNumber); + } +} diff --git a/src/Npgsql/Internal/Converters/Internal/UInt32Converter.cs b/src/Npgsql/Internal/Converters/Internal/UInt32Converter.cs new file mode 100644 index 0000000000..92061b1fd2 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/UInt32Converter.cs @@ -0,0 +1,13 @@ +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class UInt32Converter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(uint)); + return format is DataFormat.Binary; + } + protected override uint ReadCore(PgReader reader) => reader.ReadUInt32(); + protected override void WriteCore(PgWriter writer, uint value) => writer.WriteUInt32(value); +} diff --git 
a/src/Npgsql/Internal/Converters/Internal/UInt64Converter.cs b/src/Npgsql/Internal/Converters/Internal/UInt64Converter.cs new file mode 100644 index 0000000000..fcf5e3695a --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/UInt64Converter.cs @@ -0,0 +1,13 @@ +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class UInt64Converter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(ulong)); + return format is DataFormat.Binary; + } + protected override ulong ReadCore(PgReader reader) => reader.ReadUInt64(); + protected override void WriteCore(PgWriter writer, ulong value) => writer.WriteUInt64(value); +} diff --git a/src/Npgsql/Internal/Converters/Internal/VoidConverter.cs b/src/Npgsql/Internal/Converters/Internal/VoidConverter.cs new file mode 100644 index 0000000000..45b48df5b5 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Internal/VoidConverter.cs @@ -0,0 +1,13 @@ +using System; + +namespace Npgsql.Internal.Converters.Internal; + +// Void is not a value so we read it as a null reference, not a DBNull. +sealed class VoidConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(DataFormat.Binary, out bufferRequirements); // Text is identical + + protected override object? ReadCore(PgReader reader) => null; + protected override void WriteCore(PgWriter writer, object? 
value) => throw new NotSupportedException(); +} diff --git a/src/Npgsql/Internal/Converters/JsonConverter.cs b/src/Npgsql/Internal/Converters/JsonConverter.cs new file mode 100644 index 0000000000..75873e1951 --- /dev/null +++ b/src/Npgsql/Internal/Converters/JsonConverter.cs @@ -0,0 +1,212 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +sealed class JsonConverter : PgStreamingConverter where T: TBase? +{ + readonly bool _jsonb; + readonly Encoding _textEncoding; + readonly JsonTypeInfo _jsonTypeInfo; + readonly JsonTypeInfo? _objectTypeInfo; + + public JsonConverter(bool jsonb, Encoding textEncoding, JsonSerializerOptions serializerOptions) + { + if (serializerOptions.TypeInfoResolver is null) + throw new InvalidOperationException("System.Text.Json serialization requires a type info resolver, make sure to set-it up beforehand."); + + _jsonb = jsonb; + _textEncoding = textEncoding; + _jsonTypeInfo = typeof(TBase) != typeof(object) && typeof(T) != typeof(TBase) + ? (JsonTypeInfo)serializerOptions.GetTypeInfo(typeof(TBase)) + : (JsonTypeInfo)serializerOptions.GetTypeInfo(typeof(T)); + // Unspecified polymorphism, let STJ handle it. + _objectTypeInfo = typeof(TBase) == typeof(object) + ? (JsonTypeInfo)serializerOptions.GetTypeInfo(typeof(object)) + : null; + } + + public override T? 
Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (_jsonb && reader.ShouldBuffer(sizeof(byte))) + await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false); + + // We always fall back to buffers on older targets due to the absence of transcoding stream. + if (JsonConverter.TryReadStream(_jsonb, _textEncoding, reader, out var byteCount, out var stream)) + { + using var _ = stream; + return _jsonTypeInfo switch + { + JsonTypeInfo => (T)(object)(async + ? await JsonDocument.ParseAsync(stream, cancellationToken: cancellationToken).ConfigureAwait(false) + : JsonDocument.Parse(stream)), + + JsonTypeInfo typeInfoOfT => async + ? await JsonSerializer.DeserializeAsync(stream, typeInfoOfT, cancellationToken).ConfigureAwait(false) + : JsonSerializer.Deserialize(stream, typeInfoOfT), + + _ => (T?)(async + ? await JsonSerializer.DeserializeAsync(stream, (JsonTypeInfo)_jsonTypeInfo, cancellationToken) + .ConfigureAwait(false) + : JsonSerializer.Deserialize(stream, (JsonTypeInfo)_jsonTypeInfo)) + }; + } + + var (rentedChars, rentedBytes) = await JsonConverter.ReadRentedBuffer(async, _textEncoding, byteCount, reader, cancellationToken).ConfigureAwait(false); + var result = _jsonTypeInfo switch + { + JsonTypeInfo => (T)(object)JsonDocument.Parse(rentedChars.AsMemory()), + JsonTypeInfo typeInfoOfT => JsonSerializer.Deserialize(rentedChars.AsSpan(), typeInfoOfT), + _ => (T?)JsonSerializer.Deserialize(rentedChars.AsSpan(), (JsonTypeInfo)_jsonTypeInfo) + }; + + ArrayPool.Shared.Return(rentedChars.Array!); + if (rentedBytes is not null) + ArrayPool.Shared.Return(rentedBytes); + + return result; + } + + public override Size GetSize(SizeContext context, T? 
value, ref object? writeState) + { + var capacity = 0; + if (typeof(T) == typeof(JsonDocument)) + capacity = ((JsonDocument?)(object?)value)?.RootElement.GetRawText().Length ?? 0; + var stream = new MemoryStream(capacity); + + // Mirroring ASP.NET Core serialization strategy https://github.com/dotnet/aspnetcore/issues/47548 + if (_objectTypeInfo is null) + JsonSerializer.Serialize(stream, value, (JsonTypeInfo)_jsonTypeInfo); + else + JsonSerializer.Serialize(stream, value, _objectTypeInfo); + + return JsonConverter.GetSizeCore(_jsonb, stream, _textEncoding, ref writeState); + } + + public override void Write(PgWriter writer, T? value) + => JsonConverter.Write(_jsonb, async: false, writer, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T? value, CancellationToken cancellationToken = default) + => JsonConverter.Write(_jsonb, async: true, writer, cancellationToken); +} + +// Split out to avoid unnecessary code duplication. +static class JsonConverter +{ + public const byte JsonbProtocolVersion = 1; + // We pick a value that is the largest multiple of 4096 that is still smaller than the large object heap threshold (85K). + const int StreamingThreshold = 81920; + + public static bool TryReadStream(bool jsonb, Encoding encoding, PgReader reader, out int byteCount, [NotNullWhen(true)]out Stream? stream) + { + if (jsonb) + { + var version = reader.ReadByte(); + if (version != JsonbProtocolVersion) + throw new InvalidCastException($"Unknown jsonb wire format version {version}"); + } + + var isUtf8 = encoding.CodePage == Encoding.UTF8.CodePage; + byteCount = reader.CurrentRemaining; + // We always fall back to buffers on older targets + if (isUtf8 +#if !NETSTANDARD + || byteCount >= StreamingThreshold +#endif + ) + { + stream = +#if !NETSTANDARD + !isUtf8 + ? 
Encoding.CreateTranscodingStream(reader.GetStream(), encoding, Encoding.UTF8) + : reader.GetStream(); +#else + reader.GetStream(); + Debug.Assert(isUtf8); +#endif + } + else + stream = null; + + return stream is not null; + } + + public static async ValueTask<(ArraySegment RentedChars, byte[]? RentedBytes)> ReadRentedBuffer(bool async, Encoding encoding, int byteCount, PgReader reader, CancellationToken cancellationToken) + { + // Never utf8, but we may still be able to save a copy. + byte[]? rentedBuffer = null; + if (!reader.TryReadBytes(byteCount, out ReadOnlyMemory buffer)) + { + rentedBuffer = ArrayPool.Shared.Rent(byteCount); + if (async) + await reader.ReadBytesAsync(rentedBuffer.AsMemory(0, byteCount), cancellationToken).ConfigureAwait(false); + else + reader.ReadBytes(rentedBuffer.AsSpan(0, byteCount)); + buffer = rentedBuffer.AsMemory(0, byteCount); + } + + var charCount = encoding.GetCharCount(buffer.Span); + var chars = ArrayPool.Shared.Rent(charCount); + encoding.GetChars(buffer.Span, chars); + + return (new(chars, 0, charCount), rentedBuffer); + } + + public static Size GetSizeCore(bool jsonb, MemoryStream stream, Encoding encoding, ref object? writeState) + { + if (encoding.CodePage == Encoding.UTF8.CodePage) + { + writeState = stream; + return (int)stream.Length + (jsonb ? sizeof(byte) : 0); + } + + if (!stream.TryGetBuffer(out var buffer)) + throw new InvalidOperationException(); + + var bytes = encoding.GetBytes(Encoding.UTF8.GetChars(buffer.Array!, buffer.Offset, buffer.Count)); + writeState = bytes; + return bytes.Length + (jsonb ? 
sizeof(byte) : 0); + } + + public static async ValueTask Write(bool jsonb, bool async, PgWriter writer, CancellationToken cancellationToken) + { + if (jsonb) + { + if (writer.ShouldFlush(sizeof(byte))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteByte(JsonbProtocolVersion); + } + + ArraySegment buffer; + switch (writer.Current.WriteState) + { + case MemoryStream stream: + if (!stream.TryGetBuffer(out buffer)) + throw new InvalidOperationException(); + break; + case byte[] bytes: + buffer = new ArraySegment(bytes); + break; + default: + throw new InvalidCastException($"Invalid state {writer.Current.WriteState?.GetType().FullName}."); + } + + if (async) + await writer.WriteBytesAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false); + else + writer.WriteBytes(buffer.AsSpan()); + } +} diff --git a/src/Npgsql/Internal/Converters/MoneyConverter.cs b/src/Npgsql/Internal/Converters/MoneyConverter.cs new file mode 100644 index 0000000000..8443acedc3 --- /dev/null +++ b/src/Npgsql/Internal/Converters/MoneyConverter.cs @@ -0,0 +1,74 @@ +using System; +using System.Numerics; + +namespace Npgsql.Internal.Converters; + +sealed class MoneyConverter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + protected override T ReadCore(PgReader reader) => ConvertTo(new PgMoney(reader.ReadInt64())); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt64(ConvertFrom(value).GetValue()); + + static PgMoney ConvertFrom(T value) + { +#if !NET7_0_OR_GREATER + if (typeof(short) == typeof(T)) + return new PgMoney((decimal)(short)(object)value!); + if (typeof(int) == typeof(T)) + return new PgMoney((decimal)(int)(object)value!); + if (typeof(long) == typeof(T)) + return new 
PgMoney((decimal)(long)(object)value!); + + if (typeof(byte) == typeof(T)) + return new PgMoney((decimal)(byte)(object)value!); + if (typeof(sbyte) == typeof(T)) + return new PgMoney((decimal)(sbyte)(object)value!); + + if (typeof(float) == typeof(T)) + return new PgMoney((decimal)(float)(object)value!); + if (typeof(double) == typeof(T)) + return new PgMoney((decimal)(double)(object)value!); + if (typeof(decimal) == typeof(T)) + return new PgMoney((decimal)(object)value!); + + throw new NotSupportedException(); +#else + return new PgMoney(decimal.CreateChecked(value)); +#endif + } + + static T ConvertTo(PgMoney money) + { +#if !NET7_0_OR_GREATER + if (typeof(short) == typeof(T)) + return (T)(object)(short)money.ToDecimal(); + if (typeof(int) == typeof(T)) + return (T)(object)(int)money.ToDecimal(); + if (typeof(long) == typeof(T)) + return (T)(object)(long)money.ToDecimal(); + + if (typeof(byte) == typeof(T)) + return (T)(object)(byte)money.ToDecimal(); + if (typeof(sbyte) == typeof(T)) + return (T)(object)(sbyte)money.ToDecimal(); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)money.ToDecimal(); + if (typeof(double) == typeof(T)) + return (T)(object)(double)money.ToDecimal(); + if (typeof(decimal) == typeof(T)) + return (T)(object)money.ToDecimal(); + + throw new NotSupportedException(); +#else + return T.CreateChecked(money.ToDecimal()); +#endif + } +} diff --git a/src/Npgsql/Internal/Converters/MultirangeConverter.cs b/src/Npgsql/Internal/Converters/MultirangeConverter.cs new file mode 100644 index 0000000000..36ae35a11c --- /dev/null +++ b/src/Npgsql/Internal/Converters/MultirangeConverter.cs @@ -0,0 +1,139 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +sealed class MultirangeConverter : PgStreamingConverter + where T : IList + where TRange : notnull +{ + readonly PgConverter 
_rangeConverter; + readonly BufferRequirements _rangeRequirements; + + public MultirangeConverter(PgConverter rangeConverter) + { + if (!rangeConverter.CanConvert(DataFormat.Binary, out var bufferRequirements)) + throw new NotSupportedException("Range subtype converter has to support the binary format to be compatible."); + _rangeRequirements = bufferRequirements; + _rangeConverter = rangeConverter; + } + + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + public async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var numRanges = reader.ReadInt32(); + var multirange = (T)(object)(typeof(T).IsArray ? new TRange[numRanges] : new List()); + + for (var i = 0; i < numRanges; i++) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var length = reader.ReadInt32(); + Debug.Assert(length != -1); + + var scope = await reader.BeginNestedRead(async, length, _rangeRequirements.Read, cancellationToken).ConfigureAwait(false); + try + { + var range = async + ? await _rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : _rangeConverter.Read(reader); + + if (typeof(T).IsArray) + multirange[i] = range; + else + multirange.Add(range); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + + return multirange; + } + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + { + var arrayPool = ArrayPool<(Size Size, object? 
WriteState)>.Shared; + var data = arrayPool.Rent(value.Count); + + var totalSize = Size.Create(sizeof(int) + sizeof(int) * value.Count); + var anyWriteState = false; + for (var i = 0; i < value.Count; i++) + { + object? innerState = null; + var rangeSize = _rangeConverter.GetSizeOrDbNull(context.Format, _rangeRequirements.Write, value[i], ref innerState); + anyWriteState = anyWriteState || innerState is not null; + // Ranges should never be NULL. + Debug.Assert(rangeSize.HasValue); + data[i] = new(rangeSize.Value, innerState); + totalSize = totalSize.Combine(rangeSize.Value); + } + + writeState = new WriteState + { + ArrayPool = arrayPool, + Data = new(data, 0, value.Count), + AnyWriteState = anyWriteState + }; + return totalSize; + } + + public override void Write(PgWriter writer, T value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, T value, CancellationToken cancellationToken) + { + if (writer.Current.WriteState is not WriteState writeState) + throw new InvalidCastException($"Invalid state {writer.Current.WriteState?.GetType().FullName}."); + + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(value.Count); + + var data = writeState.Data.Array!; + for (var i = 0; i < value.Count; i++) + { + if (writer.ShouldFlush(sizeof(int))) // Length + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var (size, state) = data[i]; + if (size.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var length = size.Value; + writer.WriteInt32(length); + if (length != -1) + { + using var _ = await writer.BeginNestedWrite(async, _rangeRequirements.Write, length, state, 
cancellationToken).ConfigureAwait(false); + if (async) + await _rangeConverter.WriteAsync(writer, value[i], cancellationToken).ConfigureAwait(false); + else + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + _rangeConverter.Write(writer, value[i]); + } + } + } + + sealed class WriteState : MultiWriteState + { + } +} diff --git a/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs b/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs new file mode 100644 index 0000000000..9050f36f16 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs @@ -0,0 +1,23 @@ +using System.Net; +using System.Net.Sockets; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class IPAddressConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(format, out bufferRequirements); + + public override Size GetSize(SizeContext context, IPAddress value, ref object? writeState) + => NpgsqlInetConverter.GetSizeImpl(context, value, ref writeState); + + protected override IPAddress ReadCore(PgReader reader) + => NpgsqlInetConverter.ReadImpl(reader, shouldBeCidr: false).Address; + + protected override void WriteCore(PgWriter writer, IPAddress value) + => NpgsqlInetConverter.WriteImpl( + writer, + (value, (byte)(value.AddressFamily == AddressFamily.InterNetwork ? 
32 : 128)), + isCidr: false); +} diff --git a/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs b/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs new file mode 100644 index 0000000000..dd8aac78bc --- /dev/null +++ b/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs @@ -0,0 +1,40 @@ +using System; +using System.Diagnostics; +using System.Net.NetworkInformation; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class MacaddrConverter : PgBufferedConverter +{ + readonly bool _macaddr8; + + public MacaddrConverter(bool macaddr8) => _macaddr8 = macaddr8; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = _macaddr8 ? BufferRequirements.Create(Size.CreateUpperBound(8)) : BufferRequirements.CreateFixedSize(6); + return format is DataFormat.Binary; + } + + public override Size GetSize(SizeContext context, PhysicalAddress value, ref object? writeState) + => value.GetAddressBytes().Length; + + protected override PhysicalAddress ReadCore(PgReader reader) + { + var len = reader.CurrentRemaining; + Debug.Assert(len is 6 or 8); + + var bytes = new byte[len]; + reader.Read(bytes); + return new PhysicalAddress(bytes); + } + + protected override void WriteCore(PgWriter writer, PhysicalAddress value) + { + var bytes = value.GetAddressBytes(); + if (!_macaddr8 && bytes.Length is not 6) + throw new ArgumentException("A macaddr value must be 6 bytes long."); + writer.WriteBytes(bytes); + } +} diff --git a/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs b/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs new file mode 100644 index 0000000000..c6d0ab8d88 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs @@ -0,0 +1,22 @@ +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class NpgsqlCidrConverter : 
PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(format, out bufferRequirements); + + public override Size GetSize(SizeContext context, NpgsqlCidr value, ref object? writeState) + => NpgsqlInetConverter.GetSizeImpl(context, value.Address, ref writeState); + + protected override NpgsqlCidr ReadCore(PgReader reader) + { + var (ip, netmask) = NpgsqlInetConverter.ReadImpl(reader, shouldBeCidr: true); + return new(ip, netmask); + } + + protected override void WriteCore(PgWriter writer, NpgsqlCidr value) + => NpgsqlInetConverter.WriteImpl(writer, (value.Address, value.Netmask), isCidr: true); +} diff --git a/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs b/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs new file mode 100644 index 0000000000..f3af04e80a --- /dev/null +++ b/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs @@ -0,0 +1,73 @@ +using System; +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class NpgsqlInetConverter : PgBufferedConverter +{ + const byte IPv4 = 2; + const byte IPv6 = 3; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(format, out bufferRequirements); + + public override Size GetSize(SizeContext context, NpgsqlInet value, ref object? writeState) + => GetSizeImpl(context, value.Address, ref writeState); + + internal static Size GetSizeImpl(SizeContext context, IPAddress ipAddress, ref object? 
writeState) + => ipAddress.AddressFamily switch + { + AddressFamily.InterNetwork => 8, + AddressFamily.InterNetworkV6 => 20, + _ => throw new InvalidCastException( + $"Can't handle IPAddress with AddressFamily {ipAddress.AddressFamily}, only InterNetwork or InterNetworkV6!") + }; + + protected override NpgsqlInet ReadCore(PgReader reader) + { + var (ip, netmask) = ReadImpl(reader, shouldBeCidr: false); + return new(ip, netmask); + } + + internal static (IPAddress Address, byte Netmask) ReadImpl(PgReader reader, bool shouldBeCidr) + { + _ = reader.ReadByte(); // addressFamily + var mask = reader.ReadByte(); // mask + + var isCidr = reader.ReadByte() == 1; + Debug.Assert(isCidr == shouldBeCidr); + + var numBytes = reader.ReadByte(); + Span bytes = stackalloc byte[numBytes]; + reader.Read(bytes); +#if NETSTANDARD2_0 + return (new IPAddress(bytes.ToArray()), mask); +#else + return (new IPAddress(bytes), mask); +#endif + } + + protected override void WriteCore(PgWriter writer, NpgsqlInet value) + => WriteImpl(writer, (value.Address, value.Netmask), isCidr: false); + + internal static void WriteImpl(PgWriter writer, (IPAddress Address, byte Netmask) value, bool isCidr) + { + writer.WriteByte(value.Address.AddressFamily switch + { + AddressFamily.InterNetwork => IPv4, + AddressFamily.InterNetworkV6 => IPv6, + _ => throw new InvalidCastException( + $"Can't handle IPAddress with AddressFamily {value.Address.AddressFamily}, only InterNetwork or InterNetworkV6!") + }); + + writer.WriteByte(value.Netmask); + writer.WriteByte((byte)(isCidr ? 
1 : 0)); // Ignored on server side + var bytes = value.Address.GetAddressBytes(); + writer.WriteByte((byte)bytes.Length); + writer.WriteBytes(bytes); + } +} diff --git a/src/Npgsql/Internal/Converters/NullableConverter.cs b/src/Npgsql/Internal/Converters/NullableConverter.cs new file mode 100644 index 0000000000..292def140a --- /dev/null +++ b/src/Npgsql/Internal/Converters/NullableConverter.cs @@ -0,0 +1,60 @@ +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Converters; + +// NULL writing is always responsibility of the caller writing the length, so there is not much we do here. +/// Special value converter to be able to use struct converters as System.Nullable converters, it delegates all behavior to the effective converter. +sealed class NullableConverter : PgConverter where T : struct +{ + readonly PgConverter _effectiveConverter; + public NullableConverter(PgConverter effectiveConverter) + : base(effectiveConverter.DbNullPredicateKind is DbNullPredicate.Custom) + => _effectiveConverter = effectiveConverter; + + protected override bool IsDbNullValue(T? value, ref object? writeState) + => value is null || _effectiveConverter.IsDbNull(value.GetValueOrDefault(), ref writeState); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => _effectiveConverter.CanConvert(format, out bufferRequirements); + + public override T? Read(PgReader reader) + => _effectiveConverter.Read(reader); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => this.ReadAsyncAsNullable(_effectiveConverter, reader, cancellationToken); + + public override Size GetSize(SizeContext context, [DisallowNull]T? value, ref object? writeState) + => _effectiveConverter.GetSize(context, value.GetValueOrDefault(), ref writeState); + + public override void Write(PgWriter writer, T? 
value) + => _effectiveConverter.Write(writer, value.GetValueOrDefault()); + + public override ValueTask WriteAsync(PgWriter writer, T? value, CancellationToken cancellationToken = default) + => _effectiveConverter.WriteAsync(writer, value.GetValueOrDefault(), cancellationToken); + + internal override ValueTask ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken) + => _effectiveConverter.ReadAsObject(async, reader, cancellationToken); + + internal override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) + => _effectiveConverter.WriteAsObject(async, writer, value, cancellationToken); +} + +sealed class NullableConverterResolver : PgComposingConverterResolver where T : struct +{ + public NullableConverterResolver(PgResolverTypeInfo effectiveTypeInfo) + : base(effectiveTypeInfo.PgTypeId, effectiveTypeInfo) { } + + protected override PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId) => pgTypeId; + protected override PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId) => effectivePgTypeId; + + protected override PgConverter CreateConverter(PgConverterResolution effectiveResolution) + => new NullableConverter(effectiveResolution.GetConverter()); + + protected override PgConverterResolution? GetEffectiveResolution(T? value, PgTypeId? expectedEffectivePgTypeId) + => value is null + ? 
EffectiveTypeInfo.GetDefaultResolution(expectedEffectivePgTypeId) + : EffectiveTypeInfo.GetResolution(value.GetValueOrDefault(), expectedEffectivePgTypeId); +} diff --git a/src/Npgsql/Internal/Converters/ObjectConverter.cs b/src/Npgsql/Internal/Converters/ObjectConverter.cs new file mode 100644 index 0000000000..568fc32c2b --- /dev/null +++ b/src/Npgsql/Internal/Converters/ObjectConverter.cs @@ -0,0 +1,109 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +sealed class ObjectConverter : PgStreamingConverter +{ + readonly PgSerializerOptions _options; + readonly PgTypeId _pgTypeId; + + public ObjectConverter(PgSerializerOptions options, PgTypeId pgTypeId) + : base(customDbNullPredicate: true) + { + _options = options; + _pgTypeId = pgTypeId; + } + + protected override bool IsDbNullValue(object? value, ref object? writeState) + { + if (value is null or DBNull) + return true; + + var typeInfo = GetTypeInfo(value.GetType()); + + object? effectiveState = null; + var converter = typeInfo.GetObjectResolution(value).Converter; + if (converter.IsDbNullAsObject(value, ref effectiveState)) + return true; + + writeState = effectiveState is not null ? new WriteState { TypeInfo = typeInfo, EffectiveState = effectiveState } : typeInfo; + return false; + } + + public override object Read(PgReader reader) => throw new NotSupportedException(); + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) => throw new NotSupportedException(); + + public override Size GetSize(SizeContext context, object value, ref object? writeState) + { + var (typeInfo, effectiveState) = writeState switch + { + PgTypeInfo info => (info, null), + WriteState state => (state.TypeInfo, state.EffectiveState), + _ => throw new InvalidOperationException("Invalid state") + }; + + // We can call GetDefaultResolution here as validation has already happened in IsDbNullValue. 
+ // And we know it was called due to the writeState being filled. + var converter = typeInfo is PgResolverTypeInfo resolverTypeInfo + ? resolverTypeInfo.GetDefaultResolution(null).Converter + : typeInfo.GetResolution().Converter; + if (typeInfo.GetBufferRequirements(converter, context.Format) is not { } bufferRequirements) + { + ThrowHelper.ThrowNotSupportedException($"Resolved converter '{converter.GetType()}' has to support the {context.Format} format to be compatible."); + return default; + } + + // Fixed size converters won't have a GetSize implementation. + if (bufferRequirements.Write.Kind is SizeKind.Exact) + return bufferRequirements.Write; + + var result = converter.GetSizeAsObject(context, value, ref effectiveState); + if (effectiveState is not null) + { + if (writeState is WriteState state && !ReferenceEquals(state.EffectiveState, effectiveState)) + state.EffectiveState = effectiveState; + else + writeState = new WriteState { TypeInfo = typeInfo, EffectiveState = effectiveState }; + } + + return result; + } + + public override void Write(PgWriter writer, object value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, object value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, object value, CancellationToken cancellationToken) + { + var (typeInfo, effectiveState) = writer.Current.WriteState switch + { + PgTypeInfo info => (info, null), + WriteState state => (state.TypeInfo, state.EffectiveState), + _ => throw new InvalidOperationException("Invalid state") + }; + + // We can call GetDefaultResolution here as validation has already happened in IsDbNullValue. + // And we know it was called due to the writeState being filled. + var converter = typeInfo is PgResolverTypeInfo resolverTypeInfo + ? 
resolverTypeInfo.GetDefaultResolution(null).Converter + : typeInfo.GetResolution().Converter; + var writeRequirement = typeInfo.GetBufferRequirements(converter, DataFormat.Binary)!.Value.Write; + using var _ = await writer.BeginNestedWrite(async, writeRequirement, writer.Current.Size.Value, effectiveState, cancellationToken).ConfigureAwait(false); + await converter.WriteAsObject(async, writer, value, cancellationToken).ConfigureAwait(false); + } + + PgTypeInfo GetTypeInfo(Type type) + => _options.GetTypeInfo(type, _pgTypeId) + ?? throw new NotSupportedException($"Writing values of '{type.FullName}' having DataTypeName '{_options.DatabaseInfo.GetPostgresType(_pgTypeId).DisplayName}' is not supported."); + + sealed class WriteState + { + public required PgTypeInfo TypeInfo { get; init; } + public required object EffectiveState { get; set; } + } +} diff --git a/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs b/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs new file mode 100644 index 0000000000..7c78e34a24 --- /dev/null +++ b/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Converters; + +abstract class PolymorphicConverterResolver : PgConverterResolver +{ + protected PolymorphicConverterResolver(PgTypeId pgTypeId) => PgTypeId = pgTypeId; + + protected PgTypeId PgTypeId { get; } + + protected abstract PgConverter Get(Field? field); + + public sealed override PgConverterResolution GetDefault(PgTypeId? pgTypeId) + { + if (pgTypeId is not null && pgTypeId != PgTypeId) + throw CreateUnsupportedPgTypeIdException(pgTypeId.Value); + + return new(Get(null), PgTypeId); + } + + public sealed override PgConverterResolution? Get(TBase? value, PgTypeId? 
expectedPgTypeId) + => new(Get(null), PgTypeId); + + public sealed override PgConverterResolution Get(Field field) + { + if (field.PgTypeId != PgTypeId) + throw CreateUnsupportedPgTypeIdException(field.PgTypeId); + + var converter = Get(field); + return new(converter, PgTypeId); + } +} + +// Many ways to achieve strongly typed composition on top of a polymorphic element type. +// Including pushing construction through a GVM visitor pattern on the element handler, +// manual reimplementation of the element logic in the array resolver, and other ways. +// This one however is by far the most lightweight on both the implementation duplication and code bloat axes. +sealed class ArrayPolymorphicConverterResolver : PolymorphicConverterResolver +{ + readonly PgResolverTypeInfo _elemTypeInfo; + readonly Func _elemToArrayConverterFactory; + readonly PgTypeId _elemPgTypeId; + readonly ConcurrentDictionary _converterCache = new(ReferenceEqualityComparer.Instance); + + public ArrayPolymorphicConverterResolver(PgTypeId pgTypeId, PgResolverTypeInfo elemTypeInfo, Func elemToArrayConverterFactory) + : base(pgTypeId) + { + if (elemTypeInfo.PgTypeId is null) + throw new ArgumentException("elemTypeInfo.PgTypeId must be non-null.", nameof(elemTypeInfo)); + + _elemTypeInfo = elemTypeInfo; + _elemToArrayConverterFactory = elemToArrayConverterFactory; + _elemPgTypeId = elemTypeInfo.PgTypeId!.Value; + } + + protected override PgConverter Get(Field? maybeField) + { + var elemResolution = maybeField is { } field + ? 
_elemTypeInfo.GetResolution(field with { PgTypeId = _elemPgTypeId }) + : _elemTypeInfo.GetDefaultResolution(_elemPgTypeId); + + (Func Factory, PgConverterResolution Resolution) state = (_elemToArrayConverterFactory, elemResolution); + return _converterCache.GetOrAdd(elemResolution.Converter, static (_, state) => state.Factory(state.Resolution), state); + } +} diff --git a/src/Npgsql/Internal/Converters/Primitive/BoolConverter.cs b/src/Npgsql/Internal/Converters/Primitive/BoolConverter.cs new file mode 100644 index 0000000000..196877ad0e --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/BoolConverter.cs @@ -0,0 +1,13 @@ +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class BoolConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(byte)); + return format is DataFormat.Binary; + } + protected override bool ReadCore(PgReader reader) => reader.ReadByte() is not 0; + protected override void WriteCore(PgWriter writer, bool value) => writer.WriteByte((byte)(value ? 
1 : 0)); +} diff --git a/src/Npgsql/Internal/Converters/Primitive/ByteaConverters.cs b/src/Npgsql/Internal/Converters/Primitive/ByteaConverters.cs new file mode 100644 index 0000000000..f7760f836c --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/ByteaConverters.cs @@ -0,0 +1,155 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +abstract class ByteaConverters : PgStreamingConverter +{ + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).Result; + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + => ConvertTo(value).Length; + + public override void Write(PgWriter writer, T value) + => writer.WriteBytes(ConvertTo(value).Span); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => writer.WriteBytesAsync(ConvertTo(value), cancellationToken); + +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + var bytes = new byte[reader.CurrentRemaining]; + if (async) + await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false); + else + reader.ReadBytes(bytes); + + return ConvertFrom(new(bytes)); + } + + protected abstract Memory ConvertTo(T value); + protected abstract T ConvertFrom(Memory value); +} + +sealed class ArraySegmentByteaConverter : ByteaConverters> +{ + protected override Memory ConvertTo(ArraySegment value) => value; + protected override ArraySegment ConvertFrom(Memory 
value) + => MemoryMarshal.TryGetArray(value, out var segment) + ? segment + : throw new UnreachableException("Expected array-backed memory"); +} + +sealed class ArrayByteaConverter : PgStreamingConverter +{ + public override byte[] Read(PgReader reader) + { + var bytes = new byte[reader.CurrentRemaining]; + reader.ReadBytes(bytes); + return bytes; + } + + public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + { + var bytes = new byte[reader.CurrentRemaining]; + await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false); + return bytes; + } + + public override Size GetSize(SizeContext context, byte[] value, ref object? writeState) + => value.Length; + + public override void Write(PgWriter writer, byte[] value) + => writer.WriteBytes(value); + + public override ValueTask WriteAsync(PgWriter writer, byte[] value, CancellationToken cancellationToken = default) + => writer.WriteBytesAsync(value, cancellationToken); +} + +sealed class ReadOnlyMemoryByteaConverter : ByteaConverters> +{ + protected override Memory ConvertTo(ReadOnlyMemory value) => MemoryMarshal.AsMemory(value); + protected override ReadOnlyMemory ConvertFrom(Memory value) => value; +} + +sealed class MemoryByteaConverter : ByteaConverters> +{ + protected override Memory ConvertTo(Memory value) => value; + protected override Memory ConvertFrom(Memory value) => value; +} + +sealed class StreamByteaConverter : PgStreamingConverter +{ + public override Stream Read(PgReader reader) + => throw new NotSupportedException("Handled by generic stream support in NpgsqlDataReader"); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => throw new NotSupportedException("Handled by generic stream support in NpgsqlDataReader"); + + public override Size GetSize(SizeContext context, Stream value, ref object? 
writeState) + { + if (value.CanSeek) + return checked((int)(value.Length - value.Position)); + + var memoryStream = new MemoryStream(); + value.CopyTo(memoryStream); + writeState = memoryStream; + return checked((int)memoryStream.Length); + } + + public override void Write(PgWriter writer, Stream value) + { + if (writer.Current.WriteState is not null) + { + if (!((MemoryStream)writer.Current.WriteState!).TryGetBuffer(out var writeStateSegment)) + throw new InvalidOperationException(); + + writer.WriteBytes(writeStateSegment.AsSpan()); + return; + } + + // Non-derived MemoryStream fast path + if (value is MemoryStream memoryStream && memoryStream.TryGetBuffer(out var segment)) + writer.WriteBytes(segment.AsSpan((int)value.Position)); + else + value.CopyTo(writer.GetStream()); + } + + public override ValueTask WriteAsync(PgWriter writer, Stream value, CancellationToken cancellationToken = default) + { + if (writer.Current.WriteState is not null) + { + if (!((MemoryStream)writer.Current.WriteState!).TryGetBuffer(out var writeStateSegment)) + throw new InvalidOperationException(); + + return writer.WriteBytesAsync(writeStateSegment.AsMemory(), cancellationToken); + } + + // Non-derived MemoryStream fast path + if (value is MemoryStream memoryStream && memoryStream.TryGetBuffer(out var segment)) + { + return writer.WriteBytesAsync(segment.AsMemory((int)value.Position), cancellationToken); + } + else + { +#if NETSTANDARD2_0 + return new ValueTask(value.CopyToAsync(writer.GetStream())); +#else + return new ValueTask(value.CopyToAsync(writer.GetStream(), cancellationToken)); +#endif + } + } +} diff --git a/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs b/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs new file mode 100644 index 0000000000..74a56d06ae --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs @@ -0,0 +1,43 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace 
Npgsql.Internal.Converters; + +sealed class DoubleConverter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(double)); + return format is DataFormat.Binary; + } + +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadDouble()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteDouble(double.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadDouble(); + if (typeof(float) == typeof(T)) + return (T)(object)value; + if (typeof(double) == typeof(T)) + return (T)(object)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(float) == typeof(T)) + writer.WriteDouble((float)(object)value!); + else if (typeof(double) == typeof(T)) + writer.WriteDouble((double)(object)value!); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs b/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs new file mode 100644 index 0000000000..596deedfce --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs @@ -0,0 +1,70 @@ +using System; +using System.Buffers.Binary; +using System.Runtime.InteropServices; + +namespace Npgsql.Internal.Converters; + +sealed class GuidUuidConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(16 * sizeof(byte)); + return format is DataFormat.Binary; + } + protected override Guid ReadCore(PgReader reader) + { +#if NET8_0_OR_GREATER + return new Guid(reader.ReadBytes(16).FirstSpan, bigEndian: true); +#else + return new 
GuidRaw + { + Data1 = reader.ReadInt32(), + Data2 = reader.ReadInt16(), + Data3 = reader.ReadInt16(), + Data4 = BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(reader.ReadInt64()) : reader.ReadInt64() + }.Value; +#endif + } + + protected override void WriteCore(PgWriter writer, Guid value) + { +#if NET8_0_OR_GREATER + Span bytes = stackalloc byte[16]; + value.TryWriteBytes(bytes, bigEndian: true, out _); + writer.WriteBytes(bytes); +#else + var raw = new GuidRaw(value); + + writer.WriteInt32(raw.Data1); + writer.WriteInt16(raw.Data2); + writer.WriteInt16(raw.Data3); + writer.WriteInt64(BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(raw.Data4) : raw.Data4); +#endif + } + +#if !NET8_0_OR_GREATER + // The following table shows .NET GUID vs Postgres UUID (RFC 4122) layouts. + // + // Note that the first fields are converted from/to native endianness (handled by the Read* + // and Write* methods), while the last field is always read/written in big-endian format. + // + // We're reverting endianness on little endian systems to get it into big endian format. 
+ // + // | Bits | Bytes | Name | Endianness (GUID) | Endianness (RFC 4122) | + // | ---- | ----- | ----- | ----------------- | --------------------- | + // | 32 | 4 | Data1 | Native | Big | + // | 16 | 2 | Data2 | Native | Big | + // | 16 | 2 | Data3 | Native | Big | + // | 64 | 8 | Data4 | Big | Big | + [StructLayout(LayoutKind.Explicit)] + struct GuidRaw + { + [FieldOffset(0)] public Guid Value; + [FieldOffset(0)] public int Data1; + [FieldOffset(4)] public short Data2; + [FieldOffset(6)] public short Data3; + [FieldOffset(8)] public long Data4; + public GuidRaw(Guid value) : this() => Value = value; + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs b/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs new file mode 100644 index 0000000000..e54658d925 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs @@ -0,0 +1,70 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class Int2Converter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(short)); + return format is DataFormat.Binary; + } +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadInt16()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt16(short.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadInt16(); + if (typeof(short) == typeof(T)) + return (T)(object)value; + if (typeof(int) == typeof(T)) + return (T)(object)(int)value; + if (typeof(long) == typeof(T)) + return (T)(object)(long)value; + + if (typeof(byte) == typeof(T)) + return (T)(object)checked((byte)value); + if (typeof(sbyte) == typeof(T)) + return 
(T)(object)checked((sbyte)value); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)value; + if (typeof(double) == typeof(T)) + return (T)(object)(double)value; + if (typeof(decimal) == typeof(T)) + return (T)(object)(decimal)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(short) == typeof(T)) + writer.WriteInt16((short)(object)value!); + else if (typeof(int) == typeof(T)) + writer.WriteInt16(checked((short)(int)(object)value!)); + else if (typeof(long) == typeof(T)) + writer.WriteInt16(checked((short)(long)(object)value!)); + + else if (typeof(byte) == typeof(T)) + writer.WriteInt16((byte)(object)value!); + else if (typeof(sbyte) == typeof(T)) + writer.WriteInt16((sbyte)(object)value!); + + else if (typeof(float) == typeof(T)) + writer.WriteInt16(checked((short)(float)(object)value!)); + else if (typeof(double) == typeof(T)) + writer.WriteInt16(checked((short)(double)(object)value!)); + else if (typeof(decimal) == typeof(T)) + writer.WriteInt16((short)(decimal)(object)value!); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs b/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs new file mode 100644 index 0000000000..1831ca9b1e --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs @@ -0,0 +1,71 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class Int4Converter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int)); + return format is DataFormat.Binary; + } + +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadInt32()); + protected 
override void WriteCore(PgWriter writer, T value) => writer.WriteInt32(int.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadInt32(); + if (typeof(short) == typeof(T)) + return (T)(object)checked((short)value); + if (typeof(int) == typeof(T)) + return (T)(object)value; + if (typeof(long) == typeof(T)) + return (T)(object)(long)value; + + if (typeof(byte) == typeof(T)) + return (T)(object)checked((byte)value); + if (typeof(sbyte) == typeof(T)) + return (T)(object)checked((sbyte)value); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)value; + if (typeof(double) == typeof(T)) + return (T)(object)(double)value; + if (typeof(decimal) == typeof(T)) + return (T)(object)(decimal)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(short) == typeof(T)) + writer.WriteInt32((short)(object)value!); + else if (typeof(int) == typeof(T)) + writer.WriteInt32((int)(object)value!); + else if (typeof(long) == typeof(T)) + writer.WriteInt32(checked((int)(long)(object)value!)); + + else if (typeof(byte) == typeof(T)) + writer.WriteInt32((byte)(object)value!); + else if (typeof(sbyte) == typeof(T)) + writer.WriteInt32((sbyte)(object)value!); + + else if (typeof(float) == typeof(T)) + writer.WriteInt32(checked((int)(float)(object)value!)); + else if (typeof(double) == typeof(T)) + writer.WriteInt32(checked((int)(double)(object)value!)); + else if (typeof(decimal) == typeof(T)) + writer.WriteInt32((int)(decimal)(object)value!); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs b/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs new file mode 100644 index 0000000000..b422816244 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs @@ -0,0 +1,72 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace 
+namespace Npgsql.Internal.Converters; + +sealed class Int8Converter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadInt64()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt64(long.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadInt64(); + if (typeof(long) == typeof(T)) + return (T)(object)value; + + if (typeof(short) == typeof(T)) + return (T)(object)checked((short)value); + if (typeof(int) == typeof(T)) + return (T)(object)checked((int)value); + + if (typeof(byte) == typeof(T)) + return (T)(object)checked((byte)value); + if (typeof(sbyte) == typeof(T)) + return (T)(object)checked((sbyte)value); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)value; + if (typeof(double) == typeof(T)) + return (T)(object)(double)value; + if (typeof(decimal) == typeof(T)) + return (T)(object)(decimal)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(short) == typeof(T)) + writer.WriteInt64((short)(object)value!); + else if (typeof(int) == typeof(T)) + writer.WriteInt64((int)(object)value!); + else if (typeof(long) == typeof(T)) + writer.WriteInt64((long)(object)value!); + + else if (typeof(byte) == typeof(T)) + writer.WriteInt64((byte)(object)value!); + else if (typeof(sbyte) == typeof(T)) + writer.WriteInt64((sbyte)(object)value!); + + else if (typeof(float) == typeof(T)) + writer.WriteInt64(checked((long)(float)(object)value!)); + else if (typeof(double) == typeof(T)) + writer.WriteInt64(checked((long)(double)(object)value!)); + else if 
(typeof(decimal) == typeof(T)) + writer.WriteInt64((long)(decimal)(object)value!); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs b/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs new file mode 100644 index 0000000000..c43e90a1f7 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs @@ -0,0 +1,262 @@ +using System; +using System.Buffers; +using System.Numerics; +using System.Threading; +using System.Threading.Tasks; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class BigIntegerNumericConverter : PgStreamingConverter +{ + const int StackAllocByteThreshold = 64 * sizeof(uint); + + public override BigInteger Read(PgReader reader) + { + var digitCount = reader.ReadInt16(); + short[]? digitsFromPool = null; + var digits = (digitCount <= StackAllocByteThreshold / sizeof(short) + ? stackalloc short[StackAllocByteThreshold / sizeof(short)] + : (digitsFromPool = ArrayPool.Shared.Rent(digitCount)).AsSpan()).Slice(0, digitCount); + + var value = ConvertTo(NumericConverter.Read(reader, digits)); + + if (digitsFromPool is not null) + ArrayPool.Shared.Return(digitsFromPool); + + return value; + } + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + { + // If we don't need a read and can read buffered we delegate to our sync read method which won't do IO in such a case. 
+ if (!reader.ShouldBuffer(reader.CurrentRemaining)) + Read(reader); + + return AsyncCore(reader, cancellationToken); + + static async ValueTask AsyncCore(PgReader reader, CancellationToken cancellationToken) + { + await reader.BufferAsync(PgNumeric.GetByteCount(0), cancellationToken).ConfigureAwait(false); + var digitCount = reader.ReadInt16(); + var digits = new ArraySegment(ArrayPool.Shared.Rent(digitCount), 0, digitCount); + var value = ConvertTo(await NumericConverter.ReadAsync(reader, digits, cancellationToken).ConfigureAwait(false)); + + ArrayPool.Shared.Return(digits.Array!); + + return value; + } + } + + public override Size GetSize(SizeContext context, BigInteger value, ref object? writeState) => + PgNumeric.GetByteCount(PgNumeric.GetDigitCount(value)); + + public override void Write(PgWriter writer, BigInteger value) + { + // We don't know how many digits we need so we allocate a decent chunk of stack for the builder to use. + // If it's not enough for the builder will do a heap allocation (for decimal it's always enough). + Span destination = stackalloc short[StackAllocByteThreshold / sizeof(short)]; + var numeric = ConvertFrom(value, destination); + NumericConverter.Write(writer, numeric); + } + + public override ValueTask WriteAsync(PgWriter writer, BigInteger value, CancellationToken cancellationToken = default) + { + if (writer.ShouldFlush(writer.Current.Size)) + return AsyncCore(writer, value, cancellationToken); + + // If we don't need a flush and can write buffered we delegate to our sync write method which won't flush in such a case. 
+ Write(writer, value); + return new(); + + static async ValueTask AsyncCore(PgWriter writer, BigInteger value, CancellationToken cancellationToken) + { + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + var numeric = ConvertFrom(value, Array.Empty()).Build(); + await NumericConverter.WriteAsync(writer, numeric, cancellationToken).ConfigureAwait(false); + } + } + + static PgNumeric.Builder ConvertFrom(BigInteger value, Span destination) => new(value, destination); + static BigInteger ConvertTo(in PgNumeric.Builder numeric) => numeric.ToBigInteger(); + static BigInteger ConvertTo(in PgNumeric numeric) => numeric.ToBigInteger(); +} + +sealed class DecimalNumericConverter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#else + where T : notnull +#endif +{ + const int StackAllocByteThreshold = 64 * sizeof(uint); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + // This upper bound would already cause an overflow exception in the builder, no need to do + 1. + bufferRequirements = BufferRequirements.Create(Size.CreateUpperBound(NumericConverter.DecimalBasedMaxByteCount)); + return format is DataFormat.Binary; + } + + protected override T ReadCore(PgReader reader) + { + var digitCount = reader.ReadInt16(); + var digits = stackalloc short[StackAllocByteThreshold / sizeof(short)].Slice(0, digitCount);; + var value = ConvertTo(NumericConverter.Read(reader, digits)); + return value; + } + + public override Size GetSize(SizeContext context, T value, ref object? 
writeState) => + PgNumeric.GetByteCount(default(T) switch + { + _ when typeof(decimal) == typeof(T) => PgNumeric.GetDigitCount((decimal)(object)value), + _ when typeof(short) == typeof(T) => PgNumeric.GetDigitCount((decimal)(short)(object)value), + _ when typeof(int) == typeof(T) => PgNumeric.GetDigitCount((decimal)(int)(object)value), + _ when typeof(long) == typeof(T) => PgNumeric.GetDigitCount((decimal)(long)(object)value), + _ when typeof(byte) == typeof(T) => PgNumeric.GetDigitCount((decimal)(byte)(object)value), + _ when typeof(sbyte) == typeof(T) => PgNumeric.GetDigitCount((decimal)(sbyte)(object)value), + _ when typeof(float) == typeof(T) => PgNumeric.GetDigitCount((decimal)(float)(object)value), + _ when typeof(double) == typeof(T) => PgNumeric.GetDigitCount((decimal)(double)(object)value), + _ => throw new NotSupportedException() + }); + + protected override void WriteCore(PgWriter writer, T value) + { + // We don't know how many digits we need so we allocate enough for the builder to use. 
+ Span destination = stackalloc short[PgNumeric.Builder.MaxDecimalNumericDigits]; + var numeric = ConvertFrom(value, destination); + NumericConverter.Write(writer, numeric); + } + + static PgNumeric.Builder ConvertFrom(T value, Span destination) + { +#if !NET7_0_OR_GREATER + if (typeof(short) == typeof(T)) + return new PgNumeric.Builder((decimal)(short)(object)value!, destination); + if (typeof(int) == typeof(T)) + return new PgNumeric.Builder((decimal)(int)(object)value!, destination); + if (typeof(long) == typeof(T)) + return new PgNumeric.Builder((decimal)(long)(object)value!, destination); + + if (typeof(byte) == typeof(T)) + return new PgNumeric.Builder((decimal)(byte)(object)value!, destination); + if (typeof(sbyte) == typeof(T)) + return new PgNumeric.Builder((decimal)(sbyte)(object)value!, destination); + + if (typeof(float) == typeof(T)) + return new PgNumeric.Builder((decimal)(float)(object)value!, destination); + if (typeof(double) == typeof(T)) + return new PgNumeric.Builder((decimal)(double)(object)value!, destination); + if (typeof(decimal) == typeof(T)) + return new PgNumeric.Builder((decimal)(object)value!, destination); + + throw new NotSupportedException(); +#else + return new PgNumeric.Builder(decimal.CreateChecked(value), destination); +#endif + } + + static T ConvertTo(in PgNumeric.Builder numeric) + { +#if !NET7_0_OR_GREATER + if (typeof(short) == typeof(T)) + return (T)(object)(short)numeric.ToDecimal(); + if (typeof(int) == typeof(T)) + return (T)(object)(int)numeric.ToDecimal(); + if (typeof(long) == typeof(T)) + return (T)(object)(long)numeric.ToDecimal(); + + if (typeof(byte) == typeof(T)) + return (T)(object)(byte)numeric.ToDecimal(); + if (typeof(sbyte) == typeof(T)) + return (T)(object)(sbyte)numeric.ToDecimal(); + + if (typeof(float) == typeof(T)) + return (T)(object)(float)numeric.ToDecimal(); + if (typeof(double) == typeof(T)) + return (T)(object)(double)numeric.ToDecimal(); + if (typeof(decimal) == typeof(T)) + return 
(T)(object)numeric.ToDecimal(); + + throw new NotSupportedException(); +#else + return T.CreateChecked(numeric.ToDecimal()); +#endif + } +} + +static class NumericConverter +{ + public static int DecimalBasedMaxByteCount = PgNumeric.GetByteCount(PgNumeric.Builder.MaxDecimalNumericDigits); + + public static PgNumeric.Builder Read(PgReader reader, Span digits) + { + var remainingStructureSize = PgNumeric.GetByteCount(0) - sizeof(short); + if (reader.ShouldBuffer(remainingStructureSize)) + reader.Buffer(remainingStructureSize); + var weight = reader.ReadInt16(); + var sign = reader.ReadInt16(); + var scale = reader.ReadInt16(); + foreach (ref var digit in digits) + { + if (reader.ShouldBuffer(sizeof(short))) + reader.Buffer(sizeof(short)); + digit = reader.ReadInt16(); + } + + return new PgNumeric.Builder(digits, weight, sign, scale); + } + + public static async ValueTask ReadAsync(PgReader reader, ArraySegment digits, CancellationToken cancellationToken) + { + var remainingStructureSize = PgNumeric.GetByteCount(0) - sizeof(short); + if (reader.ShouldBuffer(remainingStructureSize)) + await reader.BufferAsync(remainingStructureSize, cancellationToken).ConfigureAwait(false); + var weight = reader.ReadInt16(); + var sign = reader.ReadInt16(); + var scale = reader.ReadInt16(); + var array = digits.Array!; + for (var i = digits.Offset; i < array.Length; i++) + { + if (reader.ShouldBuffer(sizeof(short))) + await reader.BufferAsync(sizeof(short), cancellationToken).ConfigureAwait(false); + array[i] = reader.ReadInt16(); + } + + return new PgNumeric.Builder(digits, weight, sign, scale).Build(); + } + + public static void Write(PgWriter writer, PgNumeric.Builder numeric) + { + if (writer.ShouldFlush(PgNumeric.GetByteCount(0))) + writer.Flush(); + writer.WriteInt16((short)numeric.Digits.Length); + writer.WriteInt16(numeric.Weight); + writer.WriteInt16(numeric.Sign); + writer.WriteInt16(numeric.Scale); + + foreach (var digit in numeric.Digits) + { + if 
(writer.ShouldFlush(sizeof(short))) + writer.Flush(); + writer.WriteInt16(digit); + } + } + + public static async ValueTask WriteAsync(PgWriter writer, PgNumeric numeric, CancellationToken cancellationToken) + { + if (writer.ShouldFlush(PgNumeric.GetByteCount(0))) + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + writer.WriteInt16((short)numeric.Digits.Count); + writer.WriteInt16(numeric.Weight); + writer.WriteInt16(numeric.Sign); + writer.WriteInt16(numeric.Scale); + + foreach (var digit in numeric.Digits) + { + if (writer.ShouldFlush(sizeof(short))) + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + writer.WriteInt16(digit); + } + } +} diff --git a/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs b/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs new file mode 100644 index 0000000000..495e2a8aba --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs @@ -0,0 +1,104 @@ +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Npgsql.Internal.Converters; + +readonly struct PgMoney +{ + const int DecimalBits = 4; + const int MoneyScale = 2; + readonly long _value; + + public PgMoney(long value) => _value = value; + + public PgMoney(decimal value) + { + if (value is < -92233720368547758.08M or > 92233720368547758.07M) + throw new OverflowException($"The supplied value '{value}' is outside the range for a PostgreSQL money value."); + + // No-op if scale was already 2 or less. + value = decimal.Round(value, MoneyScale, MidpointRounding.AwayFromZero); + + Span bits = stackalloc uint[DecimalBits]; + GetDecimalBits(value, bits, out var scale); + + var money = (long)bits[1] << 32 | bits[0]; + if (value < 0) + money = -money; + + // If we were less than scale 2, multiply. 
+ _value = (MoneyScale - scale) switch + { + 1 => money * 10, + 2 => money * 100, + _ => money + }; + } + + public long GetValue() => _value; + + public decimal ToDecimal() + { + var result = new decimal(_value); + var scaleFactor = new decimal(1, 0, 0, false, MoneyScale); + result *= scaleFactor; + return result; + } + + static void GetDecimalBits(decimal value, Span destination, out short scale) + { + Debug.Assert(destination.Length >= DecimalBits); + +#if NETSTANDARD + var raw = new DecimalRaw(value); + destination[0] = raw.Low; + destination[1] = raw.Mid; + destination[2] = raw.High; + destination[3] = (uint)raw.Flags; + scale = raw.Scale; +#else + decimal.GetBits(value, MemoryMarshal.Cast(destination)); +#endif +#if NET7_0_OR_GREATER + scale = value.Scale; +#else + scale = (byte)(destination[3] >> 16); +#endif + } + +#if NETSTANDARD + // Zero-alloc access to the decimal bits on netstandard. + [StructLayout(LayoutKind.Explicit)] + readonly struct DecimalRaw + { + const int ScaleMask = 0x00FF0000; + const int ScaleShift = 16; + + // Do not change the order in which these fields are declared. It + // should be same as in the System.Decimal.DecCalc struct. + [FieldOffset(0)] + readonly decimal _value; + [FieldOffset(0)] + readonly int _flags; + [FieldOffset(4)] + readonly uint _high; + [FieldOffset(8)] + readonly ulong _low64; + + // Convenience aliased fields but their usage needs to take endianness into account. + [FieldOffset(8)] + readonly uint _low; + [FieldOffset(12)] + readonly uint _mid; + + public DecimalRaw(decimal value) : this() => _value = value; + + public uint High => _high; + public uint Mid => BitConverter.IsLittleEndian ? _mid : _low; + public uint Low => BitConverter.IsLittleEndian ? 
_low : _mid; + public int Flags => _flags; + public short Scale => (short)((_flags & ScaleMask) >> ScaleShift); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs b/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs new file mode 100644 index 0000000000..1691170d34 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs @@ -0,0 +1,462 @@ +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Numerics; +using System.Runtime.InteropServices; +using static Npgsql.Internal.Converters.PgNumeric.Builder; + +namespace Npgsql.Internal.Converters; + +readonly struct PgNumeric +{ + // numeric digit count + weight + sign + scale + const int StructureByteCount = 4 * sizeof(short); + const int DecimalBits = 4; + const int StackAllocByteThreshold = 64 * sizeof(uint); + + readonly ushort _sign; + + public PgNumeric(ArraySegment digits, short weight, short sign, short scale) + { + Digits = digits; + Weight = weight; + _sign = (ushort)sign; + Scale = scale; + } + + /// Big endian array of numeric digits + public ArraySegment Digits { get; } + public short Weight { get; } + public short Sign => (short)_sign; + public short Scale { get; } + + public int GetByteCount() => GetByteCount(Digits.Count); + public static int GetByteCount(int digitCount) => StructureByteCount + digitCount * sizeof(short); + + static void GetDecimalBits(decimal value, Span destination, out short scale) + { + Debug.Assert(destination.Length >= DecimalBits); + +#if NETSTANDARD + var raw = new DecimalRaw(value); + destination[0] = raw.Low; + destination[1] = raw.Mid; + destination[2] = raw.High; + destination[3] = (uint)raw.Flags; + scale = raw.Scale; +#else + decimal.GetBits(value, MemoryMarshal.Cast(destination)); +#endif +#if NET7_0_OR_GREATER + scale = value.Scale; +#else + scale = (byte)(destination[3] >> 16); +#endif + } + + public static int GetDigitCount(decimal value) + { + Span bits = stackalloc 
uint[DecimalBits]; + GetDecimalBits(value, bits, out var scale); + bits = bits.Slice(0, DecimalBits - 1); + return GetDigitCountCore(bits, scale); + } + + public static int GetDigitCount(BigInteger value) + { +# if NETSTANDARD2_0 + var bits = value.ToByteArray().AsSpan(); + // Detect the presence of a padding byte and slice it away (as we don't have isUnsigned: true overloads on ns2.0). + if (value.Sign == 1 && bits.Length > 2 && (bits[bits.Length - 2] & 0x80) != 0 && bits[bits.Length - 1] == 0) + bits = bits.Slice(0, bits.Length - 1); + var uintRoundedByteCount = (bits.Length + (sizeof(uint) - 1)) / sizeof(uint) * sizeof(uint); +# else + var absValue = BigInteger.Abs(value); // isUnsigned: true fails for negative values. + var uintRoundedByteCount = (absValue.GetByteCount(isUnsigned: true) + (sizeof(uint) - 1)) / sizeof(uint) * sizeof(uint); +#endif + byte[]? uintRoundedBitsFromPool = null; + var uintRoundedBits = (uintRoundedByteCount <= StackAllocByteThreshold + ? stackalloc byte[StackAllocByteThreshold] + : uintRoundedBitsFromPool = ArrayPool.Shared.Rent(uintRoundedByteCount) + ).Slice(0, uintRoundedByteCount); + // Fill the last uint worth of bytes as it may only be partially written to. 
+ uintRoundedBits.Slice(uintRoundedBits.Length - sizeof(uint)).Fill(0); + +#if NETSTANDARD2_0 + bits.CopyTo(uintRoundedBits); +#else + var success = absValue.TryWriteBytes(uintRoundedBits, out _, isUnsigned: true); + Debug.Assert(success); +#endif + var uintBits = MemoryMarshal.Cast(uintRoundedBits); + if (!BitConverter.IsLittleEndian) + for (var i = 0; i < uintBits.Length; i++) + uintBits[i] = BinaryPrimitives.ReverseEndianness(uintBits[i]); + + var size = GetDigitCountCore(uintBits, scale: 0); + + if (uintRoundedBitsFromPool is not null) + ArrayPool.Shared.Return(uintRoundedBitsFromPool); + + return size; + } + + public decimal ToDecimal() => Builder.ToDecimal(Scale, Weight, _sign, Digits); + public BigInteger ToBigInteger() => Builder.ToBigInteger(Weight, _sign, Digits); + + public readonly ref struct Builder + { + const ushort SignPositive = 0x0000; + const ushort SignNegative = 0x4000; + const ushort SignNan = 0xC000; + const ushort SignPinf = 0xD000; + const ushort SignNinf = 0xF000; + + const uint NumericBase = 10000; + const int NumericBaseLog10 = 4; // log10(10000) + + internal const int MaxDecimalNumericDigits = 8; + + // Fast access for 10^n where n is 0-9 + static ReadOnlySpan UIntPowers10 => new uint[] { + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000 + }; + + const int MaxUInt32Scale = 9; + const int MaxUInt16Scale = 4; + + public short Weight { get; } + + readonly ushort _sign; + public short Sign => (short)_sign; + + public short Scale { get; } + public Span Digits { get; } + readonly short[]? 
_digitsArray; + + public Builder(Span digits, short weight, short sign, short scale) + { + Digits = digits; + Weight = weight; + _sign = (ushort)sign; + Scale = scale; + } + + public Builder(short[] digits, short weight, short sign, short scale) + { + Digits = _digitsArray = digits; + Weight = weight; + _sign = (ushort)sign; + Scale = scale; + } + + [Conditional("DEBUG")] + static void AssertInvariants() + { + Debug.Assert(UIntPowers10.Length >= NumericBaseLog10); + Debug.Assert(NumericBase < short.MaxValue); + } + + static void Create(ref short[]? digitsArray, ref Span destination, scoped Span bits, short scale, out short weight, out int digitCount) + { + AssertInvariants(); + digitCount = 0; + var digitWeight = -scale / NumericBaseLog10 - 1; + + var bitsUpperBound = (bits.Length * (MaxUInt32Scale + 1) + MaxUInt16Scale - 1) / MaxUInt16Scale + 1; + if (bitsUpperBound > destination.Length) + destination = digitsArray = new short[bitsUpperBound]; + + // When the given scale does not sit on a numeric digit boundary we divide once by the remainder power of 10 instead of the base. + // As a result the quotient is aligned to a digit boundary, we must then scale up the remainder by the missed power of 10 to compensate. + var scaleRemainder = scale % NumericBaseLog10; + if (scaleRemainder > 0 && DivideInPlace(bits, UIntPowers10[scaleRemainder], out var remainder) && remainder != 0) + { + remainder *= UIntPowers10[NumericBaseLog10 - scaleRemainder]; + digitWeight--; + destination[destination.Length - 1 - digitCount++] = (short)remainder; + } + while (DivideInPlace(bits, NumericBase, out remainder)) + { + // Initial zero remainders are skipped as these present trailing zero digits, which should not be stored. + if (digitCount == 0 && remainder == 0) + digitWeight++; + else + // We store the results starting from the end so the final digits end up in big endian. 
+ destination[destination.Length - 1 - digitCount++] = (short)remainder; + } + + weight = (short)(digitWeight + digitCount); + + } + + public Builder(decimal value, Span destination) + { + Span bits = stackalloc uint[DecimalBits]; + GetDecimalBits(value, bits, out var scale); + bits = bits.Slice(0, DecimalBits - 1); + + Create(ref _digitsArray, ref destination, bits, scale, out var weight, out var digitCount); + Digits = destination.Slice(destination.Length - digitCount); + Weight = weight; + _sign = value < 0 ? SignNegative : SignPositive; + Scale = scale; + } + + /// + /// + /// + /// + /// If the destination ends up being too small the builder allocates instead + public Builder(BigInteger value, Span destination) + { +# if NETSTANDARD2_0 + var bits = value.ToByteArray().AsSpan(); + // Detect the presence of a padding byte and slice it away (as we don't have isUnsigned: true overloads on ns2.0). + if (value.Sign == 1 && bits.Length > 2 && (bits[bits.Length - 2] & 0x80) != 0 && bits[bits.Length - 1] == 0) + bits = bits.Slice(0, bits.Length - 1); + var uintRoundedByteCount = (bits.Length + (sizeof(uint) - 1)) / sizeof(uint) * sizeof(uint); +# else + var absValue = BigInteger.Abs(value); // isUnsigned: true fails for negative values. + var uintRoundedByteCount = (absValue.GetByteCount(isUnsigned: true) + (sizeof(uint) - 1)) / sizeof(uint) * sizeof(uint); +#endif + byte[]? uintRoundedBitsFromPool = null; + var uintRoundedBits = (uintRoundedByteCount <= StackAllocByteThreshold + ? stackalloc byte[StackAllocByteThreshold] + : uintRoundedBitsFromPool = ArrayPool.Shared.Rent(uintRoundedByteCount) + ).Slice(0, uintRoundedByteCount); + // Fill the last uint worth of bytes as it may only be partially written to. 
+ uintRoundedBits.Slice(uintRoundedBits.Length - sizeof(uint)).Fill(0); + +#if NETSTANDARD2_0 + bits.CopyTo(uintRoundedBits); +#else + var success = absValue.TryWriteBytes(uintRoundedBits, out _, isUnsigned: true); + Debug.Assert(success); +#endif + var uintBits = MemoryMarshal.Cast(uintRoundedBits); + + // Our calculations are all done in little endian, meaning the least significant *uint* is first, just like in BigInteger. + // The bytes comprising every individual uint should still be converted to big endian though. + // As a result an array of bytes like [ 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8 ] should become [ 0x4, 0x3, 0x2, 0x1, 0x8, 0x7, 0x6, 0x5 ]. + if (!BitConverter.IsLittleEndian) + for (var i = 0; i < uintBits.Length; i++) + uintBits[i] = BinaryPrimitives.ReverseEndianness(uintBits[i]); + + Create(ref _digitsArray, ref destination, uintBits, scale: 0, out var weight, out var digitCount); + Digits = destination.Slice(destination.Length - digitCount); + Weight = weight; + _sign = value < 0 ? SignNegative : SignPositive; + Scale = 0; + + if (uintRoundedBitsFromPool is not null) + ArrayPool.Shared.Return(uintRoundedBitsFromPool); + } + + public PgNumeric Build() + { + var digitsArray = _digitsArray is not null + ? 
new ArraySegment(_digitsArray, _digitsArray.Length - Digits.Length, Digits.Length) + : new ArraySegment(Digits.ToArray()); + + return new(digitsArray, Weight, Sign, Scale); + } + + public decimal ToDecimal() => ToDecimal(Scale, Weight, _sign, Digits); + public BigInteger ToBigInteger() => ToBigInteger(Weight, _sign, Digits); + + int DigitCount => Digits.Length; + + /// + /// + /// + /// + /// + /// + /// Whether the input consists of any non zero bits + static bool DivideInPlace(Span left, uint right, out uint remainder) + => Divide(left, right, left, out remainder); + + /// Adapted from BigInteger, to allow us to operate directly on stack allocated bits + static bool Divide(ReadOnlySpan left, uint right, Span quotient, out uint remainder) + { + Debug.Assert(quotient.Length == left.Length); + + // Executes the division for one big and one 32-bit integer. + // Thus, we've similar code than below, but there is no loop for + // processing the 32-bit integer, since it's a single element. + + var carry = 0UL; + + var nonZeroInput = false; + for (var i = left.Length - 1; i >= 0; i--) + { + var value = (carry << 32) | left[i]; + nonZeroInput = nonZeroInput || value != 0; + var digit = value / right; + quotient[i] = (uint)digit; + carry = value - digit * right; + } + remainder = (uint)carry; + + return nonZeroInput; + } + + internal static int GetDigitCountCore(Span bits, int scale) + { + AssertInvariants(); + // When a fractional result is expected we must send two numeric digits. + // When the given scale does not sit on a numeric digit boundary- + // we divide once by the remaining power of 10 instead of the full base to align things. + var baseLogRemainder = scale % NumericBaseLog10; + var den = baseLogRemainder > 0 ? UIntPowers10[baseLogRemainder] : NumericBase; + var digits = 0; + while (DivideInPlace(bits, den, out var remainder)) + { + den = NumericBase; + // Initial zero remainders are skipped as these present trailing zero digits, which should not be transmitted. 
+ if (digits != 0 || remainder != 0) + digits++; + } + + return digits; + } + + internal static decimal ToDecimal(short scale, short weight, ushort sign, Span digits) + { + const int MaxUIntScale = 9; + const int MaxDecimalScale = 28; + + var digitCount = digits.Length; + if (digitCount > MaxDecimalNumericDigits) + throw new OverflowException("Numeric value does not fit in a System.Decimal"); + + if (Math.Abs(scale) > MaxDecimalScale) + throw new OverflowException("Numeric value does not fit in a System.Decimal"); + + var scaleFactor = new decimal(1, 0, 0, false, (byte)(scale > 0 ? scale : 0)); + if (digitCount == 0) + return sign switch + { + SignPositive or SignNegative => decimal.Zero * scaleFactor, + SignNan => throw new InvalidCastException("Numeric NaN not supported by System.Decimal"), + SignPinf => throw new InvalidCastException("Numeric Infinity not supported by System.Decimal"), + SignNinf => throw new InvalidCastException("Numeric -Infinity not supported by System.Decimal"), + _ => throw new ArgumentOutOfRangeException() + }; + + var numericBase = new decimal(NumericBase); + var result = decimal.Zero; + for (var i = 0; i < digitCount - 1; i++) + { + result *= numericBase; + result += digits[i]; + } + + var digitScale = (weight + 1 - digitCount) * NumericBaseLog10; + var scaleDifference = scale < 0 ? digitScale : digitScale + scale; + + var digit = digits[digitCount - 1]; + if (digitCount == MaxDecimalNumericDigits) + { + // On the max group we adjust the base based on the scale difference, to prevent overflow for valid values. 
+ var pow = UIntPowers10[-scaleDifference]; + result *= numericBase / pow; + result += new decimal(digit / pow); + } + else + { + result *= numericBase; + result += digit; + + if (scaleDifference < 0) + result /= UIntPowers10[-scaleDifference]; + else + while (scaleDifference > 0) + { + var scaleChunk = Math.Min(MaxUIntScale, scaleDifference); + result *= UIntPowers10[scaleChunk]; + scaleDifference -= scaleChunk; + } + } + + result *= scaleFactor; + return sign == SignNegative ? -result : result; + } + + internal static BigInteger ToBigInteger(short weight, ushort sign, Span digits) + { + var digitCount = digits.Length; + if (digitCount == 0) + return sign switch + { + SignPositive or SignNegative => BigInteger.Zero, + SignNan => throw new InvalidCastException("Numeric NaN not supported by BigInteger"), + SignPinf => throw new InvalidCastException("Numeric Infinity not supported by BigInteger"), + SignNinf => throw new InvalidCastException("Numeric -Infinity not supported by BigInteger"), + _ => throw new ArgumentOutOfRangeException() + }; + + var digitWeight = weight + 1 - digitCount; + if (digitWeight < 0) + throw new InvalidCastException("Numeric value with non-zero fractional digits not supported by BigInteger"); + + var numericBase = new BigInteger(NumericBase); + var result = BigInteger.Zero; + foreach (var digit in digits) + { + result *= numericBase; + result += new BigInteger(digit); + } + + var exponentCorrection = BigInteger.Pow(numericBase, digitWeight); + result *= exponentCorrection; + return sign == SignNegative ? -result : result; + } + } + +#if NETSTANDARD + // Zero-alloc access to the decimal bits on netstandard. + [StructLayout(LayoutKind.Explicit)] + readonly struct DecimalRaw + { + const int ScaleMask = 0x00FF0000; + const int ScaleShift = 16; + + // Do not change the order in which these fields are declared. It + // should be same as in the System.Decimal.DecCalc struct. 
+ [FieldOffset(0)] + readonly decimal _value; + [FieldOffset(0)] + readonly int _flags; + [FieldOffset(4)] + readonly uint _high; + [FieldOffset(8)] + readonly ulong _low64; + + // Convenience aliased fields but their usage needs to take endianness into account. + [FieldOffset(8)] + readonly uint _low; + [FieldOffset(12)] + readonly uint _mid; + + public DecimalRaw(decimal value) : this() => _value = value; + + public uint High => _high; + public uint Mid => BitConverter.IsLittleEndian ? _mid : _low; + public uint Low => BitConverter.IsLittleEndian ? _low : _mid; + public int Flags => _flags; + public short Scale => (short)((_flags & ScaleMask) >> ScaleShift); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs b/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs new file mode 100644 index 0000000000..b47e641aa5 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs @@ -0,0 +1,43 @@ +using System; +using System.Numerics; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class RealConverter : PgBufferedConverter +#if NET7_0_OR_GREATER + where T : INumberBase +#endif +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(float)); + return format is DataFormat.Binary; + } + +#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadFloat()); + protected override void WriteCore(PgWriter writer, T value) => writer.WriteFloat(float.CreateChecked(value)); +#else + protected override T ReadCore(PgReader reader) + { + var value = reader.ReadFloat(); + if (typeof(float) == typeof(T)) + return (T)(object)value; + if (typeof(double) == typeof(T)) + return (T)(object)(double)value; + + throw new NotSupportedException(); + } + + protected override void WriteCore(PgWriter writer, T value) + { + if (typeof(float) == 
typeof(T)) + writer.WriteFloat((float)(object)value!); + else if (typeof(double) == typeof(T)) + writer.WriteFloat((float)(double)(object)value!); + else + throw new NotSupportedException(); + } +#endif +} diff --git a/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs b/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs new file mode 100644 index 0000000000..d13e6f14e6 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs @@ -0,0 +1,351 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +abstract class StringBasedTextConverter : PgStreamingConverter +{ + readonly Encoding _encoding; + protected StringBasedTextConverter(Encoding encoding) => _encoding = encoding; + + public override T Read(PgReader reader) + => Read(async: false, reader, _encoding).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, _encoding, cancellationToken); + + public override Size GetSize(SizeContext context, T value, ref object? 
writeState) + => TextConverter.GetSize(ref context, ConvertTo(value), _encoding); + + public override void Write(PgWriter writer, T value) + => writer.WriteChars(ConvertTo(value).Span, _encoding); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => writer.WriteCharsAsync(ConvertTo(value), _encoding, cancellationToken); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.None; + return format is DataFormat.Binary or DataFormat.Text; + } + + protected abstract ReadOnlyMemory ConvertTo(T value); + protected abstract T ConvertFrom(string value); + + ValueTask Read(bool async, PgReader reader, Encoding encoding, CancellationToken cancellationToken = default) + { + return async + ? ReadAsync(reader, encoding, cancellationToken) + : new(ConvertFrom(encoding.GetString(reader.ReadBytes(reader.CurrentRemaining)))); + +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + async ValueTask ReadAsync(PgReader reader, Encoding encoding, CancellationToken cancellationToken) + => ConvertFrom(encoding.GetString(await reader.ReadBytesAsync(reader.CurrentRemaining, cancellationToken).ConfigureAwait(false))); + } +} + +sealed class ReadOnlyMemoryTextConverter : StringBasedTextConverter> +{ + public ReadOnlyMemoryTextConverter(Encoding encoding) : base(encoding) { } + protected override ReadOnlyMemory ConvertTo(ReadOnlyMemory value) => value; + protected override ReadOnlyMemory ConvertFrom(string value) => value.AsMemory(); +} + +sealed class StringTextConverter : StringBasedTextConverter +{ + public StringTextConverter(Encoding encoding) : base(encoding) { } + protected override ReadOnlyMemory ConvertTo(string value) => value.AsMemory(); + protected override string ConvertFrom(string value) => value; +} + +abstract class ArrayBasedTextConverter : PgStreamingConverter +{ + readonly 
Encoding _encoding; + protected ArrayBasedTextConverter(Encoding encoding) => _encoding = encoding; + + public override T Read(PgReader reader) + => Read(async: false, reader, _encoding).GetAwaiter().GetResult(); + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, _encoding); + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + => TextConverter.GetSize(ref context, ConvertTo(value), _encoding); + + public override void Write(PgWriter writer, T value) + => writer.WriteChars(ConvertTo(value).AsSpan(), _encoding); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => writer.WriteCharsAsync(ConvertTo(value), _encoding, cancellationToken); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.None; + return format is DataFormat.Binary or DataFormat.Text; + } + + protected abstract ArraySegment ConvertTo(T value); + protected abstract T ConvertFrom(ArraySegment value); + + ValueTask Read(bool async, PgReader reader, Encoding encoding) + { + return async ? 
ReadAsync(reader, encoding) : new(ConvertFrom(GetSegment(reader.ReadBytes(reader.CurrentRemaining), encoding))); + +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + async ValueTask ReadAsync(PgReader reader, Encoding encoding) + => ConvertFrom(GetSegment(await reader.ReadBytesAsync(reader.CurrentRemaining).ConfigureAwait(false), encoding)); + + static ArraySegment GetSegment(ReadOnlySequence bytes, Encoding encoding) + { + var array = TextConverter.GetChars(encoding, bytes); + return new(array, 0, array.Length); + } + } +} + +sealed class CharArraySegmentTextConverter : ArrayBasedTextConverter> +{ + public CharArraySegmentTextConverter(Encoding encoding) : base(encoding) { } + protected override ArraySegment ConvertTo(ArraySegment value) => value; + protected override ArraySegment ConvertFrom(ArraySegment value) => value; +} + +sealed class CharArrayTextConverter : ArrayBasedTextConverter +{ + public CharArrayTextConverter(Encoding encoding) : base(encoding) { } + protected override ArraySegment ConvertTo(char[] value) => new(value, 0, value.Length); + protected override char[] ConvertFrom(ArraySegment value) + { + if (value.Array?.Length == value.Count) + return value.Array!; + + var array = new char[value.Count]; + Array.Copy(value.Array!, value.Offset, array, 0, value.Count); + return array; + } +} + +sealed class CharTextConverter : PgBufferedConverter +{ + readonly Encoding _encoding; + readonly Size _oneCharMaxByteCount; + + public CharTextConverter(Encoding encoding) + { + _encoding = encoding; + _oneCharMaxByteCount = Size.CreateUpperBound(encoding.GetMaxByteCount(1)); + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.Create(_oneCharMaxByteCount); + return format is DataFormat.Binary or DataFormat.Text; + } + + protected override char ReadCore(PgReader reader) + { + var byteSeq = 
reader.ReadBytes(Math.Min(_oneCharMaxByteCount.Value, reader.CurrentRemaining)); + Debug.Assert(byteSeq.IsSingleSegment); + var bytes = byteSeq.GetFirstSpan(); + + var chars = _encoding.GetCharCount(bytes); + if (chars < 1) + throw new NpgsqlException("Could not read char - string was empty"); + + Span destination = stackalloc char[chars]; + _encoding.GetChars(bytes, destination); + return destination[0]; + } + + public override Size GetSize(SizeContext context, char value, ref object? writeState) + { + Span spanValue = stackalloc char[] { value }; + return _encoding.GetByteCount(spanValue); + } + + protected override void WriteCore(PgWriter writer, char value) + { + Span spanValue = stackalloc char[] { value }; + writer.WriteChars(spanValue, _encoding); + } +} + +sealed class TextReaderTextConverter : PgStreamingConverter +{ + readonly Encoding _encoding; + public TextReaderTextConverter(Encoding encoding) => _encoding = encoding; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.None; + return format is DataFormat.Binary or DataFormat.Text; + } + + public override TextReader Read(PgReader reader) + => reader.GetTextReader(_encoding); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => reader.GetTextReaderAsync(_encoding, cancellationToken); + + public override Size GetSize(SizeContext context, TextReader value, ref object? 
writeState) => throw new NotImplementedException(); + public override void Write(PgWriter writer, TextReader value) => throw new NotImplementedException(); + public override ValueTask WriteAsync(PgWriter writer, TextReader value, CancellationToken cancellationToken = default) => throw new NotImplementedException(); +} + + +readonly struct GetChars +{ + public int Read { get; } + public GetChars(int read) => Read = read; +} + +sealed class GetCharsTextConverter : PgStreamingConverter +{ + readonly Encoding _encoding; + public GetCharsTextConverter(Encoding encoding) => _encoding = encoding; + + public override GetChars Read(PgReader reader) + => reader.CharsReadActive + ? ResumableRead(reader) + : throw new NotSupportedException(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => throw new NotSupportedException(); + + public override Size GetSize(SizeContext context, GetChars value, ref object? writeState) => throw new NotSupportedException(); + public override void Write(PgWriter writer, GetChars value) => throw new NotSupportedException(); + public override ValueTask WriteAsync(PgWriter writer, GetChars value, CancellationToken cancellationToken = default) => throw new NotSupportedException(); + + GetChars ResumableRead(PgReader reader) + { + reader.GetCharsReadInfo(_encoding, out var charsRead, out var textReader, out var charsOffset, out var buffer); + + // With variable length encodings, moving backwards based on bytes means we have to start over. + if (charsRead > charsOffset) + { + reader.RestartCharsRead(); + charsRead = 0; + } + + // First seek towards the charsOffset. + // If buffer is null read the entire thing and report the length, see sql client remarks. + // https://learn.microsoft.com/en-us/dotnet/api/system.data.sqlclient.sqldatareader.getchars + var read = ConsumeChars(textReader, buffer is null ? 
null : charsOffset - charsRead); + Debug.Assert(buffer is null || read == charsOffset - charsRead); + reader.AdvanceCharsRead(read); + if (buffer is null) + return new(read); + + read = textReader.ReadBlock(buffer.GetValueOrDefault().Array!, buffer.GetValueOrDefault().Offset, buffer.GetValueOrDefault().Count); + reader.AdvanceCharsRead(read); + return new(read); + + static int ConsumeChars(TextReader reader, int? count) + { + if (count is 0) + return 0; + + const int maxStackAlloc = 512; +#if NETSTANDARD + var tempCharBuf = new char[maxStackAlloc]; +#else + Span tempCharBuf = stackalloc char[maxStackAlloc]; +#endif + var totalRead = 0; + var fin = false; + while (!fin) + { + var toRead = count is null ? maxStackAlloc : Math.Min(maxStackAlloc, count.Value - totalRead); +#if NETSTANDARD + var read = reader.ReadBlock(tempCharBuf, 0, toRead); +#else + var read = reader.ReadBlock(tempCharBuf.Slice(0, toRead)); +#endif + totalRead += read; + if (count is not null && read is 0) + throw new EndOfStreamException(); + + fin = count is null ? read is 0 : totalRead >= count; + } + return totalRead; + } + } +} + +// Moved out for code size/sharing. +static class TextConverter +{ + public static Size GetSize(ref SizeContext context, ReadOnlyMemory value, Encoding encoding) + => encoding.GetByteCount(value.Span); + + // Adapted version of GetString(ROSeq) removing the intermediate string allocation to make a contiguous char array. + public static char[] GetChars(Encoding encoding, ReadOnlySequence bytes) + { + if (bytes.IsSingleSegment) + { + // If the incoming sequence is single-segment, one-shot this. + var firstSpan = bytes.First.Span; + var chars = new char[encoding.GetCharCount(firstSpan)]; + encoding.GetChars(bytes.First.Span, chars); + return chars; + } + else + { + // If the incoming sequence is multi-segment, create a stateful Decoder + // and use it as the workhorse. On the final iteration we'll pass flush=true. 
+ + var decoder = encoding.GetDecoder(); + + // Maintain a list of all the segments we'll need to concat together. + // These will be released back to the pool at the end of the method. + + var listOfSegments = new List<(char[], int)>(); + var totalCharCount = 0; + + var remainingBytes = bytes; + bool isFinalSegment; + + do + { + var firstSpan = remainingBytes.First.Span; + var next = remainingBytes.GetPosition(firstSpan.Length); + isFinalSegment = remainingBytes.IsSingleSegment; + + var charCountThisIteration = decoder.GetCharCount(firstSpan, flush: isFinalSegment); // could throw ArgumentException if overflow would occur + var rentedArray = ArrayPool.Shared.Rent(charCountThisIteration); + var actualCharsWrittenThisIteration = decoder.GetChars(firstSpan, rentedArray, flush: isFinalSegment); + listOfSegments.Add((rentedArray, actualCharsWrittenThisIteration)); + + totalCharCount += actualCharsWrittenThisIteration; + if (totalCharCount < 0) + throw new OutOfMemoryException(); + + remainingBytes = remainingBytes.Slice(next); + } while (!isFinalSegment); + + // Now build up the string to return, then release all of our scratch buffers + // back to the shared pool. 
+ var chars = new char[totalCharCount]; + var span = chars.AsSpan(); + foreach (var (array, length) in listOfSegments) + { + array.AsSpan(0, length).CopyTo(span); + ArrayPool.Shared.Return(array); + span = span.Slice(length); + } + + return chars; + } + } +} diff --git a/src/Npgsql/Internal/Converters/RangeConverter.cs b/src/Npgsql/Internal/Converters/RangeConverter.cs new file mode 100644 index 0000000000..c378d830f7 --- /dev/null +++ b/src/Npgsql/Internal/Converters/RangeConverter.cs @@ -0,0 +1,216 @@ +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using NpgsqlTypes; + +namespace Npgsql.Internal.Converters; + +sealed class RangeConverter : PgStreamingConverter> +{ + readonly PgConverter _subtypeConverter; + readonly BufferRequirements _subtypeRequirements; + + public RangeConverter(PgConverter subtypeConverter) + { + if (!subtypeConverter.CanConvert(DataFormat.Binary, out var bufferRequirements)) + throw new NotSupportedException("Range subtype converter has to support the binary format to be compatible."); + _subtypeRequirements = bufferRequirements; + _subtypeConverter = subtypeConverter; + } + + public override NpgsqlRange Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask> ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask> Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(byte))) + await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false); + + var flags = (RangeFlags)reader.ReadByte(); + if ((flags & RangeFlags.Empty) != 0) + return NpgsqlRange.Empty; + + var lowerBound = default(TSubtype); + var upperBound = default(TSubtype); + + var converter = _subtypeConverter; + if ((flags & RangeFlags.LowerBoundInfinite) == 0) + { + if 
(reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var length = reader.ReadInt32(); + + // Note that we leave the CLR default for nulls + if (length != -1) + { + var scope = await reader.BeginNestedRead(async, length, _subtypeRequirements.Read, cancellationToken).ConfigureAwait(false); + try + { + lowerBound = async + ? await converter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : converter.Read(reader); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + } + + if ((flags & RangeFlags.UpperBoundInfinite) == 0) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var length = reader.ReadInt32(); + + // Note that we leave the CLR default for nulls + if (length != -1) + { + var scope = await reader.BeginNestedRead(async, length, _subtypeRequirements.Read, cancellationToken).ConfigureAwait(false); + try + { + upperBound = async + ? await converter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + : converter.Read(reader); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + } + + return new NpgsqlRange(lowerBound, upperBound, flags); + } + + public override Size GetSize(SizeContext context, NpgsqlRange value, ref object? writeState) + { + var totalSize = Size.Create(1); + if (value.IsEmpty) + return totalSize; // Just flags. + + WriteState? 
state = null; + if (!value.LowerBoundInfinite) + { + totalSize = totalSize.Combine(sizeof(int)); + var subTypeState = (object?)null; + if (_subtypeConverter.GetSizeOrDbNull(context.Format, _subtypeRequirements.Write, value.LowerBound, ref subTypeState) is { } size) + { + totalSize = totalSize.Combine(size); + (state ??= new WriteState()).LowerBoundSize = size; + state.LowerBoundWriteState = subTypeState; + } + else if (state is not null) + state.LowerBoundSize = -1; + } + + if (!value.UpperBoundInfinite) + { + totalSize = totalSize.Combine(sizeof(int)); + var subTypeState = (object?)null; + if (_subtypeConverter.GetSizeOrDbNull(context.Format, _subtypeRequirements.Write, value.UpperBound, ref subTypeState) is { } size) + { + totalSize = totalSize.Combine(size); + (state ??= new WriteState()).UpperBoundSize = size; + state.UpperBoundWriteState = subTypeState; + } + else if (state is not null) + state.UpperBoundSize = -1; + } + + writeState = state; + return totalSize; + } + + public override void Write(PgWriter writer, NpgsqlRange value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, NpgsqlRange value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, NpgsqlRange value, CancellationToken cancellationToken) + { + var writeState = writer.Current.WriteState as WriteState; + var lowerBoundSize = writeState?.LowerBoundSize ?? -1; + var upperBoundSize = writeState?.UpperBoundSize ?? -1; + + var flags = value.Flags; + if (!value.IsEmpty) + { + // Normalize nulls to infinite, as pg does. 
+ if (lowerBoundSize == -1 && !value.LowerBoundInfinite) + flags = (flags & ~RangeFlags.LowerBoundInclusive) | RangeFlags.LowerBoundInfinite; + + if (upperBoundSize == -1 && !value.UpperBoundInfinite) + flags = (flags & ~RangeFlags.UpperBoundInclusive) | RangeFlags.UpperBoundInfinite; + } + + if (writer.ShouldFlush(sizeof(byte))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteByte((byte)flags); + var lowerBoundInfinite = flags.HasFlag(RangeFlags.LowerBoundInfinite); + var upperBoundInfinite = flags.HasFlag(RangeFlags.UpperBoundInfinite); + if (value.IsEmpty || (lowerBoundInfinite && upperBoundInfinite)) + return; + + // Always need write state from this point. + if (writeState is null) + throw new InvalidCastException($"Invalid write state, expected {typeof(WriteState).FullName}."); + + if (!lowerBoundInfinite) + { + Debug.Assert(lowerBoundSize.Value != -1); + if (lowerBoundSize.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var byteCount = lowerBoundSize.Value; // Never -1 so it's a byteCount. + if (writer.ShouldFlush(sizeof(int))) // Length + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(byteCount); + using var _ = await writer.BeginNestedWrite(async, _subtypeRequirements.Write, byteCount, + writeState.LowerBoundWriteState, cancellationToken).ConfigureAwait(false); + if (async) + await _subtypeConverter.WriteAsync(writer, value.LowerBound!, cancellationToken).ConfigureAwait(false); + else + _subtypeConverter.Write(writer, value.LowerBound!); + } + + if (!upperBoundInfinite) + { + Debug.Assert(upperBoundSize.Value != -1); + if (upperBoundSize.Kind is SizeKind.Unknown) + throw new NotImplementedException(); + + var byteCount = upperBoundSize.Value; // Never -1 so it's a byteCount. 
+ if (writer.ShouldFlush(sizeof(int))) // Length + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteInt32(byteCount); + using var _ = await writer.BeginNestedWrite(async, _subtypeRequirements.Write, byteCount, + writeState.UpperBoundWriteState, cancellationToken).ConfigureAwait(false); + if (async) + await _subtypeConverter.WriteAsync(writer, value.UpperBound!, cancellationToken).ConfigureAwait(false); + else + _subtypeConverter.Write(writer, value.UpperBound!); + } + } + + sealed class WriteState + { + internal Size LowerBoundSize { get; set; } + internal object? LowerBoundWriteState { get; set; } + internal Size UpperBoundSize { get; set; } + internal object? UpperBoundWriteState { get; set; } + } +} diff --git a/src/Npgsql/Internal/Converters/RecordConverter.cs b/src/Npgsql/Internal/Converters/RecordConverter.cs new file mode 100644 index 0000000000..aabd914b49 --- /dev/null +++ b/src/Npgsql/Internal/Converters/RecordConverter.cs @@ -0,0 +1,77 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.Converters; + +sealed class RecordConverter : PgStreamingConverter +{ + readonly PgSerializerOptions _options; + readonly Func? _factory; + + public RecordConverter(PgSerializerOptions options, Func? 
factory = null) + { + _options = options; + _factory = factory; + } + + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + var fieldCount = reader.ReadInt32(); + var result = new object[fieldCount]; + for (var i = 0; i < fieldCount; i++) + { + if (reader.ShouldBuffer(sizeof(uint) + sizeof(int))) + await reader.Buffer(async, sizeof(uint) + sizeof(int), cancellationToken).ConfigureAwait(false); + + var typeOid = reader.ReadUInt32(); + var length = reader.ReadInt32(); + + // Note that we leave .NET nulls in the object array rather than DBNull. + if (length == -1) + continue; + + var postgresType = + _options.DatabaseInfo.GetPostgresType(typeOid).GetRepresentationalType() + ?? throw new NotSupportedException($"Reading isn't supported for record field {i} (unknown type OID {typeOid}"); + + var typeInfo = _options.GetObjectOrDefaultTypeInfo(postgresType) + ?? throw new NotSupportedException( + $"Reading isn't supported for record field {i} (PG type '{postgresType.DisplayName}'"); + + var converterInfo = typeInfo.Bind(new Field("?", _options.ToCanonicalTypeId(postgresType), -1), DataFormat.Binary); + var scope = await reader.BeginNestedRead(async, length, converterInfo.BufferRequirement, cancellationToken).ConfigureAwait(false); + try + { + result[i] = await converterInfo.Converter.ReadAsObject(async, reader, cancellationToken).ConfigureAwait(false); + } + finally + { + if (async) + await scope.DisposeAsync().ConfigureAwait(false); + else + scope.Dispose(); + } + } + + return _factory is null ? 
(T)(object)result : _factory(result); + } + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + => throw new NotSupportedException(); + + public override void Write(PgWriter writer, T value) + => throw new NotSupportedException(); + + public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) + => throw new NotSupportedException(); +} diff --git a/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs b/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs new file mode 100644 index 0000000000..261d305439 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs @@ -0,0 +1,103 @@ +using System; +using Npgsql.Properties; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class DateTimeDateConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + static readonly DateTime BaseValue = new(2000, 1, 1, 0, 0, 0); + + public DateTimeDateConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int)); + return format is DataFormat.Binary; + } + + protected override DateTime ReadCore(PgReader reader) + => reader.ReadInt32() switch + { + int.MaxValue => _dateTimeInfinityConversions + ? DateTime.MaxValue + : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), + int.MinValue => _dateTimeInfinityConversions + ? 
DateTime.MinValue + : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), + var value => BaseValue + TimeSpan.FromDays(value) + }; + + protected override void WriteCore(PgWriter writer, DateTime value) + { + if (_dateTimeInfinityConversions) + { + if (value == DateTime.MaxValue) + { + writer.WriteInt32(int.MaxValue); + return; + } + + if (value == DateTime.MinValue) + { + writer.WriteInt32(int.MinValue); + return; + } + } + + writer.WriteInt32((value.Date - BaseValue).Days); + } +} + +#if NET6_0_OR_GREATER +sealed class DateOnlyDateConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + static readonly DateOnly BaseValue = new(2000, 1, 1); + + public DateOnlyDateConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int)); + return format is DataFormat.Binary; + } + + protected override DateOnly ReadCore(PgReader reader) + => reader.ReadInt32() switch + { + int.MaxValue => _dateTimeInfinityConversions + ? DateOnly.MaxValue + : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), + int.MinValue => _dateTimeInfinityConversions + ? 
DateOnly.MinValue + : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), + var value => BaseValue.AddDays(value) + }; + + protected override void WriteCore(PgWriter writer, DateOnly value) + { + if (_dateTimeInfinityConversions) + { + if (value == DateOnly.MaxValue) + { + writer.WriteInt32(int.MaxValue); + return; + } + + if (value == DateOnly.MinValue) + { + writer.WriteInt32(int.MinValue); + return; + } + } + + writer.WriteInt32(value.DayNumber - BaseValue.DayNumber); + } +} +#endif diff --git a/src/Npgsql/Internal/Converters/Temporal/DateTimeConverterResolver.cs b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverterResolver.cs new file mode 100644 index 0000000000..6ae5a783a1 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverterResolver.cs @@ -0,0 +1,143 @@ +using System; +using System.Collections.Generic; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class DateTimeConverterResolver : PgConverterResolver +{ + readonly PgSerializerOptions _options; + readonly Func, T?, PgTypeId?, PgConverterResolution?> _resolver; + readonly Func _factory; + readonly PgTypeId _timestampTz; + PgConverter? _timestampTzConverter; + readonly PgTypeId _timestamp; + PgConverter? _timestampConverter; + readonly bool _dateTimeInfinityConversions; + + internal DateTimeConverterResolver(PgSerializerOptions options, Func, T?, PgTypeId?, PgConverterResolution?> resolver, Func factory, PgTypeId timestampTz, PgTypeId timestamp, bool dateTimeInfinityConversions) + { + _options = options; + _resolver = resolver; + _factory = factory; + _timestampTz = timestampTz; + _timestamp = timestamp; + _dateTimeInfinityConversions = dateTimeInfinityConversions; + } + + public override PgConverterResolution GetDefault(PgTypeId? 
pgTypeId) + { + if (pgTypeId == _timestampTz) + return new(_timestampTzConverter ??= _factory(_timestampTz), _timestampTz); + if (pgTypeId is null || pgTypeId == _timestamp) + return new(_timestampConverter ??= _factory(_timestamp), _timestamp); + + throw CreateUnsupportedPgTypeIdException(pgTypeId.Value); + } + + public PgConverterResolution? Get(DateTime value, PgTypeId? expectedPgTypeId, bool validateOnly = false) + { + if (value.Kind is DateTimeKind.Utc) + { + // We coalesce with expectedPgTypeId to throw on unknown type ids. + return expectedPgTypeId == _timestamp + ? throw new ArgumentException( + string.Format(NpgsqlStrings.TimestampNoDateTimeUtc, _options.GetDataTypeName(_timestamp).DisplayName, _options.GetDataTypeName(_timestampTz).DisplayName), nameof(value)) + : validateOnly ? null : GetDefault(expectedPgTypeId ?? _timestampTz); + } + + // For timestamptz types we'll accept unspecified MinValue/MaxValue as well. + if (expectedPgTypeId == _timestampTz + && !(_dateTimeInfinityConversions && (value == DateTime.MinValue || value == DateTime.MaxValue))) + { + throw new ArgumentException( + string.Format(NpgsqlStrings.TimestampTzNoDateTimeUnspecified, value.Kind, _options.GetDataTypeName(_timestampTz).DisplayName), nameof(value)); + } + + // We coalesce with expectedPgTypeId to throw on unknown type ids. + return GetDefault(expectedPgTypeId ?? _timestamp); + } + + public override PgConverterResolution? Get(T? value, PgTypeId? 
expectedPgTypeId) + => _resolver(this, value, expectedPgTypeId); +} + +sealed class DateTimeConverterResolver +{ + public static DateTimeConverterResolver CreateResolver(PgSerializerOptions options, PgTypeId timestampTz, PgTypeId timestamp, bool dateTimeInfinityConversions) + => new(options, static (resolver, value, expectedPgTypeId) => resolver.Get(value, expectedPgTypeId), pgTypeId => + { + if (pgTypeId == timestampTz) + return new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Utc); + if (pgTypeId == timestamp) + return new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Unspecified); + + throw new NotSupportedException(); + }, timestampTz, timestamp, dateTimeInfinityConversions); + + public static DateTimeConverterResolver> CreateRangeResolver(PgSerializerOptions options, PgTypeId timestampTz, PgTypeId timestamp, bool dateTimeInfinityConversions) + => new(options, static (resolver, value, expectedPgTypeId) => + { + // Resolve both sides to make sure we end up with consistent PgTypeIds. + PgConverterResolution? resolution = null; + if (!value.LowerBoundInfinite) + resolution = resolver.Get(value.LowerBound, expectedPgTypeId); + + if (!value.UpperBoundInfinite) + { + var result = resolver.Get(value.UpperBound, resolution?.PgTypeId ?? 
expectedPgTypeId, validateOnly: resolution is not null); + resolution ??= result; + } + + return resolution; + }, pgTypeId => + { + if (pgTypeId == timestampTz) + return new RangeConverter(new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Utc)); + if (pgTypeId == timestamp) + return new RangeConverter(new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Unspecified)); + + throw new NotSupportedException(); + }, timestampTz, timestamp, dateTimeInfinityConversions); + + public static DateTimeConverterResolver CreateMultirangeResolver(PgSerializerOptions options, PgTypeId timestampTz, PgTypeId timestamp, bool dateTimeInfinityConversions) + where T : IList where TElement : notnull + { + if (typeof(TElement) != typeof(NpgsqlRange)) + ThrowHelper.ThrowNotSupportedException("Unsupported element type"); + + return new DateTimeConverterResolver(options, static (resolver, value, expectedPgTypeId) => + { + PgConverterResolution? resolution = null; + if (value is null) + return null; + + foreach (var element in (IList>)value) + { + PgConverterResolution? result; + if (!element.LowerBoundInfinite) + { + result = resolver.Get(element.LowerBound, resolution?.PgTypeId ?? expectedPgTypeId, validateOnly: resolution is not null); + resolution ??= result; + } + if (!element.UpperBoundInfinite) + { + result = resolver.Get(element.UpperBound, resolution?.PgTypeId ?? 
expectedPgTypeId, validateOnly: resolution is not null); + resolution ??= result; + } + } + return resolution; + }, pgTypeId => + { + if (pgTypeId == timestampTz) + return new MultirangeConverter((PgConverter)(object)new RangeConverter(new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Utc))); + if (pgTypeId == timestamp) + return new MultirangeConverter((PgConverter)(object)new RangeConverter(new DateTimeConverter(dateTimeInfinityConversions, DateTimeKind.Unspecified))); + + throw new NotSupportedException(); + }, timestampTz, timestamp, dateTimeInfinityConversions); + } +} diff --git a/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs new file mode 100644 index 0000000000..ed744bb099 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs @@ -0,0 +1,53 @@ +using System; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class DateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + readonly DateTimeKind _kind; + + public DateTimeConverter(bool dateTimeInfinityConversions, DateTimeKind kind) + { + _dateTimeInfinityConversions = dateTimeInfinityConversions; + _kind = kind; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override DateTime ReadCore(PgReader reader) + => PgTimestamp.Decode(reader.ReadInt64(), _kind, _dateTimeInfinityConversions); + + protected override void WriteCore(PgWriter writer, DateTime value) + => writer.WriteInt64(PgTimestamp.Encode(value, _dateTimeInfinityConversions)); +} + +sealed class DateTimeOffsetConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + public DateTimeOffsetConverter(bool dateTimeInfinityConversions) + => 
_dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override DateTimeOffset ReadCore(PgReader reader) + => new(PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, _dateTimeInfinityConversions), TimeSpan.Zero); + + protected override void WriteCore(PgWriter writer, DateTimeOffset value) + { + if (value.Offset != TimeSpan.Zero) + throw new ArgumentException($"Cannot write DateTimeOffset with Offset={value.Offset} to PostgreSQL type 'timestamp with time zone', only offset 0 (UTC) is supported. ", nameof(value)); + + writer.WriteInt64(PgTimestamp.Encode(value.DateTime, _dateTimeInfinityConversions)); + + } +} diff --git a/src/Npgsql/Internal/Converters/Temporal/IntervalConverters.cs b/src/Npgsql/Internal/Converters/Temporal/IntervalConverters.cs new file mode 100644 index 0000000000..1e1cbe9df2 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/IntervalConverters.cs @@ -0,0 +1,58 @@ +using System; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class TimeSpanIntervalConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override TimeSpan ReadCore(PgReader reader) + { + var microseconds = reader.ReadInt64(); + var days = reader.ReadInt32(); + var months = reader.ReadInt32(); + + return months > 0 + ? throw new InvalidCastException( + "Cannot read interval values with non-zero months as TimeSpan, since that type doesn't support months. 
Consider using NodaTime Period which better corresponds to PostgreSQL interval, or read the value as NpgsqlInterval, or transform the interval to not contain months or years in PostgreSQL before reading it.") + : new(microseconds * 10 + days * TimeSpan.TicksPerDay); + } + + protected override void WriteCore(PgWriter writer, TimeSpan value) + { + var ticksInDay = value.Ticks - TimeSpan.TicksPerDay * value.Days; + writer.WriteInt64(ticksInDay / 10); + writer.WriteInt32(value.Days); + writer.WriteInt32(0); + } +} + +sealed class NpgsqlIntervalConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override NpgsqlInterval ReadCore(PgReader reader) + { + var ticks = reader.ReadInt64(); + var day = reader.ReadInt32(); + var month = reader.ReadInt32(); + return new NpgsqlInterval(month, day, ticks); + } + + protected override void WriteCore(PgWriter writer, NpgsqlInterval value) + { + writer.WriteInt64(value.Time); + writer.WriteInt32(value.Days); + writer.WriteInt32(value.Months); + } +} diff --git a/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs b/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs new file mode 100644 index 0000000000..99ad4ed599 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs @@ -0,0 +1,74 @@ +using System; + +namespace Npgsql.Internal.Converters; + +sealed class LegacyDateTimeConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + readonly bool _timestamp; + + public LegacyDateTimeConverter(bool dateTimeInfinityConversions, bool timestamp) + { + _dateTimeInfinityConversions = dateTimeInfinityConversions; + _timestamp = timestamp; + } + + public override bool CanConvert(DataFormat format, out BufferRequirements 
bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override DateTime ReadCore(PgReader reader) + { + if (_timestamp) + { + return PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Unspecified, _dateTimeInfinityConversions); + } + + var dateTime = PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, _dateTimeInfinityConversions); + return (dateTime == DateTime.MinValue || dateTime == DateTime.MaxValue) && _dateTimeInfinityConversions + ? dateTime + : dateTime.ToLocalTime(); + } + + protected override void WriteCore(PgWriter writer, DateTime value) + { + if (!_timestamp && value.Kind is DateTimeKind.Local) + value = value.ToUniversalTime(); + + writer.WriteInt64(PgTimestamp.Encode(value, _dateTimeInfinityConversions)); + } +} + +sealed class LegacyDateTimeOffsetConverter : PgBufferedConverter +{ + readonly bool _dateTimeInfinityConversions; + + public LegacyDateTimeOffsetConverter(bool dateTimeInfinityConversions) + => _dateTimeInfinityConversions = dateTimeInfinityConversions; + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + + protected override DateTimeOffset ReadCore(PgReader reader) + { + var dateTime = PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, _dateTimeInfinityConversions); + + if (_dateTimeInfinityConversions) + { + if (dateTime == DateTime.MinValue) + return DateTimeOffset.MinValue; + if (dateTime == DateTime.MaxValue) + return DateTimeOffset.MaxValue; + } + + return dateTime.ToLocalTime(); + } + + protected override void WriteCore(PgWriter writer, DateTimeOffset value) + => writer.WriteInt64(PgTimestamp.Encode(value.UtcDateTime, _dateTimeInfinityConversions)); +} diff --git a/src/Npgsql/Internal/Converters/Temporal/PgTimestamp.cs 
b/src/Npgsql/Internal/Converters/Temporal/PgTimestamp.cs new file mode 100644 index 0000000000..6a44ccbdc9 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/PgTimestamp.cs @@ -0,0 +1,43 @@ +using System; + +namespace Npgsql.Internal.Converters; + +static class PgTimestamp +{ + const long PostgresTimestampOffsetTicks = 630822816000000000L; + + internal static long Encode(DateTime value, bool dateTimeInfinityConversions) + { + if (dateTimeInfinityConversions) + { + if (value.Ticks == DateTime.MaxValue.Ticks) + return long.MaxValue; + if (value.Ticks == DateTime.MinValue.Ticks) + return long.MinValue; + } + // Rounding here would cause problems because we would round up DateTime.MaxValue + // which would make it impossible to retrieve it back from the database, so we just drop the additional precision + return (value.Ticks - PostgresTimestampOffsetTicks) / 10; + } + + internal static DateTime Decode(long value, DateTimeKind kind, bool dateTimeInfinityConversions) + { + try + { + return value switch + { + long.MaxValue => dateTimeInfinityConversions + ? DateTime.MaxValue + : throw new InvalidCastException("Cannot read infinity value since DisableDateTimeInfinityConversions is true."), + long.MinValue => dateTimeInfinityConversions + ? 
DateTime.MinValue + : throw new InvalidCastException("Cannot read infinity value since DisableDateTimeInfinityConversions is true."), + _ => new(value * 10 + PostgresTimestampOffsetTicks, kind) + }; + } + catch (ArgumentOutOfRangeException e) + { + throw new InvalidCastException("Out of range of DateTime (year must be between 1 and 9999).", e); + } + } +} diff --git a/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs b/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs new file mode 100644 index 0000000000..e756a03b85 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs @@ -0,0 +1,52 @@ +using System; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class TimeSpanTimeConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + protected override TimeSpan ReadCore(PgReader reader) => new(reader.ReadInt64() * 10); + protected override void WriteCore(PgWriter writer, TimeSpan value) => writer.WriteInt64(value.Ticks / 10); +} + +#if NET6_0_OR_GREATER +sealed class TimeOnlyTimeConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); + return format is DataFormat.Binary; + } + protected override TimeOnly ReadCore(PgReader reader) => new(reader.ReadInt64() * 10); + protected override void WriteCore(PgWriter writer, TimeOnly value) => writer.WriteInt64(value.Ticks / 10); +} +#endif + +sealed class DateTimeOffsetTimeTzConverter : PgBufferedConverter +{ + // Binary Format: int64 expressing microseconds, int32 expressing timezone in seconds, negative + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + 
bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long) + sizeof(int)); + return format is DataFormat.Binary; + } + + protected override DateTimeOffset ReadCore(PgReader reader) + { + // Adjust from 1 microsecond to 100ns. Time zone (in seconds) is inverted. + var ticks = reader.ReadInt64() * 10; + var offset = new TimeSpan(0, 0, -reader.ReadInt32()); + return new DateTimeOffset(ticks + TimeSpan.TicksPerDay, offset); + } + + protected override void WriteCore(PgWriter writer, DateTimeOffset value) + { + writer.WriteInt64(value.TimeOfDay.Ticks / 10); + writer.WriteInt32(-(int)(value.Offset.Ticks / TimeSpan.TicksPerSecond)); + } +} diff --git a/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs b/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs new file mode 100644 index 0000000000..ccb4f2041e --- /dev/null +++ b/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs @@ -0,0 +1,105 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal.Converters; + +sealed class VersionPrefixedTextConverter : PgStreamingConverter +{ + readonly byte _versionPrefix; + readonly PgConverter _textConverter; + BufferRequirements _innerRequirements; + + public VersionPrefixedTextConverter(byte versionPrefix, PgConverter textConverter) + : base(textConverter.DbNullPredicateKind is DbNullPredicate.Custom) + { + _versionPrefix = versionPrefix; + _textConverter = textConverter; + } + + protected override bool IsDbNullValue(T? value, ref object? 
writeState) => _textConverter.IsDbNull(value, ref writeState); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => VersionPrefixedTextConverter.CanConvert(_textConverter, format, out _innerRequirements, out bufferRequirements); + + public override T Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).Result; + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, cancellationToken); + + public override Size GetSize(SizeContext context, [DisallowNull]T value, ref object? writeState) + => _textConverter.GetSize(context, value, ref writeState).Combine(context.Format is DataFormat.Binary ? sizeof(byte) : 0); + + public override void Write(PgWriter writer, [DisallowNull]T value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, [DisallowNull]T value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + await VersionPrefixedTextConverter.ReadVersion(async, _versionPrefix, reader, _innerRequirements.Read, cancellationToken).ConfigureAwait(false); + return async ? 
await _textConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) : _textConverter.Read(reader); + } + + async ValueTask Write(bool async, PgWriter writer, [DisallowNull]T value, CancellationToken cancellationToken) + { + await VersionPrefixedTextConverter.WriteVersion(async, _versionPrefix, writer, cancellationToken).ConfigureAwait(false); + if (async) + await _textConverter.WriteAsync(writer, value, cancellationToken).ConfigureAwait(false); + else + _textConverter.Write(writer, value); + } +} + +static class VersionPrefixedTextConverter +{ + public static async ValueTask WriteVersion(bool async, byte version, PgWriter writer, CancellationToken cancellationToken) + { + if (writer.Current.Format is not DataFormat.Binary) + return; + + if (writer.ShouldFlush(sizeof(byte))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteByte(version); + } + + public static async ValueTask ReadVersion(bool async, byte expectedVersion, PgReader reader, Size textConverterReadRequirement, CancellationToken cancellationToken) + { + if (reader.Current.Format is not DataFormat.Binary) + return; + + if (!reader.IsResumed) + { + if (reader.ShouldBuffer(sizeof(byte))) + await reader.Buffer(async, sizeof(byte), cancellationToken).ConfigureAwait(false); + + var actualVersion = reader.ReadByte(); + if (actualVersion != expectedVersion) + throw new InvalidCastException($"Unknown wire format version: {actualVersion}"); + } + + // No need for a nested read, all text converters will read CurrentRemaining bytes. + // We only need to buffer data if we're binary, otherwise the caller would have had to do so + // as we directly expose the underlying text converter requirements for the text data format. 
+ await reader.Buffer(async, textConverterReadRequirement, cancellationToken).ConfigureAwait(false); + } + + public static bool CanConvert(PgConverter textConverter, DataFormat format, out BufferRequirements textConverterRequirements, out BufferRequirements bufferRequirements) + { + var success = textConverter.CanConvert(format, out textConverterRequirements); + if (!success) + { + bufferRequirements = default; + return false; + } + if (textConverter.CanConvert(format is DataFormat.Binary ? DataFormat.Text : DataFormat.Binary, out var otherRequirements) && otherRequirements != textConverterRequirements) + throw new InvalidOperationException("Text converter should have identical requirements for text and binary formats."); + + bufferRequirements = format is DataFormat.Binary ? textConverterRequirements.Combine(sizeof(byte)) : textConverterRequirements; + + return success; + } +} diff --git a/src/Npgsql/Internal/DataFormat.cs b/src/Npgsql/Internal/DataFormat.cs new file mode 100644 index 0000000000..c52b418b7d --- /dev/null +++ b/src/Npgsql/Internal/DataFormat.cs @@ -0,0 +1,31 @@ +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public enum DataFormat : byte +{ + Binary, + Text +} + +static class DataFormatUtils +{ + public static DataFormat Create(short formatCode) + => formatCode switch + { + 0 => DataFormat.Text, + 1 => DataFormat.Binary, + _ => throw new ArgumentOutOfRangeException(nameof(formatCode), formatCode, "Unknown postgres format code, please file a bug,") + }; + + public static short ToFormatCode(this DataFormat dataFormat) + => dataFormat switch + { + DataFormat.Text => 0, + DataFormat.Binary => 1, + _ => throw new UnreachableException() + }; +} diff --git a/src/Npgsql/Internal/DynamicTypeInfoResolver.cs b/src/Npgsql/Internal/DynamicTypeInfoResolver.cs new file mode 100644 index 0000000000..421de703f5 --- /dev/null +++ 
b/src/Npgsql/Internal/DynamicTypeInfoResolver.cs @@ -0,0 +1,133 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +[RequiresDynamicCode("A dynamic type info resolver may need to construct a generic converter for a statically unknown type.")] +public abstract class DynamicTypeInfoResolver : IPgTypeInfoResolver +{ + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + if (dataTypeName is null) + return null; + + var context = GetMappings(type, dataTypeName.GetValueOrDefault(), options); + return context?.Find(type, dataTypeName.GetValueOrDefault(), options); + } + + protected static DynamicMappingCollection CreateCollection(TypeInfoMappingCollection? baseCollection = null) => new(baseCollection); + + protected static bool IsTypeOrNullableOfType(Type type, Func predicate, out Type matchedType) + { + matchedType = Nullable.GetUnderlyingType(type) ?? type; + return predicate(matchedType); + } + + protected static bool IsArrayLikeType(Type type, [NotNullWhen(true)]out Type? elementType) => TypeInfoMappingCollection.IsArrayLikeType(type, out elementType); + + protected static bool IsArrayDataTypeName(DataTypeName dataTypeName, PgSerializerOptions options, out DataTypeName elementDataTypeName) + { + if (options.DatabaseInfo.GetPostgresType(dataTypeName) is PostgresArrayType arrayType) + { + elementDataTypeName = arrayType.Element.DataTypeName; + return true; + } + + elementDataTypeName = default; + return false; + } + + protected abstract DynamicMappingCollection? GetMappings(Type? 
type, DataTypeName dataTypeName, PgSerializerOptions options); + + [RequiresDynamicCode("A dynamic type info resolver may need to construct a generic converter for a statically unknown type.")] + protected class DynamicMappingCollection + { + TypeInfoMappingCollection? _mappings; + + static readonly MethodInfo AddTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod(nameof(TypeInfoMappingCollection.AddType), + new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddStructTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod(nameof(TypeInfoMappingCollection.AddStructType), + new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddStructArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddStructArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddResolverTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod( + nameof(TypeInfoMappingCollection.AddResolverType), + new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddResolverArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddResolverArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); + + static readonly MethodInfo AddResolverStructTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod( + nameof(TypeInfoMappingCollection.AddResolverStructType), + new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? 
throw new NullReferenceException(); + + static readonly MethodInfo AddResolverStructArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddResolverStructArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); + + internal DynamicMappingCollection(TypeInfoMappingCollection? baseCollection = null) + { + if (baseCollection is not null) + _mappings = new(baseCollection); + } + + public DynamicMappingCollection AddMapping(Type type, string dataTypeName, TypeInfoFactory factory, Func? configureMapping = null) + { + if (type.IsValueType && Nullable.GetUnderlyingType(type) is not null) + throw new NotSupportedException("Mapping nullable types is not supported, map its underlying type instead to get both."); + + (type.IsValueType ? AddStructTypeMethodInfo : AddTypeMethodInfo) + .MakeGenericMethod(type).Invoke(_mappings ??= new(), new object?[] + { + dataTypeName, + factory, + configureMapping + }); + return this; + } + + public DynamicMappingCollection AddArrayMapping(Type elementType, string dataTypeName) + { + (elementType.IsValueType ? AddStructArrayTypeMethodInfo : AddArrayTypeMethodInfo) + .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), new object?[] { dataTypeName }); + return this; + } + + public DynamicMappingCollection AddResolverMapping(Type type, string dataTypeName, TypeInfoFactory factory, Func? configureMapping = null) + { + if (type.IsValueType && Nullable.GetUnderlyingType(type) is not null) + throw new NotSupportedException("Mapping nullable types is not supported"); + + (type.IsValueType ? AddResolverStructTypeMethodInfo : AddResolverTypeMethodInfo) + .MakeGenericMethod(type).Invoke(_mappings ??= new(), new object?[] + { + dataTypeName, + factory, + configureMapping + }); + return this; + } + + public DynamicMappingCollection AddResolverArrayMapping(Type elementType, string dataTypeName) + { + (elementType.IsValueType ? 
AddResolverStructArrayTypeMethodInfo : AddResolverArrayTypeMethodInfo) + .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), new object?[] { dataTypeName }); + return this; + } + + internal PgTypeInfo? Find(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + => _mappings?.Find(type, dataTypeName, options); + + public TypeInfoMappingCollection ToTypeInfoMappingCollection() + => new(_mappings?.Items ?? Array.Empty()); + } +} diff --git a/src/Npgsql/Internal/HackyEnumTypeMapping.cs b/src/Npgsql/Internal/HackyEnumTypeMapping.cs new file mode 100644 index 0000000000..25f98ed8cc --- /dev/null +++ b/src/Npgsql/Internal/HackyEnumTypeMapping.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; +using Npgsql.Internal; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.Internal; + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + +/// +/// Hacky temporary measure used by EFCore.PG to extract user-configured enum mappings. Accessed via reflection only. 
+/// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public sealed class HackyEnumTypeMapping +{ + public HackyEnumTypeMapping(Type enumClrType, string pgTypeName, INpgsqlNameTranslator nameTranslator) + { + EnumClrType = enumClrType; + PgTypeName = pgTypeName; + NameTranslator = nameTranslator; + } + + public string PgTypeName { get; } + public Type EnumClrType { get; } + public INpgsqlNameTranslator NameTranslator { get; } +} diff --git a/src/Npgsql/Internal/INpgsqlDatabaseInfoFactory.cs b/src/Npgsql/Internal/INpgsqlDatabaseInfoFactory.cs new file mode 100644 index 0000000000..ea3f0ad525 --- /dev/null +++ b/src/Npgsql/Internal/INpgsqlDatabaseInfoFactory.cs @@ -0,0 +1,24 @@ +using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; +using Npgsql.Util; + +namespace Npgsql.Internal; + +/// +/// A factory which get generate instances of , which describe a database +/// and the types it contains. When first connecting to a database, Npgsql will attempt to load information +/// about it via this factory. +/// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public interface INpgsqlDatabaseInfoFactory +{ + /// + /// Given a connection, loads all necessary information about the connected database, e.g. its types. + /// A factory should only handle the exact database type it was meant for, and return null otherwise. + /// + /// + /// An object describing the database to which is connected, or null if the + /// database isn't of the correct type and isn't handled by this factory. 
+ /// + Task Load(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async); +} diff --git a/src/Npgsql/Internal/IPgTypeInfoResolver.cs b/src/Npgsql/Internal/IPgTypeInfoResolver.cs new file mode 100644 index 0000000000..b7b3ddc9ec --- /dev/null +++ b/src/Npgsql/Internal/IPgTypeInfoResolver.cs @@ -0,0 +1,21 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +/// +/// An Npgsql resolver for type info. Used by Npgsql to read and write values to PostgreSQL. +/// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public interface IPgTypeInfoResolver +{ + /// + /// Resolve a type info for a given type and data type name, at least one value will be non-null. + /// + /// The clr type being requested. + /// The postgres type being requested. + /// Used for configuration state and Npgsql type info or PostgreSQL type catalog lookups. + /// A result, or null if there was no match. + PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options); +} diff --git a/src/Npgsql/Internal/IntegratedSecurityHandler.cs b/src/Npgsql/Internal/IntegratedSecurityHandler.cs new file mode 100644 index 0000000000..2b2f2f1bb9 --- /dev/null +++ b/src/Npgsql/Internal/IntegratedSecurityHandler.cs @@ -0,0 +1,32 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Npgsql.Properties; + +namespace Npgsql.Internal; + +class IntegratedSecurityHandler +{ + public virtual bool IsSupported => false; + + public virtual ValueTask GetUsername(bool async, bool includeRealm, ILogger connectionLogger, CancellationToken cancellationToken) + { + connectionLogger.LogDebug(string.Format(NpgsqlStrings.IntegratedSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableIntegratedSecurity))); + return new(); + } + + public virtual ValueTask NegotiateAuthentication(bool async, NpgsqlConnector connector) + => throw new NotSupportedException(string.Format(NpgsqlStrings.IntegratedSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableIntegratedSecurity))); +} + +sealed class RealIntegratedSecurityHandler : IntegratedSecurityHandler +{ + public override bool IsSupported => true; + + public override ValueTask GetUsername(bool async, bool includeRealm, ILogger connectionLogger, CancellationToken cancellationToken) + => KerberosUsernameProvider.GetUsername(async, includeRealm, connectionLogger, cancellationToken); + + public override ValueTask NegotiateAuthentication(bool async, NpgsqlConnector connector) + => new(connector.AuthenticateGSS(async)); +} diff --git a/src/Npgsql/Internal/NpgsqlConnector.Auth.cs b/src/Npgsql/Internal/NpgsqlConnector.Auth.cs new file mode 100644 index 0000000000..6eeb0fa44b --- /dev/null +++ b/src/Npgsql/Internal/NpgsqlConnector.Auth.cs @@ -0,0 +1,402 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Security; +using System.Security.Cryptography; +using 
System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Npgsql.BackendMessages; +using Npgsql.Util; +using static Npgsql.Util.Statics; + +namespace Npgsql.Internal; + +partial class NpgsqlConnector +{ + async Task Authenticate(string username, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + { + while (true) + { + timeout.CheckAndApply(this); + var msg = ExpectAny(await ReadMessage(async).ConfigureAwait(false), this); + switch (msg.AuthRequestType) + { + case AuthenticationRequestType.AuthenticationOk: + return; + + case AuthenticationRequestType.AuthenticationCleartextPassword: + await AuthenticateCleartext(username, async, cancellationToken).ConfigureAwait(false); + break; + + case AuthenticationRequestType.AuthenticationMD5Password: + await AuthenticateMD5(username, ((AuthenticationMD5PasswordMessage)msg).Salt, async, cancellationToken).ConfigureAwait(false); + break; + + case AuthenticationRequestType.AuthenticationSASL: + await AuthenticateSASL(((AuthenticationSASLMessage)msg).Mechanisms, username, async, + cancellationToken).ConfigureAwait(false); + break; + + case AuthenticationRequestType.AuthenticationGSS: + case AuthenticationRequestType.AuthenticationSSPI: + await DataSource.IntegratedSecurityHandler.NegotiateAuthentication(async, this).ConfigureAwait(false); + return; + + case AuthenticationRequestType.AuthenticationGSSContinue: + throw new NpgsqlException("Can't start auth cycle with AuthenticationGSSContinue"); + + default: + throw new NotSupportedException($"Authentication method not supported (Received: {msg.AuthRequestType})"); + } + } + } + + async Task AuthenticateCleartext(string username, bool async, CancellationToken cancellationToken = default) + { + var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false); + if (passwd == null) + throw new NpgsqlException("No password has been 
provided but the backend requires one (in cleartext)"); + + var encoded = new byte[Encoding.UTF8.GetByteCount(passwd) + 1]; + Encoding.UTF8.GetBytes(passwd, 0, passwd.Length, encoded, 0); + + await WritePassword(encoded, async, cancellationToken).ConfigureAwait(false); + await Flush(async, cancellationToken).ConfigureAwait(false); + } + + async Task AuthenticateSASL(List mechanisms, string username, bool async, CancellationToken cancellationToken) + { + // At the time of writing PostgreSQL only supports SCRAM-SHA-256 and SCRAM-SHA-256-PLUS + var serverSupportsSha256 = mechanisms.Contains("SCRAM-SHA-256"); + var clientSupportsSha256 = serverSupportsSha256 && Settings.ChannelBinding != ChannelBinding.Require; + var serverSupportsSha256Plus = mechanisms.Contains("SCRAM-SHA-256-PLUS"); + var clientSupportsSha256Plus = serverSupportsSha256Plus && Settings.ChannelBinding != ChannelBinding.Disable; + if (!clientSupportsSha256 && !clientSupportsSha256Plus) + { + if (serverSupportsSha256 && Settings.ChannelBinding == ChannelBinding.Require) + throw new NpgsqlException($"Couldn't connect because {nameof(ChannelBinding)} is set to {nameof(ChannelBinding.Require)} " + + "but the server doesn't support SCRAM-SHA-256-PLUS"); + if (serverSupportsSha256Plus && Settings.ChannelBinding == ChannelBinding.Disable) + throw new NpgsqlException($"Couldn't connect because {nameof(ChannelBinding)} is set to {nameof(ChannelBinding.Disable)} " + + "but the server doesn't support SCRAM-SHA-256"); + + throw new NpgsqlException("No supported SASL mechanism found (only SCRAM-SHA-256 and SCRAM-SHA-256-PLUS are supported for now). 
" + + "Mechanisms received from server: " + string.Join(", ", mechanisms)); + } + + var mechanism = string.Empty; + var cbindFlag = string.Empty; + var cbind = string.Empty; + var successfulBind = false; + + if (clientSupportsSha256Plus) + DataSource.TransportSecurityHandler.AuthenticateSASLSha256Plus(this, ref mechanism, ref cbindFlag, ref cbind, ref successfulBind); + + if (!successfulBind && serverSupportsSha256) + { + mechanism = "SCRAM-SHA-256"; + // We can get here if PostgreSQL supports only SCRAM-SHA-256 or there was an error while binding to SCRAM-SHA-256-PLUS + // Or the user specifically requested to not use bindings + // So, we set 'n' (client does not support binding) if there was an error while binding + // or 'y' (client supports but server doesn't) in other case + cbindFlag = serverSupportsSha256Plus ? "n" : "y"; + cbind = serverSupportsSha256Plus ? "biws" : "eSws"; + successfulBind = true; + IsScram = true; + } + + if (!successfulBind) + { + // We can get here if PostgreSQL supports only SCRAM-SHA-256-PLUS but there was an error while binding to it + throw new NpgsqlException("Unable to bind to SCRAM-SHA-256-PLUS, check logs for more information"); + } + + var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false) ?? 
+ throw new NpgsqlException($"No password has been provided but the backend requires one (in SASL/{mechanism})"); + + // Assumption: the write buffer is big enough to contain all our outgoing messages + var clientNonce = GetNonce(); + + await WriteSASLInitialResponse(mechanism, NpgsqlWriteBuffer.UTF8Encoding.GetBytes($"{cbindFlag},,n=*,r={clientNonce}"), async, cancellationToken).ConfigureAwait(false); + await Flush(async, cancellationToken).ConfigureAwait(false); + + var saslContinueMsg = Expect(await ReadMessage(async).ConfigureAwait(false), this); + if (saslContinueMsg.AuthRequestType != AuthenticationRequestType.AuthenticationSASLContinue) + throw new NpgsqlException("[SASL] AuthenticationSASLContinue message expected"); + var firstServerMsg = AuthenticationSCRAMServerFirstMessage.Load(saslContinueMsg.Payload, ConnectionLogger); + if (!firstServerMsg.Nonce.StartsWith(clientNonce, StringComparison.Ordinal)) + throw new NpgsqlException("[SCRAM] Malformed SCRAMServerFirst message: server nonce doesn't start with client nonce"); + + var saltBytes = Convert.FromBase64String(firstServerMsg.Salt); + var saltedPassword = Hi(passwd.Normalize(NormalizationForm.FormKC), saltBytes, firstServerMsg.Iteration); + + var clientKey = HMAC(saltedPassword, "Client Key"); + byte[] storedKey; +#if NET7_0_OR_GREATER + storedKey = SHA256.HashData(clientKey); +#else + using (var sha256 = SHA256.Create()) + storedKey = sha256.ComputeHash(clientKey); +#endif + var clientFirstMessageBare = $"n=*,r={clientNonce}"; + var serverFirstMessage = $"r={firstServerMsg.Nonce},s={firstServerMsg.Salt},i={firstServerMsg.Iteration}"; + var clientFinalMessageWithoutProof = $"c={cbind},r={firstServerMsg.Nonce}"; + + var authMessage = $"{clientFirstMessageBare},{serverFirstMessage},{clientFinalMessageWithoutProof}"; + + var clientSignature = HMAC(storedKey, authMessage); + var clientProofBytes = Xor(clientKey, clientSignature); + var clientProof = Convert.ToBase64String(clientProofBytes); + + var 
serverKey = HMAC(saltedPassword, "Server Key"); + var serverSignature = HMAC(serverKey, authMessage); + + var messageStr = $"{clientFinalMessageWithoutProof},p={clientProof}"; + + await WriteSASLResponse(Encoding.UTF8.GetBytes(messageStr), async, cancellationToken).ConfigureAwait(false); + await Flush(async, cancellationToken).ConfigureAwait(false); + + var saslFinalServerMsg = Expect(await ReadMessage(async).ConfigureAwait(false), this); + if (saslFinalServerMsg.AuthRequestType != AuthenticationRequestType.AuthenticationSASLFinal) + throw new NpgsqlException("[SASL] AuthenticationSASLFinal message expected"); + + var scramFinalServerMsg = AuthenticationSCRAMServerFinalMessage.Load(saslFinalServerMsg.Payload, ConnectionLogger); + if (scramFinalServerMsg.ServerSignature != Convert.ToBase64String(serverSignature)) + throw new NpgsqlException("[SCRAM] Unable to verify server signature"); + + + static string GetNonce() + { + using var rncProvider = RandomNumberGenerator.Create(); + var nonceBytes = new byte[18]; + + rncProvider.GetBytes(nonceBytes); + return Convert.ToBase64String(nonceBytes); + } + } + + internal void AuthenticateSASLSha256Plus(ref string mechanism, ref string cbindFlag, ref string cbind, + ref bool successfulBind) + { + // The check below is copied from libpq (with commentary) + // https://github.com/postgres/postgres/blob/98640f960eb9ed80cf90de3ef5d2e829b785b3eb/src/interfaces/libpq/fe-auth.c#L507-L517 + + // The server offered SCRAM-SHA-256-PLUS, but the connection + // is not SSL-encrypted. That's not sane. Perhaps SSL was + // stripped by a proxy? There's no point in continuing, + // because the server will reject the connection anyway if we + // try authenticate without channel binding even though both + // the client and server supported it. The SCRAM exchange + // checks for that, to prevent downgrade attacks. 
+ if (!IsSecure) + throw new NpgsqlException("Server offered SCRAM-SHA-256-PLUS authentication over a non-SSL connection"); + + var sslStream = (SslStream)_stream; + if (sslStream.RemoteCertificate is null) + { + ConnectionLogger.LogWarning("Remote certificate null, falling back to SCRAM-SHA-256"); + return; + } + + using var remoteCertificate = new X509Certificate2(sslStream.RemoteCertificate); + // Checking for hashing algorithms + HashAlgorithm? hashAlgorithm = null; + var algorithmName = remoteCertificate.SignatureAlgorithm.FriendlyName; + if (algorithmName is null) + { + ConnectionLogger.LogWarning("Signature algorithm was null, falling back to SCRAM-SHA-256"); + } + else if (algorithmName.StartsWith("sha1", StringComparison.OrdinalIgnoreCase) || + algorithmName.StartsWith("md5", StringComparison.OrdinalIgnoreCase) || + algorithmName.StartsWith("sha256", StringComparison.OrdinalIgnoreCase)) + { + hashAlgorithm = SHA256.Create(); + } + else if (algorithmName.StartsWith("sha384", StringComparison.OrdinalIgnoreCase)) + { + hashAlgorithm = SHA384.Create(); + } + else if (algorithmName.StartsWith("sha512", StringComparison.OrdinalIgnoreCase)) + { + hashAlgorithm = SHA512.Create(); + } + else + { + ConnectionLogger.LogWarning( + $"Support for signature algorithm {algorithmName} is not yet implemented, falling back to SCRAM-SHA-256"); + } + + if (hashAlgorithm != null) + { + using var _ = hashAlgorithm; + + // RFC 5929 + mechanism = "SCRAM-SHA-256-PLUS"; + // PostgreSQL only supports tls-server-end-point binding + cbindFlag = "p=tls-server-end-point"; + // SCRAM-SHA-256-PLUS depends on using ssl stream, so it's fine + var cbindFlagBytes = Encoding.UTF8.GetBytes($"{cbindFlag},,"); + + var certificateHash = hashAlgorithm.ComputeHash(remoteCertificate.GetRawCertData()); + var cbindBytes = new byte[cbindFlagBytes.Length + certificateHash.Length]; + cbindFlagBytes.CopyTo(cbindBytes, 0); + certificateHash.CopyTo(cbindBytes, cbindFlagBytes.Length); + cbind = 
Convert.ToBase64String(cbindBytes); + successfulBind = true; + IsScramPlus = true; + } + } + +#if NET6_0_OR_GREATER + static byte[] Hi(string str, byte[] salt, int count) + => Rfc2898DeriveBytes.Pbkdf2(str, salt, count, HashAlgorithmName.SHA256, 256 / 8); +#endif + + static byte[] Xor(byte[] buffer1, byte[] buffer2) + { + for (var i = 0; i < buffer1.Length; i++) + buffer1[i] ^= buffer2[i]; + return buffer1; + } + + static byte[] HMAC(byte[] key, string data) + { + var dataBytes = Encoding.UTF8.GetBytes(data); +#if NET7_0_OR_GREATER + return HMACSHA256.HashData(key, dataBytes); +#else + using var ih = IncrementalHash.CreateHMAC(HashAlgorithmName.SHA256, key); + ih.AppendData(dataBytes); + return ih.GetHashAndReset(); +#endif + } + + async Task AuthenticateMD5(string username, byte[] salt, bool async, CancellationToken cancellationToken = default) + { + var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false); + if (passwd == null) + throw new NpgsqlException("No password has been provided but the backend requires one (in MD5)"); + + byte[] result; +#if !NET7_0_OR_GREATER + using (var md5 = MD5.Create()) +#endif + { + // First phase + var passwordBytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(passwd); + var usernameBytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(username); + var cryptBuf = new byte[passwordBytes.Length + usernameBytes.Length]; + passwordBytes.CopyTo(cryptBuf, 0); + usernameBytes.CopyTo(cryptBuf, passwordBytes.Length); + + var sb = new StringBuilder(); +#if NET7_0_OR_GREATER + var hashResult = MD5.HashData(cryptBuf); +#else + var hashResult = md5.ComputeHash(cryptBuf); +#endif + foreach (var b in hashResult) + sb.Append(b.ToString("x2")); + + var prehash = sb.ToString(); + + var prehashbytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(prehash); + cryptBuf = new byte[prehashbytes.Length + 4]; + + Array.Copy(salt, 0, cryptBuf, prehashbytes.Length, 4); + + // 2. 
+ prehashbytes.CopyTo(cryptBuf, 0); + + sb = new StringBuilder("md5"); +#if NET7_0_OR_GREATER + hashResult = MD5.HashData(cryptBuf); +#else + hashResult = md5.ComputeHash(cryptBuf); +#endif + foreach (var b in hashResult) + sb.Append(b.ToString("x2")); + + var resultString = sb.ToString(); + result = new byte[Encoding.UTF8.GetByteCount(resultString) + 1]; + Encoding.UTF8.GetBytes(resultString, 0, resultString.Length, result, 0); + result[result.Length - 1] = 0; + } + + await WritePassword(result, async, cancellationToken).ConfigureAwait(false); + await Flush(async, cancellationToken).ConfigureAwait(false); + } + +#if NET7_0_OR_GREATER + internal async Task AuthenticateGSS(bool async) + { + var targetName = $"{KerberosServiceName}/{Host}"; + + using var authContext = new NegotiateAuthentication(new NegotiateAuthenticationClientOptions{ TargetName = targetName}); + var data = authContext.GetOutgoingBlob(ReadOnlySpan.Empty, out var statusCode)!; + Debug.Assert(statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded); + await WritePassword(data, 0, data.Length, async, UserCancellationToken).ConfigureAwait(false); + await Flush(async, UserCancellationToken).ConfigureAwait(false); + while (true) + { + var response = ExpectAny(await ReadMessage(async).ConfigureAwait(false), this); + if (response.AuthRequestType == AuthenticationRequestType.AuthenticationOk) + break; + if (response is not AuthenticationGSSContinueMessage gssMsg) + throw new NpgsqlException($"Received unexpected authentication request message {response.AuthRequestType}"); + data = authContext.GetOutgoingBlob(gssMsg.AuthenticationData.AsSpan(), out statusCode)!; + if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded) + throw new NpgsqlException($"Error while authenticating GSS/SSPI: {statusCode}"); + // We might get NegotiateAuthenticationStatusCode.Completed but the data will not be null + // This can happen if it's the first 
cycle, in which case we have to send that data to complete handshake (#4888) + if (data is null) + continue; + await WritePassword(data, 0, data.Length, async, UserCancellationToken).ConfigureAwait(false); + await Flush(async, UserCancellationToken).ConfigureAwait(false); + } + } +#endif + + async ValueTask GetPassword(string username, bool async, CancellationToken cancellationToken = default) + { + var password = await DataSource.GetPassword(async, cancellationToken).ConfigureAwait(false); + + if (password is not null) + return password; + + if (ProvidePasswordCallback is { } passwordCallback) + { + try + { + ConnectionLogger.LogTrace($"Taking password from {nameof(ProvidePasswordCallback)} delegate"); + password = passwordCallback(Host, Port, Settings.Database!, username); + } + catch (Exception e) + { + throw new NpgsqlException($"Obtaining password using {nameof(NpgsqlConnection)}.{nameof(ProvidePasswordCallback)} delegate failed", e); + } + } + + password ??= PostgresEnvironment.Password; + + if (password != null) + return password; + + var passFile = Settings.Passfile ?? PostgresEnvironment.PassFile ?? PostgresEnvironment.PassFileDefault; + if (passFile != null) + { + var matchingEntry = new PgPassFile(passFile!) 
+ .GetFirstMatchingEntry(Host, Port, Settings.Database!, username); + if (matchingEntry != null) + { + ConnectionLogger.LogTrace("Taking password from pgpass file"); + password = matchingEntry.Password; + } + } + + return password; + } +} diff --git a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs new file mode 100644 index 0000000000..9e0fd45dd3 --- /dev/null +++ b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs @@ -0,0 +1,504 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal; + +partial class NpgsqlConnector +{ + internal Task WriteDescribe(StatementOrPortal statementOrPortal, byte[] asciiName, bool async, CancellationToken cancellationToken = default) + { + NpgsqlWriteBuffer.AssertASCIIOnly(asciiName); + + var len = sizeof(byte) + // Message code + sizeof(int) + // Length + sizeof(byte) + // Statement or portal + (asciiName.Length + 1); // Statement/portal name + + var writeBuffer = WriteBuffer; + writeBuffer.StartMessage(len); + if (writeBuffer.WriteSpaceLeft < len) + return FlushAndWrite(len, statementOrPortal, asciiName, async, cancellationToken); + + Write(writeBuffer, len, statementOrPortal, asciiName); + return Task.CompletedTask; + + async Task FlushAndWrite(int len, StatementOrPortal statementOrPortal, byte[] name, bool async, CancellationToken cancellationToken) + { + await Flush(async, cancellationToken).ConfigureAwait(false); + Debug.Assert(len <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {len} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})"); + Write(WriteBuffer, len, statementOrPortal, name); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Write(NpgsqlWriteBuffer writeBuffer, int len, StatementOrPortal statementOrPortal, byte[] name) + { + 
writeBuffer.WriteByte(FrontendMessageCode.Describe); + writeBuffer.WriteInt32(len - 1); + writeBuffer.WriteByte((byte)statementOrPortal); + writeBuffer.WriteNullTerminatedString(name); + } + } + + internal Task WriteSync(bool async, CancellationToken cancellationToken = default) + { + const int len = sizeof(byte) + // Message code + sizeof(int); // Length + + var writeBuffer = WriteBuffer; + writeBuffer.StartMessage(len); + if (writeBuffer.WriteSpaceLeft < len) + return FlushAndWrite(async, cancellationToken); + + Write(writeBuffer); + return Task.CompletedTask; + + async Task FlushAndWrite(bool async, CancellationToken cancellationToken) + { + await Flush(async, cancellationToken).ConfigureAwait(false); + Debug.Assert(len <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {len} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})"); + Write(WriteBuffer); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Write(NpgsqlWriteBuffer writeBuffer) + { + writeBuffer.WriteByte(FrontendMessageCode.Sync); + writeBuffer.WriteInt32(len - 1); + } + } + + internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellationToken = default) + { + // Note: non-empty portal currently not supported + + const int len = sizeof(byte) + // Message code + sizeof(int) + // Length + sizeof(byte) + // Null-terminated portal name (always empty for now) + sizeof(int); // Max number of rows + + var writeBuffer = WriteBuffer; + writeBuffer.StartMessage(len); + if (writeBuffer.WriteSpaceLeft < len) + return FlushAndWrite(maxRows, async, cancellationToken); + + Write(writeBuffer, maxRows); + return Task.CompletedTask; + + async Task FlushAndWrite(int maxRows, bool async, CancellationToken cancellationToken) + { + await Flush(async, cancellationToken).ConfigureAwait(false); + Debug.Assert(10 <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length 10 which is bigger than the buffer 
({WriteBuffer.WriteSpaceLeft})"); + Write(WriteBuffer, maxRows); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Write(NpgsqlWriteBuffer writeBuffer, int maxRows) + { + writeBuffer.WriteByte(FrontendMessageCode.Execute); + writeBuffer.WriteInt32(len - 1); + writeBuffer.WriteByte(0); // Portal is always empty for now + writeBuffer.WriteInt32(maxRows); + } + } + + internal async Task WriteParse(string sql, byte[] asciiName, List inputParameters, bool async, CancellationToken cancellationToken = default) + { + NpgsqlWriteBuffer.AssertASCIIOnly(asciiName); + + int queryByteLen; + try + { + queryByteLen = TextEncoding.GetByteCount(sql); + } + catch (Exception e) + { + Break(e); + throw; + } + + var writeBuffer = WriteBuffer; + var messageLength = + sizeof(byte) + // Message code + sizeof(int) + // Length + asciiName.Length + // Statement name + sizeof(byte) + // Null terminator for the statement name + queryByteLen + sizeof(byte) + // SQL query length plus null terminator + sizeof(ushort) + // Number of parameters + inputParameters.Count * sizeof(int); // Parameter OIDs + + + WriteBuffer.StartMessage(messageLength); + if (WriteBuffer.WriteSpaceLeft < 1 + 4 + asciiName.Length + 1) + await Flush(async, cancellationToken).ConfigureAwait(false); + + WriteBuffer.WriteByte(FrontendMessageCode.Parse); + WriteBuffer.WriteInt32(messageLength - 1); + WriteBuffer.WriteNullTerminatedString(asciiName); + + await writeBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false); + + if (writeBuffer.WriteSpaceLeft < 1 + 2) + await Flush(async, cancellationToken).ConfigureAwait(false); + writeBuffer.WriteByte(0); // Null terminator for the query + writeBuffer.WriteUInt16((ushort)inputParameters.Count); + + var databaseInfo = DatabaseInfo; + foreach (var p in inputParameters) + { + if (writeBuffer.WriteSpaceLeft < 4) + await Flush(async, cancellationToken).ConfigureAwait(false); + + 
writeBuffer.WriteUInt32(databaseInfo.GetOid(p.PgTypeId).Value); + } + } + + internal async Task WriteBind( + List parameters, + string portal, + byte[] asciiName, + bool allResultTypesAreUnknown, + bool[]? unknownResultTypeList, + bool async, + CancellationToken cancellationToken = default) + { + NpgsqlWriteBuffer.AssertASCIIOnly(asciiName); + NpgsqlWriteBuffer.AssertASCIIOnly(portal); + + var headerLength = + sizeof(byte) + // Message code + sizeof(int) + // Message length + sizeof(byte) + // Portal is always empty (only a null terminator) + asciiName.Length + sizeof(byte) + // Statement name plus null terminator + sizeof(ushort); // Number of parameter format codes that follow + + var writeBuffer = WriteBuffer; + var formatCodesSum = 0; + var paramsLength = 0; + for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++) + { + var param = parameters[paramIndex]; + param.Bind(out var format, out var size); + paramsLength += size.Value > 0 ? size.Value : 0; + formatCodesSum += format.ToFormatCode(); + } + + var formatCodeListLength = formatCodesSum == 0 ? 0 : formatCodesSum == parameters.Count ? 1 : parameters.Count; + + var messageLength = headerLength + + sizeof(short) * formatCodeListLength + // List of format codes + sizeof(short) + // Number of parameters + sizeof(int) * parameters.Count + // Parameter lengths + paramsLength + // Parameter values + sizeof(short) + // Number of result format codes + sizeof(short) * (unknownResultTypeList?.Length ?? 
1); // Result format codes + + WriteBuffer.StartMessage(messageLength); + if (WriteBuffer.WriteSpaceLeft < headerLength) + { + Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header"); + await Flush(async, cancellationToken).ConfigureAwait(false); + } + + WriteBuffer.WriteByte(FrontendMessageCode.Bind); + WriteBuffer.WriteInt32(messageLength - 1); + Debug.Assert(portal == string.Empty); + writeBuffer.WriteByte(0); // Portal is always empty + + writeBuffer.WriteNullTerminatedString(asciiName); + writeBuffer.WriteInt16((short)formatCodeListLength); + + // 0 length implicitly means all-text, 1 means all-binary, >1 means mix-and-match + if (formatCodeListLength == 1) + { + if (writeBuffer.WriteSpaceLeft < sizeof(short)) + await Flush(async, cancellationToken).ConfigureAwait(false); + writeBuffer.WriteInt16(DataFormat.Binary.ToFormatCode()); + } + else if (formatCodeListLength > 1) + { + for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++) + { + if (writeBuffer.WriteSpaceLeft < sizeof(short)) + await Flush(async, cancellationToken).ConfigureAwait(false); + writeBuffer.WriteInt16(parameters[paramIndex].Format.ToFormatCode()); + } + } + + if (writeBuffer.WriteSpaceLeft < sizeof(ushort)) + await Flush(async, cancellationToken).ConfigureAwait(false); + + writeBuffer.WriteUInt16((ushort)parameters.Count); + if (parameters.Count > 0) + { + var writer = writeBuffer.GetWriter(DatabaseInfo, async ? 
FlushMode.NonBlocking : FlushMode.Blocking); + try + { + for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++) + { + var param = parameters[paramIndex]; + await param.Write(async, writer, cancellationToken).ConfigureAwait(false); + } + } + catch(Exception ex) + { + Break(ex); + throw; + } + } + + if (unknownResultTypeList != null) + { + if (writeBuffer.WriteSpaceLeft < 2 + unknownResultTypeList.Length * 2) + await Flush(async, cancellationToken).ConfigureAwait(false); + writeBuffer.WriteInt16((short)unknownResultTypeList.Length); + foreach (var t in unknownResultTypeList) + writeBuffer.WriteInt16((short)(t ? 0 : 1)); + } + else + { + if (writeBuffer.WriteSpaceLeft < 4) + await Flush(async, cancellationToken).ConfigureAwait(false); + writeBuffer.WriteInt16(1); + writeBuffer.WriteInt16((short)(allResultTypesAreUnknown ? 0 : 1)); + } + } + + internal Task WriteClose(StatementOrPortal type, byte[] asciiName, bool async, CancellationToken cancellationToken = default) + { + var len = sizeof(byte) + // Message code + sizeof(int) + // Length + sizeof(byte) + // Statement or portal + asciiName.Length + sizeof(byte); // Statement or portal name plus null terminator + + var writeBuffer = WriteBuffer; + writeBuffer.StartMessage(len); + if (writeBuffer.WriteSpaceLeft < len) + return FlushAndWrite(len, type, asciiName, async, cancellationToken); + + Write(writeBuffer, len, type, asciiName); + return Task.CompletedTask; + + async Task FlushAndWrite(int len, StatementOrPortal type, byte[] name, bool async, CancellationToken cancellationToken) + { + await Flush(async, cancellationToken).ConfigureAwait(false); + Debug.Assert(len <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {len} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})"); + Write(WriteBuffer, len, type, name); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Write(NpgsqlWriteBuffer writeBuffer, int len, StatementOrPortal type, byte[] name) 
+ { + writeBuffer.WriteByte(FrontendMessageCode.Close); + writeBuffer.WriteInt32(len - 1); + writeBuffer.WriteByte((byte)type); + writeBuffer.WriteNullTerminatedString(name); + } + } + + internal async Task WriteQuery(string sql, bool async, CancellationToken cancellationToken = default) + { + var queryByteLen = TextEncoding.GetByteCount(sql); + + var len = sizeof(byte) + + sizeof(int) + // Message length (including self excluding code) + queryByteLen + // Query byte length + sizeof(byte); + + WriteBuffer.StartMessage(len); + if (WriteBuffer.WriteSpaceLeft < 1 + 4) + await Flush(async, cancellationToken).ConfigureAwait(false); + + WriteBuffer.WriteByte(FrontendMessageCode.Query); + WriteBuffer.WriteInt32(len - 1); + + await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false); + if (WriteBuffer.WriteSpaceLeft < 1) + await Flush(async, cancellationToken).ConfigureAwait(false); + WriteBuffer.WriteByte(0); // Null terminator + } + + internal async Task WriteCopyDone(bool async, CancellationToken cancellationToken = default) + { + const int len = sizeof(byte) + // Message code + sizeof(int); // Length + + WriteBuffer.StartMessage(len); + if (WriteBuffer.WriteSpaceLeft < len) + await Flush(async, cancellationToken).ConfigureAwait(false); + + WriteBuffer.WriteByte(FrontendMessageCode.CopyDone); + WriteBuffer.WriteInt32(len - 1); + } + + internal async Task WriteCopyFail(bool async, CancellationToken cancellationToken = default) + { + // Note: error message not supported for now + + const int len = sizeof(byte) + // Message code + sizeof(int) + // Length + sizeof(byte); // Error message is always empty (only a null terminator) + + WriteBuffer.StartMessage(len); + if (WriteBuffer.WriteSpaceLeft < len) + await Flush(async, cancellationToken).ConfigureAwait(false); + + WriteBuffer.WriteByte(FrontendMessageCode.CopyFail); + WriteBuffer.WriteInt32(len - 1); + WriteBuffer.WriteByte(0); // Error message is always empty (only a null 
terminator) + } + + internal void WriteCancelRequest(int backendProcessId, int backendSecretKey) + { + const int len = sizeof(int) + // Length + sizeof(int) + // Cancel request code + sizeof(int) + // Backend process id + sizeof(int); // Backend secret key + + Debug.Assert(backendProcessId != 0); + + WriteBuffer.StartMessage(len); + if (WriteBuffer.WriteSpaceLeft < len) + Flush(false).GetAwaiter().GetResult(); + + WriteBuffer.WriteInt32(len); + WriteBuffer.WriteInt32(1234 << 16 | 5678); + WriteBuffer.WriteInt32(backendProcessId); + WriteBuffer.WriteInt32(backendSecretKey); + } + + internal void WriteTerminate() + { + const int len = sizeof(byte) + // Message code + sizeof(int); // Length + + WriteBuffer.StartMessage(len); + if (WriteBuffer.WriteSpaceLeft < len) + Flush(false).GetAwaiter().GetResult(); + + WriteBuffer.WriteByte(FrontendMessageCode.Terminate); + WriteBuffer.WriteInt32(len - 1); + } + + internal void WriteSslRequest() + { + const int len = sizeof(int) + // Length + sizeof(int); // SSL request code + + WriteBuffer.StartMessage(len); + if (WriteBuffer.WriteSpaceLeft < len) + Flush(false).GetAwaiter().GetResult(); + + WriteBuffer.WriteInt32(len); + WriteBuffer.WriteInt32(80877103); + } + + internal void WriteStartup(Dictionary parameters) + { + const int protocolVersion3 = 3 << 16; // 196608 + + var len = sizeof(int) + // Length + sizeof(int) + // Protocol version + sizeof(byte); // Trailing zero byte + + foreach (var kvp in parameters) + len += NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(kvp.Key) + 1 + + NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(kvp.Value) + 1; + + // Should really never happen, just in case + WriteBuffer.StartMessage(len); + if (len > WriteBuffer.Size) + throw new Exception("Startup message bigger than buffer"); + + WriteBuffer.WriteInt32(len); + WriteBuffer.WriteInt32(protocolVersion3); + + foreach (var kv in parameters) + { + WriteBuffer.WriteString(kv.Key); + WriteBuffer.WriteByte(0); + WriteBuffer.WriteString(kv.Value); + 
WriteBuffer.WriteByte(0); + } + + WriteBuffer.WriteByte(0); + } + + #region Authentication + + internal Task WritePassword(byte[] payload, bool async, CancellationToken cancellationToken = default) => WritePassword(payload, 0, payload.Length, async, cancellationToken); + + internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default) + { + WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count); + if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int)) + await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false); + + WriteBuffer.WriteByte(FrontendMessageCode.Password); + WriteBuffer.WriteInt32(sizeof(int) + count); + + if (count <= WriteBuffer.WriteSpaceLeft) + { + // The entire array fits in our WriteBuffer, copy it into the WriteBuffer as usual. + WriteBuffer.WriteBytes(payload, offset, count); + return; + } + + await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false); + await WriteBuffer.DirectWrite(new ReadOnlyMemory(payload, offset, count), async, cancellationToken).ConfigureAwait(false); + } + + internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialResponse, bool async, CancellationToken cancellationToken = default) + { + var len = sizeof(byte) + // Message code + sizeof(int) + // Length + NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(mechanism) + sizeof(byte) + // Mechanism plus null terminator + sizeof(int) + // Initial response length + (initialResponse?.Length ?? 
0); // Initial response payload + + WriteBuffer.StartMessage(len); + if (WriteBuffer.WriteSpaceLeft < len) + await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false); + + WriteBuffer.WriteByte(FrontendMessageCode.Password); + WriteBuffer.WriteInt32(len - 1); + + WriteBuffer.WriteString(mechanism); + WriteBuffer.WriteByte(0); // null terminator + if (initialResponse == null) + WriteBuffer.WriteInt32(-1); + else + { + WriteBuffer.WriteInt32(initialResponse.Length); + WriteBuffer.WriteBytes(initialResponse); + } + } + + internal Task WriteSASLResponse(byte[] payload, bool async, CancellationToken cancellationToken = default) => WritePassword(payload, async, cancellationToken); + + #endregion Authentication + + internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default) + { + WriteBuffer.StartMessage(data.Length); + if (WriteBuffer.WriteSpaceLeft < data.Length) + return FlushAndWrite(data, async, cancellationToken); + + WriteBuffer.WriteBytes(data, 0, data.Length); + return Task.CompletedTask; + + async Task FlushAndWrite(byte[] data, bool async, CancellationToken cancellationToken) + { + await Flush(async, cancellationToken).ConfigureAwait(false); + Debug.Assert(data.Length <= WriteBuffer.WriteSpaceLeft, $"Pregenerated message has length {data.Length} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})"); + WriteBuffer.WriteBytes(data, 0, data.Length); + } + } + + internal void Flush() => WriteBuffer.Flush(false).GetAwaiter().GetResult(); + + internal Task Flush(bool async, CancellationToken cancellationToken = default) => WriteBuffer.Flush(async, cancellationToken); +} diff --git a/src/Npgsql/Internal/NpgsqlConnector.OldAuth.cs b/src/Npgsql/Internal/NpgsqlConnector.OldAuth.cs new file mode 100644 index 0000000000..e750e730cb --- /dev/null +++ b/src/Npgsql/Internal/NpgsqlConnector.OldAuth.cs @@ -0,0 +1,176 @@ +using System; +using System.IO; +using System.Net; +using System.Net.Security; 
+using System.Security.Cryptography; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.BackendMessages; +using static Npgsql.Util.Statics; + +namespace Npgsql.Internal; + + +partial class NpgsqlConnector +{ +#if !NET6_0_OR_GREATER + static byte[] Hi(string str, byte[] salt, int count) + { + using var hmac = new HMACSHA256(Encoding.UTF8.GetBytes(str)); + var salt1 = new byte[salt.Length + 4]; + byte[] hi, u1; + + Buffer.BlockCopy(salt, 0, salt1, 0, salt.Length); + salt1[salt1.Length - 1] = 1; + + hi = u1 = hmac.ComputeHash(salt1); + + for (var i = 1; i < count; i++) + { + var u2 = hmac.ComputeHash(u1); + NpgsqlConnector.Xor(hi, u2); + u1 = u2; + } + + return hi; + } +#endif + +#if !NET7_0_OR_GREATER + internal async Task AuthenticateGSS(bool async) + { + var targetName = $"{KerberosServiceName}/{Host}"; + + using var negotiateStream = new NegotiateStream(new GSSPasswordMessageStream(this), true); + try + { + if (async) + await negotiateStream.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, targetName).ConfigureAwait(false); + else + negotiateStream.AuthenticateAsClient(CredentialCache.DefaultNetworkCredentials, targetName); + } + catch (AuthenticationCompleteException) + { + return; + } + catch (IOException e) when (e.InnerException is AuthenticationCompleteException) + { + return; + } + catch (IOException e) when (e.InnerException is PostgresException) + { + throw e.InnerException; + } + + throw new NpgsqlException("NegotiateStream.AuthenticateAsClient completed unexpectedly without signaling success"); + } + + /// + /// This Stream is placed between NegotiateStream and the socket's NetworkStream (or SSLStream). It intercepts + /// traffic and performs the following operations: + /// * Outgoing messages are framed in PostgreSQL's PasswordMessage, and incoming are stripped of it. + /// * NegotiateStream frames payloads with a 5-byte header, which PostgreSQL doesn't understand. 
This header is + /// stripped from outgoing messages and added to incoming ones. + /// + /// + /// See https://referencesource.microsoft.com/#System/net/System/Net/_StreamFramer.cs,16417e735f0e9530,references + /// + sealed class GSSPasswordMessageStream : Stream + { + readonly NpgsqlConnector _connector; + int _leftToWrite; + int _leftToRead, _readPos; + byte[]? _readBuf; + + internal GSSPasswordMessageStream(NpgsqlConnector connector) + => _connector = connector; + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + => Write(buffer, offset, count, true, cancellationToken); + + public override void Write(byte[] buffer, int offset, int count) + => Write(buffer, offset, count, false).GetAwaiter().GetResult(); + + async Task Write(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) + { + if (_leftToWrite == 0) + { + // We're writing the frame header, which contains the payload size. 
+ _leftToWrite = (buffer[3] << 8) | buffer[4]; + + buffer[0] = 22; + if (buffer[1] != 1) + throw new NotSupportedException($"Received frame header major v {buffer[1]} (different from 1)"); + if (buffer[2] != 0) + throw new NotSupportedException($"Received frame header minor v {buffer[2]} (different from 0)"); + + // In case of payload data in the same buffer just after the frame header + if (count == 5) + return; + count -= 5; + offset += 5; + } + + if (count > _leftToWrite) + throw new NpgsqlException($"NegotiateStream trying to write {count} bytes but according to frame header we only have {_leftToWrite} left!"); + await _connector.WritePassword(buffer, offset, count, async, cancellationToken).ConfigureAwait(false); + await _connector.Flush(async, cancellationToken).ConfigureAwait(false); + _leftToWrite -= count; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + => Read(buffer, offset, count, true, cancellationToken); + + public override int Read(byte[] buffer, int offset, int count) + => Read(buffer, offset, count, false).GetAwaiter().GetResult(); + + async Task Read(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) + { + if (_leftToRead == 0) + { + var response = ExpectAny(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + if (response.AuthRequestType == AuthenticationRequestType.AuthenticationOk) + throw new AuthenticationCompleteException(); + var gssMsg = response as AuthenticationGSSContinueMessage; + if (gssMsg == null) + throw new NpgsqlException($"Received unexpected authentication request message {response.AuthRequestType}"); + _readBuf = gssMsg.AuthenticationData; + _leftToRead = gssMsg.AuthenticationData.Length; + _readPos = 0; + buffer[0] = 22; + buffer[1] = 1; + buffer[2] = 0; + buffer[3] = (byte)((_leftToRead >> 8) & 0xFF); + buffer[4] = (byte)(_leftToRead & 0xFF); + return 5; + } + + if (count > 
_leftToRead) + throw new NpgsqlException($"NegotiateStream trying to read {count} bytes but according to frame header we only have {_leftToRead} left!"); + count = Math.Min(count, _leftToRead); + Array.Copy(_readBuf!, _readPos, buffer, offset, count); + _leftToRead -= count; + return count; + } + + public override void Flush() { } + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + public override bool CanRead => true; + public override bool CanWrite => true; + public override bool CanSeek => false; + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + } + + sealed class AuthenticationCompleteException : Exception { } +#endif +} diff --git a/src/Npgsql/Internal/NpgsqlConnector.cs b/src/Npgsql/Internal/NpgsqlConnector.cs new file mode 100644 index 0000000000..446ff9a383 --- /dev/null +++ b/src/Npgsql/Internal/NpgsqlConnector.cs @@ -0,0 +1,2910 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Data; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.ExceptionServices; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Npgsql.BackendMessages; +using Npgsql.Util; +using static Npgsql.Util.Statics; +using System.Transactions; +using Microsoft.Extensions.Logging; +using Npgsql.Properties; + +namespace Npgsql.Internal; + +/// +/// Represents a connection to a PostgreSQL backend. 
Unlike NpgsqlConnection objects, which are +/// exposed to users, connectors are internal to Npgsql and are recycled by the connection pool. +/// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public sealed partial class NpgsqlConnector +{ + #region Fields and Properties + + /// + /// The physical connection socket to the backend. + /// + Socket _socket = default!; + + /// + /// The physical connection stream to the backend, without anything on top. + /// + NetworkStream _baseStream = default!; + + /// + /// The physical connection stream to the backend, layered with an SSL/TLS stream if in secure mode. + /// + Stream _stream = default!; + + /// + /// The parsed connection string. + /// + public NpgsqlConnectionStringBuilder Settings { get; } + + Action? ClientCertificatesCallback { get; } + RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; } +#pragma warning disable CS0618 // ProvidePasswordCallback is obsolete + ProvidePasswordCallback? ProvidePasswordCallback { get; } +#pragma warning restore CS0618 + + public Encoding TextEncoding { get; private set; } = default!; + + /// + /// Same as , except that it does not throw an exception if an invalid char is + /// encountered (exception fallback), but rather replaces it with a question mark character (replacement + /// fallback). + /// + internal Encoding RelaxedTextEncoding { get; private set; } = default!; + + /// + /// Buffer used for reading data. + /// + internal NpgsqlReadBuffer ReadBuffer { get; private set; } = default!; + + /// + /// If we read a data row that's bigger than , we allocate an oversize buffer. + /// The original (smaller) buffer is stored here, and restored when the connection is reset. + /// + NpgsqlReadBuffer? _origReadBuffer; + + /// + /// Buffer used for writing data. + /// + internal NpgsqlWriteBuffer WriteBuffer { get; private set; } = default!; + + /// + /// The secret key of the backend for this connector, used for query cancellation. 
+ /// + int _backendSecretKey; + + /// + /// The process ID of the backend for this connector. + /// + internal int BackendProcessId { get; private set; } + + string? _inferredUserName; + + /// + /// The user name that has been inferred when the connector was opened + /// + internal string InferredUserName + { + get => _inferredUserName ?? throw new InvalidOperationException($"{nameof(InferredUserName)} cannot be accessed before the connector has been opened."); + private set => _inferredUserName = value; + } + + bool SupportsPostgresCancellation => BackendProcessId != 0; + + /// + /// A unique ID identifying this connector, used for logging. Currently mapped to BackendProcessId + /// + internal int Id => BackendProcessId; + + internal PgSerializerOptions SerializerOptions { get; set; } = default!; + + /// + /// Information about PostgreSQL and PostgreSQL-like databases (e.g. type definitions, capabilities...). + /// + public NpgsqlDatabaseInfo DatabaseInfo { get; internal set; } = default!; + + /// + /// The current transaction status for this connector. + /// + internal TransactionStatus TransactionStatus { get; set; } + + /// + /// A transaction object for this connector. Since only one transaction can be in progress at any given time, + /// this instance is recycled. To check whether a transaction is currently in progress on this connector, + /// see . + /// + internal NpgsqlTransaction? Transaction { get; set; } + + internal NpgsqlTransaction? UnboundTransaction { get; set; } + + /// + /// The NpgsqlConnection that (currently) owns this connector. Null if the connector isn't + /// owned (i.e. idle in the pool) + /// + internal NpgsqlConnection? Connection { get; set; } + + /// + /// The number of messages that were prepended to the current message chain, but not yet sent. 
+ /// Note that this only tracks messages which produce a ReadyForQuery message + /// + internal int PendingPrependedResponses { get; set; } + + /// + /// A ManualResetEventSlim used to make sure a cancellation request doesn't run + /// while we're reading responses for the prepended query + /// as we can't gracefully handle their cancellation. + /// + readonly ManualResetEventSlim ReadingPrependedMessagesMRE = new(initialState: true); + + internal NpgsqlDataReader? CurrentReader; + + internal PreparedStatementManager PreparedStatementManager { get; } + + internal SqlQueryParser SqlQueryParser { get; } = new(); + + /// + /// If the connector is currently in COPY mode, holds a reference to the importer/exporter object. + /// Otherwise null. + /// + internal ICancelable? CurrentCopyOperation; + + /// + /// Holds all run-time parameters received from the backend (via ParameterStatus messages) + /// + internal Dictionary PostgresParameters { get; } + + /// + /// Holds all run-time parameters in raw, binary format for efficient handling without allocations. + /// + readonly List<(byte[] Name, byte[] Value)> _rawParameters = new(); + + /// + /// If this connector was broken, this contains the exception that caused the break. + /// + volatile Exception? _breakReason; + + // Used by replication to change our cancellation behaviour on ColumnStreams. + internal bool LongRunningConnection { get; set; } + + /// + /// + /// Used by the pool to indicate that I/O is currently in progress on this connector, so that another write + /// isn't started concurrently. Note that since we have only one write loop, this is only ever usedto + /// protect against an over-capacity writes into a connector that's currently *asynchronously* writing. + /// + /// + /// It is guaranteed that the currently-executing + /// Specifically, reading may occur - and the connector may even be returned to the pool - before this is + /// released. 
+ /// + /// + internal volatile int MultiplexAsyncWritingLock; + + /// + internal void FlagAsNotWritableForMultiplexing() + { + Debug.Assert(Settings.Multiplexing); + Debug.Assert(CommandsInFlightCount > 0 || IsBroken || IsClosed, + $"About to mark multiplexing connector as non-writable, but {nameof(CommandsInFlightCount)} is {CommandsInFlightCount}"); + + Interlocked.Exchange(ref MultiplexAsyncWritingLock, 1); + } + + /// + internal void FlagAsWritableForMultiplexing() + { + Debug.Assert(Settings.Multiplexing); + if (Interlocked.CompareExchange(ref MultiplexAsyncWritingLock, 0, 1) != 1) + throw new Exception("Multiplexing lock was not taken when releasing. Please report a bug."); + } + + /// + /// A lock that's taken while a cancellation is being delivered; new queries are blocked until the + /// cancellation is delivered. This reduces the chance that a cancellation meant for a previous + /// command will accidentally cancel a later one, see #615. + /// + object CancelLock { get; } = new(); + + /// + /// A lock that's taken to make sure no other concurrent operation is running. + /// Break takes it to set the state of the connector. + /// Anyone else should immediately check the state and exit + /// if the connector is closed. + /// + object SyncObj { get; } = new(); + + /// + /// A lock that's used to wait for the Cleanup to complete while breaking the connection. + /// + object CleanupLock { get; } = new(); + + readonly bool _isKeepAliveEnabled; + readonly Timer? _keepAliveTimer; + + /// + /// The command currently being executed by the connector, null otherwise. + /// Used only for concurrent use error reporting purposes. + /// + NpgsqlCommand? _currentCommand; + + bool _sendResetOnClose; + + /// + /// The connector source (e.g. pool) from where this connector came, and to which it will be returned. + /// Note that in multi-host scenarios, this references the host-specific rather than the + /// . 
+ /// + internal NpgsqlDataSource DataSource { get; } + + internal string UserFacingConnectionString => DataSource.ConnectionString; + + /// + /// Contains the UTC timestamp when this connector was opened, used to implement + /// . + /// + internal DateTime OpenTimestamp { get; private set; } + + internal int ClearCounter { get; set; } + + volatile bool _postgresCancellationPerformed; + internal bool PostgresCancellationPerformed + { + get => _postgresCancellationPerformed; + private set => _postgresCancellationPerformed = value; + } + + volatile bool _userCancellationRequested; + CancellationTokenRegistration _cancellationTokenRegistration; + internal bool UserCancellationRequested => _userCancellationRequested; + internal CancellationToken UserCancellationToken { get; set; } + internal bool AttemptPostgresCancellation { get; private set; } + static readonly TimeSpan _cancelImmediatelyTimeout = TimeSpan.FromMilliseconds(-1); + + IDisposable? _certificate; + + internal NpgsqlLoggingConfiguration LoggingConfiguration { get; } + + internal ILogger ConnectionLogger { get; } + internal ILogger CommandLogger { get; } + internal ILogger TransactionLogger { get; } + internal ILogger CopyLogger { get; } + + internal readonly Stopwatch QueryLogStopWatch = new(); + + internal EndPoint? ConnectedEndPoint { get; private set; } + + #endregion + + #region Constants + + /// + /// The minimum timeout that can be set on internal commands such as COMMIT, ROLLBACK. + /// + /// Precision is seconds + internal const int MinimumInternalCommandTimeout = 3; + + #endregion + + #region Reusable Message Objects + + byte[]? 
_resetWithoutDeallocateMessage; + + int _resetWithoutDeallocateResponseCount; + + // Backend + readonly CommandCompleteMessage _commandCompleteMessage = new(); + readonly ReadyForQueryMessage _readyForQueryMessage = new(); + readonly ParameterDescriptionMessage _parameterDescriptionMessage = new(); + readonly DataRowMessage _dataRowMessage = new(); + readonly RowDescriptionMessage _rowDescriptionMessage = new(connectorOwned: true); + + // Since COPY is rarely used, allocate these lazily + CopyInResponseMessage? _copyInResponseMessage; + CopyOutResponseMessage? _copyOutResponseMessage; + CopyDataMessage? _copyDataMessage; + CopyBothResponseMessage? _copyBothResponseMessage; + + #endregion + + internal NpgsqlDataReader DataReader { get; set; } + + internal NpgsqlDataReader? UnboundDataReader { get; set; } + + #region Constructors + + internal NpgsqlConnector(NpgsqlDataSource dataSource, NpgsqlConnection conn) + : this(dataSource) + { + if (conn.ProvideClientCertificatesCallback is not null) + ClientCertificatesCallback = certs => conn.ProvideClientCertificatesCallback(certs); + if (conn.UserCertificateValidationCallback is not null) + UserCertificateValidationCallback = conn.UserCertificateValidationCallback; + +#pragma warning disable CS0618 // Obsolete + ProvidePasswordCallback = conn.ProvidePasswordCallback; +#pragma warning restore CS0618 + } + + NpgsqlConnector(NpgsqlConnector connector) + : this(connector.DataSource) + { + ClientCertificatesCallback = connector.ClientCertificatesCallback; + UserCertificateValidationCallback = connector.UserCertificateValidationCallback; + ProvidePasswordCallback = connector.ProvidePasswordCallback; + } + + NpgsqlConnector(NpgsqlDataSource dataSource) + { + Debug.Assert(dataSource.OwnsConnectors); + + DataSource = dataSource; + + LoggingConfiguration = dataSource.LoggingConfiguration; + ConnectionLogger = LoggingConfiguration.ConnectionLogger; + CommandLogger = LoggingConfiguration.CommandLogger; + TransactionLogger = 
LoggingConfiguration.TransactionLogger; + CopyLogger = LoggingConfiguration.CopyLogger; + + ClientCertificatesCallback = dataSource.ClientCertificatesCallback; + UserCertificateValidationCallback = dataSource.UserCertificateValidationCallback; + + State = ConnectorState.Closed; + TransactionStatus = TransactionStatus.Idle; + Settings = dataSource.Settings; + PostgresParameters = new Dictionary(); + + _isKeepAliveEnabled = Settings.KeepAlive > 0; + if (_isKeepAliveEnabled) + _keepAliveTimer = new Timer(PerformKeepAlive, null, Timeout.Infinite, Timeout.Infinite); + + DataReader = new NpgsqlDataReader(this); + + // TODO: Not just for automatic preparation anymore... + PreparedStatementManager = new PreparedStatementManager(this); + + if (Settings.Multiplexing) + { + // Note: It's OK for this channel to be unbounded: each command enqueued to it is accompanied by sending + // it to PostgreSQL. If we overload it, a TCP zero window will make us block on the networking side + // anyway. + // Note: the in-flight channel can probably be single-writer, but that doesn't actually do anything + // at this point. And we currently rely on being able to complete the channel at any point (from + // Break). We may want to revisit this if an optimized, SingleWriter implementation is introduced. 
+ var commandsInFlightChannel = Channel.CreateUnbounded( + new UnboundedChannelOptions { SingleReader = true }); + CommandsInFlightReader = commandsInFlightChannel.Reader; + CommandsInFlightWriter = commandsInFlightChannel.Writer; + + // TODO: Properly implement this + if (_isKeepAliveEnabled) + throw new NotImplementedException("Keepalive not yet implemented for multiplexing"); + } + } + + #endregion + + #region Configuration settings + + internal string Host => Settings.Host!; + internal int Port => Settings.Port; + internal string Database => Settings.Database!; + string KerberosServiceName => Settings.KerberosServiceName; + int ConnectionTimeout => Settings.Timeout; + + #endregion Configuration settings + + #region State management + + int _state; + + /// + /// Gets the current state of the connector + /// + internal ConnectorState State + { + get => (ConnectorState)_state; + set + { + var newState = (int)value; + if (newState == _state) + return; + + if (newState is < 0 or > (int)ConnectorState.Replication) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(value), "Unknown state: " + value); + + Interlocked.Exchange(ref _state, newState); + } + } + + /// + /// Returns whether the connector is open, regardless of any task it is currently performing + /// + bool IsConnected => State is not (ConnectorState.Closed or ConnectorState.Connecting or ConnectorState.Broken); + + internal bool IsReady => State == ConnectorState.Ready; + internal bool IsClosed => State == ConnectorState.Closed; + internal bool IsBroken => State == ConnectorState.Broken; + + #endregion + + #region Open + + /// + /// Opens the physical connection to the server. + /// + /// Usually called by the RequestConnector + /// Method of the connection pool manager. 
+ internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + { + Debug.Assert(State == ConnectorState.Closed); + + State = ConnectorState.Connecting; + LogMessages.OpeningPhysicalConnection(ConnectionLogger, Host, Port, Database, UserFacingConnectionString); + var stopwatch = Stopwatch.StartNew(); + + try + { + await OpenCore(this, Settings.SslMode, timeout, async, cancellationToken).ConfigureAwait(false); + + await DataSource.Bootstrap(this, timeout, forceReload: false, async, cancellationToken).ConfigureAwait(false); + + Debug.Assert(DataSource.SerializerOptions is not null); + Debug.Assert(DataSource.DatabaseInfo is not null); + SerializerOptions = DataSource.SerializerOptions; + DatabaseInfo = DataSource.DatabaseInfo; + + if (Settings.Pooling && !Settings.Multiplexing && !Settings.NoResetOnClose && DatabaseInfo.SupportsDiscard) + { + _sendResetOnClose = true; + GenerateResetMessage(); + } + + OpenTimestamp = DateTime.UtcNow; + + if (Settings.Multiplexing) + { + // Start an infinite async loop, which processes incoming multiplexing traffic. + // It is intentionally not awaited and will run as long as the connector is alive. + // The CommandsInFlightWriter channel is completed in Cleanup, which should cause this task + // to complete. + _ = Task.Run(MultiplexingReadLoop, CancellationToken.None) + .ContinueWith(t => + { + // Note that we *must* observe the exception if the task is faulted. + ConnectionLogger.LogError(t.Exception!, "Exception bubbled out of multiplexing read loop", Id); + }, TaskContinuationOptions.OnlyOnFaulted); + } + + if (_isKeepAliveEnabled) + { + // Start the keep alive mechanism to work by scheduling the timer. + // Otherwise, it doesn't work for cases when no query executed during + // the connection lifetime in case of a new connector. 
+ lock (SyncObj) + { + var keepAlive = Settings.KeepAlive * 1000; + _keepAliveTimer!.Change(keepAlive, keepAlive); + } + } + + if (DataSource.ConnectionInitializerAsync is not null) + { + Debug.Assert(DataSource.ConnectionInitializer is not null); + + var tempConnection = new NpgsqlConnection(DataSource, this); + + try + { + if (async) + await DataSource.ConnectionInitializerAsync(tempConnection).ConfigureAwait(false); + else if (!async) + DataSource.ConnectionInitializer(tempConnection); + } + finally + { + // Note that we can't just close/dispose the NpgsqlConnection, since that puts the connector back in the pool. + // But we transition it to disposed immediately, in case the user decides to capture the NpgsqlConnection and use it + // later. + Connection?.MakeDisposed(); + Connection = null; + } + } + + LogMessages.OpenedPhysicalConnection( + ConnectionLogger, Host, Port, Database, UserFacingConnectionString, stopwatch.ElapsedMilliseconds, Id); + } + catch (Exception e) + { + Break(e); + throw; + } + + static async Task OpenCore( + NpgsqlConnector conn, + SslMode sslMode, + NpgsqlTimeout timeout, + bool async, + CancellationToken cancellationToken, + bool isFirstAttempt = true) + { + await conn.RawOpen(sslMode, timeout, async, cancellationToken, isFirstAttempt).ConfigureAwait(false); + + var username = await conn.GetUsernameAsync(async, cancellationToken).ConfigureAwait(false); + + timeout.CheckAndApply(conn); + conn.WriteStartupMessage(username); + await conn.Flush(async, cancellationToken).ConfigureAwait(false); + + using var cancellationRegistration = conn.StartCancellableOperation(cancellationToken, attemptPgCancellation: false); + try + { + await conn.Authenticate(username, timeout, async, cancellationToken).ConfigureAwait(false); + } + catch (PostgresException e) + when (e.SqlState == PostgresErrorCodes.InvalidAuthorizationSpecification && + (sslMode == SslMode.Prefer && conn.IsSecure || sslMode == SslMode.Allow && !conn.IsSecure)) + { + 
cancellationRegistration.Dispose(); + Debug.Assert(!conn.IsBroken); + + conn.Cleanup(); + + // If Prefer was specified and we failed (with SSL), retry without SSL. + // If Allow was specified and we failed (without SSL), retry with SSL + await OpenCore( + conn, + sslMode == SslMode.Prefer ? SslMode.Disable : SslMode.Require, + timeout, + async, + cancellationToken, + isFirstAttempt: false).ConfigureAwait(false); + + return; + } + + // We treat BackendKeyData as optional because some PostgreSQL-like database + // don't send it (CockroachDB, CrateDB) + var msg = await conn.ReadMessage(async).ConfigureAwait(false); + if (msg.Code == BackendMessageCode.BackendKeyData) + { + var keyDataMsg = (BackendKeyDataMessage)msg; + conn.BackendProcessId = keyDataMsg.BackendProcessId; + conn._backendSecretKey = keyDataMsg.BackendSecretKey; + msg = await conn.ReadMessage(async).ConfigureAwait(false); + } + + if (msg.Code != BackendMessageCode.ReadyForQuery) + throw new NpgsqlException($"Received backend message {msg.Code} while expecting ReadyForQuery. Please file a bug."); + + conn.State = ConnectorState.Ready; + } + } + + internal async ValueTask QueryDatabaseState( + NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken = default) + { + using var batch = CreateBatch(); + batch.BatchCommands.Add(new NpgsqlBatchCommand("select pg_is_in_recovery()")); + batch.BatchCommands.Add(new NpgsqlBatchCommand("SHOW default_transaction_read_only")); + batch.Timeout = (int)timeout.CheckAndGetTimeLeft().TotalSeconds; + + var reader = async ? 
await batch.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false) : batch.ExecuteReader(); + try + { + if (async) + { + await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + _isHotStandBy = reader.GetBoolean(0); + await reader.NextResultAsync(cancellationToken).ConfigureAwait(false); + await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + } + else + { + reader.Read(); + _isHotStandBy = reader.GetBoolean(0); + reader.NextResult(); + reader.Read(); + } + + _isTransactionReadOnly = reader.GetString(0) != "off"; + + var databaseState = UpdateDatabaseState(); + Debug.Assert(databaseState.HasValue); + return databaseState.Value; + } + finally + { + if (async) + await reader.DisposeAsync().ConfigureAwait(false); + else + reader.Dispose(); + } + } + + void WriteStartupMessage(string username) + { + var startupParams = new Dictionary + { + ["user"] = username, + ["client_encoding"] = Settings.ClientEncoding ?? + PostgresEnvironment.ClientEncoding ?? + "UTF8" + }; + + if (Settings.Database is not null) + startupParams["database"] = Settings.Database; + + if (Settings.ApplicationName?.Length > 0) + startupParams["application_name"] = Settings.ApplicationName; + + if (Settings.SearchPath?.Length > 0) + startupParams["search_path"] = Settings.SearchPath; + + var timezone = Settings.Timezone ?? PostgresEnvironment.TimeZone; + if (timezone != null) + startupParams["TimeZone"] = timezone; + + var options = Settings.Options ?? 
PostgresEnvironment.Options; + if (options?.Length > 0) + startupParams["options"] = options; + + switch (Settings.ReplicationMode) + { + case ReplicationMode.Logical: + startupParams["replication"] = "database"; + break; + case ReplicationMode.Physical: + startupParams["replication"] = "true"; + break; + } + + WriteStartup(startupParams); + } + + ValueTask GetUsernameAsync(bool async, CancellationToken cancellationToken) + { + var username = Settings.Username; + if (username?.Length > 0) + { + InferredUserName = username; + return new(username); + } + + username = PostgresEnvironment.User; + if (username?.Length > 0) + { + InferredUserName = username; + return new(username); + } + + return GetUsernameAsyncInternal(); + + async ValueTask GetUsernameAsyncInternal() + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + username = await DataSource.IntegratedSecurityHandler.GetUsername(async, Settings.IncludeRealm, ConnectionLogger, + cancellationToken).ConfigureAwait(false); + + if (username?.Length > 0) + { + InferredUserName = username; + return username; + } + } + + username = Environment.UserName; + if (username?.Length > 0) + { + InferredUserName = username; + return username; + } + + throw new NpgsqlException("No username could be found, please specify one explicitly"); + } + } + + async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken, bool isFirstAttempt = true) + { + try + { + if (async) + await ConnectAsync(timeout, cancellationToken).ConfigureAwait(false); + else + Connect(timeout); + + _baseStream = new NetworkStream(_socket, true); + _stream = _baseStream; + + if (Settings.Encoding == "UTF8") + { + TextEncoding = NpgsqlWriteBuffer.UTF8Encoding; + RelaxedTextEncoding = NpgsqlWriteBuffer.RelaxedUTF8Encoding; + } + else + { + TextEncoding = Encoding.GetEncoding(Settings.Encoding, EncoderFallback.ExceptionFallback, DecoderFallback.ExceptionFallback); + RelaxedTextEncoding = 
Encoding.GetEncoding(Settings.Encoding, EncoderFallback.ReplacementFallback, DecoderFallback.ReplacementFallback); + } + + ReadBuffer = new NpgsqlReadBuffer(this, _stream, _socket, Settings.ReadBufferSize, TextEncoding, RelaxedTextEncoding); + WriteBuffer = new NpgsqlWriteBuffer(this, _stream, _socket, Settings.WriteBufferSize, TextEncoding); + + timeout.CheckAndApply(this); + + IsSecure = false; + + if ((sslMode is SslMode.Prefer && DataSource.TransportSecurityHandler.SupportEncryption) || + sslMode is SslMode.Require or SslMode.VerifyCA or SslMode.VerifyFull) + { + WriteSslRequest(); + await Flush(async, cancellationToken).ConfigureAwait(false); + + await ReadBuffer.Ensure(1, async).ConfigureAwait(false); + var response = (char)ReadBuffer.ReadByte(); + timeout.CheckAndApply(this); + + switch (response) + { + default: + throw new NpgsqlException($"Received unknown response {response} for SSLRequest (expecting S or N)"); + case 'N': + if (sslMode != SslMode.Prefer) + throw new NpgsqlException("SSL connection requested. No SSL enabled connection from this host is configured."); + break; + case 'S': + await DataSource.TransportSecurityHandler.NegotiateEncryption(async, this, sslMode, timeout, isFirstAttempt).ConfigureAwait(false); + break; + } + + if (ReadBuffer.ReadBytesLeft > 0) + throw new NpgsqlException("Additional unencrypted data received after SSL negotiation - this should never happen, and may be an indication of a man-in-the-middle attack."); + } + + ConnectionLogger.LogTrace("Socket connected to {Host}:{Port}", Host, Port); + } + catch + { + _stream?.Dispose(); + _stream = null!; + + _baseStream?.Dispose(); + _baseStream = null!; + + _socket?.Dispose(); + _socket = null!; + + throw; + } + } + + internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout, bool async, bool isFirstAttempt) + { + var clientCertificates = new X509Certificate2Collection(); + var certPath = Settings.SslCertificate ?? PostgresEnvironment.SslCert ?? 
PostgresEnvironment.SslCertDefault; + + if (certPath != null) + { + var password = Settings.SslPassword; + + X509Certificate2? cert = null; + if (Path.GetExtension(certPath).ToUpperInvariant() != ".PFX") + { +#if NET5_0_OR_GREATER + // It's PEM time + var keyPath = Settings.SslKey ?? PostgresEnvironment.SslKey ?? PostgresEnvironment.SslKeyDefault; + cert = string.IsNullOrEmpty(password) + ? X509Certificate2.CreateFromPemFile(certPath, keyPath) + : X509Certificate2.CreateFromEncryptedPemFile(certPath, password, keyPath); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // Windows crypto API has a bug with pem certs + // See #3650 + using var previousCert = cert; + cert = new X509Certificate2(cert.Export(X509ContentType.Pkcs12)); + } + +#else + // Technically PEM certificates are supported as of .NET 5 but we don't build for the net5.0 + // TFM anymore since .NET 5 is out of support + // This is a breaking change for .NET 5 as of Npgsql 8! + throw new NotSupportedException("PEM certificates are only supported with .NET 6 and higher"); +#endif + } + + cert ??= new X509Certificate2(certPath, password); + clientCertificates.Add(cert); + + _certificate = cert; + } + + try + { + ClientCertificatesCallback?.Invoke(clientCertificates); + + var checkCertificateRevocation = Settings.CheckCertificateRevocation; + + RemoteCertificateValidationCallback? certificateValidationCallback; + X509Certificate2? caCert; + string? 
certRootPath = null; + + if (UserCertificateValidationCallback is not null) + { + if (sslMode is SslMode.VerifyCA or SslMode.VerifyFull) + throw new ArgumentException(string.Format(NpgsqlStrings.CannotUseSslVerifyWithUserCallback, sslMode)); + + if (Settings.RootCertificate is not null) + throw new ArgumentException(NpgsqlStrings.CannotUseSslRootCertificateWithUserCallback); + + if (DataSource.TransportSecurityHandler.RootCertificateCallback is not null) + throw new ArgumentException(NpgsqlStrings.CannotUseValidationRootCertificateCallbackWithUserCallback); + + certificateValidationCallback = UserCertificateValidationCallback; + } + else if (sslMode is SslMode.Prefer or SslMode.Require) + { + certificateValidationCallback = SslTrustServerValidation; + checkCertificateRevocation = false; + } + else if ((caCert = DataSource.TransportSecurityHandler.RootCertificateCallback?.Invoke()) is not null || + (certRootPath = Settings.RootCertificate ?? + PostgresEnvironment.SslCertRoot ?? PostgresEnvironment.SslCertRootDefault) is not null) + { + certificateValidationCallback = SslRootValidation(sslMode == SslMode.VerifyFull, certRootPath, caCert); + } + else if (sslMode == SslMode.VerifyCA) + { + certificateValidationCallback = SslVerifyCAValidation; + } + else + { + Debug.Assert(sslMode == SslMode.VerifyFull); + certificateValidationCallback = SslVerifyFullValidation; + } + + var host = Host; + +#if !NET8_0_OR_GREATER + // If the host is a valid IP address - replace it with an empty string + // We do that because .NET uses targetHost argument to send SNI to the server + // RFC explicitly prohibits sending an IP address so some servers might fail + // This was already fixed for .NET 8 + // See #5543 for discussion + if (IPAddress.TryParse(host, out _)) + host = string.Empty; +#endif + + timeout.CheckAndApply(this); + + try + { + var sslStream = new SslStream(_stream, leaveInnerStreamOpen: false, certificateValidationCallback); + + var sslProtocols = SslProtocols.None; +#if 
NETSTANDARD2_0 + // On .NET Framework SslProtocols.None can be disabled, see #3718 + sslProtocols = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12; +#endif + + if (async) + await sslStream.AuthenticateAsClientAsync(host, clientCertificates, sslProtocols, checkCertificateRevocation).ConfigureAwait(false); + else + sslStream.AuthenticateAsClient(host, clientCertificates, sslProtocols, checkCertificateRevocation); + + _stream = sslStream; + } + catch (Exception e) + { + throw new NpgsqlException("Exception while performing SSL handshake", e); + } + + ReadBuffer.Underlying = _stream; + WriteBuffer.Underlying = _stream; + IsSecure = true; + ConnectionLogger.LogTrace("SSL negotiation successful"); + } + catch + { + _certificate?.Dispose(); + _certificate = null; + + throw; + } + } + + void Connect(NpgsqlTimeout timeout) + { + // Note that there aren't any timeout-able or cancellable DNS methods + var endpoints = NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath) + ? new EndPoint[] { new UnixDomainSocketEndPoint(socketPath) } + : IPAddressesToEndpoints(Dns.GetHostAddresses(Host), Port); + timeout.Check(); + + // Give each endpoint an equal share of the remaining time + var perEndpointTimeout = -1; // Default to infinity + if (timeout.IsSet) + perEndpointTimeout = (int)(timeout.CheckAndGetTimeLeft().Ticks / endpoints.Length / 10); + + for (var i = 0; i < endpoints.Length; i++) + { + var endpoint = endpoints[i]; + ConnectionLogger.LogTrace("Attempting to connect to {Endpoint}", endpoint); + var protocolType = + endpoint.AddressFamily == AddressFamily.InterNetwork || + endpoint.AddressFamily == AddressFamily.InterNetworkV6 + ? 
ProtocolType.Tcp + : ProtocolType.IP; + var socket = new Socket(endpoint.AddressFamily, SocketType.Stream, protocolType) + { + Blocking = false + }; + + try + { + try + { + socket.Connect(endpoint); + } + catch (SocketException e) + { + if (e.SocketErrorCode != SocketError.WouldBlock) + throw; + } + var write = new List { socket }; + var error = new List { socket }; + Socket.Select(null, write, error, perEndpointTimeout); + var errorCode = (int) socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.Error)!; + if (errorCode != 0) + throw new SocketException(errorCode); + if (write.Count is 0) + throw new TimeoutException("Timeout during connection attempt"); + socket.Blocking = true; + SetSocketOptions(socket); + _socket = socket; + ConnectedEndPoint = endpoint; + return; + } + catch (Exception e) + { + try { socket.Dispose(); } + catch + { + // ignored + } + + ConnectionLogger.LogTrace(e, "Failed to connect to {Endpoint}", endpoint); + + if (i == endpoints.Length - 1) + throw new NpgsqlException($"Failed to connect to {endpoint}", e); + } + } + } + + async Task ConnectAsync(NpgsqlTimeout timeout, CancellationToken cancellationToken) + { + Task GetHostAddressesAsync(CancellationToken ct) => +#if NET6_0_OR_GREATER + Dns.GetHostAddressesAsync(Host, ct); +#else + Dns.GetHostAddressesAsync(Host); +#endif + + // Whether the framework and/or the OS platform support Dns.GetHostAddressesAsync cancellation API or they do not, + // we always fake-cancel the operation with the help of TaskTimeoutAndCancellation.ExecuteAsync. It stops waiting + // and raises the exception, while the actual task may be left running. + var endpoints = NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath) + ? 
new EndPoint[] { new UnixDomainSocketEndPoint(socketPath) } + : IPAddressesToEndpoints(await TaskTimeoutAndCancellation.ExecuteAsync(GetHostAddressesAsync, timeout, cancellationToken).ConfigureAwait(false), + Port); + + // Give each IP an equal share of the remaining time + var perIpTimespan = default(TimeSpan); + var perIpTimeout = timeout; + if (timeout.IsSet) + { + perIpTimespan = new TimeSpan(timeout.CheckAndGetTimeLeft().Ticks / endpoints.Length); + perIpTimeout = new NpgsqlTimeout(perIpTimespan); + } + + for (var i = 0; i < endpoints.Length; i++) + { + var endpoint = endpoints[i]; + ConnectionLogger.LogTrace("Attempting to connect to {Endpoint}", endpoint); + var protocolType = + endpoint.AddressFamily == AddressFamily.InterNetwork || + endpoint.AddressFamily == AddressFamily.InterNetworkV6 + ? ProtocolType.Tcp + : ProtocolType.IP; + var socket = new Socket(endpoint.AddressFamily, SocketType.Stream, protocolType); + try + { + await OpenSocketConnectionAsync(socket, endpoint, perIpTimeout, cancellationToken).ConfigureAwait(false); + SetSocketOptions(socket); + _socket = socket; + ConnectedEndPoint = endpoint; + return; + } + catch (Exception e) + { + try + { + socket.Dispose(); + } + catch + { + // ignored + } + + cancellationToken.ThrowIfCancellationRequested(); + + if (e is OperationCanceledException) + e = new TimeoutException("Timeout during connection attempt"); + + ConnectionLogger.LogTrace(e, "Failed to connect to {Endpoint}", endpoint); + + if (i == endpoints.Length - 1) + throw new NpgsqlException($"Failed to connect to {endpoint}", e); + } + } + + static Task OpenSocketConnectionAsync(Socket socket, EndPoint endpoint, NpgsqlTimeout perIpTimeout, CancellationToken cancellationToken) + { + // Whether the framework and/or the OS platform support Socket.ConnectAsync cancellation API or they do not, + // we always fake-cancel the operation with the help of TaskTimeoutAndCancellation.ExecuteAsync. 
It stops waiting + // and raises the exception, while the actual task may be left running. + Task ConnectAsync(CancellationToken ct) => +#if NET5_0_OR_GREATER + socket.ConnectAsync(endpoint, ct).AsTask(); +#else + socket.ConnectAsync(endpoint); +#endif + return TaskTimeoutAndCancellation.ExecuteAsync(ConnectAsync, perIpTimeout, cancellationToken); + } + } + + IPEndPoint[] IPAddressesToEndpoints(IPAddress[] ipAddresses, int port) + { + var result = new IPEndPoint[ipAddresses.Length]; + for (var i = 0; i < ipAddresses.Length; i++) + result[i] = new IPEndPoint(ipAddresses[i], port); + return result; + } + + void SetSocketOptions(Socket socket) + { + if (socket.AddressFamily == AddressFamily.InterNetwork || socket.AddressFamily == AddressFamily.InterNetworkV6) + socket.NoDelay = true; + if (Settings.SocketReceiveBufferSize > 0) + socket.ReceiveBufferSize = Settings.SocketReceiveBufferSize; + if (Settings.SocketSendBufferSize > 0) + socket.SendBufferSize = Settings.SocketSendBufferSize; + + if (Settings.TcpKeepAlive) + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true); + if (Settings.TcpKeepAliveInterval > 0 && Settings.TcpKeepAliveTime == 0) + throw new ArgumentException("If TcpKeepAliveInterval is defined, TcpKeepAliveTime must be defined as well"); + if (Settings.TcpKeepAliveTime > 0) + { + var timeSeconds = Settings.TcpKeepAliveTime; + var intervalSeconds = Settings.TcpKeepAliveInterval > 0 + ? 
Settings.TcpKeepAliveInterval + : Settings.TcpKeepAliveTime; + +#if NETSTANDARD2_0 || NETSTANDARD2_1 + var timeMilliseconds = timeSeconds * 1000; + var intervalMilliseconds = intervalSeconds * 1000; + + // For the following see https://msdn.microsoft.com/en-us/library/dd877220.aspx + var uintSize = Marshal.SizeOf(typeof(uint)); + var inOptionValues = new byte[uintSize * 3]; + BitConverter.GetBytes((uint)1).CopyTo(inOptionValues, 0); + BitConverter.GetBytes((uint)timeMilliseconds).CopyTo(inOptionValues, uintSize); + BitConverter.GetBytes((uint)intervalMilliseconds).CopyTo(inOptionValues, uintSize * 2); + var result = 0; + try + { + result = socket.IOControl(IOControlCode.KeepAliveValues, inOptionValues, null); + } + catch (PlatformNotSupportedException) + { + throw new PlatformNotSupportedException("Setting TCP Keepalive Time and TCP Keepalive Interval is supported only on Windows, Mono and .NET Core 3.1+. " + + "TCP keepalives can still be used on other systems but are enabled via the TcpKeepAlive option or configured globally for the machine, see the relevant docs."); + } + + if (result != 0) + throw new NpgsqlException($"Got non-zero value when trying to set TCP keepalive: {result}"); +#else + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true); + socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, timeSeconds); + socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, intervalSeconds); +#endif + } + } + + #endregion + + #region I/O + + readonly ChannelReader? CommandsInFlightReader; + internal readonly ChannelWriter? CommandsInFlightWriter; + + internal volatile int CommandsInFlightCount; + + internal ManualResetValueTaskSource ReaderCompleted { get; } = + new() { RunContinuationsAsynchronously = true }; + + async Task MultiplexingReadLoop() + { + Debug.Assert(Settings.Multiplexing); + Debug.Assert(CommandsInFlightReader != null); + + NpgsqlCommand? 
command = null;
+ var commandsRead = 0;
+
+ try
+ {
+ while (await CommandsInFlightReader.WaitToReadAsync().ConfigureAwait(false))
+ {
+ commandsRead = 0;
+ Debug.Assert(!InTransaction);
+
+ while (CommandsInFlightReader.TryRead(out command))
+ {
+ commandsRead++;
+
+ await ReadBuffer.Ensure(5, true).ConfigureAwait(false);
+
+ // We have a resultset for the command - hand back control to the command (which will
+ // return it to the user)
+ command.TraceReceivedFirstResponse();
+ ReaderCompleted.Reset();
+ command.ExecutionCompletion.SetResult(this);
+
+ // Now wait until that command's reader is disposed. Note that RunContinuationsAsynchronously is
+ // true, so that the user code calling NpgsqlDataReader.Dispose will not continue executing
+ // synchronously here. This prevents issues if the code after the next command's execution
+ // completion blocks.
+ await new ValueTask(ReaderCompleted, ReaderCompleted.Version).ConfigureAwait(false);
+ Debug.Assert(!InTransaction);
+ }
+
+ // Atomically update the commands in-flight counter, and check if it reached 0. If so, the
+ // connector is idle and can be returned.
+ // Note that this is racing with over-capacity writing, which can select any connector at any
+ // time (see MultiplexingWriteLoop), and we must make absolutely sure that if a connector is
+ // returned to the pool, it is *never* written to unless properly dequeued from the Idle channel.
+ if (Interlocked.Add(ref CommandsInFlightCount, -commandsRead) == 0)
+ {
+ // There's a race condition where the continuation of an asynchronous multiplexing write may not
+ // have executed yet, and the flush may still be in progress. We know all I/O has already
+ // been sent - because the reader has already consumed the entire resultset. So we wait until
+ // the connector's write lock has been released (long waiting will never occur here). 
+ SpinWait.SpinUntil(() => MultiplexAsyncWritingLock == 0 || IsBroken);
+
+ ResetReadBuffer();
+ DataSource.Return(this);
+ }
+ }
+
+ ConnectionLogger.LogTrace("Exiting multiplexing read loop", Id);
+ }
+ catch (Exception e)
+ {
+ Debug.Assert(IsBroken);
+
+ // Decrement the commands already dequeued from the in-flight counter
+ Interlocked.Add(ref CommandsInFlightCount, -commandsRead);
+
+ // When a connector is broken, the causing exception is stored on it. We fail commands with
+ // that exception - rather than the one thrown here - since the break may have happened during
+ // writing, and we want to bubble that one up.
+
+ // Drain any pending in-flight commands and fail them. Note that some have only been written
+ // to the buffer, and not sent to the server.
+ command?.ExecutionCompletion.SetException(_breakReason!);
+ try
+ {
+ while (true)
+ {
+ var pendingCommand = await CommandsInFlightReader.ReadAsync().ConfigureAwait(false);
+
+ // TODO: the exception we have here is sometimes just the result of the write loop breaking
+ // the connector, so it doesn't represent the actual root cause.
+ pendingCommand.ExecutionCompletion.SetException(new NpgsqlException("A previous command on this connection caused an error requiring all pending commands on this connection to be aborted", _breakReason!));
+ }
+ }
+ catch (ChannelClosedException)
+ {
+ // All good, drained to the channel and failed all commands
+ }
+
+ // "Return" the connector to the pool for cleanup (e.g. update total connector count)
+ DataSource.Return(this);
+
+ ConnectionLogger.LogError(e, "Exception in multiplexing read loop", Id);
+ }
+
+ Debug.Assert(CommandsInFlightCount == 0);
+ }
+
+ #endregion
+
+ #region Frontend message processing
+
+ ///
+ /// Prepends a message to be sent at the beginning of the next message chain. 
+ /// + internal void PrependInternalMessage(byte[] rawMessage, int responseMessageCount) + { + PendingPrependedResponses += responseMessageCount; + + var t = WritePregenerated(rawMessage); + Debug.Assert(t.IsCompleted, "Could not fully write pregenerated message into the buffer"); + } + + #endregion + + #region Backend message processing + + internal ValueTask ReadMessageWithNotifications(bool async) + => ReadMessageLong(async, DataRowLoadingMode.NonSequential, readingNotifications: true); + + internal ValueTask ReadMessage( + bool async, + DataRowLoadingMode dataRowLoadingMode = DataRowLoadingMode.NonSequential) + { + if (PendingPrependedResponses > 0 || + dataRowLoadingMode == DataRowLoadingMode.Skip || + ReadBuffer.ReadBytesLeft < 5) + { + return ReadMessageLong(async, dataRowLoadingMode, readingNotifications: false)!; + } + + var messageCode = (BackendMessageCode)ReadBuffer.ReadByte(); + switch (messageCode) + { + case BackendMessageCode.NoticeResponse: + case BackendMessageCode.NotificationResponse: + case BackendMessageCode.ParameterStatus: + case BackendMessageCode.ErrorResponse: + ReadBuffer.ReadPosition--; + return ReadMessageLong(async, dataRowLoadingMode, readingNotifications: false)!; + } + + ValidateBackendMessageCode(messageCode); + var len = ReadBuffer.ReadInt32() - 4; // Transmitted length includes itself + if (len > ReadBuffer.ReadBytesLeft) + { + ReadBuffer.ReadPosition -= 5; + return ReadMessageLong(async, dataRowLoadingMode, readingNotifications: false)!; + } + + return new ValueTask(ParseServerMessage(ReadBuffer, messageCode, len, false))!; + } + +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + async ValueTask ReadMessageLong( + bool async, + DataRowLoadingMode dataRowLoadingMode, + bool readingNotifications, + bool isReadingPrependedMessage = false) + { + // First read the responses of any prepended messages. 
+ if (PendingPrependedResponses > 0 && !isReadingPrependedMessage) + { + try + { + // TODO: There could be room for optimization here, rather than the async call(s) + for (; PendingPrependedResponses > 0; PendingPrependedResponses--) + await ReadMessageLong(async, DataRowLoadingMode.Skip, readingNotifications: false, isReadingPrependedMessage: true).ConfigureAwait(false); + // We've read all the prepended response. + // Allow cancellation to proceed. + ReadingPrependedMessagesMRE.Set(); + } + catch (Exception e) + { + // Prepended queries should never fail. + // If they do, we're not even going to attempt to salvage the connector. + Break(e); + throw; + } + } + + PostgresException? error = null; + + try + { + while (true) + { + await ReadBuffer.Ensure(5, async, readingNotifications).ConfigureAwait(false); + var messageCode = (BackendMessageCode)ReadBuffer.ReadByte(); + ValidateBackendMessageCode(messageCode); + var len = ReadBuffer.ReadInt32() - 4; // Transmitted length includes itself + + if ((messageCode == BackendMessageCode.DataRow && + dataRowLoadingMode != DataRowLoadingMode.NonSequential) || + messageCode == BackendMessageCode.CopyData) + { + if (dataRowLoadingMode == DataRowLoadingMode.Skip) + { + await ReadBuffer.Skip(len, async).ConfigureAwait(false); + continue; + } + } + else if (len > ReadBuffer.ReadBytesLeft) + { + if (len > ReadBuffer.Size) + { + var oversizeBuffer = ReadBuffer.AllocateOversize(len); + + if (_origReadBuffer == null) + _origReadBuffer = ReadBuffer; + else + ReadBuffer.Dispose(); + + ReadBuffer = oversizeBuffer; + } + + await ReadBuffer.Ensure(len, async).ConfigureAwait(false); + } + + var msg = ParseServerMessage(ReadBuffer, messageCode, len, isReadingPrependedMessage); + + switch (messageCode) + { + case BackendMessageCode.ErrorResponse: + Debug.Assert(msg == null); + + // An ErrorResponse is (almost) always followed by a ReadyForQuery. Save the error + // and throw it as an exception when the ReadyForQuery is received (next). 
+ error = PostgresException.Load(
+ ReadBuffer,
+ Settings.IncludeErrorDetail,
+ LoggingConfiguration.ExceptionLogger);
+
+ if (State == ConnectorState.Connecting)
+ {
+ // During the startup/authentication phase, an ErrorResponse isn't followed by
+ // an RFQ. Instead, the server closes the connection immediately
+ throw error;
+ }
+
+ if (PostgresErrorCodes.IsCriticalFailure(error, clusterError: false))
+ {
+ // Consider the connection dead
+ throw Break(error);
+ }
+
+ continue;
+
+ case BackendMessageCode.ReadyForQuery:
+ if (error != null)
+ {
+ NpgsqlEventSource.Log.CommandFailed();
+ DataSource.MetricsReporter.ReportCommandFailed();
+ throw error;
+ }
+
+ break;
+
+ // Asynchronous messages which can come anytime, they have already been handled
+ // in ParseServerMessage. Read the next message.
+ case BackendMessageCode.NoticeResponse:
+ case BackendMessageCode.NotificationResponse:
+ case BackendMessageCode.ParameterStatus:
+ Debug.Assert(msg == null);
+ if (!readingNotifications)
+ continue;
+ return null;
+ }
+
+ Debug.Assert(msg != null, "Message is null for code: " + messageCode);
+
+ // Reset flushed bytes after any RFQ or in between potentially long running operations.
+ // Just in case we'll hit that 15 exbibyte limit of a signed long...
+ if (messageCode is BackendMessageCode.ReadyForQuery or BackendMessageCode.CopyData or BackendMessageCode.NotificationResponse)
+ ReadBuffer.ResetFlushedBytes();
+
+ return msg;
+ }
+ }
+ catch (PostgresException e)
+ {
+ if (e.SqlState == PostgresErrorCodes.QueryCanceled && PostgresCancellationPerformed)
+ {
+ // The query could be canceled because of a user cancellation or a timeout - raise the proper exception.
+ // If _postgresCancellationPerformed is false, this is an unsolicited cancellation -
+ // just bubble up the PostgresException.
+ throw UserCancellationRequested
+ ? 
new OperationCanceledException("Query was cancelled", e, UserCancellationToken) + : new NpgsqlException("Exception while reading from stream", + new TimeoutException("Timeout during reading attempt")); + } + + throw; + } + catch (NpgsqlException) + { + // An ErrorResponse isn't followed by ReadyForQuery + if (error != null) + ExceptionDispatchInfo.Capture(error).Throw(); + throw; + } + } + + internal IBackendMessage? ParseResultSetMessage(NpgsqlReadBuffer buf, BackendMessageCode code, int len, bool handleCallbacks = false) + => code switch + { + BackendMessageCode.DataRow => _dataRowMessage.Load(len), + BackendMessageCode.CommandComplete => _commandCompleteMessage.Load(buf, len), + _ => ParseServerMessage(buf, code, len, false, handleCallbacks) + }; + + internal IBackendMessage? ParseServerMessage(NpgsqlReadBuffer buf, BackendMessageCode code, int len, bool isPrependedMessage, bool handleCallbacks = true) + { + switch (code) + { + case BackendMessageCode.RowDescription: + return _rowDescriptionMessage.Load(buf, SerializerOptions); + case BackendMessageCode.DataRow: + return _dataRowMessage.Load(len); + case BackendMessageCode.CommandComplete: + return _commandCompleteMessage.Load(buf, len); + case BackendMessageCode.ReadyForQuery: + var rfq = _readyForQueryMessage.Load(buf); + if (!isPrependedMessage) { + // Transaction status on prepended messages shouldn't be processed, because there may be prepended messages + // before the begin transaction message. In this case, they will contain transaction status Idle, which will + // clear our Pending transaction status. Only process transaction status on RFQ's from user-provided, non + // prepended messages. 
+ ProcessNewTransactionStatus(rfq.TransactionStatusIndicator); + } + return rfq; + case BackendMessageCode.EmptyQueryResponse: + return EmptyQueryMessage.Instance; + case BackendMessageCode.ParseComplete: + return ParseCompleteMessage.Instance; + case BackendMessageCode.ParameterDescription: + return _parameterDescriptionMessage.Load(buf); + case BackendMessageCode.BindComplete: + return BindCompleteMessage.Instance; + case BackendMessageCode.NoData: + return NoDataMessage.Instance; + case BackendMessageCode.CloseComplete: + return CloseCompletedMessage.Instance; + case BackendMessageCode.ParameterStatus: + ReadParameterStatus(buf.GetNullTerminatedBytes(), buf.GetNullTerminatedBytes()); + return null; + case BackendMessageCode.NoticeResponse: + if (handleCallbacks) + { + var notice = PostgresNotice.Load(buf, Settings.IncludeErrorDetail, LoggingConfiguration.ExceptionLogger); + LogMessages.ReceivedNotice(ConnectionLogger, notice.MessageText, Id); + Connection?.OnNotice(notice); + } + return null; + case BackendMessageCode.NotificationResponse: + if (handleCallbacks) + { + Connection?.OnNotification(new NpgsqlNotificationEventArgs(buf)); + } + return null; + + case BackendMessageCode.AuthenticationRequest: + var authType = (AuthenticationRequestType)buf.ReadInt32(); + return authType switch + { + AuthenticationRequestType.AuthenticationOk => AuthenticationOkMessage.Instance, + AuthenticationRequestType.AuthenticationCleartextPassword => AuthenticationCleartextPasswordMessage.Instance, + AuthenticationRequestType.AuthenticationMD5Password => AuthenticationMD5PasswordMessage.Load(buf), + AuthenticationRequestType.AuthenticationGSS => AuthenticationGSSMessage.Instance, + AuthenticationRequestType.AuthenticationSSPI => AuthenticationSSPIMessage.Instance, + AuthenticationRequestType.AuthenticationGSSContinue => AuthenticationGSSContinueMessage.Load(buf, len), + AuthenticationRequestType.AuthenticationSASL => new AuthenticationSASLMessage(buf), + 
AuthenticationRequestType.AuthenticationSASLContinue => new AuthenticationSASLContinueMessage(buf, len - 4), + AuthenticationRequestType.AuthenticationSASLFinal => new AuthenticationSASLFinalMessage(buf, len - 4), + _ => throw new NotSupportedException($"Authentication method not supported (Received: {authType})") + }; + + case BackendMessageCode.BackendKeyData: + return new BackendKeyDataMessage(buf); + + case BackendMessageCode.CopyInResponse: + return (_copyInResponseMessage ??= new CopyInResponseMessage()).Load(ReadBuffer); + case BackendMessageCode.CopyOutResponse: + return (_copyOutResponseMessage ??= new CopyOutResponseMessage()).Load(ReadBuffer); + case BackendMessageCode.CopyData: + return (_copyDataMessage ??= new CopyDataMessage()).Load(len); + case BackendMessageCode.CopyBothResponse: + return (_copyBothResponseMessage ??= new CopyBothResponseMessage()).Load(ReadBuffer); + + case BackendMessageCode.CopyDone: + return CopyDoneMessage.Instance; + + case BackendMessageCode.ErrorResponse: + return null; + + case BackendMessageCode.PortalSuspended: + case BackendMessageCode.FunctionCallResponse: + // We don't use the obsolete function call protocol + default: + ThrowHelper.ThrowInvalidOperationException($"Internal Npgsql bug: unexpected value {code} of enum {nameof(BackendMessageCode)}. Please file a bug."); + return null; + } + } + + /// + /// Reads backend messages and discards them, stopping only after a message of the given type has + /// been seen. Only a sync I/O version of this method exists - in async flows we inline the loop + /// rather than calling an additional async method, in order to avoid the overhead. 
+ /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal IBackendMessage SkipUntil(BackendMessageCode stopAt) + { + Debug.Assert(stopAt != BackendMessageCode.DataRow, "Shouldn't be used for rows, doesn't know about sequential"); + + while (true) + { + var msg = ReadMessage(async: false, DataRowLoadingMode.Skip).GetAwaiter().GetResult()!; + Debug.Assert(!(msg is DataRowMessage)); + if (msg.Code == stopAt) + return msg; + } + } + + #endregion Backend message processing + + #region Transactions + + internal Task Rollback(bool async, CancellationToken cancellationToken = default) + { + ConnectionLogger.LogDebug("Rolling back transaction", Id); + return ExecuteInternalCommand(PregeneratedMessages.RollbackTransaction, async, cancellationToken); + } + + internal bool InTransaction + { + get + { + switch (TransactionStatus) + { + case TransactionStatus.Idle: + return false; + case TransactionStatus.Pending: + case TransactionStatus.InTransactionBlock: + case TransactionStatus.InFailedTransactionBlock: + return true; + default: + ThrowHelper.ThrowInvalidOperationException($"Internal Npgsql bug: unexpected value {{0}} of enum {nameof(TransactionStatus)}. Please file a bug.", TransactionStatus); + return false; + } + } + } + + /// + /// Handles a new transaction indicator received on a ReadyForQuery message + /// + void ProcessNewTransactionStatus(TransactionStatus newStatus) + { + if (newStatus == TransactionStatus) + return; + + TransactionStatus = newStatus; + + switch (newStatus) + { + case TransactionStatus.Idle: + return; + case TransactionStatus.InTransactionBlock: + case TransactionStatus.InFailedTransactionBlock: + // In multiplexing mode, we can't support transaction in SQL: the connector must be removed from the + // writable connectors list, otherwise other commands may get written to it. So the user must tell us + // about the transaction via BeginTransaction. 
+ if (Connection is null)
+ {
+ Debug.Assert(Settings.Multiplexing);
+ ThrowHelper.ThrowNotSupportedException("In multiplexing mode, transactions must be started with BeginTransaction");
+ }
+ return;
+ case TransactionStatus.Pending:
+ ThrowHelper.ThrowInvalidOperationException($"Internal Npgsql bug: invalid TransactionStatus {nameof(TransactionStatus.Pending)} received, should be frontend-only");
+ return;
+ default:
+ ThrowHelper.ThrowInvalidOperationException($"Internal Npgsql bug: unexpected value {{0}} of enum {nameof(TransactionStatus)}. Please file a bug.", newStatus);
+ return;
+ }
+ }
+
+ internal void ClearTransaction(Exception? disposeReason = null)
+ {
+ Transaction?.DisposeImmediately(disposeReason);
+ TransactionStatus = TransactionStatus.Idle;
+ }
+
+ #endregion
+
+ #region SSL
+
+ ///
+ /// Returns whether SSL is being used for the connection
+ ///
+ internal bool IsSecure { get; private set; }
+
+ ///
+ /// Returns whether SCRAM-SHA256 is being used for the connection
+ ///
+ internal bool IsScram { get; private set; }
+
+ ///
+ /// Returns whether SCRAM-SHA256-PLUS is being used for the connection
+ ///
+ internal bool IsScramPlus { get; private set; }
+
+ static readonly RemoteCertificateValidationCallback SslVerifyFullValidation =
+ (sender, certificate, chain, sslPolicyErrors)
+ => sslPolicyErrors == SslPolicyErrors.None;
+
+ static readonly RemoteCertificateValidationCallback SslVerifyCAValidation =
+ (sender, certificate, chain, sslPolicyErrors)
+ => sslPolicyErrors == SslPolicyErrors.None || sslPolicyErrors == SslPolicyErrors.RemoteCertificateNameMismatch;
+
+ static readonly RemoteCertificateValidationCallback SslTrustServerValidation =
+ (sender, certificate, chain, sslPolicyErrors)
+ => true;
+
+ static RemoteCertificateValidationCallback SslRootValidation(bool verifyFull, string? certRootPath, X509Certificate2? 
caCertificate) + => (_, certificate, chain, sslPolicyErrors) => + { + if (certificate is null || chain is null) + return false; + + // No errors here - no reason to check further + if (sslPolicyErrors == SslPolicyErrors.None) + return true; + + // That's VerifyCA check and the only error is name mismatch - no reason to check further + if (!verifyFull && sslPolicyErrors == SslPolicyErrors.RemoteCertificateNameMismatch) + return true; + + // That's VerifyFull check and we have name mismatch - no reason to check further + if (verifyFull && sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch)) + return false; + + var certs = new X509Certificate2Collection(); + + if (certRootPath is null) + { + Debug.Assert(caCertificate is not null); + certs.Add(caCertificate); + } + else + { + Debug.Assert(caCertificate is null); +#if NET5_0_OR_GREATER + if (Path.GetExtension(certRootPath).ToUpperInvariant() != ".PFX") + certs.ImportFromPemFile(certRootPath); +#endif + + if (certs.Count == 0) + certs.Add(new X509Certificate2(certRootPath)); + } + +#if NET5_0_OR_GREATER + chain.ChainPolicy.CustomTrustStore.AddRange(certs); + chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust; +#endif + + chain.ChainPolicy.ExtraStore.AddRange(certs); + + return chain.Build(certificate as X509Certificate2 ?? 
new X509Certificate2(certificate)); + }; + + #endregion SSL + + #region Cancel + + internal void ResetCancellation() + { + // If a cancellation is in progress, wait for it to "complete" before proceeding (#615) + lock (CancelLock) + { + if (PendingPrependedResponses > 0) + ReadingPrependedMessagesMRE.Reset(); + Debug.Assert(ReadingPrependedMessagesMRE.IsSet || PendingPrependedResponses > 0); + } + } + + internal void PerformUserCancellation() + { + var connection = Connection; + if (connection is null || connection.ConnectorBindingScope == ConnectorBindingScope.Reader || UserCancellationRequested) + return; + + // Take the lock first to make sure there is no concurrent Break. + // We should be safe to take it as Break only take it to set the state. + lock (SyncObj) + { + // The connector is dead, exit gracefully. + if (!IsConnected) + return; + // The connector is still alive, take the CancelLock before exiting SingleUseLock. + // If a break will happen after, it's going to wait for the cancellation to complete. + Monitor.Enter(CancelLock); + } + + try + { + // Wait before we've read all responses for the prepended queries + // as we can't gracefully handle their cancellation. + // Break makes sure that it's going to be set even if we fail while reading them. 
+ + // We don't wait indefinitely to avoid deadlocks from synchronous CancellationToken.Register + // See #5032 + if (!ReadingPrependedMessagesMRE.Wait(0)) + return; + + _userCancellationRequested = true; + + if (AttemptPostgresCancellation && SupportsPostgresCancellation) + { + var cancellationTimeout = Settings.CancellationTimeout; + if (PerformPostgresCancellation() && cancellationTimeout >= 0) + { + if (cancellationTimeout > 0) + { + ReadBuffer.Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); + ReadBuffer.Cts.CancelAfter(cancellationTimeout); + } + + return; + } + } + + ReadBuffer.Timeout = _cancelImmediatelyTimeout; + ReadBuffer.Cts.Cancel(); + } + finally + { + Monitor.Exit(CancelLock); + } + } + + /// + /// Creates another connector and sends a cancel request through it for this connector. This method never throws, but returns + /// whether the cancellation attempt failed. + /// + /// + /// + /// if the cancellation request was successfully delivered, or if it was skipped because a previous + /// request was already sent. if the cancellation request could not be delivered because of an exception + /// (the method logs internally). + /// + /// + /// This does not indicate whether the cancellation attempt was successful on the PostgreSQL side - only if the request was + /// delivered. 
+ /// + /// + internal bool PerformPostgresCancellation() + { + Debug.Assert(BackendProcessId != 0, "PostgreSQL cancellation requested by the backend doesn't support it"); + + lock (CancelLock) + { + if (PostgresCancellationPerformed) + return true; + + LogMessages.CancellingCommand(ConnectionLogger, Id); + PostgresCancellationPerformed = true; + + try + { + var cancelConnector = new NpgsqlConnector(this); + cancelConnector.DoCancelRequest(BackendProcessId, _backendSecretKey); + } + catch (Exception e) + { + var socketException = e.InnerException as SocketException; + if (socketException == null || socketException.SocketErrorCode != SocketError.ConnectionReset) + { + ConnectionLogger.LogDebug(e, "Exception caught while attempting to cancel command", Id); + return false; + } + } + + return true; + } + } + + void DoCancelRequest(int backendProcessId, int backendSecretKey) + { + Debug.Assert(State == ConnectorState.Closed); + + try + { + RawOpen(Settings.SslMode, new NpgsqlTimeout(TimeSpan.FromSeconds(ConnectionTimeout)), false, CancellationToken.None) + .GetAwaiter().GetResult(); + WriteCancelRequest(backendProcessId, backendSecretKey); + Flush(); + + Debug.Assert(ReadBuffer.ReadBytesLeft == 0); + + // Now wait for the server to close the connection, better chance of the cancellation + // actually being delivered before we continue with the user's logic. + var count = _stream.Read(ReadBuffer.Buffer, 0, 1); + if (count > 0) + ConnectionLogger.LogError("Received response after sending cancel request, shouldn't happen! 
First byte: " + ReadBuffer.Buffer[0]); + } + finally + { + FullCleanup(); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal CancellationTokenRegistration StartCancellableOperation( + CancellationToken cancellationToken = default, + bool attemptPgCancellation = true) + { + _userCancellationRequested = PostgresCancellationPerformed = false; + UserCancellationToken = cancellationToken; + ReadBuffer.Cts.ResetCts(); + + AttemptPostgresCancellation = attemptPgCancellation; + return _cancellationTokenRegistration = + cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformUserCancellation(), this); + } + + /// + /// Starts a new cancellable operation within an ongoing user action. This should only be used if a single user + /// action spans several different actions which each has its own cancellation tokens. For example, a command + /// execution is a single user action, but spans ExecuteReaderQuery, NextResult, Read and so forth. + /// + /// + /// Only one level of nested operations is supported. It is an error to call this method if it has previously + /// been called, and the returned was not disposed. + /// + /// + /// The cancellation token provided by the user. Callbacks will be registered on this token for executing the + /// cancellation, and the token will be included in any thrown . + /// + /// + /// If , PostgreSQL cancellation will be attempted when the user requests cancellation or + /// a timeout occurs, followed by a client-side socket cancellation once + /// has elapsed. If , + /// PostgreSQL cancellation will be skipped and client-socket cancellation will occur immediately. 
+ /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal NestedCancellableScope StartNestedCancellableOperation( + CancellationToken cancellationToken = default, + bool attemptPgCancellation = true) + { + var currentUserCancellationToken = UserCancellationToken; + UserCancellationToken = cancellationToken; + var currentAttemptPostgresCancellation = AttemptPostgresCancellation; + AttemptPostgresCancellation = attemptPgCancellation; + + var registration = cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformUserCancellation(), this); + + return new(this, registration, currentUserCancellationToken, currentAttemptPostgresCancellation); + } + + internal readonly struct NestedCancellableScope : IDisposable + { + readonly NpgsqlConnector _connector; + readonly CancellationTokenRegistration _registration; + readonly CancellationToken _previousCancellationToken; + readonly bool _previousAttemptPostgresCancellation; + + public NestedCancellableScope(NpgsqlConnector connector, CancellationTokenRegistration registration, CancellationToken previousCancellationToken, bool previousAttemptPostgresCancellation) + { + _connector = connector; + _registration = registration; + _previousCancellationToken = previousCancellationToken; + _previousAttemptPostgresCancellation = previousAttemptPostgresCancellation; + } + + public void Dispose() + { + if (_connector is null) + return; + + _connector.UserCancellationToken = _previousCancellationToken; + _connector.AttemptPostgresCancellation = _previousAttemptPostgresCancellation; + _registration.Dispose(); + } + } + + #endregion Cancel + + #region Close / Reset + + /// + /// Closes ongoing operations, i.e. an open reader exists or a COPY operation still in progress, as + /// part of a connection close. 
+ /// + internal async Task CloseOngoingOperations(bool async) + { + var reader = CurrentReader; + var copyOperation = CurrentCopyOperation; + + if (reader != null) + await reader.Close(async, connectionClosing: true, isDisposing: false).ConfigureAwait(false); + else if (copyOperation != null) + { + // TODO: There's probably a race condition as the COPY operation may finish on its own during the next few lines + + // Note: we only want to cancel import operations, since in these cases cancel is safe. + // Export cancellations go through the PostgreSQL "asynchronous" cancel mechanism and are + // therefore vulnerable to the race condition in #615. + if (copyOperation is NpgsqlBinaryImporter || + copyOperation is NpgsqlCopyTextWriter || + copyOperation is NpgsqlRawCopyStream rawCopyStream && rawCopyStream.CanWrite) + { + try + { + if (async) + await copyOperation.CancelAsync().ConfigureAwait(false); + else + copyOperation.Cancel(); + } + catch (Exception e) + { + CopyLogger.LogWarning(e, "Error while cancelling COPY on connector close", Id); + } + } + + try + { + if (async) + await copyOperation.DisposeAsync().ConfigureAwait(false); + else + copyOperation.Dispose(); + } + catch (Exception e) + { + CopyLogger.LogWarning(e, "Error while disposing cancelled COPY on connector close", Id); + } + } + } + + // TODO in theory this should be async-optional, but the only I/O done here is the Terminate Flush, which is + // very unlikely to block (plus locking would need to be worked out) + internal void Close() + { + lock (SyncObj) + { + if (IsReady) + { + LogMessages.ClosingPhysicalConnection(ConnectionLogger, Host, Port, Database, UserFacingConnectionString, Id); + try + { + // At this point, there could be some prepended commands (like DISCARD ALL) + // which make no sense to send on connection close + // see https://github.com/npgsql/npgsql/issues/3592 + WriteBuffer.Clear(); + WriteTerminate(); + Flush(); + } + catch (Exception e) + { + ConnectionLogger.LogError(e, 
"Exception while closing connector", Id); + Debug.Assert(IsBroken); + } + } + + switch (State) + { + case ConnectorState.Broken: + case ConnectorState.Closed: + return; + } + + State = ConnectorState.Closed; + } + + FullCleanup(); + LogMessages.ClosedPhysicalConnection(ConnectionLogger, Host, Port, Database, UserFacingConnectionString, Id); + } + + internal bool TryRemovePendingEnlistedConnector(Transaction transaction) + => DataSource.TryRemovePendingEnlistedConnector(this, transaction); + + internal void Return() => DataSource.Return(this); + + /// + /// Called when an unexpected message has been received during an action. Breaks the + /// connector and returns the appropriate message. + /// + internal Exception UnexpectedMessageReceived(BackendMessageCode received) + => throw Break(new Exception($"Received unexpected backend message {received}. Please file a bug.")); + + /// + /// Called when a connector becomes completely unusable, e.g. when an unexpected I/O exception is raised or when + /// we lose protocol sync. + /// Note that fatal errors during the Open phase do *not* pass through here. + /// + /// The exception that caused the break. + /// The exception given in for chaining calls. + internal Exception Break(Exception reason) + { + Debug.Assert(!IsClosed); + + Monitor.Enter(SyncObj); + + if (State == ConnectorState.Broken) + { + // We're already broken. + // Exit SingleUseLock to unblock other threads (like cancellation). + Monitor.Exit(SyncObj); + // Wait for the break to complete before going forward. + lock (CleanupLock) { } + return reason; + } + + try + { + // If we're broken while reading prepended messages + // the cancellation request might still be waiting on the MRE. + // Unblock it. 
+ ReadingPrependedMessagesMRE.Set(); + + LogMessages.BreakingConnection(ConnectionLogger, Id, reason); + + // Note that we may be reading and writing from the same connector concurrently, so safely set + // the original reason for the break before actually closing the socket etc. + Interlocked.CompareExchange(ref _breakReason, reason, null); + State = ConnectorState.Broken; + // Take the CleanupLock while in SingleUseLock to make sure concurrent Break doesn't take it first. + Monitor.Enter(CleanupLock); + } + finally + { + // Unblock other threads (like cancellation) to proceed and exit gracefully. + Monitor.Exit(SyncObj); + } + + try + { + // Make sure there is no concurrent cancellation in process + lock (CancelLock) + { + // Note we only set the cluster to offline and clear the pool if the connection is being broken (we're in this method), + // *and* the exception indicates that the PG cluster really is down; the latter includes any IO/timeout issue, + // but does not include e.g. authentication failure or timeouts with disabled cancellation. + if (reason is NpgsqlException { IsTransient: true } ne && + (ne.InnerException is not TimeoutException || Settings.CancellationTimeout != -1) || + reason is PostgresException pe && PostgresErrorCodes.IsCriticalFailure(pe)) + { + DataSource.UpdateDatabaseState(DatabaseState.Offline, DateTime.UtcNow, Settings.HostRecheckSecondsTranslated); + DataSource.Clear(); + } + + var connection = Connection; + + FullCleanup(); + + if (connection is not null) + { + var closeLockTaken = connection.TakeCloseLock(); + Debug.Assert(closeLockTaken); + if (Settings.ReplicationMode == ReplicationMode.Off) + { + // When a connector is broken, we immediately "return" it to the pool (i.e. update the pool state so reflect the + // connector no longer being open). 
Upper layers such as EF may check DbConnection.ConnectionState, and only close if + // it's closed; so we can't set the state to Closed and expect the user to still close (in order to return to the pool). + // On the other hand leaving the state Open could indicate to the user that the connection is functional. + // (see https://github.com/npgsql/npgsql/issues/3705#issuecomment-839908772) + Connection = null; + if (connection.ConnectorBindingScope != ConnectorBindingScope.None) + Return(); + connection.EnlistedTransaction = null; + connection.Connector = null; + connection.ConnectorBindingScope = ConnectorBindingScope.None; + } + + connection.FullState = ConnectionState.Broken; + connection.ReleaseCloseLock(); + } + + return reason; + } + } + finally + { + Monitor.Exit(CleanupLock); + } + } + + void FullCleanup() + { + lock (CleanupLock) + { + if (Settings.Multiplexing) + { + FlagAsNotWritableForMultiplexing(); + + // Note that in multiplexing, this could be called from the read loop, while the write loop is + // writing into the channel. To make sure this race condition isn't a problem, the channel currently + // isn't set up with SingleWriter (since at this point it doesn't do anything). + CommandsInFlightWriter!.Complete(); + + // The connector's read loop has a continuation to observe and log any exception coming out + // (see Open) + } + + ConnectionLogger.LogTrace("Cleaning up connector", Id); + Cleanup(); + + if (_isKeepAliveEnabled) + { + _keepAliveTimer!.Change(Timeout.Infinite, Timeout.Infinite); + _keepAliveTimer.Dispose(); + } + + ReadingPrependedMessagesMRE.Dispose(); + } + } + + /// + /// Closes the socket and cleans up client-side resources associated with this connector. + /// + /// + /// This method doesn't actually perform any meaningful I/O, and therefore is sync-only. 
+ /// + void Cleanup() + { + try + { + _stream?.Dispose(); + } + catch + { + // ignored + } + + if (CurrentReader != null) + { + CurrentReader.Command.State = CommandState.Idle; + try + { + // Note that this never actually blocks on I/O, since the stream is also closed + // (which is why we don't need to call CloseAsync) + CurrentReader.Close(); + } + catch + { + // ignored + } + CurrentReader = null; + } + + if (CurrentCopyOperation != null) + { + try + { + // Note that this never actually blocks on I/O, since the stream is also closed + // (which is why we don't need to call DisposeAsync) + CurrentCopyOperation.Dispose(); + } + catch + { + // ignored + } + CurrentCopyOperation = null; + } + + ClearTransaction(_breakReason); + + _stream = null!; + _baseStream = null!; + _origReadBuffer?.Dispose(); + _origReadBuffer = null; + ReadBuffer?.Dispose(); + ReadBuffer = null!; + WriteBuffer?.Dispose(); + WriteBuffer = null!; + Connection = null; + PostgresParameters.Clear(); + _currentCommand = null; + + if (_certificate is not null) + { + _certificate.Dispose(); + _certificate = null; + } + } + + void GenerateResetMessage() + { + var sb = new StringBuilder("SET SESSION AUTHORIZATION DEFAULT;RESET ALL;"); + _resetWithoutDeallocateResponseCount = 2; + if (DatabaseInfo.SupportsCloseAll) + { + sb.Append("CLOSE ALL;"); + _resetWithoutDeallocateResponseCount++; + } + if (DatabaseInfo.SupportsUnlisten) + { + sb.Append("UNLISTEN *;"); + _resetWithoutDeallocateResponseCount++; + } + if (DatabaseInfo.SupportsAdvisoryLocks) + { + sb.Append("SELECT pg_advisory_unlock_all();"); + _resetWithoutDeallocateResponseCount += 2; + } + if (DatabaseInfo.SupportsDiscardSequences) + { + sb.Append("DISCARD SEQUENCES;"); + _resetWithoutDeallocateResponseCount++; + } + if (DatabaseInfo.SupportsDiscardTemp) + { + sb.Append("DISCARD TEMP"); + _resetWithoutDeallocateResponseCount++; + } + + _resetWithoutDeallocateResponseCount++; // One ReadyForQuery at the end + + _resetWithoutDeallocateMessage = 
PregeneratedMessages.Generate(WriteBuffer, sb.ToString()); + } + + /// + /// Called when a pooled connection is closed, and its connector is returned to the pool. + /// Resets the connector back to its initial state, releasing server-side sources + /// (e.g. prepared statements), resetting parameters to their defaults, and resetting client-side + /// state + /// + internal async Task Reset(bool async) + { + bool endBindingScope; + + // We start user action in case a keeplive happens concurrently, or a concurrent user command (bug) + using (StartUserAction(attemptPgCancellation: false)) + { + // Our buffer may contain unsent prepended messages, so clear it out. + // In practice, this is (currently) only done when beginning a transaction or a transaction savepoint. + WriteBuffer.Clear(); + PendingPrependedResponses = 0; + + ResetReadBuffer(); + + Transaction?.UnbindIfNecessary(); + + // Must rollback transaction before sending DISCARD ALL + switch (TransactionStatus) + { + case TransactionStatus.Idle: + // There is an undisposed transaction on multiplexing connection + endBindingScope = Connection?.ConnectorBindingScope == ConnectorBindingScope.Transaction; + break; + case TransactionStatus.Pending: + // BeginTransaction() was called, but was left in the write buffer and not yet sent to server. + // Just clear the transaction state. + ProcessNewTransactionStatus(TransactionStatus.Idle); + ClearTransaction(); + endBindingScope = true; + break; + case TransactionStatus.InTransactionBlock: + case TransactionStatus.InFailedTransactionBlock: + await Rollback(async).ConfigureAwait(false); + ClearTransaction(); + endBindingScope = true; + break; + default: + ThrowHelper.ThrowInvalidOperationException($"Internal Npgsql bug: unexpected value {TransactionStatus} of enum {nameof(TransactionStatus)}. 
Please file a bug."); + return; + } + + if (_sendResetOnClose) + { + if (PreparedStatementManager.NumPrepared > 0) + { + // We have prepared statements, so we can't reset the connection state with DISCARD ALL + // Note: the send buffer has been cleared above, and we assume all this will fit in it. + PrependInternalMessage(_resetWithoutDeallocateMessage!, _resetWithoutDeallocateResponseCount); + } + else + { + // There are no prepared statements. + // We simply send DISCARD ALL which is more efficient than sending the above messages separately + PrependInternalMessage(PregeneratedMessages.DiscardAll, 2); + } + } + + DataReader.UnbindIfNecessary(); + } + + if (endBindingScope) + { + // Connection is null if a connection enlisted in a TransactionScope was closed before the + // TransactionScope completed - the connector is still enlisted, but has no connection. + Connection?.EndBindingScope(ConnectorBindingScope.Transaction); + } + } + + /// + /// The connector may have allocated an oversize read buffer, to hold big rows in non-sequential reading. + /// This switches us back to the original one and returns the buffer to . 
+ /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void ResetReadBuffer() + { + LongRunningConnection = false; + if (_origReadBuffer != null) + { + Debug.Assert(_origReadBuffer.ReadBytesLeft == 0); + Debug.Assert(_origReadBuffer.ReadPosition == 0); + if (ReadBuffer.ReadBytesLeft > 0) + { + // There is still something in the buffer which we haven't read yet + // In most cases it's ParameterStatus which can be sent asynchronously + // If in some extreme case we have too much data left in the buffer to store in the original buffer + // we just leave the oversize buffer as is and will try again on next reset + if (ReadBuffer.ReadBytesLeft > _origReadBuffer.Size) + return; + + ReadBuffer.CopyTo(_origReadBuffer); + } + + ReadBuffer.Dispose(); + ReadBuffer = _origReadBuffer; + _origReadBuffer = null; + } + } + + internal void UnprepareAll() + { + ExecuteInternalCommand("DEALLOCATE ALL"); + PreparedStatementManager.ClearAll(); + } + + #endregion Close / Reset + + #region Locking + + internal UserAction StartUserAction(CancellationToken cancellationToken = default, bool attemptPgCancellation = true) + => StartUserAction(ConnectorState.Executing, command: null, cancellationToken, attemptPgCancellation); + + internal UserAction StartUserAction( + ConnectorState newState, + CancellationToken cancellationToken = default, + bool attemptPgCancellation = true) + => StartUserAction(newState, command: null, cancellationToken, attemptPgCancellation); + + /// + /// Starts a user action. This makes sure that another action isn't already in progress, handles synchronization with keepalive, + /// and sets up cancellation. + /// + /// The new state to be set when entering this user action. + /// + /// The that is starting execution - if an is + /// thrown, it will reference this. + /// + /// + /// The cancellation token provided by the user. Callbacks will be registered on this token for executing the cancellation, + /// and the token will be included in any thrown . 
+ /// + /// + /// If , PostgreSQL cancellation will be attempted when the user requests cancellation or a timeout + /// occurs, followed by a client-side socket cancellation once has + /// elapsed. If , PostgreSQL cancellation will be skipped and client-socket cancellation will occur + /// immediately. + /// + internal UserAction StartUserAction( + ConnectorState newState, + NpgsqlCommand? command, + CancellationToken cancellationToken = default, + bool attemptPgCancellation = true) + { + // If keepalive is enabled, we must protect state transitions with a lock. + // This will make the keepalive abort safely if a user query is in progress, and make + // the user query wait if a keepalive is in progress. + // If keepalive isn't enabled, we don't use the lock and rely only on the connector's + // state (updated via Interlocked.Exchange) to detect concurrent use, on a best-effort basis. + return _isKeepAliveEnabled + ? DoStartUserActionWithKeepAlive(newState, command, cancellationToken, attemptPgCancellation) + : DoStartUserAction(newState, command, cancellationToken, attemptPgCancellation); + + UserAction DoStartUserAction(ConnectorState newState, NpgsqlCommand? 
command, + CancellationToken cancellationToken, bool attemptPgCancellation) + { + switch (State) + { + case ConnectorState.Ready: + break; + case ConnectorState.Closed: + case ConnectorState.Broken: + ThrowHelper.ThrowInvalidOperationException("Connection is not open"); + break; + case ConnectorState.Executing: + case ConnectorState.Fetching: + case ConnectorState.Waiting: + case ConnectorState.Replication: + case ConnectorState.Connecting: + case ConnectorState.Copy: + var currentCommand = _currentCommand; + if (currentCommand is null) + ThrowHelper.ThrowNpgsqlOperationInProgressException(State); + else + ThrowHelper.ThrowNpgsqlOperationInProgressException(currentCommand); + break; + default: + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(State), "Invalid connector state: {0}", State); + break; + } + + Debug.Assert(IsReady); + + cancellationToken.ThrowIfCancellationRequested(); + + LogMessages.StartUserAction(ConnectionLogger, Id); + State = newState; + _currentCommand = command; + + StartCancellableOperation(cancellationToken, attemptPgCancellation); + + // We reset the ReadBuffer.Timeout for every user action, so it wouldn't leak from the previous query or action + // For example, we might have successfully cancelled the previous query (so the connection is not broken) + // But the next time, we call the Prepare, which doesn't set it's own timeout + ReadBuffer.Timeout = TimeSpan.FromSeconds(command?.CommandTimeout ?? Settings.CommandTimeout); + + return new UserAction(this); + } + + UserAction DoStartUserActionWithKeepAlive(ConnectorState newState, NpgsqlCommand? 
command, + CancellationToken cancellationToken, bool attemptPgCancellation) + { + lock (SyncObj) + { + if (!IsConnected) + { + if (IsBroken) + ThrowHelper.ThrowNpgsqlException("The connection was previously broken because of the following exception", _breakReason); + else + ThrowHelper.ThrowNpgsqlException("The connection is closed"); + } + + // Disable keepalive, it will be restarted at the end of the user action + _keepAliveTimer!.Change(Timeout.Infinite, Timeout.Infinite); + + try + { + // Check that the connector is ready. + return DoStartUserAction(newState, command, cancellationToken, attemptPgCancellation); + } + catch (Exception ex) when (ex is not NpgsqlOperationInProgressException) + { + // We failed, but there is no current operation. + // As such, we re-enable the keepalive. + var keepAlive = Settings.KeepAlive * 1000; + _keepAliveTimer!.Change(keepAlive, keepAlive); + throw; + } + } + } + } + + internal void EndUserAction() + { + Debug.Assert(CurrentReader == null); + + _cancellationTokenRegistration.Dispose(); + + if (_isKeepAliveEnabled) + { + lock (SyncObj) + { + if (IsReady || !IsConnected) + return; + + var keepAlive = Settings.KeepAlive * 1000; + _keepAliveTimer!.Change(keepAlive, keepAlive); + + LogMessages.EndUserAction(ConnectionLogger, Id); + _currentCommand = null; + State = ConnectorState.Ready; + } + } + else + { + if (IsReady || !IsConnected) + return; + + LogMessages.EndUserAction(ConnectionLogger, Id); + _currentCommand = null; + State = ConnectorState.Ready; + } + } + + /// + /// An IDisposable wrapper around . + /// + internal readonly struct UserAction : IDisposable + { + readonly NpgsqlConnector _connector; + internal UserAction(NpgsqlConnector connector) => _connector = connector; + public void Dispose() => _connector.EndUserAction(); + } + + #endregion + + #region Keepalive + +#pragma warning disable CA1801 // Review unused parameters + void PerformKeepAlive(object? 
state) + { + Debug.Assert(_isKeepAliveEnabled); + if (!Monitor.TryEnter(SyncObj)) + return; + + try + { + // There may already be a user action, or the connector may be closed etc. + if (!IsReady) + return; + + LogMessages.SendingKeepalive(ConnectionLogger, Id); + AttemptPostgresCancellation = false; + var timeout = Math.Max(Settings.CommandTimeout, MinimumInternalCommandTimeout); + ReadBuffer.Timeout = WriteBuffer.Timeout = TimeSpan.FromSeconds(timeout); + WriteSync(async: false).GetAwaiter().GetResult(); + Flush(); + SkipUntil(BackendMessageCode.ReadyForQuery); + LogMessages.CompletedKeepalive(ConnectionLogger, Id); + } + catch (Exception e) + { + LogMessages.KeepaliveFailed(ConnectionLogger, Id, e); + try + { + Break(new NpgsqlException("Exception while sending a keepalive", e)); + } + catch (Exception e2) + { + ConnectionLogger.LogError(e2, "Further exception while breaking connector on keepalive failure", Id); + } + } + finally + { + Monitor.Exit(SyncObj); + } + } +#pragma warning restore CA1801 // Review unused parameters + + #endregion + + #region Wait + + internal async Task Wait(bool async, int timeout, CancellationToken cancellationToken = default) + { + using var _ = StartUserAction(ConnectorState.Waiting, cancellationToken: cancellationToken, attemptPgCancellation: false); + + // We may have prepended messages in the connection's write buffer - these need to be flushed now. + await Flush(async, cancellationToken).ConfigureAwait(false); + + var keepaliveMs = Settings.KeepAlive * 1000; + while (true) + { + cancellationToken.ThrowIfCancellationRequested(); + + var timeoutForKeepalive = _isKeepAliveEnabled && (timeout <= 0 || keepaliveMs < timeout); + ReadBuffer.Timeout = TimeSpan.FromMilliseconds(timeoutForKeepalive ? 
keepaliveMs : timeout); + try + { + var msg = await ReadMessageWithNotifications(async).ConfigureAwait(false); + if (msg != null) + { + throw Break( + new NpgsqlException($"Received unexpected message of type {msg.Code} while waiting")); + } + return true; + } + catch (NpgsqlException e) when (e.InnerException is TimeoutException) + { + if (!timeoutForKeepalive) // We really timed out + return false; + } + + LogMessages.SendingKeepalive(ConnectionLogger, Id); + + var keepaliveTime = Stopwatch.StartNew(); + await WriteSync(async, cancellationToken).ConfigureAwait(false); + await Flush(async, cancellationToken).ConfigureAwait(false); + + var receivedNotification = false; + var expectedMessageCode = BackendMessageCode.RowDescription; + + while (true) + { + IBackendMessage? msg; + + try + { + msg = await ReadMessageWithNotifications(async).ConfigureAwait(false); + } + catch (Exception e) when (e is OperationCanceledException || e is NpgsqlException npgEx && npgEx.InnerException is TimeoutException) + { + // We're somewhere in the middle of a reading keepalive messages + // Breaking the connection, as we've lost protocol sync + Break(e); + throw; + } + + if (msg == null) + { + receivedNotification = true; + continue; + } + + if (msg.Code != BackendMessageCode.ReadyForQuery) + throw new NpgsqlException($"Received unexpected message of type {msg.Code} while expecting {expectedMessageCode} as part of keepalive"); + + LogMessages.CompletedKeepalive(ConnectionLogger, Id); + + if (receivedNotification) + return true; // Notification was received during the keepalive + cancellationToken.ThrowIfCancellationRequested(); + break; + } + + if (timeout > 0) + timeout -= (keepaliveMs + (int)keepaliveTime.ElapsedMilliseconds); + } + } + + #endregion + + #region Supported features and PostgreSQL settings + + internal bool UseConformingStrings { get; private set; } + + /// + /// The connection's timezone as reported by PostgreSQL, in the IANA/Olson database format. 
+ /// + internal string Timezone { get; private set; } = default!; + + bool? _isTransactionReadOnly; + + bool? _isHotStandBy; + + #endregion Supported features and PostgreSQL settings + + #region Execute internal command + + internal void ExecuteInternalCommand(string query) + => ExecuteInternalCommand(query, false).GetAwaiter().GetResult(); + + internal async Task ExecuteInternalCommand(string query, bool async, CancellationToken cancellationToken = default) + { + LogMessages.ExecutingInternalCommand(CommandLogger, query, Id); + + await WriteQuery(query, async, cancellationToken).ConfigureAwait(false); + await Flush(async, cancellationToken).ConfigureAwait(false); + Expect(await ReadMessage(async).ConfigureAwait(false), this); + Expect(await ReadMessage(async).ConfigureAwait(false), this); + } + + internal async Task ExecuteInternalCommand(byte[] data, bool async, CancellationToken cancellationToken = default) + { + Debug.Assert(State != ConnectorState.Ready, "Forgot to start a user action..."); + + await WritePregenerated(data, async, cancellationToken).ConfigureAwait(false); + await Flush(async, cancellationToken).ConfigureAwait(false); + Expect(await ReadMessage(async).ConfigureAwait(false), this); + Expect(await ReadMessage(async).ConfigureAwait(false), this); + } + + #endregion + + #region Misc + + /// + /// Creates and returns a object associated with the . + /// + /// The text of the query. + /// A object. + public NpgsqlCommand CreateCommand(string? cmdText = null) => new(cmdText, this); + + /// + /// Creates and returns a object associated with the . + /// + /// A object. 
+ public NpgsqlBatch CreateBatch() => new NpgsqlBatch(this); + + void ReadParameterStatus(ReadOnlySpan incomingName, ReadOnlySpan incomingValue) + { + byte[] rawName; + byte[] rawValue; + + for (var i = 0; i < _rawParameters.Count; i++) + { + (var currentName, var currentValue) = _rawParameters[i]; + if (incomingName.SequenceEqual(currentName)) + { + if (incomingValue.SequenceEqual(currentValue)) + return; + + rawName = currentName; + rawValue = incomingValue.ToArray(); + _rawParameters[i] = (rawName, rawValue); + + goto ProcessParameter; + } + } + + rawName = incomingName.ToArray(); + rawValue = incomingValue.ToArray(); + _rawParameters.Add((rawName, rawValue)); + + ProcessParameter: + var name = TextEncoding.GetString(rawName); + var value = TextEncoding.GetString(rawValue); + + PostgresParameters[name] = value; + + switch (name) + { + case "standard_conforming_strings": + if (value != "on" && Settings.Multiplexing) + throw Break(new NotSupportedException("standard_conforming_strings must be on with multiplexing")); + UseConformingStrings = value == "on"; + return; + + case "TimeZone": + Timezone = value; + return; + + case "default_transaction_read_only": + _isTransactionReadOnly = value == "on"; + UpdateDatabaseState(); + return; + + case "in_hot_standby": + _isHotStandBy = value == "on"; + UpdateDatabaseState(); + return; + } + } + + DatabaseState? UpdateDatabaseState() + { + if (_isTransactionReadOnly.HasValue && _isHotStandBy.HasValue) + { + var state = _isHotStandBy.Value + ? DatabaseState.Standby + : _isTransactionReadOnly.Value + ? DatabaseState.PrimaryReadOnly + : DatabaseState.PrimaryReadWrite; + return DataSource.UpdateDatabaseState(state, DateTime.UtcNow, Settings.HostRecheckSecondsTranslated); + } + + return null; + } + + #endregion Misc +} + +#region Enums + +/// +/// Expresses the exact state of a connector. +/// +enum ConnectorState +{ + /// + /// The connector has either not yet been opened or has been closed. 
+ /// + Closed, + + /// + /// The connector is currently connecting to a PostgreSQL server. + /// + Connecting, + + /// + /// The connector is connected and may be used to send a new query. + /// + Ready, + + /// + /// The connector is waiting for a response to a query which has been sent to the server. + /// + Executing, + + /// + /// The connector is currently fetching and processing query results. + /// + Fetching, + + /// + /// The connector is currently waiting for asynchronous notifications to arrive. + /// + Waiting, + + /// + /// The connection was broken because an unexpected error occurred which left it in an unknown state. + /// This state isn't implemented yet. + /// + Broken, + + /// + /// The connector is engaged in a COPY operation. + /// + Copy, + + /// + /// The connector is engaged in streaming replication. + /// + Replication, +} + +#pragma warning disable CA1717 +enum TransactionStatus : byte +#pragma warning restore CA1717 +{ + /// + /// Currently not in a transaction block + /// + Idle = (byte)'I', + + /// + /// Currently in a transaction block + /// + InTransactionBlock = (byte)'T', + + /// + /// Currently in a failed transaction block (queries will be rejected until block is ended) + /// + InFailedTransactionBlock = (byte)'E', + + /// + /// A new transaction has been requested but not yet transmitted to the backend. It will be transmitted + /// prepended to the next query. + /// This is a client-side state option only, and is never transmitted from the backend. + /// + Pending = byte.MaxValue, +} + +/// +/// Specifies how to load/parse DataRow messages as they're received from the backend. 
+/// +enum DataRowLoadingMode +{ + /// + /// Load DataRows in non-sequential mode + /// + NonSequential, + + /// + /// Load DataRows in sequential mode + /// + Sequential, + + /// + /// Skip DataRow messages altogether + /// + Skip +} + +#endregion diff --git a/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs b/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs new file mode 100644 index 0000000000..7fd3fe95e9 --- /dev/null +++ b/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs @@ -0,0 +1,369 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using Npgsql.Util; + +namespace Npgsql.Internal; + +/// +/// Base class for implementations which provide information about PostgreSQL and PostgreSQL-like databases +/// (e.g. type definitions, capabilities...). +/// +[Experimental(NpgsqlDiagnostics.DatabaseInfoExperimental)] +public abstract class NpgsqlDatabaseInfo +{ + #region Fields + + static volatile INpgsqlDatabaseInfoFactory[] Factories = { + new PostgresMinimalDatabaseInfoFactory(), + new PostgresDatabaseInfoFactory() + }; + + #endregion Fields + + #region General database info + + /// + /// The hostname of IP address of the database. + /// + public string Host { get; } + + /// + /// The TCP port of the database. + /// + public int Port { get; } + + /// + /// The database name. + /// + public string Name { get; } + + /// + /// The version of the PostgreSQL database we're connected to, as reported in the "server_version" parameter. + /// Exposed via . + /// + public Version Version { get; } + + /// + /// The PostgreSQL version string as returned by the server_version option. Populated during loading. + /// + public string ServerVersion { get; } + + #endregion General database info + + #region Supported capabilities and features + + /// + /// Whether the backend supports range types. 
+ /// + public virtual bool SupportsRangeTypes => Version.IsGreaterOrEqual(9, 2); + + /// + /// Whether the backend supports multirange types. + /// + public virtual bool SupportsMultirangeTypes => Version.IsGreaterOrEqual(14); + + /// + /// Whether the backend supports enum types. + /// + public virtual bool SupportsEnumTypes => Version.IsGreaterOrEqual(8, 3); + + /// + /// Whether the backend supports the CLOSE ALL statement. + /// + public virtual bool SupportsCloseAll => Version.IsGreaterOrEqual(8, 3); + + /// + /// Whether the backend supports advisory locks. + /// + public virtual bool SupportsAdvisoryLocks => Version.IsGreaterOrEqual(8, 2); + + /// + /// Whether the backend supports the DISCARD SEQUENCES statement. + /// + public virtual bool SupportsDiscardSequences => Version.IsGreaterOrEqual(9, 4); + + /// + /// Whether the backend supports the UNLISTEN statement. + /// + public virtual bool SupportsUnlisten => Version.IsGreaterOrEqual(6, 4); // overridden by PostgresDatabase + + /// + /// Whether the backend supports the DISCARD TEMP statement. + /// + public virtual bool SupportsDiscardTemp => Version.IsGreaterOrEqual(8, 3); + + /// + /// Whether the backend supports the DISCARD statement. + /// + public virtual bool SupportsDiscard => Version.IsGreaterOrEqual(8, 3); + + /// + /// Reports whether the backend uses the newer integer timestamp representation. + /// + public virtual bool HasIntegerDateTimes { get; protected set; } = true; + + /// + /// Whether the database supports transactions. 
+ /// + public virtual bool SupportsTransactions { get; protected set; } = true; + + #endregion Supported capabilities and features + + #region Types + + readonly List _baseTypesMutable = new(); + readonly List _arrayTypesMutable = new(); + readonly List _rangeTypesMutable = new(); + readonly List _multirangeTypesMutable = new(); + readonly List _enumTypesMutable = new(); + readonly List _compositeTypesMutable = new(); + readonly List _domainTypesMutable = new(); + + internal IReadOnlyList BaseTypes => _baseTypesMutable; + internal IReadOnlyList ArrayTypes => _arrayTypesMutable; + internal IReadOnlyList RangeTypes => _rangeTypesMutable; + internal IReadOnlyList MultirangeTypes => _multirangeTypesMutable; + internal IReadOnlyList EnumTypes => _enumTypesMutable; + internal IReadOnlyList CompositeTypes => _compositeTypesMutable; + internal IReadOnlyList DomainTypes => _domainTypesMutable; + + /// + /// Indexes backend types by their type OID. + /// + internal Dictionary ByOID { get; } = new(); + + /// + /// Indexes backend types by their PostgreSQL internal name, including namespace (e.g. pg_catalog.int4). + /// Only used for enums and composites. + /// + internal Dictionary ByFullName { get; } = new(); + + /// + /// Indexes backend types by their PostgreSQL name, not including namespace. + /// If more than one type exists with the same name (i.e. in different namespaces) this + /// table will contain an entry with a null value. + /// Only used for enums and composites. + /// + internal Dictionary ByName { get; } = new(); + + /// + /// Initializes the instance of . + /// + protected NpgsqlDatabaseInfo(string host, int port, string databaseName, Version version) + : this(host, port, databaseName, version, version.ToString()) + { } + + /// + /// Initializes the instance of . 
+ /// + protected NpgsqlDatabaseInfo(string host, int port, string databaseName, Version version, string serverVersion) + { + Host = host; + Port = port; + Name = databaseName; + Version = version; + ServerVersion = serverVersion; + } + + private protected NpgsqlDatabaseInfo(string host, int port, string databaseName, string serverVersion) + { + Host = host; + Port = port; + Name = databaseName; + ServerVersion = serverVersion; + Version = ParseServerVersion(serverVersion); + } + + internal PostgresType GetPostgresType(Oid oid) => GetPostgresType(oid.Value); + + public PostgresType GetPostgresType(uint oid) + => ByOID.TryGetValue(oid, out var pgType) + ? pgType + : throw new ArgumentException($"A PostgreSQL type with the oid '{oid}' was not found in the current database info"); + + internal PostgresType GetPostgresType(DataTypeName dataTypeName) + => ByFullName.TryGetValue(dataTypeName.Value, out var value) + ? value + : throw new ArgumentException($"A PostgreSQL type with the name '{dataTypeName}' was not found in the current database info"); + + public PostgresType GetPostgresType(string pgName) + => TryGetPostgresTypeByName(pgName, out var pgType) + ? pgType + : throw new ArgumentException($"A PostgreSQL type with the name '{pgName}' was not found in the current database info"); + + public bool TryGetPostgresTypeByName(string pgName, [NotNullWhen(true)] out PostgresType? pgType) + { + // Full type name with namespace + if (pgName.IndexOf('.') > -1) + { + if (ByFullName.TryGetValue(pgName, out pgType)) + return true; + } + // No dot, partial type name + else if (ByName.TryGetValue(pgName, out pgType)) + { + if (pgType is not null) + return true; + + // If the name was found but the value is null, that means that there are + // two db types with the same name (different schemas). + // Try to fall back to pg_catalog, otherwise fail. 
+ if (ByFullName.TryGetValue($"pg_catalog.{pgName}", out pgType)) + return true; + + var ambiguousTypes = new List(); + foreach (var key in ByFullName.Keys) + if (key.EndsWith($".{pgName}", StringComparison.Ordinal)) + ambiguousTypes.Add(key); + + throw new ArgumentException($"More than one PostgreSQL type was found with the name {pgName}, " + + $"please specify a full name including schema: {string.Join(", ", ambiguousTypes)}"); + } + + return false; + } + + internal void ProcessTypes() + { + var unspecified = new PostgresBaseType(DataTypeName.Unspecified, Oid.Unspecified); + ByOID[Oid.Unspecified.Value] = unspecified; + ByFullName[unspecified.DataTypeName.Value] = unspecified; + ByName[unspecified.InternalName] = unspecified; + + foreach (var type in GetTypes()) + { + ByOID[type.OID] = type; + ByFullName[type.DataTypeName.Value] = type; + // If more than one type exists with the same partial name, we place a null value. + // This allows us to detect this case later and force the user to use full names only. + ByName[type.InternalName] = ByName.ContainsKey(type.InternalName) + ? null + : type; + + switch (type) + { + case PostgresBaseType baseType: + _baseTypesMutable.Add(baseType); + continue; + case PostgresArrayType arrayType: + _arrayTypesMutable.Add(arrayType); + continue; + case PostgresRangeType rangeType: + _rangeTypesMutable.Add(rangeType); + continue; + case PostgresMultirangeType multirangeType: + _multirangeTypesMutable.Add(multirangeType); + continue; + case PostgresEnumType enumType: + _enumTypesMutable.Add(enumType); + continue; + case PostgresCompositeType compositeType: + _compositeTypesMutable.Add(compositeType); + continue; + case PostgresDomainType domainType: + _domainTypesMutable.Add(domainType); + continue; + default: + throw new ArgumentOutOfRangeException(); + } + } + } + + /// + /// Provides all PostgreSQL types detected in this database. 
+ /// + /// + protected abstract IEnumerable GetTypes(); + + #endregion Types + + #region Misc + + /// + /// Parses a PostgreSQL server version (e.g. 10.1, 9.6.3) and returns a CLR Version. + /// + protected static Version ParseServerVersion(string value) + { + var versionString = value.TrimStart(); + for (var idx = 0; idx != versionString.Length; ++idx) + { + var c = versionString[idx]; + if (!char.IsDigit(c) && c != '.') + { + versionString = versionString.Substring(0, idx); + break; + } + } + if (!versionString.Contains(".")) + versionString += ".0"; + return new Version(versionString); + } + + #endregion Misc + + #region Factory management + + /// + /// Registers a new database info factory, which is used to load information about databases. + /// + public static void RegisterFactory(INpgsqlDatabaseInfoFactory factory) + { + if (factory == null) + throw new ArgumentNullException(nameof(factory)); + + var factories = new INpgsqlDatabaseInfoFactory[Factories.Length + 1]; + factories[0] = factory; + Array.Copy(Factories, 0, factories, 1, Factories.Length); + Factories = factories; + } + + internal static async Task Load(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async) + { + foreach (var factory in Factories) + { + var dbInfo = await factory.Load(conn, timeout, async).ConfigureAwait(false); + if (dbInfo != null) + { + dbInfo.ProcessTypes(); + return dbInfo; + } + } + + // Should never be here + throw new NpgsqlException("No DatabaseInfoFactory could be found for this connection"); + } + + // For tests + internal static void ResetFactories() + => Factories = new INpgsqlDatabaseInfoFactory[] + { + new PostgresMinimalDatabaseInfoFactory(), + new PostgresDatabaseInfoFactory() + }; + + #endregion Factory management + + internal Oid GetOid(PgTypeId pgTypeId, bool validate = false) + => pgTypeId.IsOid + ? validate ? 
GetPostgresType(pgTypeId.Oid).OID : pgTypeId.Oid + : GetPostgresType(pgTypeId.DataTypeName).OID; + + internal DataTypeName GetDataTypeName(PgTypeId pgTypeId, bool validate = false) + => pgTypeId.IsDataTypeName + ? validate ? GetPostgresType(pgTypeId.DataTypeName).DataTypeName : pgTypeId.DataTypeName + : GetPostgresType(pgTypeId.Oid).DataTypeName; + + internal PostgresType GetPostgresType(PgTypeId pgTypeId) + => pgTypeId.IsOid + ? GetPostgresType(pgTypeId.Oid.Value) + : GetPostgresType(pgTypeId.DataTypeName.Value); + + internal PostgresType? FindPostgresType(PgTypeId pgTypeId) + => pgTypeId.IsOid + ? ByOID.TryGetValue(pgTypeId.Oid.Value, out var pgType) ? pgType : null + : TryGetPostgresTypeByName(pgTypeId.DataTypeName.Value, out pgType) ? pgType : null; +} diff --git a/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs b/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs new file mode 100644 index 0000000000..e99b77fa1b --- /dev/null +++ b/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs @@ -0,0 +1,246 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal; + +sealed partial class NpgsqlReadBuffer +{ + internal sealed class ColumnStream : Stream +#if NETSTANDARD2_0 + , IAsyncDisposable +#endif + { + readonly NpgsqlConnector _connector; + readonly NpgsqlReadBuffer _buf; + long _startPos; + int _start; + int _read; + bool _canSeek; + bool _commandScoped; + /// Does not throw ODE. 
+ internal int CurrentLength { get; private set; } + internal bool IsDisposed { get; private set; } + + internal ColumnStream(NpgsqlConnector connector) + { + _connector = connector; + _buf = connector.ReadBuffer; + IsDisposed = true; + } + + internal void Init(int len, bool canSeek, bool commandScoped) + { + Debug.Assert(!canSeek || _buf.ReadBytesLeft >= len, + "Seekable stream constructed but not all data is in buffer (sequential)"); + _startPos = _buf.CumulativeReadPosition; + + _canSeek = canSeek; + _start = canSeek ? _buf.ReadPosition : 0; + + CurrentLength = len; + _read = 0; + + _commandScoped = commandScoped; + IsDisposed = false; + } + + public override bool CanRead => true; + + public override bool CanWrite => false; + + public override bool CanSeek => _canSeek; + + public override long Length + { + get + { + CheckDisposed(); + return CurrentLength; + } + } + + public override void SetLength(long value) + => throw new NotSupportedException(); + + public override long Position + { + get + { + CheckDisposed(); + return _read; + } + set + { + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value), "Non - negative number required."); + Seek(value, SeekOrigin.Begin); + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + CheckDisposed(); + + if (!_canSeek) + throw new NotSupportedException(); + if (offset > int.MaxValue) + throw new ArgumentOutOfRangeException(nameof(offset), "Stream length must be non-negative and less than 2^31 - 1 - origin."); + + const string seekBeforeBegin = "An attempt was made to move the position before the beginning of the stream."; + + switch (origin) + { + case SeekOrigin.Begin: + { + var tempPosition = unchecked(_start + (int)offset); + if (offset < 0 || tempPosition < _start) + throw new IOException(seekBeforeBegin); + _buf.ReadPosition = tempPosition; + _read = (int)offset; + return _read; + } + case SeekOrigin.Current: + { + var tempPosition = unchecked(_buf.ReadPosition + (int)offset); + if 
(unchecked(_buf.ReadPosition + offset) < _start || tempPosition < _start) + throw new IOException(seekBeforeBegin); + _buf.ReadPosition = tempPosition; + _read += (int)offset; + return _read; + } + case SeekOrigin.End: + { + var tempPosition = unchecked(_start + CurrentLength + (int)offset); + if (unchecked(_start + CurrentLength + offset) < _start || tempPosition < _start) + throw new IOException(seekBeforeBegin); + _buf.ReadPosition = tempPosition; + _read = CurrentLength + (int)offset; + return _read; + } + default: + throw new ArgumentOutOfRangeException(nameof(origin), "Invalid seek origin."); + } + } + + public override void Flush() + => CheckDisposed(); + + public override Task FlushAsync(CancellationToken cancellationToken) + { + CheckDisposed(); + return cancellationToken.IsCancellationRequested + ? Task.FromCanceled(cancellationToken) : Task.CompletedTask; + } + + public override int ReadByte() + { + Span byteSpan = stackalloc byte[1]; + var read = Read(byteSpan); + return read > 0 ? 
byteSpan[0] : -1; + } + + public override int Read(byte[] buffer, int offset, int count) + { + ValidateArguments(buffer, offset, count); + return Read(new Span(buffer, offset, count)); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateArguments(buffer, offset, count); + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } + +#if NETSTANDARD2_0 + public int Read(Span span) +#else + public override int Read(Span span) +#endif + { + CheckDisposed(); + + var count = Math.Min(span.Length, CurrentLength - _read); + + if (count == 0) + return 0; + + var read = _buf.Read(_commandScoped, span.Slice(0, count)); + _read += read; + + return read; + } + +#if NETSTANDARD2_0 + public ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) +#else + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) +#endif + { + CheckDisposed(); + + var count = Math.Min(buffer.Length, CurrentLength - _read); + return count == 0 ? new ValueTask(0) : ReadLong(this, buffer.Slice(0, count), cancellationToken); + + static async ValueTask ReadLong(ColumnStream stream, Memory buffer, CancellationToken cancellationToken = default) + { + using var registration = cancellationToken.CanBeCanceled + ? 
stream._connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false) + : default; + + var read = await stream._buf.ReadAsync(stream._commandScoped, buffer, cancellationToken).ConfigureAwait(false); + stream._read += read; + return read; + } + } + + public override void Write(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + void CheckDisposed() + { + if (IsDisposed) + ThrowHelper.ThrowObjectDisposedException(nameof(ColumnStream)); + } + + protected override void Dispose(bool disposing) + => DisposeAsync(disposing, async: false).GetAwaiter().GetResult(); + +#if NETSTANDARD2_0 + public ValueTask DisposeAsync() +#else + public override ValueTask DisposeAsync() +#endif + => DisposeAsync(disposing: true, async: true); + + async ValueTask DisposeAsync(bool disposing, bool async) + { + if (IsDisposed || !disposing) + return; + + if (!_connector.IsBroken) + { + var pos = _buf.CumulativeReadPosition - _startPos; + var remaining = checked((int)(CurrentLength - pos)); + if (remaining > 0) + await _buf.Skip(remaining, async).ConfigureAwait(false); + } + + IsDisposed = true; + } + } + + static void ValidateArguments(byte[] buffer, int offset, int count) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0) + throw new ArgumentOutOfRangeException(nameof(offset)); + if (count < 0) + throw new ArgumentOutOfRangeException(nameof(count)); + if (buffer.Length - offset < count) + throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + } +} diff --git a/src/Npgsql/Internal/NpgsqlReadBuffer.cs b/src/Npgsql/Internal/NpgsqlReadBuffer.cs new file mode 100644 index 0000000000..03e2499a91 --- /dev/null +++ b/src/Npgsql/Internal/NpgsqlReadBuffer.cs @@ -0,0 +1,845 @@ +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using 
System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Util; +using static System.Threading.Timeout; + +namespace Npgsql.Internal; + +/// +/// A buffer used by Npgsql to read data from the socket efficiently. +/// Provides methods which decode different values types and tracks the current position. +/// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +sealed partial class NpgsqlReadBuffer : IDisposable +{ + #region Fields and Properties + +#if DEBUG + internal static readonly bool BufferBoundsChecks = true; +#else + internal static readonly bool BufferBoundsChecks = Statics.EnableAssertions; +#endif + + public NpgsqlConnection Connection => Connector.Connection!; + internal readonly NpgsqlConnector Connector; + internal Stream Underlying { private get; set; } + readonly Socket? _underlyingSocket; + internal ResettableCancellationTokenSource Cts { get; } + readonly MetricsReporter? _metricsReporter; + + TimeSpan _preTranslatedTimeout = TimeSpan.Zero; + + /// + /// Timeout for sync and async reads + /// + internal TimeSpan Timeout + { + get => _preTranslatedTimeout; + set + { + if (_preTranslatedTimeout != value) + { + _preTranslatedTimeout = value; + + if (value == TimeSpan.Zero) + value = InfiniteTimeSpan; + else if (value < TimeSpan.Zero) + value = TimeSpan.Zero; + + Debug.Assert(_underlyingSocket != null); + + _underlyingSocket.ReceiveTimeout = (int)value.TotalMilliseconds; + Cts.Timeout = value; + } + } + } + + /// + /// The total byte length of the buffer. + /// + internal int Size { get; } + + internal Encoding TextEncoding { get; } + + /// + /// Same as , except that it does not throw an exception if an invalid char is + /// encountered (exception fallback), but rather replaces it with a question mark character (replacement + /// fallback). 
+ /// + internal Encoding RelaxedTextEncoding { get; } + + internal int ReadPosition { get; set; } + internal int ReadBytesLeft => FilledBytes - ReadPosition; + internal PgReader PgReader { get; } + + long _flushedBytes; // this will always fit at least one message. + internal long CumulativeReadPosition + // Cast to uint to remove the sign extension (ReadPosition is never negative) + => _flushedBytes + (uint)ReadPosition; + + internal readonly byte[] Buffer; + internal int FilledBytes; + + internal ReadOnlySpan Span => Buffer.AsSpan(ReadPosition, ReadBytesLeft); + + readonly bool _usePool; + bool _disposed; + + /// + /// The minimum buffer size possible. + /// + internal const int MinimumSize = 4096; + internal const int DefaultSize = 8192; + + #endregion + + #region Constructors + + internal NpgsqlReadBuffer( + NpgsqlConnector? connector, + Stream stream, + Socket? socket, + int size, + Encoding textEncoding, + Encoding relaxedTextEncoding, + bool usePool = false) + { + if (size < MinimumSize) + { + throw new ArgumentOutOfRangeException(nameof(size), size, "Buffer size must be at least " + MinimumSize); + } + + Connector = connector!; // TODO: Clean this up + Underlying = stream; + _underlyingSocket = socket; + _metricsReporter = connector?.DataSource.MetricsReporter; + Cts = new ResettableCancellationTokenSource(); + Buffer = usePool ? 
ArrayPool.Shared.Rent(size) : new byte[size]; + Size = Buffer.Length; + _usePool = usePool; + + TextEncoding = textEncoding; + RelaxedTextEncoding = relaxedTextEncoding; + PgReader = new PgReader(this); + } + + #endregion + + #region I/O + + public void Ensure(int count) + => Ensure(count, async: false, readingNotifications: false).GetAwaiter().GetResult(); + + public ValueTask Ensure(int count, bool async) + => Ensure(count, async, readingNotifications: false); + + public ValueTask EnsureAsync(int count) + => Ensure(count, async: true, readingNotifications: false); + + // Can't share due to Span vs Memory difference (can't make a memory out of a span). + int ReadWithTimeout(Span buffer) + { + while (true) + { + try + { + var read = Underlying.Read(buffer); + _flushedBytes = unchecked(_flushedBytes + read); + NpgsqlEventSource.Log.BytesRead(read); + return read; + } + catch (Exception ex) + { + var connector = Connector; + switch (ex) + { + // Note that mono throws SocketException with the wrong error (see #1330) + case IOException e when (e.InnerException as SocketException)?.SocketErrorCode == + (Type.GetType("Mono.Runtime") == null ? SocketError.TimedOut : SocketError.WouldBlock): + { + var isStreamBroken = false; +#if NETSTANDARD2_0 + // SslStream on .NET Framework treats any IOException (including timeouts) as fatal and may + // return garbage if reused. To prevent this, we flow down and break the connection immediately. + // See #4305. + isStreamBroken = connector.IsSecure && ex is IOException; +#endif + + // If we should attempt PostgreSQL cancellation, do it the first time we get a timeout. 
+ // TODO: As an optimization, we can still attempt to send a cancellation request, but after + // that immediately break the connection + if (connector.AttemptPostgresCancellation && + !connector.PostgresCancellationPerformed && + connector.PerformPostgresCancellation() && + !isStreamBroken) + { + // Note that if the cancellation timeout is negative, we flow down and break the + // connection immediately. + var cancellationTimeout = connector.Settings.CancellationTimeout; + if (cancellationTimeout >= 0) + { + if (cancellationTimeout > 0) + Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); + + continue; + } + } + + // If we're here, the PostgreSQL cancellation either failed or skipped entirely. + // Break the connection, bubbling up the correct exception type (cancellation or timeout) + throw connector.Break(CreateCancelException(connector)); + } + default: + throw connector.Break(new NpgsqlException("Exception while reading from stream", ex)); + } + } + } + } + + async ValueTask ReadWithTimeoutAsync(Memory buffer, CancellationToken cancellationToken) + { + var finalCt = Timeout != TimeSpan.Zero + ? Cts.Start(cancellationToken) + : Cts.Reset(); + + while (true) + { + try + { + var read = await Underlying.ReadAsync(buffer, finalCt).ConfigureAwait(false); + _flushedBytes = unchecked(_flushedBytes + read); + Cts.Stop(); + NpgsqlEventSource.Log.BytesRead(read); + return read; + } + catch (Exception ex) + { + var connector = Connector; + Cts.Stop(); + switch (ex) + { + // Read timeout + case OperationCanceledException: + // Note that mono throws SocketException with the wrong error (see #1330) + case IOException e when (e.InnerException as SocketException)?.SocketErrorCode == + (Type.GetType("Mono.Runtime") == null ? 
SocketError.TimedOut : SocketError.WouldBlock): + { + Debug.Assert(ex is OperationCanceledException); + var isStreamBroken = false; +#if NETSTANDARD2_0 + // SslStream on .NET Framework treats any IOException (including timeouts) as fatal and may + // return garbage if reused. To prevent this, we flow down and break the connection immediately. + // See #4305. + isStreamBroken = connector.IsSecure && ex is IOException; +#endif + // If we should attempt PostgreSQL cancellation, do it the first time we get a timeout. + // TODO: As an optimization, we can still attempt to send a cancellation request, but after + // that immediately break the connection + if (connector.AttemptPostgresCancellation && + !connector.PostgresCancellationPerformed && + connector.PerformPostgresCancellation() && + !isStreamBroken) + { + // Note that if the cancellation timeout is negative, we flow down and break the + // connection immediately. + var cancellationTimeout = connector.Settings.CancellationTimeout; + if (cancellationTimeout >= 0) + { + if (cancellationTimeout > 0) + Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); + + finalCt = Cts.Start(cancellationToken); + continue; + } + } + + // If we're here, the PostgreSQL cancellation either failed or skipped entirely. + // Break the connection, bubbling up the correct exception type (cancellation or timeout) + throw connector.Break(CreateCancelException(connector)); + } + default: + throw connector.Break(new NpgsqlException("Exception while reading from stream", ex)); + } + } + } + } + + static Exception CreateCancelException(NpgsqlConnector connector) + => !connector.UserCancellationRequested + ? NpgsqlTimeoutException() + : connector.PostgresCancellationPerformed + ? 
new OperationCanceledException("Query was cancelled", TimeoutException(), connector.UserCancellationToken) + : new OperationCanceledException("Query was cancelled", connector.UserCancellationToken); + + static Exception NpgsqlTimeoutException() => new NpgsqlException("Exception while reading from stream", TimeoutException()); + + static Exception TimeoutException() => new TimeoutException("Timeout during reading attempt"); + + /// + /// Ensures that bytes are available in the buffer, and if + /// not, reads from the socket until enough is available. + /// + internal ValueTask Ensure(int count, bool async, bool readingNotifications) + { + return count <= ReadBytesLeft ? new() : EnsureLong(this, count, async, readingNotifications); + +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] +#endif + static async ValueTask EnsureLong( + NpgsqlReadBuffer buffer, + int count, + bool async, + bool readingNotifications) + { + Debug.Assert(count <= buffer.Size); + Debug.Assert(count > buffer.ReadBytesLeft); + count -= buffer.ReadBytesLeft; + + if (buffer.ReadPosition == buffer.FilledBytes) + { + buffer.ResetPosition(); + } + else if (count > buffer.Size - buffer.FilledBytes) + { + Array.Copy(buffer.Buffer, buffer.ReadPosition, buffer.Buffer, 0, buffer.ReadBytesLeft); + buffer.FilledBytes = buffer.ReadBytesLeft; + buffer._flushedBytes = unchecked(buffer._flushedBytes + buffer.ReadPosition); + buffer.ReadPosition = 0; + } + + var finalCt = async && buffer.Timeout != TimeSpan.Zero + ? buffer.Cts.Start() + : buffer.Cts.Reset(); + + var totalRead = 0; + while (count > 0) + { + try + { + var toRead = buffer.Size - buffer.FilledBytes; + var read = async + ? 
await buffer.Underlying.ReadAsync(buffer.Buffer.AsMemory(buffer.FilledBytes, toRead), finalCt).ConfigureAwait(false) + : buffer.Underlying.Read(buffer.Buffer, buffer.FilledBytes, toRead); + + if (read == 0) + throw new EndOfStreamException(); + count -= read; + buffer.FilledBytes += read; + totalRead += read; + + // Most of the time, it should be fine to reset cancellation token source, so we can use it again + // It's still possible for cancellation token to cancel between reading and resetting (although highly improbable) + // In this case, we consider it as timed out and fail with OperationCancelledException on next ReadAsync + // Or we consider it not timed out if we have already read everything (count == 0) + // In which case we reinitialize it on the next call to EnsureLong() + if (async && count > 0) + buffer.Cts.RestartTimeoutWithoutReset(); + } + catch (Exception e) + { + var connector = buffer.Connector; + + // Stopping twice (in case the previous Stop() call succeeded) doesn't hurt. + // Not stopping will cause an assertion failure in debug mode when we call Start() the next time. + // We can't stop in a finally block because Connector.Break() will dispose the buffer and the contained + // _timeoutCts + buffer.Cts.Stop(); + + switch (e) + { + // Read timeout + case OperationCanceledException: + // Note that mono throws SocketException with the wrong error (see #1330) + case IOException when (e.InnerException as SocketException)?.SocketErrorCode == + (Type.GetType("Mono.Runtime") == null ? SocketError.TimedOut : SocketError.WouldBlock): + { + Debug.Assert(e is OperationCanceledException ? async : !async); + + var isStreamBroken = false; +#if NETSTANDARD2_0 + // SslStream on .NET Framework treats any IOException (including timeouts) as fatal and may + // return garbage if reused. To prevent this, we flow down and break the connection immediately. + // See #4305. 
+ isStreamBroken = connector.IsSecure && e is IOException; +#endif + // When reading notifications (Wait), just throw TimeoutException or + // OperationCanceledException immediately. + // Nothing to cancel, and no breaking of the connection. + if (readingNotifications && !isStreamBroken) + throw CreateException(connector); + + // If we should attempt PostgreSQL cancellation, do it the first time we get a timeout. + // TODO: As an optimization, we can still attempt to send a cancellation request, but after + // that immediately break the connection + if (connector.AttemptPostgresCancellation && + !connector.PostgresCancellationPerformed && + connector.PerformPostgresCancellation() && + !isStreamBroken) + { + // Note that if the cancellation timeout is negative, we flow down and break the + // connection immediately. + var cancellationTimeout = connector.Settings.CancellationTimeout; + if (cancellationTimeout >= 0) + { + if (cancellationTimeout > 0) + buffer.Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); + + if (async) + finalCt = buffer.Cts.Start(); + + continue; + } + } + + // If we're here, the PostgreSQL cancellation either failed or skipped entirely. + // Break the connection, bubbling up the correct exception type (cancellation or timeout) + throw connector.Break(CreateException(connector)); + + static Exception CreateException(NpgsqlConnector connector) + => !connector.UserCancellationRequested + ? NpgsqlTimeoutException() + : connector.PostgresCancellationPerformed + ? 
new OperationCanceledException("Query was cancelled", TimeoutException(), connector.UserCancellationToken) + : new OperationCanceledException("Query was cancelled", connector.UserCancellationToken); + } + + default: + throw connector.Break(new NpgsqlException("Exception while reading from stream", e)); + } + } + } + + buffer.Cts.Stop(); + NpgsqlEventSource.Log.BytesRead(totalRead); + buffer._metricsReporter?.ReportBytesRead(totalRead); + + static Exception NpgsqlTimeoutException() => new NpgsqlException("Exception while reading from stream", TimeoutException()); + + static Exception TimeoutException() => new TimeoutException("Timeout during reading attempt"); + } + } + + internal ValueTask ReadMore(bool async) => Ensure(ReadBytesLeft + 1, async); + + internal NpgsqlReadBuffer AllocateOversize(int count) + { + Debug.Assert(count > Size); + var tempBuf = new NpgsqlReadBuffer(Connector, Underlying, _underlyingSocket, count, TextEncoding, RelaxedTextEncoding, usePool: true); + if (_underlyingSocket != null) + tempBuf.Timeout = Timeout; + CopyTo(tempBuf); + ResetPosition(); + return tempBuf; + } + + /// + /// Does not perform any I/O - assuming that the bytes to be skipped are in the memory buffer. + /// + internal void Skip(int len) + { + Debug.Assert(ReadBytesLeft >= len); + ReadPosition += len; + } + + /// + /// Skip a given number of bytes. 
+ /// + public async Task Skip(int len, bool async) + { + Debug.Assert(len >= 0); + + if (len > ReadBytesLeft) + { + len -= ReadBytesLeft; + while (len > Size) + { + ResetPosition(); + await Ensure(Size, async).ConfigureAwait(false); + len -= Size; + } + ResetPosition(); + await Ensure(len, async).ConfigureAwait(false); + } + + ReadPosition += len; + } + + #endregion + + #region Read Simple + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public byte ReadByte() + { + CheckBounds(sizeof(byte)); + var result = Buffer[ReadPosition]; + ReadPosition += sizeof(byte); + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public short ReadInt16() + { + CheckBounds(sizeof(short)); + var result = BitConverter.IsLittleEndian + ? BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition])) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); + ReadPosition += sizeof(short); + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ushort ReadUInt16() + { + CheckBounds(sizeof(ushort)); + var result = BitConverter.IsLittleEndian + ? BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition])) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); + ReadPosition += sizeof(ushort); + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int ReadInt32() + { + CheckBounds(sizeof(int)); + var result = BitConverter.IsLittleEndian + ? BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition])) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); + ReadPosition += sizeof(int); + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public uint ReadUInt32() + { + CheckBounds(sizeof(uint)); + var result = BitConverter.IsLittleEndian + ? 
BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition])) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); + ReadPosition += sizeof(uint); + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public long ReadInt64() + { + CheckBounds(sizeof(long)); + var result = BitConverter.IsLittleEndian + ? BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition])) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); + ReadPosition += sizeof(long); + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ulong ReadUInt64() + { + CheckBounds(sizeof(ulong)); + var result = BitConverter.IsLittleEndian + ? BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition])) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); + ReadPosition += sizeof(ulong); + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float ReadSingle() + { + CheckBounds(sizeof(float)); +#if NETSTANDARD2_0 + float result; + if (BitConverter.IsLittleEndian) + { + var value = BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition])); + result = Unsafe.As(ref value); + } + else + result = Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); +#else + var result = BitConverter.IsLittleEndian + ? BitConverter.Int32BitsToSingle(BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition]))) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); +#endif + ReadPosition += sizeof(float); + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public double ReadDouble() + { + CheckBounds(sizeof(double)); + var result = BitConverter.IsLittleEndian + ? 
BitConverter.Int64BitsToDouble(BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition]))) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); + ReadPosition += sizeof(double); + return result; + } + + void CheckBounds(int count) + { + if (BufferBoundsChecks) + Core(count); + + [MethodImpl(MethodImplOptions.NoInlining)] + void Core(int count) + { + if (count > ReadBytesLeft) + ThrowHelper.ThrowInvalidOperationException("There is not enough data left in the buffer."); + } + } + + public string ReadString(int byteLen) + { + Debug.Assert(byteLen <= ReadBytesLeft); + var result = TextEncoding.GetString(Buffer, ReadPosition, byteLen); + ReadPosition += byteLen; + return result; + } + + public void ReadBytes(Span output) + { + Debug.Assert(output.Length <= ReadBytesLeft); + new Span(Buffer, ReadPosition, output.Length).CopyTo(output); + ReadPosition += output.Length; + } + + public void ReadBytes(byte[] output, int outputOffset, int len) + => ReadBytes(new Span(output, outputOffset, len)); + + public ReadOnlyMemory ReadMemory(int len) + { + Debug.Assert(len <= ReadBytesLeft); + var memory = new ReadOnlyMemory(Buffer, ReadPosition, len); + ReadPosition += len; + return memory; + } + + #endregion + + #region Read Complex + + public int Read(bool commandScoped, Span output) + { + var readFromBuffer = Math.Min(ReadBytesLeft, output.Length); + if (readFromBuffer > 0) + { + Buffer.AsSpan(ReadPosition, readFromBuffer).CopyTo(output); + ReadPosition += readFromBuffer; + return readFromBuffer; + } + + // Only reset if we'll be able to read data, this is to support zero-byte reads. 
+ if (output.Length > 0) + { + Debug.Assert(ReadBytesLeft == 0); + ResetPosition(); + } + + if (commandScoped) + return ReadWithTimeout(output); + + try + { + var read = Underlying.Read(output); + _flushedBytes = unchecked(_flushedBytes + read); + NpgsqlEventSource.Log.BytesRead(read); + return read; + } + catch (Exception e) + { + throw Connector.Break(new NpgsqlException("Exception while reading from stream", e)); + } + } + + public ValueTask ReadAsync(bool commandScoped, Memory output, CancellationToken cancellationToken = default) + { + var readFromBuffer = Math.Min(ReadBytesLeft, output.Length); + if (readFromBuffer > 0) + { + Buffer.AsSpan(ReadPosition, readFromBuffer).CopyTo(output.Span); + ReadPosition += readFromBuffer; + return new ValueTask(readFromBuffer); + } + + return ReadAsyncLong(this, commandScoped, output, cancellationToken); + + static async ValueTask ReadAsyncLong(NpgsqlReadBuffer buffer, bool commandScoped, Memory output, CancellationToken cancellationToken) + { + // Only reset if we'll be able to read data, this is to support zero-byte reads. + if (output.Length > 0) + { + Debug.Assert(buffer.ReadBytesLeft == 0); + buffer.ResetPosition(); + } + + if (commandScoped) + return await buffer.ReadWithTimeoutAsync(output, cancellationToken).ConfigureAwait(false); + + try + { + var read = await buffer.Underlying.ReadAsync(output, cancellationToken).ConfigureAwait(false); + buffer._flushedBytes = unchecked(buffer._flushedBytes + read); + NpgsqlEventSource.Log.BytesRead(read); + return read; + } + catch (Exception e) + { + throw buffer.Connector.Break(new NpgsqlException("Exception while reading from stream", e)); + } + } + } + + ColumnStream? 
_lastStream; + public ColumnStream CreateStream(int len, bool canSeek) + { + if (_lastStream is not { IsDisposed: true }) + _lastStream = new ColumnStream(Connector); + _lastStream.Init(len, canSeek, !Connector.LongRunningConnection); + return _lastStream; + } + + /// + /// Seeks the first null terminator (\0) and returns the string up to it. The buffer must already + /// contain the entire string and its terminator. + /// + public string ReadNullTerminatedString() + => ReadNullTerminatedString(TextEncoding, async: false).GetAwaiter().GetResult(); + + /// + /// Seeks the first null terminator (\0) and returns the string up to it. The buffer must already + /// contain the entire string and its terminator. If any character could not be decoded, a question + /// mark character is returned instead of throwing an exception. + /// + public string ReadNullTerminatedStringRelaxed() + => ReadNullTerminatedString(RelaxedTextEncoding, async: false).GetAwaiter().GetResult(); + + public ValueTask ReadNullTerminatedString(bool async, CancellationToken cancellationToken = default) + => ReadNullTerminatedString(TextEncoding, async, cancellationToken); + + /// + /// Seeks the first null terminator (\0) and returns the string up to it. Reads additional data from the network if a null + /// terminator isn't found in the buffered data. 
+ /// + public ValueTask ReadNullTerminatedString(Encoding encoding, bool async, CancellationToken cancellationToken = default) + { + var index = Span.IndexOf((byte)0); + if (index >= 0) + { + var result = new ValueTask(encoding.GetString(Buffer, ReadPosition, index)); + ReadPosition += index + 1; + return result; + } + + return ReadLong(encoding, async); + + async ValueTask ReadLong(Encoding encoding, bool async) + { + var chunkSize = FilledBytes - ReadPosition; + var tempBuf = ArrayPool.Shared.Rent(chunkSize + 1024); + + try + { + bool foundTerminator; + var byteLen = chunkSize; + Array.Copy(Buffer, ReadPosition, tempBuf, 0, chunkSize); + ReadPosition += chunkSize; + + do + { + await ReadMore(async).ConfigureAwait(false); + Debug.Assert(ReadPosition == 0); + + foundTerminator = false; + int i; + for (i = 0; i < FilledBytes; i++) + { + if (Buffer[i] == 0) + { + foundTerminator = true; + break; + } + } + + if (byteLen + i > tempBuf.Length) + { + var newTempBuf = ArrayPool.Shared.Rent( + foundTerminator ? 
byteLen + i : byteLen + i + 1024); + + Array.Copy(tempBuf, 0, newTempBuf, 0, byteLen); + ArrayPool.Shared.Return(tempBuf); + tempBuf = newTempBuf; + } + + Array.Copy(Buffer, 0, tempBuf, byteLen, i); + byteLen += i; + ReadPosition = i; + } while (!foundTerminator); + + ReadPosition++; + return encoding.GetString(tempBuf, 0, byteLen); + } + finally + { + ArrayPool.Shared.Return(tempBuf); + } + } + } + + public ReadOnlySpan GetNullTerminatedBytes() + { + var i = Span.IndexOf((byte)0); + Debug.Assert(i >= 0); + var result = new ReadOnlySpan(Buffer, ReadPosition, i); + ReadPosition += i + 1; + return result; + } + + #endregion + + #region Dispose + + public void Dispose() + { + if (_disposed) + return; + + if (_usePool) + ArrayPool.Shared.Return(Buffer); + + Cts.Dispose(); + _disposed = true; + } + + #endregion + + #region Misc + + void ResetPosition() + { + _flushedBytes = unchecked(_flushedBytes + FilledBytes); + ReadPosition = 0; + FilledBytes = 0; + } + + internal void ResetFlushedBytes() => _flushedBytes = 0; + + internal void CopyTo(NpgsqlReadBuffer other) + { + Debug.Assert(other.Size - other.FilledBytes >= ReadBytesLeft); + Array.Copy(Buffer, ReadPosition, other.Buffer, other.FilledBytes, ReadBytesLeft); + other.FilledBytes += ReadBytesLeft; + } + + #endregion +} diff --git a/src/Npgsql/Internal/NpgsqlWriteBuffer.cs b/src/Npgsql/Internal/NpgsqlWriteBuffer.cs new file mode 100644 index 0000000000..99146339d0 --- /dev/null +++ b/src/Npgsql/Internal/NpgsqlWriteBuffer.cs @@ -0,0 +1,630 @@ +using System; +using System.Buffers.Binary; +using System.Diagnostics; +using System.IO; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Util; +using static System.Threading.Timeout; + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member +namespace Npgsql.Internal; + +/// +/// A buffer used by Npgsql to write data to the socket 
efficiently. +/// Provides methods which encode different values types and tracks the current position. +/// +sealed class NpgsqlWriteBuffer : IDisposable +{ + #region Fields and Properties + + internal static readonly UTF8Encoding UTF8Encoding = new(false, true); + internal static readonly UTF8Encoding RelaxedUTF8Encoding = new(false, false); + + internal readonly NpgsqlConnector Connector; + + internal Stream Underlying { private get; set; } + + readonly Socket? _underlyingSocket; + internal bool MessageLengthValidation { get; set; } = true; + + readonly ResettableCancellationTokenSource _timeoutCts; + readonly MetricsReporter? _metricsReporter; + + /// + /// Timeout for sync and async writes + /// + internal TimeSpan Timeout + { + get => _timeoutCts.Timeout; + set + { + if (_timeoutCts.Timeout != value) + { + Debug.Assert(_underlyingSocket != null); + + if (value > TimeSpan.Zero) + { + _underlyingSocket.SendTimeout = (int)value.TotalMilliseconds; + _timeoutCts.Timeout = value; + } + else + { + _underlyingSocket.SendTimeout = -1; + _timeoutCts.Timeout = InfiniteTimeSpan; + } + } + } + } + + /// + /// The total byte length of the buffer. + /// + internal int Size { get; private set; } + + bool _copyMode; + internal Encoding TextEncoding { get; } + + public int WriteSpaceLeft => Size - WritePosition; + + // (Re)init to make sure we'll refetch from the write buffer. + internal PgWriter GetWriter(NpgsqlDatabaseInfo typeCatalog, FlushMode flushMode = FlushMode.None) + => _pgWriter.Init(typeCatalog, flushMode); + + internal readonly byte[] Buffer; + readonly Encoder _textEncoder; + + internal int WritePosition; + + int _messageBytesFlushed; + int? _messageLength; + + bool _disposed; + readonly PgWriter _pgWriter; + + /// + /// The minimum buffer size possible. + /// + internal const int MinimumSize = 4096; + internal const int DefaultSize = 8192; + + #endregion + + #region Constructors + + internal NpgsqlWriteBuffer( + NpgsqlConnector? 
connector, + Stream stream, + Socket? socket, + int size, + Encoding textEncoding) + { + if (size < MinimumSize) + throw new ArgumentOutOfRangeException(nameof(size), size, "Buffer size must be at least " + MinimumSize); + + Connector = connector!; // TODO: Clean this up; only null when used from PregeneratedMessages, where we don't care. + Underlying = stream; + _underlyingSocket = socket; + _metricsReporter = connector?.DataSource.MetricsReporter!; + _timeoutCts = new ResettableCancellationTokenSource(); + Buffer = new byte[size]; + Size = size; + + TextEncoding = textEncoding; + _textEncoder = TextEncoding.GetEncoder(); + _pgWriter = new PgWriter(new NpgsqlBufferWriter(this)); + } + + #endregion + + #region I/O + + public async Task Flush(bool async, CancellationToken cancellationToken = default) + { + if (_copyMode) + { + // In copy mode, we write CopyData messages. The message code has already been + // written to the beginning of the buffer, but we need to go back and write the + // length. + if (WritePosition == 1) + return; + var pos = WritePosition; + WritePosition = 1; + WriteInt32(pos - 1); + WritePosition = pos; + } else if (WritePosition == 0) + return; + else + AdvanceMessageBytesFlushed(WritePosition); + + var finalCt = async && Timeout > TimeSpan.Zero + ? _timeoutCts.Start(cancellationToken) + : cancellationToken; + + try + { + if (async) + { + await Underlying.WriteAsync(Buffer, 0, WritePosition, finalCt).ConfigureAwait(false); + await Underlying.FlushAsync(finalCt).ConfigureAwait(false); + if (Timeout > TimeSpan.Zero) + _timeoutCts.Stop(); + } + else + { + Underlying.Write(Buffer, 0, WritePosition); + Underlying.Flush(); + } + } + catch (Exception e) + { + // Stopping twice (in case the previous Stop() call succeeded) doesn't hurt. + // Not stopping will cause an assertion failure in debug mode when we call Start() the next time. 
+ // We can't stop in a finally block because Connector.Break() will dispose the buffer and the contained + // _timeoutCts + _timeoutCts.Stop(); + switch (e) + { + // User requested the cancellation + case OperationCanceledException _ when (cancellationToken.IsCancellationRequested): + throw Connector.Break(e); + // Read timeout + case OperationCanceledException _: + // Note that mono throws SocketException with the wrong error (see #1330) + case IOException _ when (e.InnerException as SocketException)?.SocketErrorCode == + (Type.GetType("Mono.Runtime") == null ? SocketError.TimedOut : SocketError.WouldBlock): + Debug.Assert(e is OperationCanceledException ? async : !async); + throw Connector.Break(new NpgsqlException("Exception while writing to stream", new TimeoutException("Timeout during writing attempt"))); + } + + throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); + } + NpgsqlEventSource.Log.BytesWritten(WritePosition); + _metricsReporter?.ReportBytesWritten(WritePosition); + + WritePosition = 0; + if (_copyMode) + WriteCopyDataHeader(); + } + + internal void Flush() => Flush(false).GetAwaiter().GetResult(); + + #endregion + + #region Direct write + + internal void DirectWrite(ReadOnlySpan buffer) + { + Flush(); + + if (_copyMode) + { + // Flush has already written the CopyData header for us, but write the CopyData + // header to the socket with the write length before we can start writing the data directly. 
+ Debug.Assert(WritePosition == 5); + + WritePosition = 1; + WriteInt32(checked(buffer.Length + 4)); + WritePosition = 5; + _copyMode = false; + StartMessage(5); + Flush(); + _copyMode = true; + WriteCopyDataHeader(); // And ready the buffer after the direct write completes + } + else + { + Debug.Assert(WritePosition == 0); + AdvanceMessageBytesFlushed(buffer.Length); + } + + try + { + Underlying.Write(buffer); + } + catch (Exception e) + { + throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); + } + } + + internal async Task DirectWrite(ReadOnlyMemory memory, bool async, CancellationToken cancellationToken = default) + { + await Flush(async, cancellationToken).ConfigureAwait(false); + + if (_copyMode) + { + // Flush has already written the CopyData header for us, but write the CopyData + // header to the socket with the write length before we can start writing the data directly. + Debug.Assert(WritePosition == 5); + + WritePosition = 1; + WriteInt32(checked(memory.Length + 4)); + WritePosition = 5; + _copyMode = false; + StartMessage(5); + await Flush(async, cancellationToken).ConfigureAwait(false); + _copyMode = true; + WriteCopyDataHeader(); // And ready the buffer after the direct write completes + } + else + { + Debug.Assert(WritePosition == 0); + AdvanceMessageBytesFlushed(memory.Length); + } + + try + { + if (async) + await Underlying.WriteAsync(memory, cancellationToken).ConfigureAwait(false); + else + Underlying.Write(memory.Span); + } + catch (Exception e) + { + throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); + } + } + + #endregion Direct write + + #region Write Simple + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void WriteByte(byte value) + { + CheckBounds(); + Buffer[WritePosition] = value; + WritePosition += sizeof(byte); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void WriteInt16(short value) + { + CheckBounds(); + Unsafe.WriteUnaligned(ref 
Buffer[WritePosition], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value); + WritePosition += sizeof(short); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void WriteUInt16(ushort value) + { + CheckBounds(); + Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value); + WritePosition += sizeof(ushort); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void WriteInt32(int value) + { + CheckBounds(); + Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value); + WritePosition += sizeof(int); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void WriteUInt32(uint value) + { + CheckBounds(); + Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value); + WritePosition += sizeof(uint); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void WriteInt64(long value) + { + CheckBounds(); + Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? 
BinaryPrimitives.ReverseEndianness(value) : value); + WritePosition += sizeof(long); + } + + [Conditional("DEBUG")] + unsafe void CheckBounds() where T : unmanaged + { + if (sizeof(T) > WriteSpaceLeft) + ThrowNotSpaceLeft(); + } + + static void ThrowNotSpaceLeft() + => ThrowHelper.ThrowInvalidOperationException("There is not enough space left in the buffer."); + + public Task WriteString(string s, int byteLen, bool async, CancellationToken cancellationToken = default) + => WriteString(s, s.Length, byteLen, async, cancellationToken); + + public Task WriteString(string s, int charLen, int byteLen, bool async, CancellationToken cancellationToken = default) + { + if (byteLen <= WriteSpaceLeft) + { + WriteString(s, charLen); + return Task.CompletedTask; + } + return WriteStringLong(this, async, s, charLen, byteLen, cancellationToken); + + static async Task WriteStringLong(NpgsqlWriteBuffer buffer, bool async, string s, int charLen, int byteLen, CancellationToken cancellationToken) + { + Debug.Assert(byteLen > buffer.WriteSpaceLeft); + if (byteLen <= buffer.Size) + { + // String can fit entirely in an empty buffer. Flush and retry rather than + // going into the partial writing flow below (which requires ToCharArray()) + await buffer.Flush(async, cancellationToken).ConfigureAwait(false); + buffer.WriteString(s, charLen); + } + else + { + var charPos = 0; + while (true) + { + buffer.WriteStringChunked(s, charPos, charLen - charPos, true, out var charsUsed, out var completed); + if (completed) + break; + await buffer.Flush(async, cancellationToken).ConfigureAwait(false); + charPos += charsUsed; + } + } + } + } + + public void WriteString(string s, int len = 0) + { + Debug.Assert(TextEncoding.GetByteCount(s) <= WriteSpaceLeft); + WritePosition += TextEncoding.GetBytes(s, 0, len == 0 ? 
s.Length : len, Buffer, WritePosition); + } + + public void WriteBytes(ReadOnlySpan buf) + { + Debug.Assert(buf.Length <= WriteSpaceLeft); + buf.CopyTo(new Span(Buffer, WritePosition, Buffer.Length - WritePosition)); + WritePosition += buf.Length; + } + + public void WriteBytes(ReadOnlyMemory buf) + => WriteBytes(buf.Span); + + public void WriteBytes(byte[] buf) => WriteBytes(buf.AsSpan()); + + public void WriteBytes(byte[] buf, int offset, int count) + => WriteBytes(new ReadOnlySpan(buf, offset, count)); + + public Task WriteBytesRaw(ReadOnlyMemory bytes, bool async, CancellationToken cancellationToken = default) + { + if (bytes.Length <= WriteSpaceLeft) + { + WriteBytes(bytes); + return Task.CompletedTask; + } + return WriteBytesLong(this, async, bytes, cancellationToken); + + static async Task WriteBytesLong(NpgsqlWriteBuffer buffer, bool async, ReadOnlyMemory bytes, CancellationToken cancellationToken) + { + if (bytes.Length <= buffer.Size) + { + // value can fit entirely in an empty buffer. Flush and retry rather than + // going into the partial writing flow below + await buffer.Flush(async, cancellationToken).ConfigureAwait(false); + buffer.WriteBytes(bytes); + } + else + { + var remaining = bytes.Length; + do + { + if (buffer.WriteSpaceLeft == 0) + await buffer.Flush(async, cancellationToken).ConfigureAwait(false); + var writeLen = Math.Min(remaining, buffer.WriteSpaceLeft); + var offset = bytes.Length - remaining; + buffer.WriteBytes(bytes.Slice(offset, writeLen)); + remaining -= writeLen; + } + while (remaining > 0); + } + } + } + + public async Task WriteStreamRaw(Stream stream, int count, bool async, CancellationToken cancellationToken = default) + { + while (count > 0) + { + if (WriteSpaceLeft == 0) + await Flush(async, cancellationToken).ConfigureAwait(false); + try + { + var read = async + ? 
await stream.ReadAsync(Buffer, WritePosition, Math.Min(WriteSpaceLeft, count), cancellationToken).ConfigureAwait(false) + : stream.Read(Buffer, WritePosition, Math.Min(WriteSpaceLeft, count)); + if (read == 0) + throw new EndOfStreamException(); + WritePosition += read; + count -= read; + } + catch (Exception e) + { + throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); + } + } + Debug.Assert(count == 0); + } + + public void WriteNullTerminatedString(string s) + { + AssertASCIIOnly(s); + Debug.Assert(WriteSpaceLeft >= s.Length + 1); + WritePosition += Encoding.ASCII.GetBytes(s, 0, s.Length, Buffer, WritePosition); + WriteByte(0); + } + + public void WriteNullTerminatedString(byte[] s) + { + AssertASCIIOnly(s); + Debug.Assert(WriteSpaceLeft >= s.Length + 1); + WriteBytes(s); + WriteByte(0); + } + + #endregion + + #region Write Complex + + internal void WriteStringChunked(char[] chars, int charIndex, int charCount, + bool flush, out int charsUsed, out bool completed) + { + if (WriteSpaceLeft < _textEncoder.GetByteCount(chars, charIndex, char.IsHighSurrogate(chars[charIndex]) ? 2 : 1, flush: false)) + { + charsUsed = 0; + completed = false; + return; + } + + _textEncoder.Convert(chars, charIndex, charCount, Buffer, WritePosition, WriteSpaceLeft, + flush, out charsUsed, out var bytesUsed, out completed); + WritePosition += bytesUsed; + } + + internal unsafe void WriteStringChunked(string s, int charIndex, int charCount, + bool flush, out int charsUsed, out bool completed) + { + int bytesUsed; + + fixed (char* sPtr = s) + fixed (byte* bufPtr = Buffer) + { + if (WriteSpaceLeft < _textEncoder.GetByteCount(sPtr + charIndex, char.IsHighSurrogate(*(sPtr + charIndex)) ? 
2 : 1, flush: false)) + { + charsUsed = 0; + completed = false; + return; + } + + _textEncoder.Convert(sPtr + charIndex, charCount, bufPtr + WritePosition, WriteSpaceLeft, + flush, out charsUsed, out bytesUsed, out completed); + } + + WritePosition += bytesUsed; + } + + #endregion + + #region Copy + + internal void StartCopyMode() + { + _copyMode = true; + Size -= 5; + WriteCopyDataHeader(); + } + + internal void EndCopyMode() + { + // EndCopyMode is usually called after a Flush which ended the last CopyData message. + // That Flush also wrote the header for another CopyData which we clear here. + _copyMode = false; + Size += 5; + Clear(); + } + + void WriteCopyDataHeader() + { + Debug.Assert(_copyMode); + Debug.Assert(WritePosition == 0); + WriteByte(FrontendMessageCode.CopyData); + // Leave space for the message length + WriteInt32(0); + } + + #endregion + + #region Dispose + + public void Dispose() + { + if (_disposed) + return; + + _timeoutCts.Dispose(); + _disposed = true; + } + + #endregion + + #region Misc + + internal void StartMessage(int messageLength) + { + if (!MessageLengthValidation) + return; + + if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength) + Throw(); + + // Add negative WritePosition to compensate for previous message(s) written without flushing. 
+ _messageBytesFlushed = -WritePosition; + _messageLength = messageLength; + + void Throw() + { + throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified")); + } + } + + void AdvanceMessageBytesFlushed(int count) + { + if (!MessageLengthValidation) + return; + + if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength) + Throw(); + + _messageBytesFlushed += count; + + void Throw() + { + if (count < 0) + throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count"); + + if (_messageLength is null) + throw Connector.Break(new InvalidOperationException("No message was started")); + + if ((long)_messageBytesFlushed + count > _messageLength) + throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified")); + } + } + + internal void Clear() + { + WritePosition = 0; + _messageLength = null; + } + + /// + /// Returns all contents currently written to the buffer (but not flushed). + /// Useful for pre-generating messages. 
+ /// + internal byte[] GetContents() + { + var buf = new byte[WritePosition]; + Array.Copy(Buffer, buf, WritePosition); + return buf; + } + + [Conditional("DEBUG")] + internal static void AssertASCIIOnly(string s) + { + foreach (var c in s) + if (c >= 128) + Debug.Fail("Method only supports ASCII strings"); + } + + [Conditional("DEBUG")] + internal static void AssertASCIIOnly(byte[] s) + { + foreach (var c in s) + if (c >= 128) + Debug.Fail("Method only supports ASCII strings"); + } + + #endregion +} diff --git a/src/Npgsql/Internal/PgBufferedConverter.cs b/src/Npgsql/Internal/PgBufferedConverter.cs new file mode 100644 index 0000000000..d7b673fb7c --- /dev/null +++ b/src/Npgsql/Internal/PgBufferedConverter.cs @@ -0,0 +1,53 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public abstract class PgBufferedConverter : PgConverter +{ + protected PgBufferedConverter(bool customDbNullPredicate = false) : base(customDbNullPredicate) { } + + protected abstract T ReadCore(PgReader reader); + protected abstract void WriteCore(PgWriter writer, T value); + + public override Size GetSize(SizeContext context, T value, ref object? writeState) + => throw new NotSupportedException(); + + public sealed override T Read(PgReader reader) + { + // We check IsAtStart first to speed up primitive reads. 
+ if (!reader.IsAtStart && reader.ShouldBufferCurrent()) + ThrowIORequired(reader.CurrentBufferRequirement); + + return ReadCore(reader); + } + + public sealed override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => new(Read(reader)); + + internal sealed override ValueTask ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken) + => new(Read(reader)!); + + public sealed override void Write(PgWriter writer, T value) + { + if (!writer.BufferingWrite && writer.ShouldFlush(writer.CurrentBufferRequirement)) + ThrowIORequired(writer.CurrentBufferRequirement); + + WriteCore(writer, value); + } + + public sealed override ValueTask WriteAsync(PgWriter writer, [DisallowNull] T value, CancellationToken cancellationToken = default) + { + Write(writer, value); + return new(); + } + + internal sealed override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) + { + Write(writer, (T)value); + return new(); + } +} diff --git a/src/Npgsql/Internal/PgComposingConverterResolver.cs b/src/Npgsql/Internal/PgComposingConverterResolver.cs new file mode 100644 index 0000000000..543ef8bdbd --- /dev/null +++ b/src/Npgsql/Internal/PgComposingConverterResolver.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +abstract class PgComposingConverterResolver : PgConverterResolver +{ + readonly PgTypeId? _pgTypeId; + public PgResolverTypeInfo EffectiveTypeInfo { get; } + readonly ConcurrentDictionary _converters = new(ReferenceEqualityComparer.Instance); + + protected PgComposingConverterResolver(PgTypeId? 
pgTypeId, PgResolverTypeInfo effectiveTypeInfo) + { + if (pgTypeId is null && effectiveTypeInfo.PgTypeId is not null) + throw new ArgumentNullException(nameof(pgTypeId), $"Cannot be null if {nameof(effectiveTypeInfo)}.{nameof(PgTypeInfo.PgTypeId)} is not null."); + + _pgTypeId = pgTypeId; + EffectiveTypeInfo = effectiveTypeInfo; + } + + protected abstract PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId); + protected abstract PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId); + protected abstract PgConverter CreateConverter(PgConverterResolution effectiveResolution); + protected abstract PgConverterResolution? GetEffectiveResolution(T? value, PgTypeId? expectedEffectivePgTypeId); + + public override PgConverterResolution GetDefault(PgTypeId? pgTypeId) + { + PgTypeId? effectivePgTypeId = pgTypeId is not null ? GetEffectiveTypeId(pgTypeId.GetValueOrDefault()) : null; + var effectiveResolution = EffectiveTypeInfo.GetDefaultResolution(effectivePgTypeId); + return new(GetOrAdd(effectiveResolution), pgTypeId ?? _pgTypeId ?? GetPgTypeId(effectiveResolution.PgTypeId)); + } + + public override PgConverterResolution? Get(T? value, PgTypeId? expectedPgTypeId) + { + PgTypeId? expectedEffectiveId = expectedPgTypeId is not null ? GetEffectiveTypeId(expectedPgTypeId.GetValueOrDefault()) : null; + if (GetEffectiveResolution(value, expectedEffectiveId) is { } resolution) + return new PgConverterResolution(GetOrAdd(resolution), expectedPgTypeId ?? _pgTypeId ?? 
GetPgTypeId(resolution.PgTypeId)); + + return null; + } + + public override PgConverterResolution Get(Field field) + { + var effectiveResolution = EffectiveTypeInfo.GetResolution(field with { PgTypeId = GetEffectiveTypeId(field.PgTypeId) }); + return new PgConverterResolution(GetOrAdd(effectiveResolution), field.PgTypeId); + } + + PgTypeId GetEffectiveTypeId(PgTypeId pgTypeId) + { + if (_pgTypeId == pgTypeId) + return EffectiveTypeInfo.PgTypeId.GetValueOrDefault(); + + // We have an undecided type info which is asked to resolve for a specific type id + // we'll unfortunately have to look up the effective id, this is rare though. + return GetEffectivePgTypeId(pgTypeId); + } + + PgConverter GetOrAdd(PgConverterResolution effectiveResolution) + { + (PgComposingConverterResolver Instance, PgConverterResolution EffectiveResolution) state = (this, effectiveResolution); + return (PgConverter)_converters.GetOrAdd( + effectiveResolution.Converter, + static (_, state) => state.Instance.CreateConverter(state.EffectiveResolution), + state); + } +} diff --git a/src/Npgsql/Internal/PgConverter.cs b/src/Npgsql/Internal/PgConverter.cs new file mode 100644 index 0000000000..323c572e0a --- /dev/null +++ b/src/Npgsql/Internal/PgConverter.cs @@ -0,0 +1,225 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public abstract class PgConverter +{ + internal DbNullPredicate DbNullPredicateKind { get; } + public bool IsDbNullable => DbNullPredicateKind is not DbNullPredicate.None; + + private protected PgConverter(Type type, bool isNullDefaultValue, bool customDbNullPredicate = false) + => DbNullPredicateKind = customDbNullPredicate ? 
DbNullPredicate.Custom : InferDbNullPredicate(type, isNullDefaultValue); + + /// + /// Whether this converter can handle the given format and with which buffer requirements. + /// + /// The data format. + /// Returns the buffer requirements. + /// Returns true if the given data format is supported. + /// The buffer requirements should not cover database NULL reads or writes, these are handled by the caller. + public abstract bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements); + + internal abstract Type TypeToConvert { get; } + + internal bool IsDbNullAsObject([NotNullWhen(false)] object? value, ref object? writeState) + => DbNullPredicateKind switch + { + DbNullPredicate.Null => value is null, + DbNullPredicate.None => false, + DbNullPredicate.PolymorphicNull => value is null or DBNull, + // We do the null check to keep the NotNullWhen(false) invariant. + DbNullPredicate.Custom => IsDbNullValueAsObject(value, ref writeState) || (value is null && ThrowInvalidNullValue()), + _ => ThrowDbNullPredicateOutOfRange() + }; + + private protected abstract bool IsDbNullValueAsObject(object? value, ref object? writeState); + + internal abstract Size GetSizeAsObject(SizeContext context, object value, ref object? writeState); + + internal object ReadAsObject(PgReader reader) + => ReadAsObject(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + internal ValueTask ReadAsObjectAsync(PgReader reader, CancellationToken cancellationToken = default) + => ReadAsObject(async: true, reader, cancellationToken); + + // Shared sync/async abstract to reduce virtual method table size overhead and code size for each NpgsqlConverter instantiation. 
+ internal abstract ValueTask ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken); + + internal void WriteAsObject(PgWriter writer, object value) + => WriteAsObject(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + internal ValueTask WriteAsObjectAsync(PgWriter writer, object value, CancellationToken cancellationToken = default) + => WriteAsObject(async: true, writer, value, cancellationToken); + + // Shared sync/async abstract to reduce virtual method table size overhead and code size for each NpgsqlConverter instantiation. + internal abstract ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken); + + static DbNullPredicate InferDbNullPredicate(Type type, bool isNullDefaultValue) + => type == typeof(object) || type == typeof(DBNull) + ? DbNullPredicate.PolymorphicNull + : isNullDefaultValue + ? DbNullPredicate.Null + : DbNullPredicate.None; + + internal enum DbNullPredicate : byte + { + /// Never DbNull (struct types) + None, + /// DbNull when *user code* + Custom, + /// DbNull when value is null + Null, + /// DbNull when value is null or DBNull + PolymorphicNull + } + + [DoesNotReturn] + private protected void ThrowIORequired(Size bufferRequirement) + => throw new InvalidOperationException($"Buffer requirement '{bufferRequirement}' not respected for converter '{GetType().FullName}', expected no IO to be required."); + + private protected static bool ThrowInvalidNullValue() + => throw new ArgumentNullException("value", "Null value given for non-nullable type converter"); + + private protected bool ThrowDbNullPredicateOutOfRange() + => throw new UnreachableException($"Unknown case {DbNullPredicateKind.ToString()}"); + + protected bool CanConvertBufferedDefault(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.Value; + return format is DataFormat.Binary; + } +} + +public abstract class PgConverter : 
PgConverter +{ + private protected PgConverter(bool customDbNullPredicate) + : base(typeof(T), default(T) is null, customDbNullPredicate) { } + + protected virtual bool IsDbNullValue(T? value, ref object? writeState) => throw new NotSupportedException(); + + // Object null semantics as follows, if T is a struct (so excluding nullable) report false for null values, don't throw on the cast. + // As a result this creates symmetry with IsDbNull when we're dealing with a struct T, as it cannot be passed null at all. + private protected override bool IsDbNullValueAsObject(object? value, ref object? writeState) + => (default(T) is null || value is not null) && IsDbNullValue((T?)value, ref writeState); + + public bool IsDbNull([NotNullWhen(false)] T? value, ref object? writeState) + => DbNullPredicateKind switch + { + DbNullPredicate.Null => value is null, + DbNullPredicate.None => false, + DbNullPredicate.PolymorphicNull => value is null or DBNull, + // We do the null check to keep the NotNullWhen(false) invariant. + DbNullPredicate.Custom => IsDbNullValue(value, ref writeState) || (value is null && ThrowInvalidNullValue()), + _ => ThrowDbNullPredicateOutOfRange() + }; + + public abstract T Read(PgReader reader); + public abstract ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default); + + public abstract Size GetSize(SizeContext context, [DisallowNull]T value, ref object? writeState); + public abstract void Write(PgWriter writer, [DisallowNull] T value); + public abstract ValueTask WriteAsync(PgWriter writer, [DisallowNull] T value, CancellationToken cancellationToken = default); + + internal sealed override Type TypeToConvert => typeof(T); + + internal sealed override Size GetSizeAsObject(SizeContext context, object value, ref object? writeState) + => GetSize(context, (T)value, ref writeState); +} + +static class PgConverterExtensions +{ + public static Size? 
GetSizeOrDbNull(this PgConverter converter, DataFormat format, Size writeRequirement, T? value, ref object? writeState) + { + if (converter.IsDbNull(value, ref writeState)) + return null; + + if (writeRequirement is { Kind: SizeKind.Exact, Value: var byteCount }) + return byteCount; + var size = converter.GetSize(new(format, writeRequirement), value, ref writeState); + + switch (size.Kind) + { + case SizeKind.UpperBound: + ThrowHelper.ThrowInvalidOperationException($"{nameof(SizeKind.UpperBound)} is not a valid return value for GetSize."); + break; + case SizeKind.Unknown: + // Not valid yet. + ThrowHelper.ThrowInvalidOperationException($"{nameof(SizeKind.Unknown)} is not a valid return value for GetSize."); + break; + } + + return size; + } + + public static Size? GetSizeOrDbNullAsObject(this PgConverter converter, DataFormat format, Size writeRequirement, object? value, ref object? writeState) + { + if (converter.IsDbNullAsObject(value, ref writeState)) + return null; + + if (writeRequirement is { Kind: SizeKind.Exact, Value: var byteCount }) + return byteCount; + var size = converter.GetSizeAsObject(new(format, writeRequirement), value, ref writeState); + + switch (size.Kind) + { + case SizeKind.UpperBound: + ThrowHelper.ThrowInvalidOperationException($"{nameof(SizeKind.UpperBound)} is not a valid return value for GetSize."); + break; + case SizeKind.Unknown: + // Not valid yet. + ThrowHelper.ThrowInvalidOperationException($"{nameof(SizeKind.Unknown)} is not a valid return value for GetSize."); + break; + } + + return size; + } + + internal static PgConverter UnsafeDowncast(this PgConverter converter) + { + // Justification: avoid perf cost of casting to a known base class type per read/write, see callers. 
+ Debug.Assert(converter is PgConverter); + return Unsafe.As>(converter); + } +} + +public readonly struct SizeContext +{ + [SetsRequiredMembers] + public SizeContext(DataFormat format, Size bufferRequirement) + { + Format = format; + BufferRequirement = bufferRequirement; + } + + public required Size BufferRequirement { get; init; } + public DataFormat Format { get; } +} + +class MultiWriteState : IDisposable +{ + public required ArrayPool<(Size Size, object? WriteState)>? ArrayPool { get; init; } + public required ArraySegment<(Size Size, object? WriteState)> Data { get; init; } + public required bool AnyWriteState { get; init; } + + public void Dispose() + { + if (Data.Array is not { } array) + return; + + if (AnyWriteState) + { + for (var i = Data.Offset; i < array.Length; i++) + if (array[i].WriteState is IDisposable disposable) + disposable.Dispose(); + + Array.Clear(Data.Array, Data.Offset, Data.Count); + } + + ArrayPool?.Return(Data.Array); + } +} diff --git a/src/Npgsql/Internal/PgConverterResolver.cs b/src/Npgsql/Internal/PgConverterResolver.cs new file mode 100644 index 0000000000..5fbe699017 --- /dev/null +++ b/src/Npgsql/Internal/PgConverterResolver.cs @@ -0,0 +1,111 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public abstract class PgConverterResolver +{ + private protected PgConverterResolver() { } + + /// + /// Gets the appropriate converter solely based on PgTypeId. + /// + /// + /// The converter resolution. + /// + /// Implementations should not return new instances of the possible converters that can be returned, instead its expected these are cached once used. + /// Array or other collection converters depend on this to cache their own converter - which wraps the element converter - with the cache key being the element converter reference. + /// + public abstract PgConverterResolution GetDefault(PgTypeId? 
pgTypeId); + + /// + /// Gets the appropriate converter to read with based on the given field info. + /// + /// + /// The converter resolution. + /// + /// Implementations should not return new instances of the possible converters that can be returned, instead its expected these are cached once used. + /// Array or other collection converters depend on this to cache their own converter - which wraps the element converter - with the cache key being the element converter reference. + /// + public virtual PgConverterResolution Get(Field field) => GetDefault(field.PgTypeId); + + internal abstract Type TypeToConvert { get; } + + internal abstract PgConverterResolution? GetAsObjectInternal(PgTypeInfo typeInfo, object? value, PgTypeId? expectedPgTypeId); + + internal PgConverterResolution GetDefaultInternal(bool validate, bool expectPortableTypeIds, PgTypeId? pgTypeId) + { + var resolution = GetDefault(pgTypeId); + if (validate) + Validate(nameof(GetDefault), resolution, TypeToConvert, pgTypeId, expectPortableTypeIds); + return resolution; + } + + internal PgConverterResolution GetInternal(PgTypeInfo typeInfo, Field field) + { + var resolution = Get(field); + if (typeInfo.ValidateResolution) + Validate(nameof(Get), resolution, TypeToConvert, field.PgTypeId, typeInfo.Options.PortableTypeIds); + return resolution; + } + + private protected static void Validate(string methodName, PgConverterResolution resolution, Type expectedTypeToConvert, PgTypeId? expectedPgTypeId, bool expectPortableTypeIds) + { + if (resolution.Converter is null) + throw new InvalidOperationException($"'{methodName}' returned a null {nameof(PgConverterResolution.Converter)} unexpectedly."); + + // We allow object resolvers to return any converter, this is to help: + // - Composing resolvers being able to use converter type identity (instead of everything being CastingConverter). + // - Reduce indirection by allowing disparate type converters to be returned directly. 
+ // As a consequence any object typed resolver info is always a boxing one, to reduce the chances invalid casts to PgConverter are attempted. + if (expectedTypeToConvert != typeof(object) && resolution.Converter.TypeToConvert != expectedTypeToConvert) + throw new InvalidOperationException($"'{methodName}' returned a {nameof(PgConverterResolution.Converter)} of type {resolution.Converter.TypeToConvert} instead of {expectedTypeToConvert} unexpectedly."); + + if (expectPortableTypeIds && resolution.PgTypeId.IsOid || !expectPortableTypeIds && resolution.PgTypeId.IsDataTypeName) + throw new InvalidOperationException($"{methodName}' returned a resolution with a {nameof(PgConverterResolution.PgTypeId)} that was not in canonical form."); + + if (expectedPgTypeId is not null && resolution.PgTypeId != expectedPgTypeId) + throw new InvalidOperationException( + $"'{methodName}' returned a different {nameof(PgConverterResolution.PgTypeId)} than was passed in as expected." + + $" If such a mismatch occurs an exception should be thrown instead."); + } + + protected ArgumentOutOfRangeException CreateUnsupportedPgTypeIdException(PgTypeId pgTypeId) + => new(nameof(pgTypeId), pgTypeId, "Unsupported PgTypeId."); +} + +public abstract class PgConverterResolver : PgConverterResolver +{ + /// + /// Gets the appropriate converter to write with based on the given value. + /// + /// + /// + /// The converter resolution. + /// + /// Implementations should not return new instances of the possible converters that can be returned, instead its expected these are + /// cached once used. Array or other collection converters depend on this to cache their own converter - which wraps the element + /// converter - with the cache key being the element converter reference. + /// + public abstract PgConverterResolution? Get(T? value, PgTypeId? expectedPgTypeId); + + internal sealed override Type TypeToConvert => typeof(T); + + internal PgConverterResolution? GetInternal(PgTypeInfo typeInfo, T? 
value, PgTypeId? expectedPgTypeId) + { + var resolution = Get(value, expectedPgTypeId); + if (typeInfo.ValidateResolution && resolution is not null) + Validate(nameof(Get), resolution.GetValueOrDefault(), TypeToConvert, expectedPgTypeId, typeInfo.Options.PortableTypeIds); + return resolution; + } + + internal sealed override PgConverterResolution? GetAsObjectInternal(PgTypeInfo typeInfo, object? value, PgTypeId? expectedPgTypeId) + { + var resolution = Get(value is null ? default : (T)value, expectedPgTypeId); + if (typeInfo.ValidateResolution && resolution is not null) + Validate(nameof(Get), resolution.GetValueOrDefault(), TypeToConvert, expectedPgTypeId, typeInfo.Options.PortableTypeIds); + return resolution; + } +} diff --git a/src/Npgsql/Internal/PgReader.cs b/src/Npgsql/Internal/PgReader.cs new file mode 100644 index 0000000000..90f5b53e14 --- /dev/null +++ b/src/Npgsql/Internal/PgReader.cs @@ -0,0 +1,844 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public class PgReader +{ + // We don't want to add a ton of memory pressure for large strings. + internal const int MaxPreparedTextReaderSize = 1024 * 64; + + readonly NpgsqlReadBuffer _buffer; + + bool _resumable; + + byte[]? _pooledArray; + NpgsqlReadBuffer.ColumnStream? _userActiveStream; + PreparedTextReader? _preparedTextReader; + + long _fieldStartPos; + Size _fieldBufferRequirement; + DataFormat _fieldFormat; + int _fieldSize; + + // This position is relative to _fieldStartPos, which is why it can be an int. + int _currentStartPos; + Size _currentBufferRequirement; + int _currentSize; + + // GetChars Internal state + TextReader? _charsReadReader; + int _charsRead; + + // GetChars User state + int? 
_charsReadOffset; + ArraySegment? _charsReadBuffer; + + bool _requiresCleanup; + // The field reading process of doing init/commit and startread/endread pairs is very perf sensitive. + // So this is used in Commit as a fast-path alternative to FieldRemaining to detect if the field was consumed succesfully. + bool _fieldConsumed; + + internal PgReader(NpgsqlReadBuffer buffer) + { + _buffer = buffer; + _fieldStartPos = -1; + _currentSize = -1; + } + + internal long FieldStartPos => _fieldStartPos; + internal int FieldSize => _fieldSize; + internal bool Initialized => _fieldStartPos is not -1; + internal int FieldOffset => (int)(_buffer.CumulativeReadPosition - _fieldStartPos); + internal int FieldRemaining => FieldSize - FieldOffset; + + bool HasCurrent => _currentSize is not -1; + int CurrentSize => HasCurrent ? _currentSize : _fieldSize; + + public ValueMetadata Current => new() { Size = CurrentSize, Format = _fieldFormat, BufferRequirement = CurrentBufferRequirement }; + public int CurrentRemaining => HasCurrent ? _currentSize - CurrentOffset : FieldRemaining; + + internal Size CurrentBufferRequirement => HasCurrent ? 
_currentBufferRequirement : _fieldBufferRequirement; + int CurrentOffset => FieldOffset - _currentStartPos; + + internal bool IsAtStart => FieldOffset is 0; + internal bool Resumable => _resumable; + public bool IsResumed => Resumable && CurrentSize != CurrentRemaining; + + ArrayPool ArrayPool => ArrayPool.Shared; + + // Here for testing purposes + internal void BreakConnection() => throw _buffer.Connector.Break(new Exception("Broken")); + + internal void Revert(int size, int startPos, Size bufferRequirement) + { + if (startPos > FieldOffset) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(startPos), "Can't revert forwardly"); + + _currentStartPos = startPos; + _currentBufferRequirement = bufferRequirement; + _currentSize = size; + } + + void CheckBounds(int count) + { + if (NpgsqlReadBuffer.BufferBoundsChecks) + Core(count); + + [MethodImpl(MethodImplOptions.NoInlining)] + void Core(int count) + { + if (count > FieldRemaining) + ThrowHelper.ThrowInvalidOperationException("Attempt to read past the end of the field."); + } + } + + public byte ReadByte() + { + CheckBounds(sizeof(byte)); + var result = _buffer.ReadByte(); + return result; + } + + public short ReadInt16() + { + CheckBounds(sizeof(short)); + var result = _buffer.ReadInt16(); + return result; + } + + public int ReadInt32() + { + CheckBounds(sizeof(int)); + var result = _buffer.ReadInt32(); + return result; + } + + public long ReadInt64() + { + CheckBounds(sizeof(long)); + var result = _buffer.ReadInt64(); + return result; + } + + public ushort ReadUInt16() + { + CheckBounds(sizeof(ushort)); + var result = _buffer.ReadUInt16(); + return result; + } + + public uint ReadUInt32() + { + CheckBounds(sizeof(uint)); + var result = _buffer.ReadUInt32(); + return result; + } + + public ulong ReadUInt64() + { + CheckBounds(sizeof(ulong)); + var result = _buffer.ReadUInt64(); + return result; + } + + public float ReadFloat() + { + CheckBounds(sizeof(float)); + var result = _buffer.ReadSingle(); + return 
result; + } + + public double ReadDouble() + { + CheckBounds(sizeof(double)); + var result = _buffer.ReadDouble(); + return result; + } + + public void Read(Span destination) + { + CheckBounds(destination.Length); + _buffer.ReadBytes(destination); + } + + public async ValueTask ReadNullTerminatedStringAsync(Encoding encoding, CancellationToken cancellationToken = default) + { + var result = await _buffer.ReadNullTerminatedString(encoding, async: true, cancellationToken).ConfigureAwait(false); + // Can only check after the fact. + CheckBounds(0); + return result; + } + + public string ReadNullTerminatedString(Encoding encoding) + { + var result = _buffer.ReadNullTerminatedString(encoding, async: false, CancellationToken.None).GetAwaiter().GetResult(); + CheckBounds(0); + return result; + } + public Stream GetStream(int? length = null) => GetColumnStream(false, length); + + internal Stream GetStream(bool canSeek, int? length = null) => GetColumnStream(canSeek, length); + + NpgsqlReadBuffer.ColumnStream GetColumnStream(bool canSeek = false, int? length = null) + { + if (length > CurrentRemaining) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(length), "Length is larger than the current remaining value size"); + + _requiresCleanup = true; + // This will cause any previously handed out StreamReaders etc to throw, as intended. 
+ if (_userActiveStream is not null) + DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); + + length ??= CurrentRemaining; + CheckBounds(length.GetValueOrDefault()); + return _userActiveStream = _buffer.CreateStream(length.GetValueOrDefault(), canSeek && length <= _buffer.ReadBytesLeft); + } + + public TextReader GetTextReader(Encoding encoding) + => GetTextReader(async: false, encoding, CancellationToken.None).GetAwaiter().GetResult(); + + public ValueTask GetTextReaderAsync(Encoding encoding, CancellationToken cancellationToken) + => GetTextReader(async: true, encoding, cancellationToken); + + async ValueTask GetTextReader(bool async, Encoding encoding, CancellationToken cancellationToken) + { + _requiresCleanup = true; + if (CurrentRemaining > _buffer.ReadBytesLeft || CurrentRemaining > MaxPreparedTextReaderSize) + return new StreamReader(GetColumnStream(), encoding, detectEncodingFromByteOrderMarks: false); + + if (_preparedTextReader is { IsDisposed: false }) + { + _preparedTextReader.Dispose(); + _preparedTextReader = null; + } + + _preparedTextReader ??= new PreparedTextReader(); + _preparedTextReader.Init( + encoding.GetString(async + ? 
await ReadBytesAsync(CurrentRemaining, cancellationToken).ConfigureAwait(false) + : ReadBytes(CurrentRemaining)), GetColumnStream(canSeek: false, 0)); + return _preparedTextReader; + } + + public ValueTask ReadBytesAsync(Memory buffer, CancellationToken cancellationToken = default) + { + var count = buffer.Length; + CheckBounds(count); + var offset = _buffer.ReadPosition; + var remaining = _buffer.FilledBytes - offset; + if (remaining >= count) + { + _buffer.Buffer.AsSpan(offset, count).CopyTo(buffer.Span); + _buffer.ReadPosition += count; + return new(); + } + + return Slow(count, buffer, cancellationToken); + + async ValueTask Slow(int count, Memory buffer, CancellationToken cancellationToken) + { + var stream = _buffer.CreateStream(count, canSeek: false); + await using var _ = stream.ConfigureAwait(false); + await stream.ReadExactlyAsync(buffer, cancellationToken).ConfigureAwait(false); + } + } + + public void ReadBytes(Span buffer) + { + var count = buffer.Length; + CheckBounds(count); + var offset = _buffer.ReadPosition; + var remaining = _buffer.FilledBytes - offset; + if (remaining >= count) + { + _buffer.Buffer.AsSpan(offset, count).CopyTo(buffer); + _buffer.ReadPosition += count; + return; + } + + Slow(count, buffer); + + void Slow(int count, Span buffer) + { + using var stream = _buffer.CreateStream(count, canSeek: false); + stream.ReadExactly(buffer); + } + } + + public bool TryReadBytes(int count, out ReadOnlySpan bytes) + { + CheckBounds(count); + var offset = _buffer.ReadPosition; + var remaining = _buffer.FilledBytes - offset; + if (remaining >= count) + { + bytes = new ReadOnlySpan(_buffer.Buffer, offset, count); + _buffer.ReadPosition += count; + return true; + } + bytes = default; + return false; + } + + public bool TryReadBytes(int count, out ReadOnlyMemory bytes) + { + CheckBounds(count); + var offset = _buffer.ReadPosition; + var remaining = _buffer.FilledBytes - offset; + if (remaining >= count) + { + bytes = new ReadOnlyMemory(_buffer.Buffer, 
offset, count); + _buffer.ReadPosition += count; + return true; + } + bytes = default; + return false; + } + + /// ReadBytes without memory management, the next read invalidates the underlying buffer(s), only use this for intermediate transformations. + public ReadOnlySequence ReadBytes(int count) + { + CheckBounds(count); + var offset = _buffer.ReadPosition; + var remaining = _buffer.FilledBytes - offset; + if (remaining >= count) + { + var result = new ReadOnlySequence(_buffer.Buffer, offset, count); + _buffer.ReadPosition += count; + return result; + } + + var array = RentArray(count); + ReadBytes(array.AsSpan(0, count)); + return new(array, 0, count); + } + + /// ReadBytesAsync without memory management, the next read invalidates the underlying buffer(s), only use this for intermediate transformations. + public async ValueTask> ReadBytesAsync(int count, CancellationToken cancellationToken = default) + { + CheckBounds(count); + var offset = _buffer.ReadPosition; + var remaining = _buffer.FilledBytes - offset; + if (remaining >= count) + { + var result = new ReadOnlySequence(_buffer.Buffer, offset, count); + _buffer.ReadPosition += count; + return result; + } + + var array = RentArray(count); + await ReadBytesAsync(array.AsMemory(0, count), cancellationToken).ConfigureAwait(false); + return new(array, 0, count); + } + + public void Rewind(int count) + { + // Shut down any streaming going on on the column + DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); + + if (_buffer.ReadPosition < count) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count), "Cannot rewind further than the buffer start"); + + if (CurrentOffset < count) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count), "Cannot rewind further than the current field offset"); + + _buffer.ReadPosition -= count; + } + + /// + /// + /// + /// + /// The stream length, if any + async ValueTask DisposeUserActiveStream(bool async) + { + if (StreamActive) + { + if (async) + await 
_userActiveStream.DisposeAsync().ConfigureAwait(false); + else + _userActiveStream.Dispose(); + } + + _userActiveStream = null; + } + + internal int CharsRead => _charsRead; + internal bool CharsReadActive => _charsReadOffset is not null; + + internal void GetCharsReadInfo(Encoding encoding, out int charsRead, out TextReader reader, out int charsOffset, out ArraySegment? buffer) + { + if (!CharsReadActive) + ThrowHelper.ThrowInvalidOperationException("No active chars read"); + + charsRead = _charsRead; + reader = _charsReadReader ??= GetTextReader(encoding); + charsOffset = _charsReadOffset ?? 0; + buffer = _charsReadBuffer; + } + + internal void RestartCharsRead() + { + if (!CharsReadActive) + ThrowHelper.ThrowInvalidOperationException("No active chars read"); + + switch (_charsReadReader) + { + case PreparedTextReader reader: + reader.Restart(); + break; + case StreamReader reader: + reader.BaseStream.Seek(0, SeekOrigin.Begin); + reader.DiscardBufferedData(); + break; + } + _charsRead = 0; + } + + internal void AdvanceCharsRead(int charsRead) => _charsRead += charsRead; + + internal void StartCharsRead(int dataOffset, ArraySegment? 
buffer) + { + if (!Resumable) + ThrowHelper.ThrowInvalidOperationException("Wasn't initialized as resumed"); + + _charsReadOffset = dataOffset; + _charsReadBuffer = buffer; + } + + internal void EndCharsRead() + { + if (!Resumable) + ThrowHelper.ThrowInvalidOperationException("Wasn't initialized as resumed"); + + if (!CharsReadActive) + ThrowHelper.ThrowInvalidOperationException("No active chars read"); + + _charsReadOffset = null; + _charsReadBuffer = null; + } + + internal PgReader Init(int fieldLength, DataFormat format, bool resumable = false) + { + if (Initialized) + { + if (resumable) + { + if (Resumable) + return this; + _resumable = true; + } + else + { + if (!IsAtStart) + ThrowHelper.ThrowInvalidOperationException("Cannot be initialized to be non-resumable until a commit is issued."); + _resumable = false; + } + } + + Debug.Assert(!_requiresCleanup, "Reader wasn't properly committed before next init"); + + _fieldStartPos = _buffer.CumulativeReadPosition; + _fieldFormat = format; + _fieldSize = fieldLength; + _resumable = resumable; + _fieldConsumed = false; + return this; + } + + internal void StartRead(Size bufferRequirement) + { + Debug.Assert(FieldSize >= 0); + _fieldBufferRequirement = bufferRequirement; + if (ShouldBuffer(bufferRequirement)) + Buffer(bufferRequirement); + } + + internal ValueTask StartReadAsync(Size bufferRequirement, CancellationToken cancellationToken) + { + Debug.Assert(FieldSize >= 0); + _fieldBufferRequirement = bufferRequirement; + return ShouldBuffer(bufferRequirement) ? BufferAsync(bufferRequirement, cancellationToken) : new(); + } + + internal void EndRead() + { + if (_resumable || StreamActive) + return; + + // If it was upper bound we should consume. 
+ if (_fieldBufferRequirement is { Kind: SizeKind.UpperBound }) + { + Consume(FieldRemaining); + return; + } + + if (FieldOffset != FieldSize) + ThrowNotConsumedExactly(); + + _fieldConsumed = true; + } + + internal ValueTask EndReadAsync() + { + if (_resumable || StreamActive) + return new(); + + // If it was upper bound we should consume. + if (_fieldBufferRequirement is { Kind: SizeKind.UpperBound }) + return ConsumeAsync(FieldRemaining); + + if (FieldOffset != FieldSize) + ThrowNotConsumedExactly(); + + _fieldConsumed = true; + return new(); + } + + internal async ValueTask BeginNestedRead(bool async, int size, Size bufferRequirement, CancellationToken cancellationToken = default) + { + if (size > CurrentRemaining) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(size), "Cannot begin a read for a larger size than the current remaining size."); + + if (size < 0) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(size), "Cannot be negative"); + + var previousSize = CurrentSize; + var previousStartPos = _currentStartPos; + var previousBufferRequirement = CurrentBufferRequirement; + _currentSize = size; + _currentBufferRequirement = bufferRequirement; + _currentStartPos = FieldOffset; + + await Buffer(async, bufferRequirement, cancellationToken).ConfigureAwait(false); + return new NestedReadScope(async, this, previousSize, previousStartPos, previousBufferRequirement); + } + + public NestedReadScope BeginNestedRead(int size, Size bufferRequirement) + => BeginNestedRead(async: false, size, bufferRequirement, CancellationToken.None).GetAwaiter().GetResult(); + + public ValueTask BeginNestedReadAsync(int size, Size bufferRequirement, CancellationToken cancellationToken = default) + => BeginNestedRead(async: true, size, bufferRequirement, cancellationToken); + + internal void Seek(int offset) + { + if (CurrentOffset > offset) + Rewind(CurrentOffset - offset); + else if (CurrentOffset < offset) + Consume(offset - CurrentOffset); + } + + internal async ValueTask 
Consume(bool async, int? count = null, CancellationToken cancellationToken = default) + { + if (count <= 0 || FieldSize < 0 || FieldRemaining == 0) + return; + + var remaining = count ?? CurrentRemaining; + CheckBounds(remaining); + + var origOffset = FieldOffset; + // A breaking exception unwind from a nested scope should not try to consume its remaining data. + if (!_buffer.Connector.IsBroken) + await _buffer.Skip(remaining, async).ConfigureAwait(false); + + Debug.Assert(FieldRemaining == FieldSize - origOffset - remaining); + } + + public void Consume(int? count = null) => Consume(async: false, count).GetAwaiter().GetResult(); + public ValueTask ConsumeAsync(int? count = null, CancellationToken cancellationToken = default) => Consume(async: true, count, cancellationToken); + + [MemberNotNullWhen(true, nameof(_userActiveStream))] + bool StreamActive => _userActiveStream is { IsDisposed: false }; + internal void ThrowIfStreamActive() + { + if (StreamActive) + ThrowHelper.ThrowInvalidOperationException("A stream is already open for this reader"); + } + + internal bool CommitHasIO(bool resuming) => Initialized && !resuming && FieldRemaining > 0; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Commit(bool resuming) + { + if (!Initialized) + return; + + if (resuming) + { + if (!Resumable) + ThrowHelper.ThrowInvalidOperationException("Cannot resume a non-resumable read."); + return; + } + + // We don't rely on CurrentRemaining, just to make sure we consume fully in the event of a nested scope not being disposed. + // Also shut down any streaming, pooled arrays etc. 
+ if (_requiresCleanup || (!_fieldConsumed && FieldRemaining > 0)) + { + CommitSlow(); + return; + } + + _fieldStartPos = -1; + Debug.Assert(!Initialized); + + // These will always be re-initialized by Init() + // _fieldSize = default; + // _fieldFormat = default; + // _resumable = default; + // _fieldCompleted = default; + + if (HasCurrent) + { + _currentStartPos = 0; + _currentBufferRequirement = default; + _currentSize = -1; + Debug.Assert(!HasCurrent); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + void CommitSlow() + { + // Shut down any streaming and pooling going on on the column. + if (_requiresCleanup) + { + if (StreamActive) + DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); + + if (_pooledArray is not null) + { + ArrayPool.Return(_pooledArray); + _pooledArray = null; + } + + if (_charsReadReader is not null) + { + _charsReadReader.Dispose(); + _charsReadReader = null; + _charsRead = default; + } + _requiresCleanup = false; + } + + Consume(async: false, count: FieldRemaining).GetAwaiter().GetResult(); + + _fieldStartPos = -1; + Debug.Assert(!Initialized); + + // These will always be re-initialized by Init() + // _fieldSize = default; + // _fieldFormat = default; + // _resumable = default; + // _fieldCompleted = default; + + if (HasCurrent) + { + _currentStartPos = 0; + _currentBufferRequirement = default; + _currentSize = -1; + Debug.Assert(!HasCurrent); + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal ValueTask CommitAsync(bool resuming) + { + if (!Initialized) + return new(); + + if (resuming) + { + if (!Resumable) + ThrowHelper.ThrowInvalidOperationException("Cannot resume a non-resumable read."); + return new(); + } + + // We don't rely on CurrentRemaining, just to make sure we consume fully in the event of a nested scope not being disposed. + // Also shut down any streaming, pooled arrays etc. 
+ if (_requiresCleanup || (!_fieldConsumed && FieldRemaining > 0)) + return CommitSlow(); + + _fieldStartPos = -1; + Debug.Assert(!Initialized); + + // These will always be re-initialized by Init() + // _fieldSize = default; + // _fieldFormat = default; + // _resumable = default; + // _fieldCompleted = default; + + if (HasCurrent) + { + _currentStartPos = 0; + _currentBufferRequirement = default; + _currentSize = -1; + Debug.Assert(!HasCurrent); + } + + return new(); + + async ValueTask CommitSlow() + { + // Shut down any streaming and pooling going on on the column. + if (_requiresCleanup) + { + if (StreamActive) + await DisposeUserActiveStream(async: true).ConfigureAwait(false); + + if (_pooledArray is not null) + { + ArrayPool.Return(_pooledArray); + _pooledArray = null; + } + + if (_charsReadReader is not null) + { + _charsReadReader.Dispose(); + _charsReadReader = null; + _charsRead = default; + } + _requiresCleanup = false; + } + + await Consume(async: true, count: FieldRemaining).ConfigureAwait(false); + + _fieldStartPos = -1; + Debug.Assert(!Initialized); + + // These will always be re-initialized by Init() + // _fieldSize = default; + // _fieldFormat = default; + // _resumable = default; + // _fieldCompleted = default; + + if (HasCurrent) + { + _currentStartPos = 0; + _currentBufferRequirement = default; + _currentSize = -1; + Debug.Assert(!HasCurrent); + } + } + } + + byte[] RentArray(int count) + { + _requiresCleanup = true; + var pooledArray = _pooledArray; + if (pooledArray is not null) + { + if (pooledArray.Length >= count) + return pooledArray; + ArrayPool.Return(pooledArray); + } + var array = _pooledArray = ArrayPool.Rent(count); + return array; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + int GetBufferRequirementByteCount(Size bufferRequirement) + => bufferRequirement is { Kind: SizeKind.UpperBound } + ? 
Math.Min(CurrentRemaining, bufferRequirement.Value) + : bufferRequirement.GetValueOrDefault(); + + internal bool ShouldBufferCurrent() => ShouldBuffer(CurrentBufferRequirement); + + public bool ShouldBuffer(Size bufferRequirement) + => ShouldBuffer(GetBufferRequirementByteCount(bufferRequirement)); + public bool ShouldBuffer(int byteCount) + { + return _buffer.ReadBytesLeft < byteCount && ShouldBufferSlow(); + + [MethodImpl(MethodImplOptions.NoInlining)] + bool ShouldBufferSlow() + { + if (byteCount > _buffer.Size) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(byteCount), + "Buffer requirement is larger than the buffer size, this can never succeed by buffering data but requires a larger buffer size instead."); + if (byteCount > CurrentRemaining) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(byteCount), + "Buffer requirement is larger than the remaining length of the value, make sure the value is always at least this size or use an upper bound requirement instead."); + + return true; + } + } + + public void Buffer(Size bufferRequirement) + => Buffer(GetBufferRequirementByteCount(bufferRequirement)); + public void Buffer(int byteCount) => _buffer.Ensure(byteCount); + + public ValueTask BufferAsync(Size bufferRequirement, CancellationToken cancellationToken) + => BufferAsync(GetBufferRequirementByteCount(bufferRequirement), cancellationToken); + public ValueTask BufferAsync(int byteCount, CancellationToken cancellationToken) => _buffer.EnsureAsync(byteCount); + + internal ValueTask Buffer(bool async, Size bufferRequirement, CancellationToken cancellationToken) + => Buffer(async, GetBufferRequirementByteCount(bufferRequirement), cancellationToken); + internal ValueTask Buffer(bool async, int byteCount, CancellationToken cancellationToken) + { + if (async) + return BufferAsync(byteCount, cancellationToken); + + Buffer(byteCount); + return new(); + } + + void ThrowNotConsumedExactly() => + throw _buffer.Connector.Break( + new InvalidOperationException( + 
FieldOffset < FieldSize + ? $"The read on this field has not consumed all of its bytes (pos: {FieldOffset}, len: {FieldSize})" + : $"The read on this field has consumed all of its bytes and read into the subsequent bytes (pos: {FieldOffset}, len: {FieldSize})")); +} + +public readonly struct NestedReadScope : IDisposable, IAsyncDisposable +{ + readonly PgReader _reader; + readonly int _previousSize; + readonly int _previousStartPos; + readonly Size _previousBufferRequirement; + readonly bool _async; + + internal NestedReadScope(bool async, PgReader reader, int previousSize, int previousStartPos, Size previousBufferRequirement) + { + _async = async; + _reader = reader; + _previousSize = previousSize; + _previousStartPos = previousStartPos; + _previousBufferRequirement = previousBufferRequirement; + } + + public void Dispose() + { + if (_async) + ThrowHelper.ThrowInvalidOperationException("Cannot synchronously dispose async scopes, call DisposeAsync instead."); + DisposeAsync().GetAwaiter().GetResult(); + } + + public ValueTask DisposeAsync() + { + if (_reader.CurrentRemaining > 0) + { + if (_async) + return AsyncCore(_reader, _previousSize, _previousStartPos, _previousBufferRequirement); + + _reader.Consume(); + } + _reader.Revert(_previousSize, _previousStartPos, _previousBufferRequirement); + return new(); + + static async ValueTask AsyncCore(PgReader reader, int previousSize, int previousStartPos, Size previousBufferRequirement) + { + await reader.ConsumeAsync().ConfigureAwait(false); + reader.Revert(previousSize, previousStartPos, previousBufferRequirement); + } + } +} diff --git a/src/Npgsql/Internal/PgSerializerOptions.cs b/src/Npgsql/Internal/PgSerializerOptions.cs new file mode 100644 index 0000000000..b79b5757ec --- /dev/null +++ b/src/Npgsql/Internal/PgSerializerOptions.cs @@ -0,0 +1,165 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text; +using Npgsql.Internal.Postgres; 
+using Npgsql.NameTranslation; +using Npgsql.PostgresTypes; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public sealed class PgSerializerOptions +{ + /// + /// Used by GetSchema to be able to attempt to resolve all type catalog types without exceptions. + /// + [field: ThreadStatic] + internal static bool IntrospectionCaller { get; set; } + + readonly PgTypeInfoResolverChain _resolverChain; + readonly Func? _timeZoneProvider; + IPgTypeInfoResolver? _typeInfoResolver; + object? _typeInfoCache; + + internal PgSerializerOptions(NpgsqlDatabaseInfo databaseInfo, PgTypeInfoResolverChain? resolverChain = null, Func? timeZoneProvider = null) + { + _resolverChain = resolverChain ?? new(); + _timeZoneProvider = timeZoneProvider; + DatabaseInfo = databaseInfo; + UnspecifiedDBNullTypeInfo = new(this, new Converters.Internal.VoidConverter(), DataTypeName.Unspecified, unboxedType: typeof(DBNull)); + } + + internal PgTypeInfo UnspecifiedDBNullTypeInfo { get; } + + PostgresType? _textPgType; + internal PostgresType TextPgType => _textPgType ??= DatabaseInfo.GetPostgresType(DataTypeNames.Text); + + // Used purely for type mapping, where we don't have a full set of types but resolvers might know enough. + readonly bool _introspectionInstance; + internal bool IntrospectionMode + { + get => _introspectionInstance || IntrospectionCaller; + init => _introspectionInstance = value; + } + + /// Whether options should return a portable identifier (data type name) to prevent any generated id (oid) confusion across backends, this comes with a perf penalty. + internal bool PortableTypeIds { get; init; } + internal NpgsqlDatabaseInfo DatabaseInfo { get; } + + public string TimeZone => _timeZoneProvider?.Invoke() ?? 
throw new NotSupportedException("TimeZone was not configured."); + public Encoding TextEncoding { get; init; } = Encoding.UTF8; + public IPgTypeInfoResolver TypeInfoResolver + { + get => _typeInfoResolver ??= new ChainTypeInfoResolver(_resolverChain); + internal init => _typeInfoResolver = value; + } + public bool EnableDateTimeInfinityConversions { get; init; } = true; + + public ArrayNullabilityMode ArrayNullabilityMode { get; init; } = ArrayNullabilityMode.Never; + public INpgsqlNameTranslator DefaultNameTranslator { get; init; } = NpgsqlSnakeCaseNameTranslator.Instance; + + public static bool IsWellKnownTextType(Type type) + { + type = type.IsValueType ? Nullable.GetUnderlyingType(type) ?? type : type; + return Array.IndexOf([ + typeof(string), typeof(char), + typeof(char[]), typeof(ReadOnlyMemory), typeof(ArraySegment), + typeof(byte[]), typeof(ReadOnlyMemory) + ], type) != -1 || typeof(Stream).IsAssignableFrom(type); + } + + internal bool RangesEnabled => _resolverChain.RangesEnabled; + internal bool MultirangesEnabled => _resolverChain.MultirangesEnabled; + internal bool ArraysEnabled => _resolverChain.ArraysEnabled; + + // We don't verify the kind of pgTypeId we get, it'll throw if it's incorrect. + // It's up to the caller to call GetCanonicalTypeId if they want to use an oid instead of a DataTypeName. + // This also makes it easier to realize it should be a cached value if infos for different CLR types are requested for the same + // pgTypeId. Effectively it should be 'impossible' to get the wrong kind via any PgConverterOptions api which is what this is mainly + // for. + PgTypeInfo? GetTypeInfoCore(Type? type, PgTypeId? pgTypeId, bool defaultTypeFallback) + => PortableTypeIds + ? ((TypeInfoCache)(_typeInfoCache ??= new TypeInfoCache(this))).GetOrAddInfo(type, pgTypeId?.DataTypeName, defaultTypeFallback) + : ((TypeInfoCache)(_typeInfoCache ??= new TypeInfoCache(this))).GetOrAddInfo(type, pgTypeId?.Oid, defaultTypeFallback); + + public PgTypeInfo? 
GetDefaultTypeInfo(PostgresType pgType) + => GetTypeInfoCore(null, ToCanonicalTypeId(pgType), false); + + public PgTypeInfo? GetDefaultTypeInfo(PgTypeId pgTypeId) + => GetTypeInfoCore(null, pgTypeId, false); + + public PgTypeInfo? GetTypeInfo(Type type, PostgresType pgType) + => GetTypeInfoCore(type, ToCanonicalTypeId(pgType), false); + + public PgTypeInfo? GetTypeInfo(Type type, PgTypeId? pgTypeId = null) + => GetTypeInfoCore(type, pgTypeId, false); + + public PgTypeInfo? GetObjectOrDefaultTypeInfo(PostgresType pgType) + => GetTypeInfoCore(typeof(object), ToCanonicalTypeId(pgType), true); + + public PgTypeInfo? GetObjectOrDefaultTypeInfo(PgTypeId pgTypeId) + => GetTypeInfoCore(typeof(object), pgTypeId, true); + + // If a given type id is in the opposite form than what was expected it will be mapped according to the requirement. + internal PgTypeId GetCanonicalTypeId(PgTypeId pgTypeId) + => PortableTypeIds ? DatabaseInfo.GetDataTypeName(pgTypeId) : DatabaseInfo.GetOid(pgTypeId); + + // If a given type id is in the opposite form than what was expected it will be mapped according to the requirement. + internal PgTypeId ToCanonicalTypeId(PostgresType pgType) + => PortableTypeIds ? pgType.DataTypeName : (Oid)pgType.OID; + + public PgTypeId GetArrayTypeId(PgTypeId elementTypeId) + { + // Static affordance to help the global type mapper. + if (PortableTypeIds && elementTypeId.IsDataTypeName) + return elementTypeId.DataTypeName.ToArrayName(); + + return ToCanonicalTypeId(DatabaseInfo.GetPostgresType(elementTypeId).Array + ?? throw new NotSupportedException("Cannot resolve array type id")); + } + + public PgTypeId GetArrayElementTypeId(PgTypeId arrayTypeId) + { + // Static affordance to help the global type mapper. 
+ if (PortableTypeIds && arrayTypeId.IsDataTypeName && arrayTypeId.DataTypeName.UnqualifiedNameSpan.StartsWith("_".AsSpan(), StringComparison.Ordinal)) + return new DataTypeName(arrayTypeId.DataTypeName.Schema + arrayTypeId.DataTypeName.UnqualifiedNameSpan.Slice(1).ToString()); + + return ToCanonicalTypeId((DatabaseInfo.GetPostgresType(arrayTypeId) as PostgresArrayType)?.Element + ?? throw new NotSupportedException("Cannot resolve array element type id")); + } + + public PgTypeId GetRangeTypeId(PgTypeId subtypeTypeId) => + ToCanonicalTypeId(DatabaseInfo.GetPostgresType(subtypeTypeId).Range + ?? throw new NotSupportedException("Cannot resolve range type id")); + + public PgTypeId GetRangeSubtypeTypeId(PgTypeId rangeTypeId) => + ToCanonicalTypeId((DatabaseInfo.GetPostgresType(rangeTypeId) as PostgresRangeType)?.Subtype + ?? throw new NotSupportedException("Cannot resolve range subtype type id")); + + public PgTypeId GetMultirangeTypeId(PgTypeId rangeTypeId) => + ToCanonicalTypeId((DatabaseInfo.GetPostgresType(rangeTypeId) as PostgresRangeType)?.Multirange + ?? throw new NotSupportedException("Cannot resolve multirange type id")); + + public PgTypeId GetMultirangeElementTypeId(PgTypeId multirangeTypeId) => + ToCanonicalTypeId((DatabaseInfo.GetPostgresType(multirangeTypeId) as PostgresMultirangeType)?.Subrange + ?? throw new NotSupportedException("Cannot resolve multirange element type id")); + + public bool TryGetDataTypeName(PgTypeId pgTypeId, out DataTypeName dataTypeName) + { + if (DatabaseInfo.FindPostgresType(pgTypeId) is { } pgType) + { + dataTypeName = pgType.DataTypeName; + return true; + } + + dataTypeName = default; + return false; + } + + public DataTypeName GetDataTypeName(PgTypeId pgTypeId) + => !TryGetDataTypeName(pgTypeId, out var name) + ? 
throw new ArgumentException("Unknown type id", nameof(pgTypeId)) + : name; +} diff --git a/src/Npgsql/Internal/PgStreamingConverter.cs b/src/Npgsql/Internal/PgStreamingConverter.cs new file mode 100644 index 0000000000..971f1c6980 --- /dev/null +++ b/src/Npgsql/Internal/PgStreamingConverter.cs @@ -0,0 +1,103 @@ +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public abstract class PgStreamingConverter : PgConverter +{ + protected PgStreamingConverter(bool customDbNullPredicate = false) : base(customDbNullPredicate) { } + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.None; + return format is DataFormat.Binary; + } + + // Workaround for trimming https://github.com/dotnet/runtime/issues/92850#issuecomment-1744521361 + internal Task? ReadAsyncAsTask(PgReader reader, CancellationToken cancellationToken, out T result) + { + var task = ReadAsync(reader, cancellationToken); + if (task.IsCompletedSuccessfully) + { + result = task.Result; + return null; + } + result = default!; + return task.AsTask(); + } + + internal sealed override unsafe ValueTask ReadAsObject( + bool async, PgReader reader, CancellationToken cancellationToken) + { + if (!async) + return new(Read(reader)!); + + var task = ReadAsync(reader, cancellationToken); + return task.IsCompletedSuccessfully + ? new(task.Result!) + : PgStreamingConverterHelpers.AwaitTask(task.AsTask(), new(this, &BoxResult)); + + static object BoxResult(Task task) + { + // We're using ValueTask.Result here to avoid rooting any TaskAwaiter or ValueTaskAwaiter types. + // On ValueTask calling .Result is equivalent to GetAwaiter().GetResult() w.r.t. exception wrapping. 
+ return new ValueTask(task: (Task)task).Result!; + } + } + + internal sealed override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) + { + if (async) + return WriteAsync(writer, (T)value, cancellationToken); + + Write(writer, (T)value); + return new(); + } +} + +// Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is +// passed along. As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're +// done. +// The alternatives are: +// 1. Add a virtual method and make AwaitTask call into it (bloating the vtable of all derived types). +// 2. Using a delegate, meaning we add a static field + an alloc per T + metadata, slightly slower dispatch perf so overall strictly worse +// as well. +static class PgStreamingConverterHelpers +{ + // Split out from the generic class to amortize the huge size penalty per async state machine, which would otherwise be per + // instantiation. +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + public static async ValueTask AwaitTask(Task task, Continuation continuation) + { + await task.ConfigureAwait(false); + var result = continuation.Invoke(task); + // Guarantee the type stays loaded until the function pointer call is done. + GC.KeepAlive(continuation.Handle); + return result; + } + + // Split out into a struct as unsafe and async don't mix, while we do want a nicely typed function pointer signature to prevent + // mistakes. + public readonly unsafe struct Continuation + { + public object Handle { get; } + readonly delegate* _continuation; + + /// A reference to the type that houses the static method points to. 
+ /// The continuation + public Continuation(object handle, delegate* continuation) + { + Handle = handle; + _continuation = continuation; + } + + public object Invoke(Task task) => _continuation(task); + } +} diff --git a/src/Npgsql/Internal/PgTypeInfo.cs b/src/Npgsql/Internal/PgTypeInfo.cs new file mode 100644 index 0000000000..d83c5dfa36 --- /dev/null +++ b/src/Npgsql/Internal/PgTypeInfo.cs @@ -0,0 +1,328 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public class PgTypeInfo +{ + readonly bool _canBinaryConvert; + readonly BufferRequirements _binaryBufferRequirements; + + readonly bool _canTextConvert; + readonly BufferRequirements _textBufferRequirements; + + PgTypeInfo(PgSerializerOptions options, Type type, Type? unboxedType) + { + if (unboxedType is not null && !type.IsAssignableFrom(unboxedType)) + throw new ArgumentException("A value of unboxed type is not assignable to converter type", nameof(unboxedType)); + + Options = options; + IsBoxing = unboxedType is not null; + Type = unboxedType ?? type; + SupportsWriting = true; + } + + public PgTypeInfo(PgSerializerOptions options, PgConverter converter, PgTypeId pgTypeId, Type? unboxedType = null) + : this(options, converter.TypeToConvert, unboxedType) + { + Converter = converter; + PgTypeId = options.GetCanonicalTypeId(pgTypeId); + _canBinaryConvert = converter.CanConvert(DataFormat.Binary, out _binaryBufferRequirements); + _canTextConvert = converter.CanConvert(DataFormat.Text, out _textBufferRequirements); + } + + private protected PgTypeInfo(PgSerializerOptions options, Type type, PgConverterResolution? resolution, Type? unboxedType = null) + : this(options, type, unboxedType) + { + if (resolution is { } res) + { + // Resolutions should always be in canonical form already. 
+ if (options.PortableTypeIds && res.PgTypeId.IsOid || !options.PortableTypeIds && res.PgTypeId.IsDataTypeName) + throw new ArgumentException("Given type id is not in canonical form. Make sure ConverterResolver implementations close over canonical ids, e.g. by calling options.GetCanonicalTypeId(pgTypeId) on the constructor arguments.", nameof(PgTypeId)); + + PgTypeId = res.PgTypeId; + Converter = res.Converter; + _canBinaryConvert = res.Converter.CanConvert(DataFormat.Binary, out _binaryBufferRequirements); + _canTextConvert = res.Converter.CanConvert(DataFormat.Text, out _textBufferRequirements); + } + } + + bool HasCachedInfo(PgConverter converter) => ReferenceEquals(Converter, converter); + + public Type Type { get; } + public PgSerializerOptions Options { get; } + + public bool SupportsWriting { get; init; } + public DataFormat? PreferredFormat { get; init; } + + // Doubles as the storage for the converter coming from a default resolution (used to confirm whether we can use cached info). + PgConverter? Converter { get; } + [MemberNotNullWhen(false, nameof(Converter))] + [MemberNotNullWhen(false, nameof(PgTypeId))] + internal bool IsResolverInfo => GetType() == typeof(PgResolverTypeInfo); + + // TODO pull validate from options + internal exempt for perf? + internal bool ValidateResolution => true; + + // Used for internal converters to save on binary bloat. + internal bool IsBoxing { get; } + + public PgTypeId? PgTypeId { get; } + + // Having it here so we can easily extend any behavior. + internal void DisposeWriteState(object writeState) + { + if (writeState is IDisposable disposable) + disposable.Dispose(); + } + + public PgConverterResolution GetResolution(T? value) + { + if (this is not PgResolverTypeInfo resolverInfo) + return new(Converter!, PgTypeId.GetValueOrDefault()); + + var resolution = resolverInfo.GetResolution(value, null); + return resolution ?? 
resolverInfo.GetDefaultResolution(null); + } + + // Note: this api is not called GetResolutionAsObject as the semantics are extended, DBNull is a NULL value for all object values. + public PgConverterResolution GetObjectResolution(object? value) + { + switch (this) + { + case { IsResolverInfo: false }: + return new(Converter, PgTypeId.GetValueOrDefault()); + case PgResolverTypeInfo resolverInfo: + PgConverterResolution? resolution = null; + if (value is not DBNull) + resolution = resolverInfo.GetResolutionAsObject(value, null); + return resolution ?? resolverInfo.GetDefaultResolution(null); + default: + return ThrowNotSupported(); + } + + static PgConverterResolution ThrowNotSupported() + => throw new NotSupportedException("Should not happen, please file a bug."); + } + + /// Throws if the instance is a PgResolverTypeInfo. + internal PgConverterResolution GetResolution() + { + if (IsResolverInfo) + ThrowHelper.ThrowInvalidOperationException("Instance is a PgResolverTypeInfo."); + return new(Converter, PgTypeId.GetValueOrDefault()); + } + + bool CanConvert(PgConverter converter, DataFormat format, out BufferRequirements bufferRequirements) + { + if (HasCachedInfo(converter)) + { + switch (format) + { + case DataFormat.Binary: + bufferRequirements = _binaryBufferRequirements; + return _canBinaryConvert; + case DataFormat.Text: + bufferRequirements = _textBufferRequirements; + return _canTextConvert; + } + } + + return converter.CanConvert(format, out bufferRequirements); + } + + public BufferRequirements? GetBufferRequirements(PgConverter converter, DataFormat format) + { + var success = CanConvert(converter, format, out var bufferRequirements); + return success ? bufferRequirements : null; + } + + // TryBind for reading. 
+ internal bool TryBind(Field field, DataFormat format, out PgConverterInfo info) + { + switch (this) + { + case { IsResolverInfo: false }: + if (!CanConvert(Converter, format, out var bufferRequirements)) + { + info = default; + return false; + } + info = new(this, Converter, bufferRequirements.Read); + return true; + case PgResolverTypeInfo resolverInfo: + var resolution = resolverInfo.GetResolution(field); + if (!CanConvert(resolution.Converter, format, out bufferRequirements)) + { + info = default; + return false; + } + info = new(this, resolution.Converter, bufferRequirements.Read); + return true; + default: + throw new NotSupportedException("Should not happen, please file a bug."); + } + } + + // Bind for reading. + internal PgConverterInfo Bind(Field field, DataFormat format) + { + if (!TryBind(field, format, out var info)) + ThrowHelper.ThrowInvalidOperationException($"Resolved converter does not support {format} format."); + + return info; + } + + // Bind for writing. + /// When result is null, the value was interpreted to be a SQL NULL. + internal PgConverterInfo? Bind(PgConverter converter, T? value, out Size size, out object? writeState, out DataFormat format, DataFormat? formatPreference = null) + { + // Basically exists to catch cases like object[] resolving a polymorphic read converter, better to fail during binding than writing. + if (!SupportsWriting) + ThrowHelper.ThrowNotSupportedException($"Writing {Type} is not supported for this type info."); + + format = ResolveFormat(converter, out var bufferRequirements, formatPreference ?? PreferredFormat); + + writeState = null; + if (converter.GetSizeOrDbNull(format, bufferRequirements.Write, value, ref writeState) is not { } sizeOrDbNull) + { + size = default; + return null; + } + + size = sizeOrDbNull; + return new(this, converter, bufferRequirements.Write); + } + + // Bind for writing. 
+ // Note: this api is not called BindAsObject as the semantics are extended, DBNull is a NULL value for all object values. + /// When result is null or DBNull, the value was interpreted to be a SQL NULL. + internal PgConverterInfo? BindObject(PgConverter converter, object? value, out Size size, out object? writeState, out DataFormat format, DataFormat? formatPreference = null) + { + // Basically exists to catch cases like object[] resolving a polymorphic read converter, better to fail during binding than writing. + if (!SupportsWriting) + throw new NotSupportedException($"Writing {Type} is not supported for this type info."); + + format = ResolveFormat(converter, out var bufferRequirements, formatPreference ?? PreferredFormat); + + // Given SQL values are effectively a union of T | NULL we support DBNull.Value to signify a NULL value for all types except DBNull in this api. + writeState = null; + if (value is DBNull && Type != typeof(DBNull) || converter.GetSizeOrDbNullAsObject(format, bufferRequirements.Write, value, ref writeState) is not { } sizeOrDbNull) + { + size = default; + return null; + } + + size = sizeOrDbNull; + return new(this, converter, bufferRequirements.Write); + } + + DataFormat ResolveFormat(PgConverter converter, out BufferRequirements bufferRequirements, DataFormat? formatPreference = null) + { + // First try to check for preferred support. + switch (formatPreference) + { + case DataFormat.Binary when CanConvert(converter, DataFormat.Binary, out bufferRequirements): + return DataFormat.Binary; + case DataFormat.Text when CanConvert(converter, DataFormat.Text, out bufferRequirements): + return DataFormat.Text; + default: + // The common case, no preference given (or no match) means we default to binary if supported. 
+ if (CanConvert(converter, DataFormat.Binary, out bufferRequirements)) + return DataFormat.Binary; + if (CanConvert(converter, DataFormat.Text, out bufferRequirements)) + return DataFormat.Text; + + ThrowHelper.ThrowInvalidOperationException("Converter doesn't support any data format."); + bufferRequirements = default; + return default; + } + } +} + +public sealed class PgResolverTypeInfo : PgTypeInfo +{ + readonly PgConverterResolver _converterResolver; + + public PgResolverTypeInfo(PgSerializerOptions options, PgConverterResolver converterResolver, PgTypeId? pgTypeId, Type? unboxedType = null) + : base(options, + converterResolver.TypeToConvert, + pgTypeId is { } typeId ? ResolveDefaultId(options, converterResolver, typeId) : null, + // We always mark resolvers with type object as boxing, as they may freely return converters for any type (see PgConverterResolver.Validate). + unboxedType ?? (converterResolver.TypeToConvert == typeof(object) ? typeof(object) : null)) + => _converterResolver = converterResolver; + + // We'll always validate the default resolution, the info will be re-used so there is no real downside. + static PgConverterResolution ResolveDefaultId(PgSerializerOptions options, PgConverterResolver converterResolver, PgTypeId typeId) + => converterResolver.GetDefaultInternal(validate: true, options.PortableTypeIds, options.GetCanonicalTypeId(typeId)); + + public PgConverterResolution? GetResolution(T? value, PgTypeId? expectedPgTypeId) + { + return _converterResolver is PgConverterResolver resolverT + ? resolverT.GetInternal(this, value, expectedPgTypeId ?? PgTypeId) + : ThrowNotSupportedType(typeof(T)); + + PgConverterResolution ThrowNotSupportedType(Type? type) + => throw new NotSupportedException(IsBoxing + ? "TypeInfo only supports boxing conversions, call GetResolutionAsObject instead." + : $"TypeInfo is not of type {type}"); + } + + public PgConverterResolution? GetResolutionAsObject(object? value, PgTypeId? 
expectedPgTypeId) + => _converterResolver.GetAsObjectInternal(this, value, expectedPgTypeId ?? PgTypeId); + + public PgConverterResolution GetResolution(Field field) + => _converterResolver.GetInternal(this, field); + + public PgConverterResolution GetDefaultResolution(PgTypeId? expectedPgTypeId) + => _converterResolver.GetDefaultInternal(ValidateResolution, Options.PortableTypeIds, expectedPgTypeId ?? PgTypeId); + + public PgConverterResolver GetConverterResolver() => _converterResolver; +} + +public readonly struct PgConverterResolution +{ + public PgConverterResolution(PgConverter converter, PgTypeId pgTypeId) + { + Converter = converter; + PgTypeId = pgTypeId; + } + + public PgConverter Converter { get; } + public PgTypeId PgTypeId { get; } + + public PgConverter GetConverter() => (PgConverter)Converter; +} + +readonly struct PgConverterInfo +{ + readonly PgTypeInfo _typeInfo; + + public PgConverterInfo(PgTypeInfo pgTypeInfo, PgConverter converter, Size bufferRequirement) + { + _typeInfo = pgTypeInfo; + Converter = converter; + BufferRequirement = bufferRequirement; + + // Object typed resolvers can return any type of converter, so we check the type of the converter instead. + // We cannot do this in general as we should respect the 'unboxed type' of infos, which can differ from the converter type. + if (pgTypeInfo.IsResolverInfo && pgTypeInfo.Type == typeof(object)) + TypeToConvert = Converter.TypeToConvert; + else + TypeToConvert = pgTypeInfo.Type; + } + + public bool IsDefault => _typeInfo is null; + + public Type TypeToConvert { get; } + + public PgTypeInfo TypeInfo => _typeInfo; + + public PgConverter Converter { get; } + public Size BufferRequirement { get; } + + /// Whether Converter.TypeToConvert matches PgTypeInfo.Type, if it doesn't object apis should be used. 
+ public bool IsBoxingConverter => _typeInfo.IsBoxing; +} diff --git a/src/Npgsql/Internal/PgTypeInfoResolverChainBuilder.cs b/src/Npgsql/Internal/PgTypeInfoResolverChainBuilder.cs new file mode 100644 index 0000000000..548d236096 --- /dev/null +++ b/src/Npgsql/Internal/PgTypeInfoResolverChainBuilder.cs @@ -0,0 +1,186 @@ +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Npgsql.Internal; + +struct PgTypeInfoResolverChainBuilder +{ + readonly List<(Type ImplementationType, object)> _factories = new(); + Action>? _addRangeResolvers; + Action>? _addMultirangeResolvers; + RangeArrayHandler _rangeArrayHandler = RangeArrayHandler.Instance; + MultirangeArrayHandler _multirangeArrayHandler = MultirangeArrayHandler.Instance; + Action>? _addArrayResolvers; + + public PgTypeInfoResolverChainBuilder() + { + } + + public void Clear() => _factories.Clear(); + + public void AppendResolverFactory(PgTypeInfoResolverFactory factory) + => AddResolverFactory(factory.GetType(), factory); + public void AppendResolverFactory(Func factory) where T : PgTypeInfoResolverFactory + => AddResolverFactory(typeof(T), Memoize(factory)); + + public void PrependResolverFactory(PgTypeInfoResolverFactory factory) + => AddResolverFactory(factory.GetType(), factory, prepend: true); + public void PrependResolverFactory(Func factory) where T : PgTypeInfoResolverFactory + => AddResolverFactory(typeof(T), Memoize(factory), prepend: true); + + // Memoize the caller factory so all our actions (_addArrayResolvers etc.) call into the same instance. + static Func Memoize(Func factory) + { + PgTypeInfoResolverFactory? 
instance = null; + return () => instance ??= factory(); + } + + static PgTypeInfoResolverFactory GetInstance((Type, object Instance) factory) => factory.Instance switch + { + PgTypeInfoResolverFactory f => f, + Func f => f(), + _ => throw new ArgumentOutOfRangeException(nameof(factory), factory, null) + }; + + void AddResolverFactory(Type type, object factory, bool prepend = false) + { + for (var i = 0; i < _factories.Count; i++) + if (_factories[i].ImplementationType == type) + { + _factories.RemoveAt(i); + break; + } + + if (prepend) + _factories.Insert(0, (type, factory)); + else + _factories.Add((type, factory)); + } + + public void EnableRanges() + { + _addRangeResolvers ??= AddResolvers; + _rangeArrayHandler = RangeArrayHandlerImpl.Instance; + + static void AddResolvers(PgTypeInfoResolverChainBuilder instance, List resolvers) + { + foreach (var factory in instance._factories) + if (GetInstance(factory).CreateRangeResolver() is { } resolver) + resolvers.Add(resolver); + } + } + + public void EnableMultiranges() + { + _addMultirangeResolvers ??= AddResolvers; + _multirangeArrayHandler = MultirangeArrayHandlerImpl.Instance; + + static void AddResolvers(PgTypeInfoResolverChainBuilder instance, List resolvers) + { + foreach (var factory in instance._factories) + if (GetInstance(factory).CreateMultirangeResolver() is { } resolver) + resolvers.Add(resolver); + } + } + + public void EnableArrays() + { + _addArrayResolvers ??= AddResolvers; + + static void AddResolvers(PgTypeInfoResolverChainBuilder instance, List resolvers) + { + foreach (var factory in instance._factories) + if (GetInstance(factory).CreateArrayResolver() is { } resolver) + resolvers.Add(resolver); + + if (instance._addRangeResolvers is not null) + foreach (var factory in instance._factories) + if (instance._rangeArrayHandler.CreateRangeArrayResolver(GetInstance(factory)) is { } resolver) + resolvers.Add(resolver); + + if (instance._addMultirangeResolvers is not null) + foreach (var factory in 
instance._factories) + if (instance._multirangeArrayHandler.CreateMultirangeArrayResolver(GetInstance(factory)) is { } resolver) + resolvers.Add(resolver); + } + } + + public PgTypeInfoResolverChain Build(Action>? configure = null) + { + var resolvers = new List(); + foreach (var factory in _factories) + resolvers.Add(GetInstance(factory).CreateResolver()); + var instance = this; + _addRangeResolvers?.Invoke(instance, resolvers); + _addMultirangeResolvers?.Invoke(instance, resolvers); + _addArrayResolvers?.Invoke(instance, resolvers); + configure?.Invoke(resolvers); + return new( + resolvers, + rangesEnabled: _addRangeResolvers is not null, + multirangesEnabled: _addMultirangeResolvers is not null, + arraysEnabled: _addArrayResolvers is not null + ); + } + + class RangeArrayHandler + { + public static RangeArrayHandler Instance { get; } = new(); + + public virtual IPgTypeInfoResolver? CreateRangeArrayResolver(PgTypeInfoResolverFactory factory) => null; + } + + sealed class RangeArrayHandlerImpl : RangeArrayHandler + { + public new static RangeArrayHandlerImpl Instance { get; } = new(); + + public override IPgTypeInfoResolver? CreateRangeArrayResolver(PgTypeInfoResolverFactory factory) => factory.CreateRangeArrayResolver(); + } + + class MultirangeArrayHandler + { + public static MultirangeArrayHandler Instance { get; } = new(); + + public virtual IPgTypeInfoResolver? CreateMultirangeArrayResolver(PgTypeInfoResolverFactory factory) => null; + } + + sealed class MultirangeArrayHandlerImpl : MultirangeArrayHandler + { + public new static MultirangeArrayHandlerImpl Instance { get; } = new(); + + public override IPgTypeInfoResolver? 
CreateMultirangeArrayResolver(PgTypeInfoResolverFactory factory) => factory.CreateMultirangeArrayResolver(); + } +} + +readonly struct PgTypeInfoResolverChain : IEnumerable +{ + [Flags] + enum EnabledFlags + { + None = 0, + Ranges = 1, + Multiranges = 2, + Arrays = 4 + } + + readonly EnabledFlags _enabled; + readonly List _resolvers; + + public PgTypeInfoResolverChain(List resolvers, bool rangesEnabled, bool multirangesEnabled, bool arraysEnabled) + { + _enabled = rangesEnabled ? EnabledFlags.Ranges | _enabled : _enabled; + _enabled = multirangesEnabled ? EnabledFlags.Multiranges | _enabled : _enabled; + _enabled = arraysEnabled ? EnabledFlags.Arrays | _enabled : _enabled; + _resolvers = resolvers; + } + + public bool RangesEnabled => _enabled.HasFlag(EnabledFlags.Ranges); + public bool MultirangesEnabled => _enabled.HasFlag(EnabledFlags.Multiranges); + public bool ArraysEnabled => _enabled.HasFlag(EnabledFlags.Arrays); + + public IEnumerator GetEnumerator() + => _resolvers?.GetEnumerator() ?? (IEnumerator)Array.Empty().GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() + => _resolvers?.GetEnumerator() ?? Array.Empty().GetEnumerator(); +} diff --git a/src/Npgsql/Internal/PgTypeInfoResolverFactory.cs b/src/Npgsql/Internal/PgTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..9392e2c840 --- /dev/null +++ b/src/Npgsql/Internal/PgTypeInfoResolverFactory.cs @@ -0,0 +1,16 @@ +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public abstract class PgTypeInfoResolverFactory +{ + public abstract IPgTypeInfoResolver CreateResolver(); + public abstract IPgTypeInfoResolver? CreateArrayResolver(); + + public virtual IPgTypeInfoResolver? CreateRangeResolver() => null; + public virtual IPgTypeInfoResolver? CreateRangeArrayResolver() => null; + + public virtual IPgTypeInfoResolver? CreateMultirangeResolver() => null; + public virtual IPgTypeInfoResolver? 
CreateMultirangeArrayResolver() => null; +} diff --git a/src/Npgsql/Internal/PgWriter.cs b/src/Npgsql/Internal/PgWriter.cs new file mode 100644 index 0000000000..69a36afa1d --- /dev/null +++ b/src/Npgsql/Internal/PgWriter.cs @@ -0,0 +1,579 @@ +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +enum FlushMode +{ + None, + Blocking, + NonBlocking +} + +// A streaming alternative to a System.IO.Stream, instead based on the preferable IBufferWriter. +interface IStreamingWriter: IBufferWriter +{ + void Flush(TimeSpan timeout = default); + ValueTask FlushAsync(CancellationToken cancellationToken = default); +} + +sealed class NpgsqlBufferWriter : IStreamingWriter +{ + readonly NpgsqlWriteBuffer _buffer; + int? 
_lastBufferSize; + public NpgsqlBufferWriter(NpgsqlWriteBuffer buffer) => _buffer = buffer; + + public void Advance(int count) + { + if (_lastBufferSize < count || _buffer.WriteSpaceLeft < count) + ThrowHelper.ThrowInvalidOperationException("Cannot advance past the end of the current buffer."); + _lastBufferSize = null; + _buffer.WritePosition += count; + } + + public Memory GetMemory(int sizeHint = 0) + { + var writePosition = _buffer.WritePosition; + var bufferSize = _buffer.Size - writePosition; + if (sizeHint > bufferSize) + ThrowOutOfMemoryException(); + + _lastBufferSize = bufferSize; + return _buffer.Buffer.AsMemory(writePosition, bufferSize); + } + + public Span GetSpan(int sizeHint = 0) + { + var writePosition = _buffer.WritePosition; + var bufferSize = _buffer.Size - writePosition; + if (sizeHint > bufferSize) + ThrowOutOfMemoryException(); + + _lastBufferSize = bufferSize; + return _buffer.Buffer.AsSpan(writePosition, bufferSize); + } + + static void ThrowOutOfMemoryException() => throw new OutOfMemoryException("Not enough space left in buffer."); + + public void Flush(TimeSpan timeout = default) + { + if (timeout == TimeSpan.Zero) + _buffer.Flush(); + else + { + TimeSpan? originalTimeout = null; + try + { + if (timeout != TimeSpan.Zero) + { + originalTimeout = _buffer.Timeout; + _buffer.Timeout = timeout; + } + _buffer.Flush(); + } + finally + { + if (originalTimeout is { } value) + _buffer.Timeout = value; + } + } + } + + public ValueTask FlushAsync(CancellationToken cancellationToken = default) + => new(_buffer.Flush(async: true, cancellationToken)); +} + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public sealed class PgWriter +{ + readonly IBufferWriter _writer; + + byte[]? _buffer; + int _offset; + int _pos; + int _length; + + int _totalBytesWritten; + + ValueMetadata _current; + NpgsqlDatabaseInfo? 
_typeCatalog; + + internal PgWriter(IBufferWriter writer) => _writer = writer; + + internal PgWriter Init(NpgsqlDatabaseInfo typeCatalog, FlushMode flushMode = FlushMode.None) + { + if (_pos != _offset) + ThrowHelper.ThrowInvalidOperationException("Invalid concurrent use or PgWriter was not committed properly, PgWriter still has uncommitted bytes."); + + // Elide write barrier if we can. + if (!ReferenceEquals(_typeCatalog, typeCatalog)) + _typeCatalog = typeCatalog; + + FlushMode = flushMode; + _totalBytesWritten = 0; + RequestBuffer(count: 0); + return this; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + void RequestBuffer(int count) + { + // GetMemory will check whether count is larger than the max buffer size. + var mem = _writer.GetMemory(count); + if (!MemoryMarshal.TryGetArray(mem, out var segment)) + ThrowHelper.ThrowNotSupportedException("Only array backed writers are supported."); + + _buffer = segment.Array!; + _offset = _pos = segment.Offset; + _length = segment.Offset + segment.Count; + } + + internal FlushMode FlushMode { get; private set; } + + internal void RefreshBuffer() => RequestBuffer(count: 0); + + internal PgWriter WithFlushMode(FlushMode mode) + { + FlushMode = mode; + return this; + } + + void Ensure(int count = 1) + { + if (count <= Remaining) + return; + + Slow(count); + + [MethodImpl(MethodImplOptions.NoInlining)] + void Slow(int count) + { + // Try to re-request a larger size. + Commit(); + RequestBuffer(count); + // GetMemory is expected to throw if count is too large for the remaining space. + Debug.Assert(count <= Remaining); + } + } + + Span Span => _buffer.AsSpan(_pos, _length - _pos); + + int Remaining => _length - _pos; + + void Advance(int count) => _pos += count; + + internal void Commit(int? 
expectedByteCount = null) + { + _totalBytesWritten += _pos - _offset; + _writer.Advance(_pos - _offset); + _offset = _pos; + + if (expectedByteCount is not null) + { + var totalBytesWritten = _totalBytesWritten; + _totalBytesWritten = 0; + if (totalBytesWritten != expectedByteCount) + ThrowHelper.ThrowInvalidOperationException($"Bytes written ({totalBytesWritten}) and expected byte count ({expectedByteCount}) don't match."); + } + } + + internal ValueTask BeginWrite(bool async, ValueMetadata current, CancellationToken cancellationToken) + { + _current = current; + if (ShouldFlush(current.BufferRequirement)) + return Flush(async, cancellationToken); + + return new(); + } + + public ValueMetadata Current => _current; + internal Size CurrentBufferRequirement => _current.BufferRequirement; + + // When we don't know the size during writing we're using the writer buffer as a sizing mechanism. + internal bool BufferingWrite => Current.Size.Kind is SizeKind.Unknown; + + // This method lives here to remove the chances oids will be cached on converters inadvertently when data type names should be used. + // Such a mapping (for instance for array element oids) should be done per operation to ensure it is done in the context of a specific backend. 
+ public void WriteAsOid(PgTypeId pgTypeId) + { + var oid = _typeCatalog!.GetOid(pgTypeId); + WriteUInt32((uint)oid); + } + + public void WriteByte(byte value) + { + Ensure(sizeof(byte)); + Span[0] = value; + Advance(sizeof(byte)); + } + + public void WriteInt16(short value) + { + Ensure(sizeof(short)); + BinaryPrimitives.WriteInt16BigEndian(Span, value); + Advance(sizeof(short)); + } + + public void WriteInt32(int value) + { + Ensure(sizeof(int)); + BinaryPrimitives.WriteInt32BigEndian(Span, value); + Advance(sizeof(int)); + } + + public void WriteInt64(long value) + { + Ensure(sizeof(long)); + BinaryPrimitives.WriteInt64BigEndian(Span, value); + Advance(sizeof(long)); + } + + public void WriteUInt16(ushort value) + { + Ensure(sizeof(ushort)); + BinaryPrimitives.WriteUInt16BigEndian(Span, value); + Advance(sizeof(ushort)); + } + + public void WriteUInt32(uint value) + { + Ensure(sizeof(uint)); + BinaryPrimitives.WriteUInt32BigEndian(Span, value); + Advance(sizeof(uint)); + } + + public void WriteUInt64(ulong value) + { + Ensure(sizeof(ulong)); + BinaryPrimitives.WriteUInt64BigEndian(Span, value); + Advance(sizeof(ulong)); + } + + public void WriteFloat(float value) + { +#if NET5_0_OR_GREATER + Ensure(sizeof(float)); + BinaryPrimitives.WriteSingleBigEndian(Span, value); + Advance(sizeof(float)); +#else + WriteUInt32(Unsafe.As(ref value)); +#endif + } + + public void WriteDouble(double value) + { +#if NET5_0_OR_GREATER + Ensure(sizeof(double)); + BinaryPrimitives.WriteDoubleBigEndian(Span, value); + Advance(sizeof(double)); +#else + WriteUInt64(Unsafe.As(ref value)); +#endif + } + + public void WriteChars(ReadOnlySpan data, Encoding encoding) + { + // If we have more chars than bytes remaining we can immediately go to the slow path. + if (data.Length <= Remaining) + { + // If not, it's worth a shot to see if we can convert in one go. 
+ var encodedLength = encoding.GetByteCount(data); + if (!ShouldFlush(encodedLength)) + { + var count = encoding.GetBytes(data, Span); + Advance(count); + return; + } + } + Core(data, encoding); + + void Core(ReadOnlySpan data, Encoding encoding) + { + var encoder = encoding.GetEncoder(); + var minBufferSize = encoding.GetMaxByteCount(1); + + bool completed; + do + { + if (ShouldFlush(minBufferSize)) + Flush(); + Ensure(minBufferSize); + encoder.Convert(data, Span, flush: data.Length <= Span.Length, out var charsUsed, out var bytesUsed, out completed); + data = data.Slice(charsUsed); + Advance(bytesUsed); + } while (!completed); + } + } + + public ValueTask WriteCharsAsync(ReadOnlyMemory data, Encoding encoding, CancellationToken cancellationToken = default) + { + var dataSpan = data.Span; + // If we have more chars than bytes remaining we can immediately go to the slow path. + if (data.Length <= Remaining) + { + // If not, it's worth a shot to see if we can convert in one go. + var encodedLength = encoding.GetByteCount(dataSpan); + if (!ShouldFlush(encodedLength)) + { + var count = encoding.GetBytes(dataSpan, Span); + Advance(count); + return new(); + } + } + + return Core(data, encoding, cancellationToken); + + async ValueTask Core(ReadOnlyMemory data, Encoding encoding, CancellationToken cancellationToken) + { + var encoder = encoding.GetEncoder(); + var minBufferSize = encoding.GetMaxByteCount(1); + + bool completed; + do + { + if (ShouldFlush(minBufferSize)) + await FlushAsync(cancellationToken).ConfigureAwait(false); + Ensure(minBufferSize); + encoder.Convert(data.Span, Span, flush: data.Length <= Span.Length, out var charsUsed, out var bytesUsed, out completed); + data = data.Slice(charsUsed); + Advance(bytesUsed); + } while (!completed); + } + } + + public void WriteBytes(ReadOnlySpan buffer) + => WriteBytes(allowMixedIO: false, buffer); + + internal void WriteBytes(bool allowMixedIO, ReadOnlySpan buffer) + { + while (!buffer.IsEmpty) + { + if (Remaining is 
0) + Flush(allowWhenNonBlocking: allowMixedIO); + var write = Math.Min(buffer.Length, Remaining); + buffer.Slice(0, write).CopyTo(Span); + Advance(write); + buffer = buffer.Slice(write); + } + } + + public ValueTask WriteBytesAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + => WriteBytesAsync(allowMixedIO: false, buffer, cancellationToken); + + internal ValueTask WriteBytesAsync(bool allowMixedIO, ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + if (buffer.Length <= Remaining) + { + buffer.Span.CopyTo(Span); + Advance(buffer.Length); + return new(); + } + + return Core(allowMixedIO, buffer, cancellationToken); + + async ValueTask Core(bool allowMixedIO, ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + while (!buffer.IsEmpty) + { + if (Remaining is 0) + await FlushAsync(allowWhenBlocking: allowMixedIO, cancellationToken).ConfigureAwait(false); + var write = Math.Min(buffer.Length, Remaining); + buffer.Span.Slice(0, write).CopyTo(Span); + Advance(write); + buffer = buffer.Slice(write); + } + } + } + /// + /// Gets a that can be used to write to the underlying buffer. + /// + /// Blocking flushes during writes that were expected to be non-blocking and vice versa cause an exception to be thrown unless allowMixedIO is set to true, false by default. + /// The stream. + public Stream GetStream(bool allowMixedIO = false) + => new PgWriterStream(this, allowMixedIO); + + public bool ShouldFlush(Size bufferRequirement) + => ShouldFlush(bufferRequirement is { Kind: SizeKind.UpperBound } + ? 
Math.Min(Current.Size.Value, bufferRequirement.Value) + : bufferRequirement.GetValueOrDefault()); + + public bool ShouldFlush(int byteCount) => Remaining < byteCount && FlushMode is not FlushMode.None; + + public void Flush(TimeSpan timeout = default) + => Flush(allowWhenNonBlocking: false, timeout); + + void Flush(bool allowWhenNonBlocking, TimeSpan timeout = default) + { + switch (FlushMode) + { + case FlushMode.None: + return; + case FlushMode.NonBlocking when !allowWhenNonBlocking: + throw new NotSupportedException($"Cannot call {nameof(Flush)} on a non-blocking {nameof(PgWriter)}, call FlushAsync instead."); + } + + if (_writer is not IStreamingWriter writer) + throw new NotSupportedException($"Cannot call {nameof(Flush)} on a buffered {nameof(PgWriter)}, {nameof(FlushMode)}.{nameof(FlushMode.None)} should be used to prevent this."); + + Commit(); + writer.Flush(timeout); + RequestBuffer(count: 0); + } + + public ValueTask FlushAsync(CancellationToken cancellationToken = default) + => FlushAsync(allowWhenBlocking: false, cancellationToken); + + async ValueTask FlushAsync(bool allowWhenBlocking, CancellationToken cancellationToken = default) + { + switch (FlushMode) + { + case FlushMode.None: + return; + case FlushMode.Blocking when !allowWhenBlocking: + throw new NotSupportedException($"Cannot call {nameof(FlushAsync)} on a blocking {nameof(PgWriter)}, call Flush instead."); + } + + if (_writer is not IStreamingWriter writer) + throw new NotSupportedException($"Cannot call {nameof(FlushAsync)} on a buffered {nameof(PgWriter)}, {nameof(FlushMode)}.{nameof(FlushMode.None)} should be used to prevent this."); + + Commit(); + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + RequestBuffer(count: 0); + } + + internal ValueTask Flush(bool async, CancellationToken cancellationToken = default) + { + if (async) + return FlushAsync(cancellationToken); + + Flush(); + return new(); + } + + internal ValueTask BeginNestedWrite(bool async, Size 
bufferRequirement, int byteCount, object? state, CancellationToken cancellationToken) + { + Debug.Assert(bufferRequirement != -1); + + // ShouldFlush depends on the current size for upper bound requirements, so we must set it beforehand. + _current = new() { Format = _current.Format, Size = byteCount, BufferRequirement = bufferRequirement, WriteState = state }; + + if (ShouldFlush(bufferRequirement)) + return Core(async, cancellationToken); + + return new(new NestedWriteScope()); +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + async ValueTask Core(bool async, CancellationToken cancellationToken) + { + await Flush(async, cancellationToken).ConfigureAwait(false); + return new(); + } + } + + public NestedWriteScope BeginNestedWrite(Size bufferRequirement, int byteCount, object? state) + => BeginNestedWrite(async: false, bufferRequirement, byteCount, state, CancellationToken.None).GetAwaiter().GetResult(); + + public ValueTask BeginNestedWriteAsync(Size bufferRequirement, int byteCount, object? 
state, CancellationToken cancellationToken = default) + => BeginNestedWrite(async: true, bufferRequirement, byteCount, state, cancellationToken); + + sealed class PgWriterStream : Stream + { + readonly PgWriter _writer; + readonly bool _allowMixedIO; + + internal PgWriterStream(PgWriter writer, bool allowMixedIO) + { + _writer = writer; + _allowMixedIO = allowMixedIO; + } + + public override void Write(byte[] buffer, int offset, int count) + => Write(async: false, buffer: buffer, offset: offset, count: count, CancellationToken.None).GetAwaiter().GetResult(); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => Write(async: true, buffer: buffer, offset: offset, count: count, cancellationToken: cancellationToken); + + Task Write(bool async, byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (buffer is null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0) + throw new ArgumentNullException(nameof(offset)); + if (count < 0) + throw new ArgumentNullException(nameof(count)); + if (buffer.Length - offset < count) + throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + + if (async) + { + if (cancellationToken.IsCancellationRequested) + return Task.FromCanceled(cancellationToken); + + return _writer.WriteBytesAsync(_allowMixedIO, buffer.AsMemory(offset, count), cancellationToken).AsTask(); + } + + _writer.WriteBytes(_allowMixedIO, new Span(buffer, offset, count)); + return Task.CompletedTask; + } + +#if !NETSTANDARD2_0 + public override void Write(ReadOnlySpan buffer) => _writer.WriteBytes(_allowMixedIO, buffer); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (cancellationToken.IsCancellationRequested) + return new(Task.FromCanceled(cancellationToken)); + + return 
_writer.WriteBytesAsync(buffer, cancellationToken); + } +#endif + + public override void Flush() + => _writer.Flush(); + + public override Task FlushAsync(CancellationToken cancellationToken) + => _writer.FlushAsync(cancellationToken).AsTask(); + + public override bool CanRead => false; + public override bool CanWrite => true; + public override bool CanSeek => false; + + public override int Read(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw new NotSupportedException(); + + public override long Length => throw new NotSupportedException(); + public override void SetLength(long value) + => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + public override long Seek(long offset, SeekOrigin origin) + => throw new NotSupportedException(); + } +} + +// No-op for now. +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public struct NestedWriteScope : IDisposable +{ + public void Dispose() + { + } +} diff --git a/src/Npgsql/Internal/Postgres/DataTypeName.cs b/src/Npgsql/Internal/Postgres/DataTypeName.cs new file mode 100644 index 0000000000..c5b223f866 --- /dev/null +++ b/src/Npgsql/Internal/Postgres/DataTypeName.cs @@ -0,0 +1,240 @@ +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Internal.Postgres; + +/// +/// Represents the fully-qualified name of a PostgreSQL type. +/// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +[DebuggerDisplay("{DisplayName,nq}")] +public readonly struct DataTypeName : IEquatable +{ + /// + /// The maximum length of names in an unmodified PostgreSQL installation. + /// + /// + /// We need to respect this to get to valid names when deriving them (for multirange/arrays etc). 
+ /// This does not include the namespace. + /// + internal const int NAMEDATALEN = 64 - 1; // Minus null terminator. + + readonly string _value; + + DataTypeName(string fullyQualifiedDataTypeName, bool validated) + { + if (!validated) + { + var schemaEndIndex = fullyQualifiedDataTypeName.IndexOf('.'); + if (schemaEndIndex == -1) + throw new ArgumentException("Given value does not contain a schema.", nameof(fullyQualifiedDataTypeName)); + + // Friendly array syntax is the only fully qualified name quirk that's allowed by postgres (see FromDisplayName). + if (fullyQualifiedDataTypeName.AsSpan(schemaEndIndex).EndsWith("[]".AsSpan())) + fullyQualifiedDataTypeName = NormalizeName(fullyQualifiedDataTypeName); + + var typeNameLength = fullyQualifiedDataTypeName.Length - (schemaEndIndex + 1); + if (typeNameLength > NAMEDATALEN) + throw new ArgumentException( + $"Name is too long and would be truncated to: {fullyQualifiedDataTypeName.Substring(0, + fullyQualifiedDataTypeName.Length - typeNameLength + NAMEDATALEN)}"); + } + + _value = fullyQualifiedDataTypeName; + } + + public DataTypeName(string fullyQualifiedDataTypeName) + : this(fullyQualifiedDataTypeName, validated: false) { } + + internal static DataTypeName ValidatedName(string fullyQualifiedDataTypeName) + => new(fullyQualifiedDataTypeName, validated: true); + + // Includes schema unless it's pg_catalog or the name is unspecified. + public string DisplayName => + Value.StartsWith("pg_catalog", StringComparison.Ordinal) || Value == Unspecified + ? UnqualifiedDisplayName + : Schema + "." 
+ UnqualifiedDisplayName; + + public string UnqualifiedDisplayName => ToDisplayName(UnqualifiedNameSpan); + + internal ReadOnlySpan SchemaSpan => Value.AsSpan(0, _value.IndexOf('.')); + public string Schema => Value.Substring(0, _value.IndexOf('.')); + internal ReadOnlySpan UnqualifiedNameSpan => Value.AsSpan(_value.IndexOf('.') + 1); + public string UnqualifiedName => Value.Substring(_value.IndexOf('.') + 1); + public string Value => _value is null ? ThrowDefaultException() : _value; + + static string ThrowDefaultException() => + throw new InvalidOperationException($"This operation cannot be performed on a default value of {nameof(DataTypeName)}."); + + public static implicit operator string(DataTypeName value) => value.Value; + + // This contains two invalid sql identifiers (schema and name are both separate identifiers, and would both have to be quoted to be valid). + // Given this is an invalid name it's fine for us to represent a fully qualified 'unspecified' name with it. + public static DataTypeName Unspecified => new("-.-", validated: true); + + public bool IsArray => UnqualifiedNameSpan.StartsWith("_".AsSpan(), StringComparison.Ordinal); + + internal static DataTypeName CreateFullyQualifiedName(string dataTypeName) + => dataTypeName.IndexOf('.') != -1 ? new(dataTypeName) : new("pg_catalog." + dataTypeName); + + // Static transform as defined by https://www.postgresql.org/docs/current/sql-createtype.html#SQL-CREATETYPE-ARRAY + // We don't have to deal with [] as we're always starting from a normalized fully qualified name. 
+ public DataTypeName ToArrayName() + { + var unqualifiedNameSpan = UnqualifiedNameSpan; + if (unqualifiedNameSpan.StartsWith("_".AsSpan(), StringComparison.Ordinal)) + return this; + + var unqualifiedName = unqualifiedNameSpan.ToString(); + if (unqualifiedName.Length + "_".Length > NAMEDATALEN) + unqualifiedName = unqualifiedName.Substring(0, NAMEDATALEN - "_".Length); + + return new(Schema + "._" + unqualifiedName); + } + + // Static transform as defined by https://www.postgresql.org/docs/current/sql-createtype.html#SQL-CREATETYPE-RANGE + // Manual testing on PG confirmed it's only the first occurence of 'range' that gets replaced. + public DataTypeName ToDefaultMultirangeName() + { + var unqualifiedNameSpan = UnqualifiedNameSpan; + if (UnqualifiedNameSpan.IndexOf("multirange".AsSpan(), StringComparison.Ordinal) != -1) + return this; + + var unqualifiedName = unqualifiedNameSpan.ToString(); + var rangeIndex = unqualifiedName.IndexOf("range", StringComparison.Ordinal); + if (rangeIndex != -1) + { + var str = unqualifiedName.Substring(0, rangeIndex) + "multirange" + unqualifiedName.Substring(rangeIndex + "range".Length); + + return new($"{Schema}." + (unqualifiedName.Length + "multi".Length > NAMEDATALEN + ? str.Substring(0, NAMEDATALEN - "multi".Length) + : str)); + } + + return new($"{Schema}." + (unqualifiedName.Length + "multi".Length > NAMEDATALEN + ? unqualifiedName.Substring(0, NAMEDATALEN - "_multirange".Length) + "_multirange" + : unqualifiedName + "_multirange")); + } + + // Create a DataTypeName from a broader range of valid names. + // including SQL aliases like 'timestamp without time zone', trailing facet info etc. + public static DataTypeName FromDisplayName(string displayName, string? schema = null) + { + var displayNameSpan = displayName.AsSpan().Trim(); + + // If we have a schema we're done, Postgres doesn't do display name conversions on fully qualified names. 
+ // There is one exception and that's array syntax, which is always resolvable in both ways, while we want the canonical name. + var schemaEndIndex = displayNameSpan.IndexOf('.'); + if (schemaEndIndex is not -1 && + !displayNameSpan.Slice(schemaEndIndex).StartsWith("_".AsSpan(), StringComparison.Ordinal) && + !displayNameSpan.EndsWith("[]".AsSpan(), StringComparison.Ordinal)) + return new(displayName); + + // First we strip the schema to get the type name. + if (schemaEndIndex is not -1) + { + schema = displayNameSpan.Slice(0, schemaEndIndex).ToString(); + displayNameSpan = displayNameSpan.Slice(schemaEndIndex + 1); + } + + // Then we strip either of the two valid array representations to get the base type name (with or without facets). + var isArray = false; + if (displayNameSpan.StartsWith("_".AsSpan())) + { + isArray = true; + displayNameSpan = displayNameSpan.Slice(1); + } + else if (displayNameSpan.EndsWith("[]".AsSpan())) + { + isArray = true; + displayNameSpan = displayNameSpan.Slice(0, displayNameSpan.Length - 2); + } + + string mapped; + if (schemaEndIndex is -1) + { + // Finally we strip the facet info. + var parenIndex = displayNameSpan.IndexOf('('); + if (parenIndex > -1) + displayNameSpan = displayNameSpan.Slice(0, parenIndex); + + // Map any aliases to the internal type name. + mapped = displayNameSpan.ToString() switch + { + "boolean" => "bool", + "character" => "bpchar", + "decimal" => "numeric", + "real" => "float4", + "double precision" => "float8", + "smallint" => "int2", + "integer" => "int4", + "bigint" => "int8", + "time without time zone" => "time", + "timestamp without time zone" => "timestamp", + "time with time zone" => "timetz", + "timestamp with time zone" => "timestamptz", + "bit varying" => "varbit", + "character varying" => "varchar", + var value => value + }; + } + else + { + // If we had a schema originally we stop here, see comment at schemaEndIndex. + mapped = displayNameSpan.ToString(); + } + + return new((schema ?? 
"pg_catalog") + "." + (isArray ? "_" : "") + mapped); + } + + // The type names stored in a DataTypeName are usually the actual typname from the pg_type column. + // There are some canonical aliases defined in the SQL standard which we take into account. + // Additionally array types have a '_' prefix while for readability their element type should be postfixed with '[]'. + // See the table for all the aliases https://www.postgresql.org/docs/current/static/datatype.html#DATATYPE-TABLE + // Alternatively some of the source lives at https://github.com/postgres/postgres/blob/c8e1ba736b2b9e8c98d37a5b77c4ed31baf94147/src/backend/utils/adt/format_type.c#L186 + static string ToDisplayName(ReadOnlySpan unqualifiedName) + { + var isArray = unqualifiedName.IndexOf('_') == 0; + var baseTypeName = isArray ? unqualifiedName.Slice(1).ToString() : unqualifiedName.ToString(); + + var mappedBaseType = baseTypeName switch + { + "bool" => "boolean", + "bpchar" => "character", + "decimal" => "numeric", + "float4" => "real", + "float8" => "double precision", + "int2" => "smallint", + "int4" => "integer", + "int8" => "bigint", + "time" => "time without time zone", + "timestamp" => "timestamp without time zone", + "timetz" => "time with time zone", + "timestamptz" => "timestamp with time zone", + "varbit" => "bit varying", + "varchar" => "character varying", + _ => baseTypeName + }; + + if (isArray) + return mappedBaseType + "[]"; + + return mappedBaseType; + } + + internal static bool IsFullyQualified(ReadOnlySpan dataTypeName) => dataTypeName.Contains(".".AsSpan(), StringComparison.Ordinal); + + internal static string NormalizeName(string dataTypeName) + { + var fqName = FromDisplayName(dataTypeName); + return IsFullyQualified(dataTypeName.AsSpan()) ? fqName.Value : fqName.UnqualifiedName; + } + + public override string ToString() => Value; + public bool Equals(DataTypeName other) => string.Equals(_value, other._value); + public override bool Equals(object? 
obj) => obj is DataTypeName other && Equals(other); + public override int GetHashCode() => _value.GetHashCode(); + public static bool operator ==(DataTypeName left, DataTypeName right) => left.Equals(right); + public static bool operator !=(DataTypeName left, DataTypeName right) => !left.Equals(right); +} diff --git a/src/Npgsql/Internal/Postgres/DataTypeNames.cs b/src/Npgsql/Internal/Postgres/DataTypeNames.cs new file mode 100644 index 0000000000..275bcb9937 --- /dev/null +++ b/src/Npgsql/Internal/Postgres/DataTypeNames.cs @@ -0,0 +1,79 @@ +using static Npgsql.Internal.Postgres.DataTypeName; + +namespace Npgsql.Internal.Postgres; + +/// +/// Well-known PostgreSQL data type names. +/// +static class DataTypeNames +{ + // Note: The names are fully qualified in source so the strings are constants and instances will be interned after the first call. + // Uses an internal constructor bypassing the public DataTypeName constructor validation, as we don't want to store all these names on + // fields either. 
    // Numeric types, plus their range/multirange variants.
    public static DataTypeName Int2 => ValidatedName("pg_catalog.int2");
    public static DataTypeName Int4 => ValidatedName("pg_catalog.int4");
    public static DataTypeName Int4Range => ValidatedName("pg_catalog.int4range");
    public static DataTypeName Int4Multirange => ValidatedName("pg_catalog.int4multirange");
    public static DataTypeName Int8 => ValidatedName("pg_catalog.int8");
    public static DataTypeName Int8Range => ValidatedName("pg_catalog.int8range");
    public static DataTypeName Int8Multirange => ValidatedName("pg_catalog.int8multirange");
    public static DataTypeName Float4 => ValidatedName("pg_catalog.float4");
    public static DataTypeName Float8 => ValidatedName("pg_catalog.float8");
    public static DataTypeName Numeric => ValidatedName("pg_catalog.numeric");
    public static DataTypeName NumRange => ValidatedName("pg_catalog.numrange");
    public static DataTypeName NumMultirange => ValidatedName("pg_catalog.nummultirange");
    public static DataTypeName Money => ValidatedName("pg_catalog.money");
    // Boolean and geometric types.
    public static DataTypeName Bool => ValidatedName("pg_catalog.bool");
    public static DataTypeName Box => ValidatedName("pg_catalog.box");
    public static DataTypeName Circle => ValidatedName("pg_catalog.circle");
    public static DataTypeName Line => ValidatedName("pg_catalog.line");
    public static DataTypeName LSeg => ValidatedName("pg_catalog.lseg");
    public static DataTypeName Path => ValidatedName("pg_catalog.path");
    public static DataTypeName Point => ValidatedName("pg_catalog.point");
    public static DataTypeName Polygon => ValidatedName("pg_catalog.polygon");
    // Character and binary types.
    public static DataTypeName Bpchar => ValidatedName("pg_catalog.bpchar");
    public static DataTypeName Text => ValidatedName("pg_catalog.text");
    public static DataTypeName Varchar => ValidatedName("pg_catalog.varchar");
    public static DataTypeName Char => ValidatedName("pg_catalog.char");
    public static DataTypeName Name => ValidatedName("pg_catalog.name");
    public static DataTypeName Bytea => ValidatedName("pg_catalog.bytea");
    // Date/time types, plus their range/multirange variants.
    public static DataTypeName Date => ValidatedName("pg_catalog.date");
    public static DataTypeName DateRange => ValidatedName("pg_catalog.daterange");
    public static DataTypeName DateMultirange => ValidatedName("pg_catalog.datemultirange");
    public static DataTypeName Time => ValidatedName("pg_catalog.time");
    public static DataTypeName Timestamp => ValidatedName("pg_catalog.timestamp");
    public static DataTypeName TsRange => ValidatedName("pg_catalog.tsrange");
    public static DataTypeName TsMultirange => ValidatedName("pg_catalog.tsmultirange");
    public static DataTypeName TimestampTz => ValidatedName("pg_catalog.timestamptz");
    public static DataTypeName TsTzRange => ValidatedName("pg_catalog.tstzrange");
    public static DataTypeName TsTzMultirange => ValidatedName("pg_catalog.tstzmultirange");
    public static DataTypeName Interval => ValidatedName("pg_catalog.interval");
    public static DataTypeName TimeTz => ValidatedName("pg_catalog.timetz");
    // Network address types.
    public static DataTypeName Inet => ValidatedName("pg_catalog.inet");
    public static DataTypeName Cidr => ValidatedName("pg_catalog.cidr");
    public static DataTypeName MacAddr => ValidatedName("pg_catalog.macaddr");
    public static DataTypeName MacAddr8 => ValidatedName("pg_catalog.macaddr8");
    // Bit string types.
    public static DataTypeName Bit => ValidatedName("pg_catalog.bit");
    public static DataTypeName Varbit => ValidatedName("pg_catalog.varbit");
    // Text search types.
    public static DataTypeName TsVector => ValidatedName("pg_catalog.tsvector");
    public static DataTypeName TsQuery => ValidatedName("pg_catalog.tsquery");
    public static DataTypeName RegConfig => ValidatedName("pg_catalog.regconfig");
    // Structured/document types.
    public static DataTypeName Uuid => ValidatedName("pg_catalog.uuid");
    public static DataTypeName Xml => ValidatedName("pg_catalog.xml");
    public static DataTypeName Json => ValidatedName("pg_catalog.json");
    public static DataTypeName Jsonb => ValidatedName("pg_catalog.jsonb");
    public static DataTypeName Jsonpath => ValidatedName("pg_catalog.jsonpath");
    public static DataTypeName Record => ValidatedName("pg_catalog.record");
    public static DataTypeName RefCursor => ValidatedName("pg_catalog.refcursor");
    // Internal/system identifier types.
    public static DataTypeName OidVector => ValidatedName("pg_catalog.oidvector");
    public static DataTypeName Int2Vector => ValidatedName("pg_catalog.int2vector");
    public static DataTypeName Oid => ValidatedName("pg_catalog.oid");
    public static DataTypeName Xid => ValidatedName("pg_catalog.xid");
    public static DataTypeName Xid8 => ValidatedName("pg_catalog.xid8");
    public static DataTypeName Cid => ValidatedName("pg_catalog.cid");
    public static DataTypeName RegType => ValidatedName("pg_catalog.regtype");
    public static DataTypeName Tid => ValidatedName("pg_catalog.tid");
    public static DataTypeName PgLsn => ValidatedName("pg_catalog.pg_lsn");
    // Pseudo-types.
    public static DataTypeName Unknown => ValidatedName("pg_catalog.unknown");
    public static DataTypeName Void => ValidatedName("pg_catalog.void");
}
/// <summary>Base field type shared between tables and composites.</summary>
[Experimental(NpgsqlDiagnostics.ConvertersExperimental)]
public readonly struct Field
{
    public Field(string name, PgTypeId pgTypeId, int typeModifier)
    {
        Name = name;
        PgTypeId = pgTypeId;
        TypeModifier = typeModifier;
    }

    // Column/attribute name.
    public string Name { get; init; }
    // Type of the field, by oid or data type name.
    public PgTypeId PgTypeId { get; init; }
    // Postgres typmod (e.g. varchar length facet); -1 when unspecified.
    public int TypeModifier { get; init; }
}

/// <summary>A PostgreSQL object identifier, a thin wrapper over a uint.</summary>
[Experimental(NpgsqlDiagnostics.ConvertersExperimental)]
public readonly struct Oid : IEquatable<Oid>
{
    public Oid(uint value) => Value = value;

    public static explicit operator uint(Oid oid) => oid.Value;
    public static implicit operator Oid(uint oid) => new(oid);

    public uint Value { get; init; }

    // Oid 0 (InvalidOid) marks an unspecified type.
    public static Oid Unspecified => new(0);

    public override string ToString() => Value.ToString();

    public bool Equals(Oid other) => Value == other.Value;
    public override bool Equals(object? obj) => obj is Oid other && Equals(other);
    public override int GetHashCode() => (int)Value;
    public static bool operator ==(Oid left, Oid right) => left.Equals(right);
    public static bool operator !=(Oid left, Oid right) => !left.Equals(right);
}
+/// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public readonly struct PgTypeId: IEquatable +{ + readonly DataTypeName _dataTypeName; + readonly Oid _oid; + + public PgTypeId(DataTypeName name) => _dataTypeName = name; + public PgTypeId(Oid oid) => _oid = oid; + + [MemberNotNullWhen(true, nameof(_dataTypeName))] + public bool IsDataTypeName => _dataTypeName != default; + public bool IsOid => _dataTypeName == default; + + public DataTypeName DataTypeName + => IsDataTypeName ? _dataTypeName : throw new InvalidOperationException("This value does not describe a DataTypeName."); + + public Oid Oid + => IsOid ? _oid : throw new InvalidOperationException("This value does not describe an Oid."); + + public static implicit operator PgTypeId(DataTypeName name) => new(name); + public static implicit operator PgTypeId(Oid id) => new(id); + + public override string ToString() => IsOid ? "OID " + _oid : "DataTypeName " + _dataTypeName.Value; + + public bool Equals(PgTypeId other) + { + if (IsOid && other.IsOid) + return _oid == other._oid; + if (IsDataTypeName && other.IsDataTypeName) + return _dataTypeName.Equals(other._dataTypeName); + return false; + } + + public override bool Equals(object? obj) => obj is PgTypeId other && Equals(other); + public override int GetHashCode() => IsOid ? 
_oid.GetHashCode() : _dataTypeName.GetHashCode(); + public static bool operator ==(PgTypeId left, PgTypeId right) => left.Equals(right); + public static bool operator !=(PgTypeId left, PgTypeId right) => !left.Equals(right); + + internal bool IsUnspecified => IsOid && _oid == Oid.Unspecified || _dataTypeName == DataTypeName.Unspecified; +} diff --git a/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Multirange.cs b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Multirange.cs new file mode 100644 index 0000000000..5d86451357 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Multirange.cs @@ -0,0 +1,254 @@ +using System; +using System.Collections.Generic; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Util; +using NpgsqlTypes; +using static Npgsql.Internal.PgConverterFactory; + +namespace Npgsql.Internal.ResolverFactories; + +sealed partial class AdoTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateMultirangeResolver() => new MultirangeResolver(); + public override IPgTypeInfoResolver CreateMultirangeArrayResolver() => new MultirangeArrayResolver(); + + class MultirangeResolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => options.DatabaseInfo.SupportsMultirangeTypes ? 
Mappings.Find(type, dataTypeName, options) : null; + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // int4multirange + mappings.AddType[]>(DataTypeNames.Int4Multirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateArrayMultirangeConverter(CreateRangeConverter(new Int4Converter(), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.Int4Multirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter(CreateRangeConverter(new Int4Converter(), options), options))); + + // int8multirange + mappings.AddType[]>(DataTypeNames.Int8Multirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateArrayMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.Int8Multirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + + // nummultirange + mappings.AddType[]>(DataTypeNames.NumMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateArrayMultirangeConverter(CreateRangeConverter(new DecimalNumericConverter(), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.NumMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter(CreateRangeConverter(new DecimalNumericConverter(), options), options))); + + // tsmultirange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddType[]>(DataTypeNames.TsMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: true), + options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsMultirange, + static (options, mapping, _) => + 
mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: true), + options), options))); + } + else + { + mappings.AddResolverType[]>(DataTypeNames.TsMultirange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateMultirangeResolver[], NpgsqlRange>(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), + options.GetCanonicalTypeId(DataTypeNames.TsMultirange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + isDefault: true); + mappings.AddResolverType>>(DataTypeNames.TsMultirange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateMultirangeResolver>, NpgsqlRange>(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), + options.GetCanonicalTypeId(DataTypeNames.TsMultirange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch)); + } + + mappings.AddType[]>(DataTypeNames.TsMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateArrayMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + mappings.AddType>>(DataTypeNames.TsMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + + // tstzmultirange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: false), + options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new 
LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: false), + options), options))); + mappings.AddType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options), + options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new LegacyDateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options), + options))); + } + else + { + mappings.AddResolverType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateMultirangeResolver[], NpgsqlRange>(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), + options.GetCanonicalTypeId(DataTypeNames.TsMultirange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + isDefault: true); + mappings.AddResolverType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateMultirangeResolver>, NpgsqlRange>(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), + options.GetCanonicalTypeId(DataTypeNames.TsMultirange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch)); + mappings.AddType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new DateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new 
DateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options), options))); + } + + mappings.AddType[]>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateArrayMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + mappings.AddType>>(DataTypeNames.TsTzMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateListMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); + + // datemultirange + mappings.AddType[]>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options), options))); + #if NET6_0_OR_GREATER + mappings.AddType[]>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options), options)), + isDefault: true); + mappings.AddType>>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options), options))); + #endif + + return mappings; + } + } + + sealed class MultirangeArrayResolver : MultirangeResolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => options.DatabaseInfo.SupportsMultirangeTypes ? Mappings.Find(type, dataTypeName, options) : null; + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // int4multirange + mappings.AddArrayType[]>(DataTypeNames.Int4Multirange); + mappings.AddArrayType>>(DataTypeNames.Int4Multirange); + + // int8multirange + mappings.AddArrayType[]>(DataTypeNames.Int8Multirange); + mappings.AddArrayType>>(DataTypeNames.Int8Multirange); + + // nummultirange + mappings.AddArrayType[]>(DataTypeNames.NumMultirange); + mappings.AddArrayType>>(DataTypeNames.NumMultirange); + + // tsmultirange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddArrayType[]>(DataTypeNames.TsMultirange); + mappings.AddArrayType>>(DataTypeNames.TsMultirange); + } + else + { + mappings.AddResolverArrayType[]>(DataTypeNames.TsMultirange); + mappings.AddResolverArrayType>>(DataTypeNames.TsMultirange); + } + + mappings.AddArrayType[]>(DataTypeNames.TsMultirange); + mappings.AddArrayType>>(DataTypeNames.TsMultirange); + + // tstzmultirange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType>>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType>>(DataTypeNames.TsTzMultirange); + } + else + { + mappings.AddResolverArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddResolverArrayType>>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType>>(DataTypeNames.TsTzMultirange); + } + + mappings.AddArrayType[]>(DataTypeNames.TsTzMultirange); + mappings.AddArrayType>>(DataTypeNames.TsTzMultirange); + + // datemultirange + mappings.AddArrayType[]>(DataTypeNames.DateMultirange); + mappings.AddArrayType>>(DataTypeNames.DateMultirange); + #if NET6_0_OR_GREATER + mappings.AddArrayType[]>(DataTypeNames.DateMultirange); + 
mappings.AddArrayType>>(DataTypeNames.DateMultirange); + #endif + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Range.cs b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Range.cs new file mode 100644 index 0000000000..35993c8830 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Range.cs @@ -0,0 +1,152 @@ +using System; +using System.Numerics; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Util; +using NpgsqlTypes; +using static Npgsql.Internal.PgConverterFactory; + +namespace Npgsql.Internal.ResolverFactories; + +sealed partial class AdoTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateRangeResolver() => new RangeResolver(); + public override IPgTypeInfoResolver CreateRangeArrayResolver() => new RangeArrayResolver(); + + class RangeResolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // numeric ranges + mappings.AddStructType>(DataTypeNames.Int4Range, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int4Converter(), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.Int8Range, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int8Converter(), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.NumRange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateRangeConverter(new DecimalNumericConverter(), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.NumRange, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new BigIntegerNumericConverter(), options))); + + // tsrange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddStructType>(DataTypeNames.TsRange, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: true), options)), + isDefault: true); + } + else + { + mappings.AddResolverStructType>(DataTypeNames.TsRange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateRangeResolver(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzRange), + options.GetCanonicalTypeId(DataTypeNames.TsRange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + isDefault: true); + } + mappings.AddStructType>(DataTypeNames.TsRange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateRangeConverter(new Int8Converter(), options))); + + // tstzrange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddStructType>(DataTypeNames.TsTzRange, + static (options, mapping, _) => 
mapping.CreateInfo(options, + CreateRangeConverter(new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: false), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.TsTzRange, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new LegacyDateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options))); + } + else + { + mappings.AddResolverStructType>(DataTypeNames.TsTzRange, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateRangeResolver(options, + options.GetCanonicalTypeId(DataTypeNames.TsTzRange), + options.GetCanonicalTypeId(DataTypeNames.TsRange), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + isDefault: true); + mappings.AddStructType>(DataTypeNames.TsTzRange, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new DateTimeOffsetConverter(options.EnableDateTimeInfinityConversions), options))); + } + mappings.AddStructType>(DataTypeNames.TsTzRange, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int8Converter(), options))); + + // daterange + mappings.AddStructType>(DataTypeNames.DateRange, + static (options, mapping, _) => mapping.CreateInfo(options, + CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options)), + isDefault: true); + mappings.AddStructType>(DataTypeNames.DateRange, + static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int4Converter(), options))); + #if NET6_0_OR_GREATER + mappings.AddStructType>(DataTypeNames.DateRange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options))); + #endif + + return mappings; + } + } + + sealed class RangeArrayResolver : RangeResolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? 
_mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // numeric ranges + mappings.AddStructArrayType>(DataTypeNames.Int4Range); + mappings.AddStructArrayType>(DataTypeNames.Int8Range); + mappings.AddStructArrayType>(DataTypeNames.NumRange); + mappings.AddStructArrayType>(DataTypeNames.NumRange); + + // tsrange + if (Statics.LegacyTimestampBehavior) + mappings.AddStructArrayType>(DataTypeNames.TsRange); + else + mappings.AddResolverStructArrayType>(DataTypeNames.TsRange); + mappings.AddStructArrayType>(DataTypeNames.TsRange); + + // tstzrange + if (Statics.LegacyTimestampBehavior) + { + mappings.AddStructArrayType>(DataTypeNames.TsTzRange); + mappings.AddStructArrayType>(DataTypeNames.TsTzRange); + } + else + { + mappings.AddResolverStructArrayType>(DataTypeNames.TsTzRange); + mappings.AddStructArrayType>(DataTypeNames.TsTzRange); + } + mappings.AddStructArrayType>(DataTypeNames.TsTzRange); + + // daterange + mappings.AddStructArrayType>(DataTypeNames.DateRange); + mappings.AddStructArrayType>(DataTypeNames.DateRange); +#if NET6_0_OR_GREATER + mappings.AddStructArrayType>(DataTypeNames.DateRange); +#endif + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..61ed5cf011 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.cs @@ -0,0 +1,532 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Specialized; +using System.IO; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Converters.Internal; +using 
Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using Npgsql.Util; +using NpgsqlTypes; + +namespace Npgsql.Internal.ResolverFactories; + +sealed partial class AdoTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + Resolver ResolverInstance { get; } = new(); + + public static AdoTypeInfoResolverFactory Instance { get; } = new(); + + public override IPgTypeInfoResolver CreateResolver() => ResolverInstance; + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + // Baseline types that are always supported. + class Resolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + var info = Mappings.Find(type, dataTypeName, options); + if (info is null && dataTypeName is not null) + info = GetEnumTypeInfo(type, dataTypeName.GetValueOrDefault(), options); + + return info; + } + + static PgTypeInfo? GetEnumTypeInfo(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + if (type is not null && type != typeof(string)) + return null; + + if (options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresEnumType) + return null; + + return new PgTypeInfo(options, new StringTextConverter(options.TextEncoding), dataTypeName); + } + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // Bool + mappings.AddStructType(DataTypeNames.Bool, + static (options, mapping, _) => mapping.CreateInfo(options, new BoolConverter()), isDefault: true); + + // Numeric + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter()), isDefault: true); + // Clr byte/sbyte maps to 'int2' as there is no byte type in PostgreSQL. 
+ mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Float4, + static (options, mapping, _) => mapping.CreateInfo(options, new RealConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Float8, + static (options, mapping, _) => mapping.CreateInfo(options, new DoubleConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Money, + static (options, mapping, _) => mapping.CreateInfo(options, new MoneyConverter()), MatchRequirement.DataTypeName); + + // Text + // Update PgSerializerOptions.IsWellKnownTextType(Type) after any changes to this list. + mappings.AddType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new StringTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text), isDefault: true); + mappings.AddStructType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new CharTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text)); + // Uses the bytea converters, as neither type has a header. 
+ mappings.AddType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new ArrayByteaConverter()), + MatchRequirement.DataTypeName); + mappings.AddStructType>(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryByteaConverter()), + MatchRequirement.DataTypeName); + mappings.AddType(DataTypeNames.Text, + static (options, mapping, _) => new PgTypeInfo(options, new StreamByteaConverter(), new DataTypeName(mapping.DataTypeName), unboxedType: mapping.Type != typeof(Stream) ? mapping.Type : null), + mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName, TypeMatchPredicate = type => typeof(Stream).IsAssignableFrom(type) }); + //Special mappings, these have no corresponding array mapping. + mappings.AddType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new TextReaderTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new GetCharsTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + + // Alternative text types + foreach(var dataTypeName in new[] { "citext", DataTypeNames.Varchar, + DataTypeNames.Bpchar, DataTypeNames.Json, + DataTypeNames.Xml, DataTypeNames.Name, DataTypeNames.RefCursor }) + { + mappings.AddType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new StringTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text), isDefault: true); + mappings.AddStructType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new CharTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text)); + // Uses the bytea converters, as neither type has a header. 
+ mappings.AddType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new ArrayByteaConverter()), + MatchRequirement.DataTypeName); + mappings.AddStructType>(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryByteaConverter()), + MatchRequirement.DataTypeName); + mappings.AddType(dataTypeName, + static (options, mapping, _) => new PgTypeInfo(options, new StreamByteaConverter(), new DataTypeName(mapping.DataTypeName), unboxedType: mapping.Type != typeof(Stream) ? mapping.Type : null), + mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName, TypeMatchPredicate = type => typeof(Stream).IsAssignableFrom(type) }); + //Special mappings, these have no corresponding array mapping. + mappings.AddType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new TextReaderTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + mappings.AddStructType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new GetCharsTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + } + + // Jsonb + const byte jsonbVersion = 1; + mappings.AddType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new StringTextConverter(options.TextEncoding))), isDefault: true); + mappings.AddStructType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new CharTextConverter(options.TextEncoding)))); + mappings.AddType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new ArrayByteaConverter())), + MatchRequirement.DataTypeName); + mappings.AddStructType>(DataTypeNames.Jsonb, + static (options, 
mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter>(jsonbVersion, new ReadOnlyMemoryByteaConverter())), + MatchRequirement.DataTypeName); + mappings.AddType(DataTypeNames.Jsonb, + static (options, mapping, _) => new PgTypeInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new StreamByteaConverter()), new DataTypeName(mapping.DataTypeName), unboxedType: mapping.Type != typeof(Stream) ? mapping.Type : null), + mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName, TypeMatchPredicate = type => typeof(Stream).IsAssignableFrom(type) }); + //Special mappings, these have no corresponding array mapping. + mappings.AddType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new TextReaderTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new GetCharsTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + + // Jsonpath + const byte jsonpathVersion = 1; + mappings.AddType(DataTypeNames.Jsonpath, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new StringTextConverter(options.TextEncoding))), isDefault: true); + //Special mappings, these have no corresponding array mapping. 
+ mappings.AddType(DataTypeNames.Jsonpath, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new TextReaderTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Jsonpath, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new GetCharsTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + + // Bytea + mappings.AddType(DataTypeNames.Bytea, + static (options, mapping, _) => mapping.CreateInfo(options, new ArrayByteaConverter()), isDefault: true); + mappings.AddStructType>(DataTypeNames.Bytea, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryByteaConverter())); + mappings.AddType(DataTypeNames.Bytea, + static (options, mapping, _) => new PgTypeInfo(options, new StreamByteaConverter(), new DataTypeName(mapping.DataTypeName), unboxedType: mapping.Type != typeof(Stream) ? 
mapping.Type : null), + mapping => mapping with { TypeMatchPredicate = type => typeof(Stream).IsAssignableFrom(type) }); + + // Varbit + mappings.AddType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, + new PolymorphicBitStringConverterResolver(options.GetCanonicalTypeId(DataTypeNames.Varbit)), supportsWriting: false)); + mappings.AddType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, new BitArrayBitStringConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, new BoolBitStringConverter())); + mappings.AddStructType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, new BitVector32BitStringConverter())); + + // Bit + mappings.AddType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, + new PolymorphicBitStringConverterResolver(options.GetCanonicalTypeId(DataTypeNames.Bit)), supportsWriting: false)); + mappings.AddType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, new BitArrayBitStringConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, new BoolBitStringConverter())); + mappings.AddStructType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, new BitVector32BitStringConverter())); + + // Timestamp + if (Statics.LegacyTimestampBehavior) + { + mappings.AddStructType(DataTypeNames.Timestamp, + static (options, mapping, _) => mapping.CreateInfo(options, + new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: true)), isDefault: true); + } + else + { + mappings.AddResolverStructType(DataTypeNames.Timestamp, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateResolver(options, 
options.GetCanonicalTypeId(DataTypeNames.TimestampTz), options.GetCanonicalTypeId(DataTypeNames.Timestamp), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), isDefault: true); + } + mappings.AddStructType(DataTypeNames.Timestamp, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + + // TimestampTz + if (Statics.LegacyTimestampBehavior) + { + mappings.AddStructType(DataTypeNames.TimestampTz, + static (options, mapping, _) => mapping.CreateInfo(options, + new LegacyDateTimeConverter(options.EnableDateTimeInfinityConversions, timestamp: false)), matchRequirement: MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.TimestampTz, + static (options, mapping, _) => mapping.CreateInfo(options, new LegacyDateTimeOffsetConverter(options.EnableDateTimeInfinityConversions))); + } + else + { + mappings.AddResolverStructType(DataTypeNames.TimestampTz, + static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + DateTimeConverterResolver.CreateResolver(options, options.GetCanonicalTypeId(DataTypeNames.TimestampTz), options.GetCanonicalTypeId(DataTypeNames.Timestamp), + options.EnableDateTimeInfinityConversions), dataTypeNameMatch), isDefault: true); + mappings.AddStructType(DataTypeNames.TimestampTz, + static (options, mapping, _) => mapping.CreateInfo(options, new DateTimeOffsetConverter(options.EnableDateTimeInfinityConversions))); + } + mappings.AddStructType(DataTypeNames.TimestampTz, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + + // Date + mappings.AddStructType(DataTypeNames.Date, + static (options, mapping, _) => + mapping.CreateInfo(options, new DateTimeDateConverter(options.EnableDateTimeInfinityConversions)), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Date, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + #if NET6_0_OR_GREATER + mappings.AddStructType(DataTypeNames.Date, + 
static (options, mapping, _) => mapping.CreateInfo(options, new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions))); + #endif + + // Interval + mappings.AddStructType(DataTypeNames.Interval, + static (options, mapping, _) => mapping.CreateInfo(options, new TimeSpanIntervalConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Interval, + static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlIntervalConverter())); + + // Time + mappings.AddStructType(DataTypeNames.Time, + static (options, mapping, _) => mapping.CreateInfo(options, new TimeSpanTimeConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Time, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + #if NET6_0_OR_GREATER + mappings.AddStructType(DataTypeNames.Time, + static (options, mapping, _) => mapping.CreateInfo(options, new TimeOnlyTimeConverter())); + #endif + + // TimeTz + mappings.AddStructType(DataTypeNames.TimeTz, + static (options, mapping, _) => mapping.CreateInfo(options, new DateTimeOffsetTimeTzConverter()), + MatchRequirement.DataTypeName); + + // Uuid + mappings.AddStructType(DataTypeNames.Uuid, + static (options, mapping, _) => mapping.CreateInfo(options, new GuidUuidConverter()), isDefault: true); + + // Hstore + mappings.AddType>("hstore", + static (options, mapping, _) => mapping.CreateInfo(options, new HstoreConverter>(options.TextEncoding)), isDefault: true); + mappings.AddType>("hstore", + static (options, mapping, _) => mapping.CreateInfo(options, new HstoreConverter>(options.TextEncoding))); + + // Unknown + mappings.AddType(DataTypeNames.Unknown, + static (options, mapping, _) => mapping.CreateInfo(options, new StringTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text), + MatchRequirement.DataTypeName); + + // Void + mappings.AddType(DataTypeNames.Void, + static (options, mapping, _) => mapping.CreateInfo(options, new VoidConverter(), supportsWriting: false), + 
MatchRequirement.DataTypeName); + + // UInt internal types + foreach (var dataTypeName in new[] { DataTypeNames.Oid, DataTypeNames.Xid, DataTypeNames.Cid, DataTypeNames.RegType, DataTypeNames.RegConfig }) + { + mappings.AddStructType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new UInt32Converter()), + MatchRequirement.DataTypeName); + } + + // Char + mappings.AddStructType(DataTypeNames.Char, + static (options, mapping, _) => mapping.CreateInfo(options, new InternalCharConverter()), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.Char, + static (options, mapping, _) => mapping.CreateInfo(options, new InternalCharConverter()), + MatchRequirement.DataTypeName); + + // Xid8 + mappings.AddStructType(DataTypeNames.Xid8, + static (options, mapping, _) => mapping.CreateInfo(options, new UInt64Converter()), + MatchRequirement.DataTypeName); + + // Oidvector + mappings.AddType( + DataTypeNames.OidVector, + static (options, mapping, _) => mapping.CreateInfo(options, + new ArrayBasedArrayConverter(new(new UInt32Converter(), new PgTypeId(DataTypeNames.Oid)), pgLowerBound: 0)), + MatchRequirement.DataTypeName); + + // Int2vector + mappings.AddType( + DataTypeNames.Int2Vector, + static (options, mapping, _) => mapping.CreateInfo(options, + new ArrayBasedArrayConverter(new(new Int2Converter(), new PgTypeId(DataTypeNames.Int2)), pgLowerBound: 0)), + MatchRequirement.DataTypeName); + + // Tid + mappings.AddStructType(DataTypeNames.Tid, + static (options, mapping, _) => mapping.CreateInfo(options, new TidConverter()), + MatchRequirement.DataTypeName); + + // PgLsn + mappings.AddStructType(DataTypeNames.PgLsn, + static (options, mapping, _) => mapping.CreateInfo(options, new PgLsnConverter()), + MatchRequirement.DataTypeName); + mappings.AddStructType(DataTypeNames.PgLsn, + static (options, mapping, _) => mapping.CreateInfo(options, new UInt64Converter()), + MatchRequirement.DataTypeName); + + return mappings; + } + } + + 
sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + var info = Mappings.Find(type, dataTypeName, options); + + Type? elementType = null; + if (info is null && dataTypeName is not null + && options.DatabaseInfo.GetPostgresType(dataTypeName) is PostgresArrayType { Element: var pgElementType } + && (type is null || type == typeof(object) || TypeInfoMappingCollection.IsArrayLikeType(type, out elementType))) + { + info = GetEnumArrayTypeInfo(elementType, pgElementType, type, dataTypeName.GetValueOrDefault(), options) ?? + GetObjectArrayTypeInfo(elementType, pgElementType, type, dataTypeName.GetValueOrDefault(), options); + } + return info; + } + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // Bool + mappings.AddStructArrayType(DataTypeNames.Bool); + + // Numeric + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Float4); + mappings.AddStructArrayType(DataTypeNames.Float8); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Money); + + // Text + mappings.AddArrayType(DataTypeNames.Text); + mappings.AddStructArrayType(DataTypeNames.Text); + mappings.AddArrayType(DataTypeNames.Text); + mappings.AddStructArrayType>(DataTypeNames.Text); + mappings.AddArrayType(DataTypeNames.Text); + + // Alternative text types + foreach(var dataTypeName in new[] { "citext", DataTypeNames.Varchar, + DataTypeNames.Bpchar, DataTypeNames.Json, + DataTypeNames.Xml, DataTypeNames.Name, DataTypeNames.RefCursor }) + 
{ + mappings.AddArrayType(dataTypeName); + mappings.AddStructArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddStructArrayType>(dataTypeName); + mappings.AddArrayType(dataTypeName); + } + + // Jsonb + mappings.AddArrayType(DataTypeNames.Jsonb); + mappings.AddStructArrayType(DataTypeNames.Jsonb); + mappings.AddArrayType(DataTypeNames.Jsonb); + mappings.AddStructArrayType>(DataTypeNames.Jsonb); + mappings.AddArrayType(DataTypeNames.Jsonb); + + // Jsonpath + mappings.AddArrayType(DataTypeNames.Jsonpath); + + // Bytea + mappings.AddArrayType(DataTypeNames.Bytea); + mappings.AddStructArrayType>(DataTypeNames.Bytea); + mappings.AddArrayType(DataTypeNames.Bytea); + + // Varbit + // Object mapping first. + mappings.AddPolymorphicResolverArrayType(DataTypeNames.Varbit, static options => resolution => resolution.Converter switch + { + BoolBitStringConverter => PgConverterFactory.CreatePolymorphicArrayConverter( + () => new ArrayBasedArrayConverter(resolution, typeof(Array)), + () => new ArrayBasedArrayConverter(new(new NullableConverter(resolution.GetConverter()), resolution.PgTypeId), typeof(Array)), + options), + BitArrayBitStringConverter => new ArrayBasedArrayConverter(resolution, typeof(Array)), + _ => throw new NotSupportedException() + }); + mappings.AddArrayType(DataTypeNames.Varbit); + mappings.AddStructArrayType(DataTypeNames.Varbit); + mappings.AddStructArrayType(DataTypeNames.Varbit); + + // Bit + // Object mapping first. 
+ mappings.AddPolymorphicResolverArrayType(DataTypeNames.Bit, static options => resolution => resolution.Converter switch + { + BoolBitStringConverter => PgConverterFactory.CreatePolymorphicArrayConverter( + () => new ArrayBasedArrayConverter(resolution, typeof(Array)), + () => new ArrayBasedArrayConverter(new(new NullableConverter(resolution.GetConverter()), resolution.PgTypeId), typeof(Array)), + options), + BitArrayBitStringConverter => new ArrayBasedArrayConverter(resolution, typeof(Array)), + _ => throw new NotSupportedException() + }); + mappings.AddArrayType(DataTypeNames.Bit); + mappings.AddStructArrayType(DataTypeNames.Bit); + mappings.AddStructArrayType(DataTypeNames.Bit); + + // Timestamp + if (Statics.LegacyTimestampBehavior) + mappings.AddStructArrayType(DataTypeNames.Timestamp); + else + mappings.AddResolverStructArrayType(DataTypeNames.Timestamp); + mappings.AddStructArrayType(DataTypeNames.Timestamp); + + // TimestampTz + if (Statics.LegacyTimestampBehavior) + mappings.AddStructArrayType(DataTypeNames.TimestampTz); + else + mappings.AddResolverStructArrayType(DataTypeNames.TimestampTz); + mappings.AddStructArrayType(DataTypeNames.TimestampTz); + mappings.AddStructArrayType(DataTypeNames.TimestampTz); + + // Date + mappings.AddStructArrayType(DataTypeNames.Date); + mappings.AddStructArrayType(DataTypeNames.Date); + #if NET6_0_OR_GREATER + mappings.AddStructArrayType(DataTypeNames.Date); + #endif + + // Interval + mappings.AddStructArrayType(DataTypeNames.Interval); + mappings.AddStructArrayType(DataTypeNames.Interval); + + // Time + mappings.AddStructArrayType(DataTypeNames.Time); + mappings.AddStructArrayType(DataTypeNames.Time); + #if NET6_0_OR_GREATER + mappings.AddStructArrayType(DataTypeNames.Time); + #endif + + // TimeTz + mappings.AddStructArrayType(DataTypeNames.TimeTz); + // Uuid + mappings.AddStructArrayType(DataTypeNames.Uuid); + + // Hstore + mappings.AddArrayType>("hstore"); + mappings.AddArrayType>("hstore"); + + // UInt internal types 
+ foreach (var dataTypeName in new[] { DataTypeNames.Oid, DataTypeNames.Xid, DataTypeNames.Cid, DataTypeNames.RegType, (string)DataTypeNames.RegConfig }) + { + mappings.AddStructArrayType(dataTypeName); + } + + // Char + mappings.AddStructArrayType(DataTypeNames.Char); + mappings.AddStructArrayType(DataTypeNames.Char); + + // Xid8 + mappings.AddStructArrayType(DataTypeNames.Xid8); + + // Oidvector + mappings.AddArrayType(DataTypeNames.OidVector); + + // Int2vector + mappings.AddArrayType(DataTypeNames.Int2Vector); + + return mappings; + } + + static PgTypeInfo? GetObjectArrayTypeInfo(Type? elementType, PostgresType pgElementType, Type? type, DataTypeName dataTypeName, + PgSerializerOptions options) + { + if (elementType != typeof(object)) + return null; + + // Probe if there is any mapping at all for this element type. + var elementId = options.ToCanonicalTypeId(pgElementType); + if (options.GetDefaultTypeInfo(elementId) is null) + return null; + + var mappings = new TypeInfoMappingCollection(); + mappings.AddType(pgElementType.DataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new ObjectConverter(options, elementId)), MatchRequirement.DataTypeName); + mappings.AddArrayType(pgElementType.DataTypeName); + return mappings.Find(type, dataTypeName, options); + } + + static PgTypeInfo? GetEnumArrayTypeInfo(Type? elementType, PostgresType pgElementType, Type? 
type, DataTypeName dataTypeName, PgSerializerOptions options) + { + if ((type != typeof(object) && elementType is not null && elementType != typeof(string)) || pgElementType is not PostgresEnumType enumType) + return null; + + var mappings = new TypeInfoMappingCollection(); + mappings.AddType(enumType.DataTypeName, + (options, mapping, _) => mapping.CreateInfo(options, new StringTextConverter(options.TextEncoding)), MatchRequirement.DataTypeName); + mappings.AddArrayType(enumType.DataTypeName); + return mappings.Find(type, dataTypeName, options); + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/ExtraConversionsTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/ExtraConversionsTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..9b5de89736 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/ExtraConversionsTypeInfoResolverFactory.cs @@ -0,0 +1,234 @@ +using System; +using System.Collections.Immutable; +using System.Numerics; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.ResolverFactories; + +sealed class ExtraConversionResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + class Resolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddInfos(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddInfos(TypeInfoMappingCollection mappings) + { + // Int2 + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + mappings.AddStructType(DataTypeNames.Int2, + static (options, mapping, _) => mapping.CreateInfo(options, new Int2Converter())); + + // Int4 + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + mappings.AddStructType(DataTypeNames.Int4, + static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); + + // Int8 + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => 
mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + mappings.AddStructType(DataTypeNames.Int8, + static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); + + // Float4 + mappings.AddStructType(DataTypeNames.Float4, + static (options, mapping, _) => mapping.CreateInfo(options, new RealConverter())); + + // Float8 + mappings.AddStructType(DataTypeNames.Float8, + static (options, mapping, _) => mapping.CreateInfo(options, new DoubleConverter())); + + // Numeric + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new DecimalNumericConverter())); + mappings.AddStructType(DataTypeNames.Numeric, + static (options, mapping, _) => mapping.CreateInfo(options, new 
BigIntegerNumericConverter())); + + // Bytea + mappings.AddStructType>(DataTypeNames.Bytea, + static (options, mapping, _) => mapping.CreateInfo(options, new ArraySegmentByteaConverter())); + mappings.AddStructType>(DataTypeNames.Bytea, + static (options, mapping, _) => mapping.CreateInfo(options, new MemoryByteaConverter())); + + // Varbit + mappings.AddType(DataTypeNames.Varbit, + static (options, mapping, _) => mapping.CreateInfo(options, new StringBitStringConverter())); + + // Bit + mappings.AddType(DataTypeNames.Bit, + static (options, mapping, _) => mapping.CreateInfo(options, new StringBitStringConverter())); + + // Text + // Update PgSerializerOptions.IsWellKnownTextType(Type) after any changes to this list. + mappings.AddType(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new CharArrayTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text)); + mappings.AddStructType>(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text)); + mappings.AddStructType>(DataTypeNames.Text, + static (options, mapping, _) => mapping.CreateInfo(options, new CharArraySegmentTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text)); + + // Alternative text types + foreach(var dataTypeName in new[] { "citext", DataTypeNames.Varchar, + DataTypeNames.Bpchar, DataTypeNames.Json, + DataTypeNames.Xml, DataTypeNames.Name, DataTypeNames.RefCursor }) + { + mappings.AddType(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new CharArrayTextConverter(options.TextEncoding), + preferredFormat: DataFormat.Text)); + mappings.AddStructType>(dataTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new ReadOnlyMemoryTextConverter(options.TextEncoding), + preferredFormat: DataFormat.Text)); + mappings.AddStructType>(dataTypeName, + static (options, mapping, _) => 
mapping.CreateInfo(options, new CharArraySegmentTextConverter(options.TextEncoding), + preferredFormat: DataFormat.Text)); + } + + // Jsonb + const byte jsonbVersion = 1; + mappings.AddType(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new CharArrayTextConverter(options.TextEncoding)))); + mappings.AddStructType>(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter>(jsonbVersion, new ReadOnlyMemoryTextConverter(options.TextEncoding)))); + mappings.AddStructType>(DataTypeNames.Jsonb, + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter>(jsonbVersion, new CharArraySegmentTextConverter(options.TextEncoding)))); + + // Hstore + mappings.AddType>("hstore", + static (options, mapping, _) => mapping.CreateInfo(options, new HstoreConverter>(options.TextEncoding, result => result.ToImmutableDictionary()))); + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddArrayInfos(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddArrayInfos(TypeInfoMappingCollection mappings) + { + // Int2 + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + mappings.AddStructArrayType(DataTypeNames.Int2); + + // Int4 + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + mappings.AddStructArrayType(DataTypeNames.Int4); + + // Int8 + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + mappings.AddStructArrayType(DataTypeNames.Int8); + + // Float4 + mappings.AddStructArrayType(DataTypeNames.Float4); + + // Float8 + mappings.AddStructArrayType(DataTypeNames.Float8); + + // Numeric + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + mappings.AddStructArrayType(DataTypeNames.Numeric); + + // Bytea + mappings.AddStructArrayType>(DataTypeNames.Bytea); + mappings.AddStructArrayType>(DataTypeNames.Bytea); + + // Varbit + mappings.AddArrayType(DataTypeNames.Varbit); + + // Bit + mappings.AddArrayType(DataTypeNames.Bit); + + // Text + 
mappings.AddArrayType(DataTypeNames.Text); + mappings.AddStructArrayType>(DataTypeNames.Text); + mappings.AddStructArrayType>(DataTypeNames.Text); + + // Alternative text types + foreach(var dataTypeName in new[] { "citext", DataTypeNames.Varchar, + DataTypeNames.Bpchar, DataTypeNames.Json, + DataTypeNames.Xml, DataTypeNames.Name, DataTypeNames.RefCursor }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddStructArrayType>(dataTypeName); + mappings.AddStructArrayType>(dataTypeName); + } + + // Jsonb + mappings.AddArrayType(DataTypeNames.Jsonb); + mappings.AddStructArrayType>(DataTypeNames.Jsonb); + mappings.AddStructArrayType>(DataTypeNames.Jsonb); + + // Hstore + mappings.AddArrayType>("hstore"); + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/FullTextSearchTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/FullTextSearchTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..272824ad2d --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/FullTextSearchTypeInfoResolverFactory.cs @@ -0,0 +1,93 @@ +using System; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; +using NpgsqlTypes; + +namespace Npgsql.Internal.ResolverFactories; + +sealed class FullTextSearchTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + public static void ThrowIfUnsupported(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + { + if (dataTypeName is { SchemaSpan: "pg_catalog", UnqualifiedNameSpan: "tsquery" or "_tsquery" or "tsvector" or "_tsvector" }) + throw new NotSupportedException( + string.Format(NpgsqlStrings.FullTextSearchNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableFullTextSearch), typeof(TBuilder).Name)); + + if (type is null) + return; + + if (TypeInfoMappingCollection.IsArrayLikeType(type, out var elementType)) + type = elementType; + + if (Nullable.GetUnderlyingType(type) is { } underlyingType) + type = underlyingType; + + if (type == typeof(NpgsqlTsVector) || typeof(NpgsqlTsQuery).IsAssignableFrom(type)) + throw new NotSupportedException( + string.Format(NpgsqlStrings.FullTextSearchNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableFullTextSearch), typeof(TBuilder).Name)); + } + + class Resolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // tsvector + mappings.AddType(DataTypeNames.TsVector, + static (options, mapping, _) => mapping.CreateInfo(options, new TsVectorConverter(options.TextEncoding)), isDefault: true); + + // tsquery + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding)), isDefault: true); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + mappings.AddType(DataTypeNames.TsQuery, + static (options, mapping, _) => mapping.CreateInfo(options, new TsQueryConverter(options.TextEncoding))); + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // tsvector + mappings.AddArrayType(DataTypeNames.TsVector); + + // tsquery + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + mappings.AddArrayType(DataTypeNames.TsQuery); + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/GeometricTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/GeometricTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..a365434f54 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/GeometricTypeInfoResolverFactory.cs @@ -0,0 +1,63 @@ +using System; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using NpgsqlTypes; + +namespace Npgsql.Internal.ResolverFactories; + +sealed class GeometricTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + class Resolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddStructType(DataTypeNames.Point, + static (options, mapping, _) => mapping.CreateInfo(options, new PointConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Box, + static (options, mapping, _) => mapping.CreateInfo(options, new BoxConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Polygon, + static (options, mapping, _) => mapping.CreateInfo(options, new PolygonConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Line, + static (options, mapping, _) => mapping.CreateInfo(options, new LineConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.LSeg, + static (options, mapping, _) => mapping.CreateInfo(options, new LineSegmentConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Path, + static (options, mapping, _) => mapping.CreateInfo(options, new PathConverter()), isDefault: true); + mappings.AddStructType(DataTypeNames.Circle, + static (options, mapping, _) => mapping.CreateInfo(options, new CircleConverter()), isDefault: true); + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddStructArrayType(DataTypeNames.Point); + mappings.AddStructArrayType(DataTypeNames.Box); + mappings.AddStructArrayType(DataTypeNames.Polygon); + mappings.AddStructArrayType(DataTypeNames.Line); + mappings.AddStructArrayType(DataTypeNames.LSeg); + mappings.AddStructArrayType(DataTypeNames.Path); + mappings.AddStructArrayType(DataTypeNames.Circle); + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/JsonDynamicTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/JsonDynamicTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..2515cf9a5b --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/JsonDynamicTypeInfoResolverFactory.cs @@ -0,0 +1,188 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; + +namespace Npgsql.Internal.ResolverFactories; + +[RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] +[RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] +sealed class JsonDynamicTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + readonly Type[]? _jsonbClrTypes; + readonly Type[]? _jsonClrTypes; + readonly JsonSerializerOptions? _serializerOptions; + + public JsonDynamicTypeInfoResolverFactory(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? 
serializerOptions = null) + { + _jsonbClrTypes = jsonbClrTypes; + _jsonClrTypes = jsonClrTypes; + _serializerOptions = serializerOptions; + } + + public override IPgTypeInfoResolver CreateResolver() => new Resolver(_jsonbClrTypes, _jsonClrTypes, _serializerOptions); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(_jsonbClrTypes, _jsonClrTypes, _serializerOptions); + + // Split into a nested class to avoid erroneous trimming/AOT warnings because the JsonDynamicTypeInfoResolverFactory is marked as incompatible. + internal static class Support + { + public static void ThrowIfUnsupported(Type? type, DataTypeName? dataTypeName) + { + if (dataTypeName is { SchemaSpan: "pg_catalog", UnqualifiedNameSpan: "json" or "_json" or "jsonb" or "_jsonb" }) + throw new NotSupportedException( + string.Format( + NpgsqlStrings.DynamicJsonNotEnabled, + type is null || type == typeof(object) ? "" : type.Name, + nameof(NpgsqlSlimDataSourceBuilder.EnableDynamicJson), + typeof(TBuilder).Name)); + } + } + + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + class Resolver : DynamicTypeInfoResolver, IPgTypeInfoResolver + { + JsonSerializerOptions? _serializerOptions; + JsonSerializerOptions SerializerOptions + #if NET7_0_OR_GREATER + => _serializerOptions ??= JsonSerializerOptions.Default; + #else + => _serializerOptions ??= new(); + #endif + + readonly Type[] _jsonbClrTypes; + readonly Type[] _jsonClrTypes; + TypeInfoMappingCollection? _mappings; + + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _jsonbClrTypes, _jsonClrTypes, SerializerOptions); + + public Resolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? 
serializerOptions = null) + { + _jsonbClrTypes = jsonbClrTypes ?? Array.Empty(); + _jsonClrTypes = jsonClrTypes ?? Array.Empty(); + _serializerOptions = serializerOptions; + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options) ?? base.GetTypeInfo(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, Type[] jsonbClrTypes, Type[] jsonClrTypes, JsonSerializerOptions serializerOptions) + { + // We do GetTypeInfo calls directly so we need a resolver. + serializerOptions.TypeInfoResolver ??= new DefaultJsonTypeInfoResolver(); + + // These live in the RUC/RDC part as JsonValues can contain any .NET type. + foreach (var dataTypeName in new[] { DataTypeNames.Jsonb, DataTypeNames.Json }) + { + var jsonb = dataTypeName == DataTypeNames.Jsonb; + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); + } + + AddUserMappings(jsonb: true, jsonbClrTypes); + AddUserMappings(jsonb: false, jsonClrTypes); + + void AddUserMappings(bool jsonb, Type[] clrTypes) + { + var dynamicMappings = CreateCollection(); + var dataTypeName = (string)(jsonb ? 
DataTypeNames.Jsonb : DataTypeNames.Json); + foreach (var jsonType in clrTypes) + { + var jsonTypeInfo = serializerOptions.GetTypeInfo(jsonType); + dynamicMappings.AddMapping(jsonTypeInfo.Type, dataTypeName, + factory: (options, mapping, _) => mapping.CreateInfo(options, + CreateSystemTextJsonConverter(mapping.Type, jsonb, options.TextEncoding, serializerOptions, jsonType))); + + if (!jsonType.IsValueType && jsonTypeInfo.PolymorphismOptions is not null) + { + foreach (var derived in jsonTypeInfo.PolymorphismOptions.DerivedTypes) + dynamicMappings.AddMapping(derived.DerivedType, dataTypeName, + factory: (options, mapping, _) => mapping.CreateInfo(options, + CreateSystemTextJsonConverter(mapping.Type, jsonb, options.TextEncoding, serializerOptions, jsonType))); + } + } + mappings.AddRange(dynamicMappings.ToTypeInfoMappingCollection()); + } + + return mappings; + } + + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + // Match all types except null, object and text types as long as DataTypeName (json/jsonb) is present. + if (type is null || type == typeof(object) || PgSerializerOptions.IsWellKnownTextType(type) + || dataTypeName != DataTypeNames.Jsonb && dataTypeName != DataTypeNames.Json) + return null; + + return CreateCollection().AddMapping(type, dataTypeName, (options, mapping, _) => + { + var jsonb = dataTypeName == DataTypeNames.Jsonb; + + // For jsonb we can't properly support polymorphic serialization unless we do quite some additional work + // so we default to mapping.Type instead (exact types will never serialize their "$type" fields, essentially disabling the feature). + var baseType = jsonb ? 
mapping.Type : typeof(object); + + return mapping.CreateInfo(options, + CreateSystemTextJsonConverter(mapping.Type, jsonb, options.TextEncoding, SerializerOptions, baseType)); + }); + } + + static PgConverter CreateSystemTextJsonConverter(Type valueType, bool jsonb, Encoding textEncoding, JsonSerializerOptions serializerOptions, Type baseType) + => (PgConverter)Activator.CreateInstance( + typeof(JsonConverter<,>).MakeGenericType(valueType, baseType), + jsonb, + textEncoding, + serializerOptions)!; + } + + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings), base.Mappings); + + public ArrayResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? serializerOptions = null) + : base(jsonbClrTypes, jsonClrTypes, serializerOptions) { } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options) ?? base.GetTypeInfo(type, dataTypeName, options); + + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + => type is not null && IsArrayLikeType(type, out var elementType) && IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName) + ? 
base.GetMappings(elementType, elementDataTypeName, options)?.AddArrayMapping(elementType, elementDataTypeName) + : null; + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, TypeInfoMappingCollection baseMappings) + { + if (baseMappings.Items.Count == 0) + return mappings; + + foreach (var dataTypeName in new[] { DataTypeNames.Jsonb, DataTypeNames.Json }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + } + + var dynamicMappings = CreateCollection(baseMappings); + foreach (var mapping in baseMappings.Items) + dynamicMappings.AddArrayMapping(mapping.Type, mapping.DataTypeName); + mappings.AddRange(dynamicMappings.ToTypeInfoMappingCollection()); + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/JsonTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/JsonTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..a94d5d36f8 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/JsonTypeInfoResolverFactory.cs @@ -0,0 +1,100 @@ +using System; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.ResolverFactories; + +sealed class JsonTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + readonly JsonSerializerOptions? _serializerOptions; + + public JsonTypeInfoResolverFactory(JsonSerializerOptions? serializerOptions = null) => _serializerOptions = serializerOptions; + + public override IPgTypeInfoResolver CreateResolver() => new Resolver(_serializerOptions); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(_serializerOptions); + + class Resolver : IPgTypeInfoResolver + { + static JsonSerializerOptions? DefaultSerializerOptions; + + readonly JsonSerializerOptions _serializerOptions; + TypeInfoMappingCollection? 
_mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _serializerOptions); + + public Resolver(JsonSerializerOptions? serializerOptions = null) + { + if (serializerOptions is null) + { + serializerOptions = DefaultSerializerOptions; + if (serializerOptions is null) + { + serializerOptions = new JsonSerializerOptions(); + serializerOptions.TypeInfoResolver = new BasicJsonTypeInfoResolver(); + DefaultSerializerOptions = serializerOptions; + } + } + + _serializerOptions = serializerOptions; + } + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, JsonSerializerOptions serializerOptions) + { + // Jsonb is the first default for JsonDocument + foreach (var dataTypeName in new[] { DataTypeNames.Jsonb, DataTypeNames.Json }) + { + var jsonb = dataTypeName == DataTypeNames.Jsonb; + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, + new JsonConverter(jsonb, options.TextEncoding, serializerOptions)), + isDefault: true); + mappings.AddStructType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, + new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); + } + + return mappings; + } + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + sealed class BasicJsonTypeInfoResolver : IJsonTypeInfoResolver + { + public JsonTypeInfo? GetTypeInfo(Type type, JsonSerializerOptions options) + { + if (type == typeof(JsonDocument)) + return JsonMetadataServices.CreateValueInfo(options, JsonMetadataServices.JsonDocumentConverter); + if (type == typeof(JsonElement)) + return JsonMetadataServices.CreateValueInfo(options, JsonMetadataServices.JsonElementConverter); + return null; + } + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? 
_mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public ArrayResolver(JsonSerializerOptions? serializerOptions = null) + : base(serializerOptions) + { + } + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + foreach (var dataTypeName in new[] { DataTypeNames.Jsonb, DataTypeNames.Json }) + { + mappings.AddArrayType(dataTypeName); + mappings.AddStructArrayType(dataTypeName); + } + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/LTreeTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/LTreeTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..720d8ee78d --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/LTreeTypeInfoResolverFactory.cs @@ -0,0 +1,66 @@ +using System; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; + +namespace Npgsql.Internal.ResolverFactories; + +sealed class LTreeTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + public static void ThrowIfUnsupported(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + if (dataTypeName is { UnqualifiedNameSpan: "ltree" or "_ltree" or "lquery" or "_lquery" or "ltxtquery" or "_ltxtquery" }) + throw new NotSupportedException( + string.Format(NpgsqlStrings.LTreeNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableLTree), + typeof(TBuilder).Name)); + } + + class Resolver : IPgTypeInfoResolver + { + const byte LTreeVersion = 1; + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? 
GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddType("ltree", + static (options, mapping, _) => mapping.CreateInfo(options, + new VersionPrefixedTextConverter(LTreeVersion, new StringTextConverter(options.TextEncoding))), + MatchRequirement.DataTypeName); + mappings.AddType("lquery", + static (options, mapping, _) => mapping.CreateInfo(options, + new VersionPrefixedTextConverter(LTreeVersion, new StringTextConverter(options.TextEncoding))), + MatchRequirement.DataTypeName); + mappings.AddType("ltxtquery", + static (options, mapping, _) => mapping.CreateInfo(options, + new VersionPrefixedTextConverter(LTreeVersion, new StringTextConverter(options.TextEncoding))), + MatchRequirement.DataTypeName); + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddArrayType("ltree"); + mappings.AddArrayType("lquery"); + mappings.AddArrayType("ltxtquery"); + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/NetworkTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/NetworkTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..eca3dfec64 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/NetworkTypeInfoResolverFactory.cs @@ -0,0 +1,82 @@ +using System; +using System.Net; +using System.Net.NetworkInformation; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using NpgsqlTypes; + +namespace Npgsql.Internal.ResolverFactories; + +sealed class NetworkTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + class Resolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // macaddr + mappings.AddType(DataTypeNames.MacAddr, + static (options, mapping, _) => mapping.CreateInfo(options, new MacaddrConverter(macaddr8: false)), isDefault: true); + mappings.AddType(DataTypeNames.MacAddr8, + static (options, mapping, _) => mapping.CreateInfo(options, new MacaddrConverter(macaddr8: true)), + mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName }); + + // inet + // There are certain IPAddress values like Loopback or Any that return a *private* derived type (see https://github.com/dotnet/runtime/issues/27870). + // However we still need to be able to resolve some typed converter for those values. + // We do so by returning a boxing info when we deal with a derived type, as a result we don't need an exact typed converter. + // For arrays users can't actually reference the private type so we'll only see some version of ArrayType. + // For reads we'll only see the public type so we never surface an InvalidCastException trying to cast IPAddress to ReadOnlyIPAddress. + // Finally we add a custom predicate to be able to match any type which values are assignable to IPAddress. + mappings.AddType(DataTypeNames.Inet, + static (options, mapping, _) => new PgTypeInfo(options, new IPAddressConverter(), + new DataTypeName(mapping.DataTypeName), unboxedType: mapping.Type == typeof(IPAddress) ? 
null : mapping.Type), + mapping => mapping with + { + MatchRequirement = MatchRequirement.Single, + TypeMatchPredicate = type => type is null || typeof(IPAddress).IsAssignableFrom(type) + }); + mappings.AddStructType(DataTypeNames.Inet, + static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlInetConverter())); + + // cidr + mappings.AddStructType(DataTypeNames.Cidr, + static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlCidrConverter()), isDefault: true); + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + // macaddr + mappings.AddArrayType(DataTypeNames.MacAddr); + mappings.AddArrayType(DataTypeNames.MacAddr8); + + // inet + mappings.AddArrayType(DataTypeNames.Inet); + mappings.AddStructArrayType(DataTypeNames.Inet); + + // cidr + mappings.AddStructArrayType(DataTypeNames.Cidr); + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/RecordTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/RecordTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..eb7de18a1f --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/RecordTypeInfoResolverFactory.cs @@ -0,0 +1,59 @@ +using System; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; + +namespace Npgsql.Internal.ResolverFactories; + +sealed class RecordTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + public 
static void ThrowIfUnsupported(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + if (dataTypeName is { SchemaSpan: "pg_catalog", UnqualifiedNameSpan: "record" or "_record" }) + { + throw new NotSupportedException( + string.Format( + NpgsqlStrings.RecordsNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableRecordsAsTuples), + typeof(TBuilder).Name, + nameof(NpgsqlSlimDataSourceBuilder.EnableRecords))); + } + } + + class Resolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddType(DataTypeNames.Record, static (options, mapping, _) => + mapping.CreateInfo(options, new RecordConverter(options), supportsWriting: false), + MatchRequirement.DataTypeName); + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddArrayType(DataTypeNames.Record); + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/TupledRecordTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/TupledRecordTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..189f84a868 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/TupledRecordTypeInfoResolverFactory.cs @@ -0,0 +1,74 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal.ResolverFactories; + +[RequiresUnreferencedCode("Tupled record resolver may perform reflection on trimmed tuple types.")] +[RequiresDynamicCode("Tupled records need to construct a generic converter for a statically unknown (value)tuple type.")] +sealed class TupledRecordTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + [RequiresUnreferencedCode("Tupled record resolver may perform reflection on trimmed tuple types.")] + [RequiresDynamicCode("Tupled records need to construct a generic converter for a statically unknown (value)tuple type.")] + class Resolver : DynamicTypeInfoResolver + { + protected override DynamicMappingCollection? GetMappings(Type? 
type, DataTypeName dataTypeName, PgSerializerOptions options) + { + if (!(dataTypeName == DataTypeNames.Record && type is { IsConstructedGenericType: true, FullName: not null } && ( + type.FullName.StartsWith("System.Tuple", StringComparison.Ordinal) + || type.FullName.StartsWith("System.ValueTuple", StringComparison.Ordinal)))) + return null; + + return CreateCollection().AddMapping(type, dataTypeName, (options, mapping, _) => + { + var constructors = mapping.Type.GetConstructors(); + ConstructorInfo? constructor = null; + if (constructors.Length is 1) + constructor = constructors[0]; + else + { + var args = mapping.Type.GenericTypeArguments.Length; + foreach (var ctor in constructors) + if (ctor.GetParameters().Length == args) + { + constructor = ctor; + break; + } + } + + if (constructor is null) + throw new InvalidOperationException($"Couldn't find a suitable constructor for record type: {mapping.Type.FullName}"); + + var factory = typeof(Resolver).GetMethod(nameof(CreateFactory), BindingFlags.Static | BindingFlags.NonPublic)! 
+ .MakeGenericMethod(mapping.Type) + .Invoke(null, new object[] { constructor, constructor.GetParameters().Length }); + + var converterType = typeof(RecordConverter<>).MakeGenericType(mapping.Type); + var converter = (PgConverter)Activator.CreateInstance(converterType, options, factory)!; + return mapping.CreateInfo(options, converter, supportsWriting: false); + }); + } + + static Func CreateFactory(ConstructorInfo constructor, int constructorParameters) => array => + { + if (array.Length != constructorParameters) + throw new InvalidCastException($"Cannot read record type with {array.Length} fields as {typeof(T)}"); + return (T)constructor.Invoke(array); + }; + } + + [RequiresUnreferencedCode("Tupled record resolver may perform reflection on trimmed tuple types.")] + [RequiresDynamicCode("Tupled records need to construct a generic converter for a statically unknown (value)tuple type.")] + sealed class ArrayResolver : Resolver + { + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + => type is not null && IsArrayLikeType(type, out var elementType) && IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName) + ? 
base.GetMappings(elementType, elementDataTypeName, options)?.AddArrayMapping(elementType, elementDataTypeName) + : null; + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/UnmappedTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/UnmappedTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..a04c3cc111 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/UnmappedTypeInfoResolverFactory.cs @@ -0,0 +1,179 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.Internal.ResolverFactories; + +[RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] +[RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] +sealed class UnmappedTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + public override IPgTypeInfoResolver CreateResolver() => new EnumResolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new EnumArrayResolver(); + + public override IPgTypeInfoResolver CreateRangeResolver() => new RangeResolver(); + public override IPgTypeInfoResolver CreateRangeArrayResolver() => new RangeArrayResolver(); + + public override IPgTypeInfoResolver? CreateMultirangeResolver() => new MultirangeResolver(); + public override IPgTypeInfoResolver? 
CreateMultirangeArrayResolver() => new MultirangeArrayResolver(); + + [RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + class EnumResolver : DynamicTypeInfoResolver + { + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + if (type is null || !IsTypeOrNullableOfType(type, static type => type.IsEnum, out var matchedType) || options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresEnumType) + return null; + + return CreateCollection().AddMapping(matchedType, dataTypeName, static (options, mapping, _) => + { + var enumToLabel = new Dictionary(); + var labelToEnum = new Dictionary(); + foreach (var field in mapping.Type.GetFields(BindingFlags.Static | BindingFlags.Public)) + { + var attribute = (PgNameAttribute?)field.GetCustomAttribute(typeof(PgNameAttribute), false); + var enumName = attribute?.PgName ?? options.DefaultNameTranslator.TranslateMemberName(field.Name); + var enumValue = (Enum)field.GetValue(null)!; + + enumToLabel[enumValue] = enumName; + labelToEnum[enumName] = enumValue; + } + + return mapping.CreateInfo(options, (PgConverter)Activator.CreateInstance(typeof(EnumConverter<>).MakeGenericType(mapping.Type), + enumToLabel, labelToEnum, + options.TextEncoding)!); + }); + } + } + + [RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + sealed class EnumArrayResolver : EnumResolver + { + protected override DynamicMappingCollection? GetMappings(Type? 
type, DataTypeName dataTypeName, PgSerializerOptions options) + => type is not null && IsArrayLikeType(type, out var elementType) && IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName) + ? base.GetMappings(elementType, elementDataTypeName, options)?.AddArrayMapping(elementType, elementDataTypeName) + : null; + } + + [RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + class RangeResolver : DynamicTypeInfoResolver + { + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + var matchedType = type; + if ((type is not null && type != typeof(object) && !IsTypeOrNullableOfType(type, + static type => type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(NpgsqlRange<>), + out matchedType)) + || options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresRangeType rangeType) + return null; + + var subInfo = + matchedType is null + ? options.GetDefaultTypeInfo(rangeType.Subtype) + // Input matchedType here as we don't want an NpgsqlRange over Nullable (it has its own nullability tracking, for better or worse) + : options.GetTypeInfo(matchedType == typeof(object) ? matchedType : matchedType.GetGenericArguments()[0], rangeType.Subtype); + + // We have no generic RangeConverterResolver so we would not know how to compose a range mapping for such infos. + // See https://github.com/npgsql/npgsql/issues/5268 + if (subInfo is not { IsResolverInfo: false }) + return null; + + subInfo = subInfo.ToNonBoxing(); + + var converterType = typeof(NpgsqlRange<>).MakeGenericType(subInfo.Type); + + return CreateCollection().AddMapping(matchedType ?? 
converterType, dataTypeName, + (options, mapping, _) => + new PgTypeInfo( + options, + (PgConverter)Activator.CreateInstance(typeof(RangeConverter<>).MakeGenericType(subInfo.Type), + subInfo.GetResolution().Converter)!, + new DataTypeName(mapping.DataTypeName), + unboxedType: matchedType is not null && matchedType != converterType ? converterType : null + ) { PreferredFormat = subInfo.PreferredFormat, SupportsWriting = subInfo.SupportsWriting }, + mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName }); + } + } + + [RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + sealed class RangeArrayResolver : RangeResolver + { + protected override DynamicMappingCollection? GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + Type? elementType = null; + if ((type is not null && type != typeof(object) && !IsArrayLikeType(type, out elementType)) + || !IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName)) + return null; + + var mappings = base.GetMappings(elementType, elementDataTypeName, options); + elementType ??= mappings?.Find(null, elementDataTypeName, options)?.Type; // Try to get the default mapping. + return elementType is null ? null : mappings?.AddArrayMapping(elementType, elementDataTypeName); + } + } + + [RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + class MultirangeResolver : DynamicTypeInfoResolver + { + protected override DynamicMappingCollection? GetMappings(Type? 
type, DataTypeName dataTypeName, PgSerializerOptions options) + { + Type? elementType = null; + if ((type is not null && type != typeof(object) && !IsArrayLikeType(type, out elementType)) + || elementType is not null && !IsTypeOrNullableOfType(elementType, + static type => type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(NpgsqlRange<>), out _) + || options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresMultirangeType multirangeType) + return null; + + var subInfo = + type is null + ? options.GetDefaultTypeInfo(multirangeType.Subrange) + : options.GetTypeInfo(elementType ?? typeof(object), multirangeType.Subrange); + + // We have no generic MultirangeConverterResolver so we would not know how to compose a range mapping for such infos. + // See https://github.com/npgsql/npgsql/issues/5268 + if (subInfo is not { IsResolverInfo: false }) + return null; + + subInfo = subInfo.ToNonBoxing(); + + var converterType = subInfo.Type.MakeArrayType(); + + return CreateCollection().AddMapping(type ?? converterType, dataTypeName, + (options, mapping, _) => + new PgTypeInfo( + options, + (PgConverter)Activator.CreateInstance(typeof(MultirangeConverter<,>).MakeGenericType(converterType, subInfo.Type), + subInfo.GetResolution().Converter)!, + new DataTypeName(mapping.DataTypeName), + unboxedType: type is not null && type != converterType ? converterType : null + ) { PreferredFormat = subInfo.PreferredFormat, SupportsWriting = subInfo.SupportsWriting }, + mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName }); + } + } + + [RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + sealed class MultirangeArrayResolver : MultirangeResolver + { + protected override DynamicMappingCollection? 
GetMappings(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) + { + var elementType = type == typeof(object) ? type : null; + if ((type is not null && type != typeof(object) && !IsArrayLikeType(type, out elementType)) + || !IsArrayDataTypeName(dataTypeName, options, out var elementDataTypeName)) + return null; + + var mappings = base.GetMappings(elementType, elementDataTypeName, options); + elementType ??= mappings?.Find(null, elementDataTypeName, options)?.Type; // Try to get the default mapping. + return elementType is null ? null : mappings?.AddArrayMapping(elementType, elementDataTypeName); + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/UnsupportedTypeInfoResolver.cs b/src/Npgsql/Internal/ResolverFactories/UnsupportedTypeInfoResolver.cs new file mode 100644 index 0000000000..2d47f86807 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/UnsupportedTypeInfoResolver.cs @@ -0,0 +1,70 @@ +using System; +using System.Collections; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using Npgsql.Properties; + +namespace Npgsql.Internal.ResolverFactories; + +sealed class UnsupportedTypeInfoResolver : IPgTypeInfoResolver +{ + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + if (options.IntrospectionMode) + return null; + + RecordTypeInfoResolverFactory.ThrowIfUnsupported(type, dataTypeName, options); + FullTextSearchTypeInfoResolverFactory.ThrowIfUnsupported(type, dataTypeName, options); + LTreeTypeInfoResolverFactory.ThrowIfUnsupported(type, dataTypeName, options); + + JsonDynamicTypeInfoResolverFactory.Support.ThrowIfUnsupported(type, dataTypeName); + + switch (dataTypeName is null ? null : options.DatabaseInfo.GetPostgresType(dataTypeName.GetValueOrDefault())) + { + case PostgresEnumType: + // Unmapped enum types never work on object or default. 
+ if (type is not null && type != typeof(object)) + throw new NotSupportedException( + string.Format( + NpgsqlStrings.UnmappedEnumsNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableUnmappedTypes), + typeof(TBuilder).Name)); + break; + + case PostgresRangeType when !options.RangesEnabled: + throw new NotSupportedException( + string.Format(NpgsqlStrings.RangesNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableRanges), typeof(TBuilder).Name)); + case PostgresRangeType: + throw new NotSupportedException( + string.Format( + NpgsqlStrings.UnmappedRangesNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableUnmappedTypes), + typeof(TBuilder).Name)); + + case PostgresMultirangeType when !options.MultirangesEnabled: + throw new NotSupportedException( + string.Format(NpgsqlStrings.MultirangesNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableMultiranges), typeof(TBuilder).Name)); + case PostgresMultirangeType: + throw new NotSupportedException( + string.Format( + NpgsqlStrings.UnmappedRangesNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableUnmappedTypes), + typeof(TBuilder).Name)); + + case PostgresArrayType when !options.ArraysEnabled: + throw new NotSupportedException( + string.Format(NpgsqlStrings.ArraysNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableArrays), typeof(TBuilder).Name)); + } + + if (type is not null) + { + if (TypeInfoMappingCollection.IsArrayLikeType(type, out var elementType) && TypeInfoMappingCollection.IsArrayLikeType(elementType, out _)) + throw new NotSupportedException("Writing is not supported for jagged collections, use a multidimensional array instead."); + + if (typeof(IEnumerable).IsAssignableFrom(type) && !typeof(IList).IsAssignableFrom(type) && type != typeof(string) && (dataTypeName is null || dataTypeName.Value.IsArray)) + throw new NotSupportedException("Writing is not supported for IEnumerable parameters, use an array or some implementation of IList instead."); + } + + return null; + } +} diff --git 
a/src/Npgsql/Internal/Size.cs b/src/Npgsql/Internal/Size.cs new file mode 100644 index 0000000000..7cbdd9bb20 --- /dev/null +++ b/src/Npgsql/Internal/Size.cs @@ -0,0 +1,74 @@ +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public enum SizeKind +{ + Unknown = 0, + Exact, + UpperBound +} + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public readonly struct Size : IEquatable +{ + readonly int _value; + readonly SizeKind _kind; + + Size(SizeKind kind, int value) + { + _value = value; + _kind = kind; + } + + public int Value + { + get + { + if (_kind is SizeKind.Unknown) + ThrowHelper.ThrowInvalidOperationException("Cannot get value from default or Unknown kind"); + return _value; + } + } + + internal int GetValueOrDefault() => _value; + + public SizeKind Kind => _kind; + + public static Size Create(int byteCount) => new(SizeKind.Exact, byteCount); + public static Size CreateUpperBound(int byteCount) => new(SizeKind.UpperBound, byteCount); + public static Size Unknown { get; } = new(SizeKind.Unknown, 0); + public static Size Zero { get; } = new(SizeKind.Exact, 0); + + public Size Combine(Size result) + { + if (_kind is SizeKind.Unknown || result._kind is SizeKind.Unknown) + return Unknown; + + if (_kind is SizeKind.UpperBound || result._kind is SizeKind.UpperBound) + return CreateUpperBound((int)Math.Min((long)(_value + result._value), int.MaxValue)); + + return Create((int)Math.Min((long)(_value + result._value), int.MaxValue)); + } + + public static implicit operator Size(int value) => Create(value); + + string DebuggerDisplay => ToString(); + + public bool Equals(Size other) => _value == other._value && _kind == other.Kind; + public override bool Equals(object? 
obj) => obj is Size other && Equals(other); + public override int GetHashCode() => HashCode.Combine(_value, (int)_kind); + public static bool operator ==(Size left, Size right) => left.Equals(right); + public static bool operator !=(Size left, Size right) => !left.Equals(right); + + public override string ToString() => _kind switch + { + SizeKind.Exact or SizeKind.UpperBound => $"{_value} ({_kind.ToString()})", + SizeKind.Unknown => nameof(SizeKind.Unknown), + _ => throw new ArgumentOutOfRangeException() + }; +} diff --git a/src/Npgsql/Internal/TransportSecurityHandler.cs b/src/Npgsql/Internal/TransportSecurityHandler.cs new file mode 100644 index 0000000000..e34b2444a7 --- /dev/null +++ b/src/Npgsql/Internal/TransportSecurityHandler.cs @@ -0,0 +1,39 @@ +using System; +using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; +using Npgsql.Properties; +using Npgsql.Util; + +namespace Npgsql.Internal; + +class TransportSecurityHandler +{ + public virtual bool SupportEncryption => false; + + public virtual Func? 
RootCertificateCallback + { + get => throw new NotSupportedException(string.Format(NpgsqlStrings.TransportSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableTransportSecurity))); + set => throw new NotSupportedException(string.Format(NpgsqlStrings.TransportSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableTransportSecurity))); + } + + public virtual Task NegotiateEncryption(bool async, NpgsqlConnector connector, SslMode sslMode, NpgsqlTimeout timeout, bool isFirstAttempt) + => throw new NotSupportedException(string.Format(NpgsqlStrings.TransportSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableTransportSecurity))); + + public virtual void AuthenticateSASLSha256Plus(NpgsqlConnector connector, ref string mechanism, ref string cbindFlag, ref string cbind, + ref bool successfulBind) + => throw new NotSupportedException(string.Format(NpgsqlStrings.TransportSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableTransportSecurity))); +} + +sealed class RealTransportSecurityHandler : TransportSecurityHandler +{ + public override bool SupportEncryption => true; + + public override Func? 
RootCertificateCallback { get; set; } + + public override Task NegotiateEncryption(bool async, NpgsqlConnector connector, SslMode sslMode, NpgsqlTimeout timeout, bool isFirstAttempt) + => connector.NegotiateEncryption(sslMode, timeout, async, isFirstAttempt); + + public override void AuthenticateSASLSha256Plus(NpgsqlConnector connector, ref string mechanism, ref string cbindFlag, ref string cbind, + ref bool successfulBind) + => connector.AuthenticateSASLSha256Plus(ref mechanism, ref cbindFlag, ref cbind, ref successfulBind); +} diff --git a/src/Npgsql/Internal/TypeInfoCache.cs b/src/Npgsql/Internal/TypeInfoCache.cs new file mode 100644 index 0000000000..df570ca825 --- /dev/null +++ b/src/Npgsql/Internal/TypeInfoCache.cs @@ -0,0 +1,169 @@ +using System; +using System.Collections.Concurrent; +using System.Runtime.CompilerServices; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +sealed class TypeInfoCache where TPgTypeId : struct +{ + readonly PgSerializerOptions _options; + readonly bool _validatePgTypeIds; + + // Mostly used for parameter writing, 8ns + readonly ConcurrentDictionary _cacheByClrType = new(); + + // Used for reading, occasionally for parameter writing where a db type was given. + // 8ns, about 10ns total to scan an array with 6, 7 different clr types under one pg type + readonly ConcurrentDictionary _cacheByPgTypeId = new(); + + static TypeInfoCache() + { + if (typeof(TPgTypeId) != typeof(Oid) && typeof(TPgTypeId) != typeof(DataTypeName)) + throw new InvalidOperationException("Cannot use this type argument."); + } + + public TypeInfoCache(PgSerializerOptions options, bool validatePgTypeIds = true) + { + _options = options; + _validatePgTypeIds = validatePgTypeIds; + } + + /// + /// + /// + /// + /// + /// + /// When this flag is true, and both type and pgTypeId are non null, a default info for the pgTypeId can be returned if an exact match + /// can't be found. + /// + /// + /// + public PgTypeInfo? GetOrAddInfo(Type? 
type, TPgTypeId? pgTypeId, bool defaultTypeFallback = false) + { + if (pgTypeId is { } id) + { + if (_cacheByPgTypeId.TryGetValue(id, out var infos)) + if (FindMatch(type, infos, defaultTypeFallback) is { } info) + return info; + + return AddEntryById(type, id, infos, defaultTypeFallback); + } + + if (type is not null) + return _cacheByClrType.TryGetValue(type, out var info) ? info : AddByType(type); + + return null; + + PgTypeInfo? FindMatch(Type? type, (Type? Type, PgTypeInfo? Info)[] infos, bool defaultTypeFallback) + { + PgTypeInfo? defaultInfo = null; + var negativeExactMatch = false; + for (var i = 0; i < infos.Length; i++) + { + ref var item = ref infos[i]; + if (item.Type == type) + { + if (item.Info is not null || !defaultTypeFallback) + return item.Info; + negativeExactMatch = true; + } + + if (defaultTypeFallback && item.Type is null) + defaultInfo = item.Info; + } + + // We can only return default info if we've seen a negative match (type: typeof(object), info: null) + // Otherwise we might return a previously requested default while the resolvers could produce the exact match. + return negativeExactMatch ? defaultInfo : null; + } + + PgTypeInfo? AddByType(Type type) + { + // We don't pass PgTypeId as we're interested in default converters here. + var info = CreateInfo(type, null, _options, defaultTypeFallback: false, _validatePgTypeIds); + + return info is null + ? null + : _cacheByClrType.TryAdd(type, info) // We never remove entries so either of these branches will always succeed. + ? info + : _cacheByClrType[type]; + } + + PgTypeInfo? AddEntryById(Type? type, TPgTypeId pgTypeId, (Type? Type, PgTypeInfo? Info)[]? infos, bool defaultTypeFallback) + { + // We cache negatives (null info) to allow 'object or default' checks to never hit the resolvers after the first lookup. 
+ var info = CreateInfo(type, pgTypeId, _options, defaultTypeFallback, _validatePgTypeIds); + + var isDefaultInfo = type is null && info is not null; + if (infos is null) + { + // Also add defaults by their info type to save a future resolver lookup + resize. + infos = isDefaultInfo + ? new [] { (type, info), (info!.Type, info) } + : new [] { (type, info) }; + + if (_cacheByPgTypeId.TryAdd(pgTypeId, infos)) + return info; + } + + // We have to update it instead. + while (true) + { + infos = _cacheByPgTypeId[pgTypeId]; + if (FindMatch(type, infos, defaultTypeFallback) is { } racedInfo) + return racedInfo; + + // Also add defaults by their info type to save a future resolver lookup + resize. + var oldInfos = infos; + var hasExactType = false; + if (isDefaultInfo) + { + foreach (var oldInfo in oldInfos) + if (oldInfo.Type == info!.Type) + hasExactType = true; + } + Array.Resize(ref infos, oldInfos.Length + (isDefaultInfo && !hasExactType ? 2 : 1)); + infos[oldInfos.Length] = (type, info); + if (isDefaultInfo && !hasExactType) + infos[oldInfos.Length + 1] = (info!.Type, info); + + if (_cacheByPgTypeId.TryUpdate(pgTypeId, infos, oldInfos)) + return info; + } + } + + static PgTypeInfo? CreateInfo(Type? type, TPgTypeId? typeId, PgSerializerOptions options, bool defaultTypeFallback, bool validatePgTypeIds) + { + var pgTypeId = AsPgTypeId(typeId); + // Validate that we only pass data types that are supported by the backend. + var dataTypeName = pgTypeId is { } id ? 
(DataTypeName?)options.DatabaseInfo.GetDataTypeName(id, validate: validatePgTypeIds) : null; + var info = options.TypeInfoResolver.GetTypeInfo(type, dataTypeName, options); + if (info is null && defaultTypeFallback) + { + type = null; + info = options.TypeInfoResolver.GetTypeInfo(type, dataTypeName, options); + } + + if (info is null) + return null; + + if (pgTypeId is not null && info.PgTypeId != pgTypeId) + throw new InvalidOperationException("A Postgres type was passed but the resolved PgTypeInfo does not have an equal PgTypeId."); + + if (type is not null && !info.IsBoxing && info.Type != type) + throw new InvalidOperationException($"A CLR type '{type}' was passed but the resolved PgTypeInfo does not have an equal Type: {info.Type}."); + + return info; + } + + static PgTypeId? AsPgTypeId(TPgTypeId? pgTypeId) + => pgTypeId switch + { + { } id when typeof(TPgTypeId) == typeof(DataTypeName) => new((DataTypeName)(object)id), + { } id => new((Oid)(object)id), + null => null + }; + } +} diff --git a/src/Npgsql/Internal/TypeInfoMapping.cs b/src/Npgsql/Internal/TypeInfoMapping.cs new file mode 100644 index 0000000000..753c2bcac3 --- /dev/null +++ b/src/Npgsql/Internal/TypeInfoMapping.cs @@ -0,0 +1,776 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.Internal; + +/// +/// +/// +/// +/// +/// +/// Signals whether a resolver based TypeInfo can keep its PgTypeId undecided or whether it should follow mapping.DataTypeName. 
+/// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public delegate PgTypeInfo TypeInfoFactory(PgSerializerOptions options, TypeInfoMapping mapping, bool resolvedDataTypeName); + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public enum MatchRequirement +{ + /// Match when the clr type and datatype name both match. + /// It's also the only requirement that participates in clr type fallback matching. + All, + /// Match when the datatype name or CLR type matches while the other also matches or is absent. + Single, + /// Match when the datatype name matches and the clr type also matches or is absent. + DataTypeName +} + +/// A factory for well-known PgConverters. +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public static class PgConverterFactory +{ + public static PgConverter CreateArrayMultirangeConverter(PgConverter rangeConverter, PgSerializerOptions options) where T : notnull + => new MultirangeConverter(rangeConverter); + + public static PgConverter> CreateListMultirangeConverter(PgConverter rangeConverter, PgSerializerOptions options) where T : notnull + => new MultirangeConverter, T>(rangeConverter); + + public static PgConverter> CreateRangeConverter(PgConverter subTypeConverter, PgSerializerOptions options) + => new RangeConverter(subTypeConverter); + + public static PgConverter CreatePolymorphicArrayConverter(Func> arrayConverterFactory, Func> nullableArrayConverterFactory, PgSerializerOptions options) + => options.ArrayNullabilityMode switch + { + ArrayNullabilityMode.Never => arrayConverterFactory(), + ArrayNullabilityMode.Always => nullableArrayConverterFactory(), + ArrayNullabilityMode.PerInstance => new PolymorphicArrayConverter(arrayConverterFactory(), nullableArrayConverterFactory()), + _ => throw new ArgumentOutOfRangeException() + }; +} + +[DebuggerDisplay("{DebuggerDisplay,nq}")] +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public readonly struct TypeInfoMapping +{ + public TypeInfoMapping(Type 
type, string dataTypeName, TypeInfoFactory factory) + { + Type = type; + // For objects it makes no sense to have clr type only matches by default, there are too many implementations. + MatchRequirement = type == typeof(object) ? MatchRequirement.DataTypeName : MatchRequirement.All; + DataTypeName = Postgres.DataTypeName.NormalizeName(dataTypeName); + Factory = factory; + } + + public TypeInfoFactory Factory { get; init; } + public Type Type { get; init; } + public string DataTypeName { get; init; } + + public MatchRequirement MatchRequirement { get; init; } + public Func? TypeMatchPredicate { get; init; } + + public bool TypeEquals(Type type) => TypeMatchPredicate?.Invoke(type) ?? Type == type; + public bool DataTypeNameEquals(string dataTypeName) + { + var span = DataTypeName.AsSpan(); + return Postgres.DataTypeName.IsFullyQualified(span) + ? span.Equals(dataTypeName.AsSpan(), StringComparison.Ordinal) + : span.Equals(Postgres.DataTypeName.ValidatedName(dataTypeName).UnqualifiedNameSpan, StringComparison.Ordinal); + } + + string DebuggerDisplay + { + get + { + var builder = new StringBuilder() + .Append(Type.Name) + .Append(" <-> ") + .Append(Postgres.DataTypeName.FromDisplayName(DataTypeName).DisplayName); + + if (MatchRequirement is not MatchRequirement.All) + builder.Append($" ({MatchRequirement.ToString().ToLowerInvariant()})"); + + return builder.ToString(); + } + } +} + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public sealed class TypeInfoMappingCollection +{ + readonly TypeInfoMappingCollection? _baseCollection; + readonly List _items; + + public TypeInfoMappingCollection(int capacity = 0) + => _items = new(capacity); + + public TypeInfoMappingCollection() : this(0) { } + + // Not used for resolving, only for composing (arrays that need to find the element mapping etc). 
+ public TypeInfoMappingCollection(TypeInfoMappingCollection baseCollection) : this(0) + => _baseCollection = baseCollection; + + public TypeInfoMappingCollection(IEnumerable items) + => _items = new(items); + + public IReadOnlyList Items => _items; + + /// Returns the first default converter or the first converter that matches both type and dataTypeName. + /// If just a type was passed and no default was found we return the first converter with a type match. + public PgTypeInfo? Find(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + TypeInfoMapping? fallback = null; + foreach (var mapping in _items) + { + var looseTypeMatch = mapping.TypeMatchPredicate is { } pred ? pred(type) : type is null || mapping.Type == type; + var typeMatch = type is not null && looseTypeMatch; + var dataTypeMatch = dataTypeName is not null && mapping.DataTypeNameEquals(dataTypeName.Value.Value); + + var matchRequirement = mapping.MatchRequirement; + if (dataTypeMatch && typeMatch + || matchRequirement is not MatchRequirement.All && dataTypeMatch && looseTypeMatch + || matchRequirement is MatchRequirement.Single && dataTypeName is null && typeMatch) + { + var resolvedDataTypeName = ResolveFullyQualifiedDataTypeName(dataTypeName, mapping.DataTypeName, options); + return mapping.Factory(options, mapping with { Type = type ?? mapping.Type, DataTypeName = resolvedDataTypeName }, dataTypeName is not null); + } + + // DataTypeName is explicitly requiring dataTypeName so it won't be used for a fallback, Single would have matched above already. 
+ if (matchRequirement is MatchRequirement.All && fallback is null && dataTypeName is null && typeMatch) + fallback = mapping; + } + + if (fallback is { } fbMapping) + { + var resolvedDataTypeName = ResolveFullyQualifiedDataTypeName(dataTypeName, fbMapping.DataTypeName, options); + return fbMapping.Factory(options, fbMapping with { Type = type!, DataTypeName = resolvedDataTypeName }, dataTypeName is not null); + } + + return null; + + static string ResolveFullyQualifiedDataTypeName(DataTypeName? dataTypeName, string mappingDataTypeName, PgSerializerOptions options) + { + // Make sure plugins (which match on unqualified names) and converter resolvers get the fully qualified name to canonicalize. + if (dataTypeName is not null) + return dataTypeName.GetValueOrDefault().Value; + + if (TypeInfoMappingHelpers.TryResolveFullyQualifiedName(options, mappingDataTypeName, out var fqDataTypeName)) + return fqDataTypeName.Value; + + throw new NotSupportedException($"Cannot resolve '{mappingDataTypeName}' to a fully qualified datatype name. The datatype was not found in the current database info."); + } + } + + bool TryGetMapping(Type type, string dataTypeName, out TypeInfoMapping value) + { + foreach (var mapping in _baseCollection?._items ?? _items) + { + // During mapping we just use look for the declared type, regardless of TypeMatchPredicate. + if (mapping.Type == type && mapping.DataTypeNameEquals(dataTypeName)) + { + value = mapping; + return true; + } + } + + value = default; + return false; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + TypeInfoMapping GetMapping(Type type, string dataTypeName) + => TryGetMapping(type, dataTypeName, out var info) ? info : throw new InvalidOperationException($"Could not find mapping for {type} <-> {dataTypeName}"); + + // Helper to eliminate generic display class duplication. 
+ static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping innerMapping, Func mapper, bool copyPreferredFormat = false, bool supportsWriting = true) + => (options, mapping, dataTypeNameMatch) => + { + var resolvedInnerMapping = innerMapping; + if (!DataTypeName.IsFullyQualified(innerMapping.DataTypeName.AsSpan())) + resolvedInnerMapping = innerMapping with { DataTypeName = new DataTypeName(mapping.DataTypeName).Schema + "." + innerMapping.DataTypeName }; + + var innerInfo = innerMapping.Factory(options, resolvedInnerMapping, dataTypeNameMatch); + var converter = mapper(mapping, innerInfo); + var preferredFormat = copyPreferredFormat ? innerInfo.PreferredFormat : null; + var writingSupported = supportsWriting && innerInfo.SupportsWriting; + var unboxedType = ComputeUnboxedType(defaultType: mappingType, converter.TypeToConvert, mapping.Type); + + return new PgTypeInfo(options, converter, options.GetCanonicalTypeId(new DataTypeName(mapping.DataTypeName)), unboxedType) + { + PreferredFormat = preferredFormat, + SupportsWriting = writingSupported + }; + }; + + // Helper to eliminate generic display class duplication. + static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping innerMapping, Func mapper, bool copyPreferredFormat = false, bool supportsWriting = true) + => (options, mapping, dataTypeNameMatch) => + { + var resolvedInnerMapping = innerMapping; + if (!DataTypeName.IsFullyQualified(innerMapping.DataTypeName.AsSpan())) + resolvedInnerMapping = innerMapping with { DataTypeName = new DataTypeName(mapping.DataTypeName).Schema + "." + innerMapping.DataTypeName }; + + var innerInfo = (PgResolverTypeInfo)innerMapping.Factory(options, resolvedInnerMapping, dataTypeNameMatch); + var resolver = mapper(mapping, innerInfo); + var preferredFormat = copyPreferredFormat ? 
innerInfo.PreferredFormat : null; + var writingSupported = supportsWriting && innerInfo.SupportsWriting; + var unboxedType = ComputeUnboxedType(defaultType: mappingType, resolver.TypeToConvert, mapping.Type); + // We include the data type name if the inner info did so as well. + // This way we can rely on its logic around resolvedDataTypeName, including when it ignores that flag. + PgTypeId? pgTypeId = innerInfo.PgTypeId is not null + ? options.GetCanonicalTypeId(new DataTypeName(mapping.DataTypeName)) + : null; + return new PgResolverTypeInfo(options, resolver, pgTypeId, unboxedType) + { + PreferredFormat = preferredFormat, + SupportsWriting = writingSupported + }; + }; + + static Type? ComputeUnboxedType(Type defaultType, Type converterType, Type matchedType) + { + // The minimal hierarchy that should hold for things to work is object < converterType < matchedType. + // Though these types could often be seen in a hierarchy: object < converterType < defaultType < matchedType. + // Some caveats with the latter being for instance Array being the matchedType while the defaultType is int[]. + Debug.Assert(converterType.IsAssignableFrom(matchedType) || matchedType == typeof(object)); + Debug.Assert(converterType.IsAssignableFrom(defaultType)); + + // A special case for object matches, where we return a more specific type than was matched. + // This is to report e.g. Array converters as Array when their matched type was object. + if (matchedType == typeof(object)) + return converterType; + + // This is to report e.g. Array converters as int[,,,] when their matched type was such. + if (matchedType != defaultType) + return matchedType; + + // If defaultType does not equal converterType we take defaultType as it's more specific. + // This is to report e.g. Array converters as int[] when their matched type was their default type. + if (defaultType != converterType) + return defaultType; + + // Keep the converter type. 
+ return null; + } + + public void Add(TypeInfoMapping mapping) => _items.Add(mapping); + + public void AddRange(TypeInfoMappingCollection collection) => _items.AddRange(collection._items); + + Func GetDefaultConfigure(bool isDefault) + => GetDefaultConfigure(isDefault ? MatchRequirement.Single : MatchRequirement.All); + Func GetDefaultConfigure(MatchRequirement matchRequirement) + => matchRequirement switch + { + MatchRequirement.All => static mapping => mapping with { MatchRequirement = MatchRequirement.All }, + MatchRequirement.DataTypeName => static mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName }, + MatchRequirement.Single => static mapping => mapping with { MatchRequirement = MatchRequirement.Single }, + _ => throw new ArgumentOutOfRangeException(nameof(matchRequirement), matchRequirement, null) + }; + + Func GetArrayTypeMatchPredicate(Func elementTypeMatchPredicate) + => type => type is null ? elementTypeMatchPredicate(null) : type.IsArray && elementTypeMatchPredicate(type.GetElementType()!); + Func GetListTypeMatchPredicate(Func elementTypeMatchPredicate) + => type => type is null ? elementTypeMatchPredicate(null) + // We anti-constrain on IsArray to avoid matching byte/sbyte, short/ushort int/uint + // with the list mapping of the earlier type when an exact match is probably available. + : !type.IsArray && typeof(IList).IsAssignableFrom(type) && elementTypeMatchPredicate(typeof(TElement)); + + public void AddType(string dataTypeName, TypeInfoFactory createInfo, bool isDefault = false) where T : class + => AddType(dataTypeName, createInfo, GetDefaultConfigure(isDefault)); + + public void AddType(string dataTypeName, TypeInfoFactory createInfo, MatchRequirement matchRequirement) where T : class + => AddType(dataTypeName, createInfo, GetDefaultConfigure(matchRequirement)); + + public void AddType(string dataTypeName, TypeInfoFactory createInfo, Func? 
configure) where T : class + { + var mapping = new TypeInfoMapping(typeof(T), dataTypeName, createInfo); + mapping = configure?.Invoke(mapping) ?? mapping; + if (typeof(T) != typeof(object) && mapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single && !TryGetMapping(typeof(object), mapping.DataTypeName, out _)) + _items.Add(new TypeInfoMapping(typeof(object), dataTypeName, + CreateComposedFactory(typeof(T), mapping, static (_, info) => info.GetResolution().Converter, copyPreferredFormat: true)) + { + MatchRequirement = mapping.MatchRequirement + }); + _items.Add(mapping); + } + + public void AddResolverType(string dataTypeName, TypeInfoFactory createInfo, bool isDefault = false) where T : class + => AddResolverType(dataTypeName, createInfo, GetDefaultConfigure(isDefault)); + + public void AddResolverType(string dataTypeName, TypeInfoFactory createInfo, MatchRequirement matchRequirement) where T : class + => AddResolverType(dataTypeName, createInfo, GetDefaultConfigure(matchRequirement)); + + public void AddResolverType(string dataTypeName, TypeInfoFactory createInfo, Func? configure) where T : class + { + var mapping = new TypeInfoMapping(typeof(T), dataTypeName, createInfo); + mapping = configure?.Invoke(mapping) ?? 
mapping; + if (typeof(T) != typeof(object) && mapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single && !TryGetMapping(typeof(object), mapping.DataTypeName, out _)) + _items.Add(new TypeInfoMapping(typeof(object), dataTypeName, + CreateComposedFactory(typeof(T), mapping, static (_, info) => info.GetConverterResolver(), copyPreferredFormat: true)) + { + MatchRequirement = mapping.MatchRequirement + }); + _items.Add(mapping); + } + + + public void AddArrayType(string elementDataTypeName) where TElement : class + => AddArrayType(GetMapping(typeof(TElement), elementDataTypeName), suppressObjectMapping: false); + + public void AddArrayType(string elementDataTypeName, bool suppressObjectMapping) where TElement : class + => AddArrayType(GetMapping(typeof(TElement), elementDataTypeName), suppressObjectMapping); + + public void AddArrayType(TypeInfoMapping elementMapping) where TElement : class + => AddArrayType(elementMapping, suppressObjectMapping: false); + + public void AddArrayType(TypeInfoMapping elementMapping, bool suppressObjectMapping) where TElement : class + { + // Always use a predicate to match all dimensions. + var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement))); + var listTypeMatchPredicate = GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement))); + + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + + AddArrayType(elementMapping, typeof(TElement[]), CreateArrayBasedConverter, arrayTypeMatchPredicate, suppressObjectMapping: suppressObjectMapping || TryGetMapping(typeof(object), arrayDataTypeName, out _)); + AddArrayType(elementMapping, typeof(IList), CreateListBasedConverter, listTypeMatchPredicate, suppressObjectMapping: true); + + void AddArrayType(TypeInfoMapping elementMapping, Type type, Func converter, Func? 
typeMatchPredicate = null, bool suppressObjectMapping = false) + { + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + _items.Add(arrayMapping); + suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); + if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, dataTypeNameMatch) => + { + if (!dataTypeNameMatch) + throw new InvalidOperationException("Should not happen, please file a bug."); + + return arrayMapping.Factory(options, mapping, dataTypeNameMatch); + })); + } + } + + public void AddResolverArrayType(string elementDataTypeName) where TElement : class + => AddResolverArrayType(GetMapping(typeof(TElement), elementDataTypeName), suppressObjectMapping: false); + + public void AddResolverArrayType(string elementDataTypeName, bool suppressObjectMapping) where TElement : class + => AddResolverArrayType(GetMapping(typeof(TElement), elementDataTypeName), suppressObjectMapping); + + public void AddResolverArrayType(TypeInfoMapping elementMapping) where TElement : class + => AddResolverArrayType(elementMapping, suppressObjectMapping: false); + + public void AddResolverArrayType(TypeInfoMapping elementMapping, bool suppressObjectMapping) where TElement : class + { + // Always use a predicate to match all dimensions. + var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement))); + var listTypeMatchPredicate = GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? 
(static type => type is null || type == typeof(TElement))); + + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + + AddResolverArrayType(elementMapping, typeof(TElement[]), CreateArrayBasedConverterResolver, arrayTypeMatchPredicate, suppressObjectMapping: suppressObjectMapping || TryGetMapping(typeof(object), arrayDataTypeName, out _)); + AddResolverArrayType(elementMapping, typeof(IList), CreateListBasedConverterResolver, listTypeMatchPredicate, suppressObjectMapping: true); + + void AddResolverArrayType(TypeInfoMapping elementMapping, Type type, Func converter, Func? typeMatchPredicate = null, bool suppressObjectMapping = false) + { + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + _items.Add(arrayMapping); + suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); + if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, dataTypeNameMatch) => + { + if (!dataTypeNameMatch) + throw new InvalidOperationException("Should not happen, please file a bug."); + + return arrayMapping.Factory(options, mapping, dataTypeNameMatch); + })); + } + } + + public void AddStructType(string dataTypeName, TypeInfoFactory createInfo, bool isDefault = false) where T : struct + => AddStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverter(innerInfo.GetResolution().GetConverter()), GetDefaultConfigure(isDefault)); + + public void AddStructType(string dataTypeName, TypeInfoFactory createInfo, MatchRequirement matchRequirement) where T : struct + => AddStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new 
NullableConverter(innerInfo.GetResolution().GetConverter()), GetDefaultConfigure(matchRequirement)); + + public void AddStructType(string dataTypeName, TypeInfoFactory createInfo, Func? configure) where T : struct + => AddStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverter(innerInfo.GetResolution().GetConverter()), configure); + + // Lives outside to prevent capture of T. + void AddStructType(Type type, Type nullableType, string dataTypeName, TypeInfoFactory createInfo, + Func nullableConverter, Func? configure) + { + var mapping = new TypeInfoMapping(type, dataTypeName, createInfo); + mapping = configure?.Invoke(mapping) ?? mapping; + if (type != typeof(object) && mapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single && !TryGetMapping(typeof(object), mapping.DataTypeName, out _)) + _items.Add(new TypeInfoMapping(typeof(object), dataTypeName, + CreateComposedFactory(type, mapping, static (_, info) => info.GetResolution().Converter, copyPreferredFormat: true)) + { + MatchRequirement = mapping.MatchRequirement + }); + _items.Add(mapping); + _items.Add(new TypeInfoMapping(nullableType, dataTypeName, + CreateComposedFactory(nullableType, mapping, nullableConverter, copyPreferredFormat: true)) + { + MatchRequirement = mapping.MatchRequirement, + TypeMatchPredicate = mapping.TypeMatchPredicate is not null + ? matchType => matchType is null + ? 
mapping.TypeMatchPredicate(null) + : matchType == nullableType && mapping.TypeMatchPredicate(type) + : null + }); + } + + public void AddStructArrayType(string elementDataTypeName) where TElement : struct + => AddStructArrayType(GetMapping(typeof(TElement), elementDataTypeName), GetMapping(typeof(TElement?), elementDataTypeName), suppressObjectMapping: false); + + public void AddStructArrayType(string elementDataTypeName, bool suppressObjectMapping) where TElement : struct + => AddStructArrayType(GetMapping(typeof(TElement), elementDataTypeName), GetMapping(typeof(TElement?), elementDataTypeName), suppressObjectMapping); + + public void AddStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping) where TElement : struct + => AddStructArrayType(elementMapping, nullableElementMapping, suppressObjectMapping: false); + + public void AddStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping, bool suppressObjectMapping) where TElement : struct + { + // Always use a predicate to match all dimensions. + var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement))); + var nullableArrayTypeMatchPredicate = GetArrayTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? (static type => + type is null || type == typeof(TElement?))); + var listTypeMatchPredicate = GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement))); + var nullableListTypeMatchPredicate = GetListTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? 
(static type => + type is null || type == typeof(TElement?))); + + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + + AddStructArrayType(elementMapping, nullableElementMapping, typeof(TElement[]), typeof(TElement?[]), + CreateArrayBasedConverter, CreateArrayBasedConverter, + arrayTypeMatchPredicate, nullableArrayTypeMatchPredicate, suppressObjectMapping: suppressObjectMapping || TryGetMapping(typeof(object), arrayDataTypeName, out _)); + + // Don't add the object converter for the list based converter. + AddStructArrayType(elementMapping, nullableElementMapping, typeof(IList), typeof(IList), + CreateListBasedConverter, CreateListBasedConverter, + listTypeMatchPredicate, nullableListTypeMatchPredicate, suppressObjectMapping: true); + } + + // Lives outside to prevent capture of TElement. + void AddStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping, Type type, Type nullableType, + Func converter, Func nullableConverter, + Func? typeMatchPredicate, Func? 
nullableTypeMatchPredicate, bool suppressObjectMapping) + { + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + var nullableArrayMapping = new TypeInfoMapping(nullableType, arrayDataTypeName, CreateComposedFactory(nullableType, nullableElementMapping, nullableConverter)) + { + MatchRequirement = arrayMapping.MatchRequirement, + TypeMatchPredicate = nullableTypeMatchPredicate + }; + + _items.Add(arrayMapping); + _items.Add(nullableArrayMapping); + suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); + if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, dataTypeNameMatch) => + { + return options.ArrayNullabilityMode switch + { + _ when !dataTypeNameMatch => throw new InvalidOperationException("Should not happen, please file a bug."), + ArrayNullabilityMode.Never => arrayMapping.Factory(options, mapping, dataTypeNameMatch), + ArrayNullabilityMode.Always => nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + ArrayNullabilityMode.PerInstance => CreateComposedPerInstance( + arrayMapping.Factory(options, mapping, dataTypeNameMatch), + nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + mapping.DataTypeName + ), + _ => throw new ArgumentOutOfRangeException() + }; + }) { MatchRequirement = MatchRequirement.DataTypeName }); + + PgTypeInfo CreateComposedPerInstance(PgTypeInfo innerTypeInfo, PgTypeInfo nullableInnerTypeInfo, string dataTypeName) + { + var converter = + new PolymorphicArrayConverter( + innerTypeInfo.GetResolution().GetConverter(), + nullableInnerTypeInfo.GetResolution().GetConverter()); + 
+ return new PgTypeInfo(innerTypeInfo.Options, converter, + innerTypeInfo.Options.GetCanonicalTypeId(new DataTypeName(dataTypeName)), unboxedType: typeof(Array)) { SupportsWriting = false }; + } + } + + public void AddResolverStructType(string dataTypeName, TypeInfoFactory createInfo, bool isDefault = false) where T : struct + => AddResolverStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverterResolver(innerInfo), GetDefaultConfigure(isDefault)); + + public void AddResolverStructType(string dataTypeName, TypeInfoFactory createInfo, MatchRequirement matchRequirement) where T : struct + => AddResolverStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverterResolver(innerInfo), GetDefaultConfigure(matchRequirement)); + + public void AddResolverStructType(string dataTypeName, TypeInfoFactory createInfo, Func? configure) where T : struct + => AddResolverStructType(typeof(T), typeof(T?), dataTypeName, createInfo, + static (_, innerInfo) => new NullableConverterResolver(innerInfo), configure); + + // Lives outside to prevent capture of T. + void AddResolverStructType(Type type, Type nullableType, string dataTypeName, TypeInfoFactory createInfo, + Func nullableConverter, Func? configure) + { + var mapping = new TypeInfoMapping(type, dataTypeName, createInfo); + mapping = configure?.Invoke(mapping) ?? 
mapping; + if (type != typeof(object) && mapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single && !TryGetMapping(typeof(object), mapping.DataTypeName, out _)) + _items.Add(new TypeInfoMapping(typeof(object), dataTypeName, + CreateComposedFactory(type, mapping, static (_, info) => info.GetConverterResolver(), copyPreferredFormat: true)) + { + MatchRequirement = mapping.MatchRequirement + }); + _items.Add(mapping); + _items.Add(new TypeInfoMapping(nullableType, dataTypeName, + CreateComposedFactory(nullableType, mapping, nullableConverter, copyPreferredFormat: true)) + { + MatchRequirement = mapping.MatchRequirement, + TypeMatchPredicate = mapping.TypeMatchPredicate is not null + ? matchType => matchType is null + ? mapping.TypeMatchPredicate(null) + : matchType == nullableType && mapping.TypeMatchPredicate(type) + : null + }); + } + + public void AddResolverStructArrayType(string elementDataTypeName) where TElement : struct + => AddResolverStructArrayType(GetMapping(typeof(TElement), elementDataTypeName), GetMapping(typeof(TElement?), elementDataTypeName), suppressObjectMapping: false); + + public void AddResolverStructArrayType(string elementDataTypeName, bool suppressObjectMapping) where TElement : struct + => AddResolverStructArrayType(GetMapping(typeof(TElement), elementDataTypeName), GetMapping(typeof(TElement?), elementDataTypeName), suppressObjectMapping); + + public void AddResolverStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping) where TElement : struct + => AddResolverStructArrayType(elementMapping, nullableElementMapping, suppressObjectMapping: false); + + public void AddResolverStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping, bool suppressObjectMapping) where TElement : struct + { + // Always use a predicate to match all dimensions. + var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? 
(static type => type is null || type == typeof(TElement))); + var nullableArrayTypeMatchPredicate = GetArrayTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? (static type => + type is null || type == typeof(TElement?))); + var listTypeMatchPredicate = GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement))); + var nullableListTypeMatchPredicate = GetListTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? (static type => + type is null || type == typeof(TElement?))); + + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + + AddResolverStructArrayType(elementMapping, nullableElementMapping, typeof(TElement[]), typeof(TElement?[]), + CreateArrayBasedConverterResolver, + CreateArrayBasedConverterResolver, suppressObjectMapping: suppressObjectMapping || TryGetMapping(typeof(object), arrayDataTypeName, out _), arrayTypeMatchPredicate, nullableArrayTypeMatchPredicate); + + // Don't add the object converter for the list based converter. + AddResolverStructArrayType(elementMapping, nullableElementMapping, typeof(IList), typeof(IList), + CreateListBasedConverterResolver, + CreateListBasedConverterResolver, suppressObjectMapping: true, listTypeMatchPredicate, nullableListTypeMatchPredicate); + } + + // Lives outside to prevent capture of TElement. + void AddResolverStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullableElementMapping, Type type, Type nullableType, + Func converter, Func nullableConverter, + bool suppressObjectMapping, Func? typeMatchPredicate, Func? 
nullableTypeMatchPredicate) + { + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + var nullableArrayMapping = new TypeInfoMapping(nullableType, arrayDataTypeName, CreateComposedFactory(nullableType, nullableElementMapping, nullableConverter)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = nullableTypeMatchPredicate + }; + + _items.Add(arrayMapping); + _items.Add(nullableArrayMapping); + suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); + if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, dataTypeNameMatch) => options.ArrayNullabilityMode switch + { + _ when !dataTypeNameMatch => throw new InvalidOperationException("Should not happen, please file a bug."), + ArrayNullabilityMode.Never => arrayMapping.Factory(options, mapping, dataTypeNameMatch), + ArrayNullabilityMode.Always => nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + ArrayNullabilityMode.PerInstance => CreateComposedPerInstance( + arrayMapping.Factory(options, mapping, dataTypeNameMatch), + nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + mapping.DataTypeName + ), + _ => throw new ArgumentOutOfRangeException() + }) { MatchRequirement = MatchRequirement.DataTypeName }); + + PgTypeInfo CreateComposedPerInstance(PgTypeInfo innerTypeInfo, PgTypeInfo nullableInnerTypeInfo, string dataTypeName) + { + var resolver = + new PolymorphicArrayConverterResolver((PgResolverTypeInfo)innerTypeInfo, + (PgResolverTypeInfo)nullableInnerTypeInfo); + + return new PgResolverTypeInfo(innerTypeInfo.Options, 
resolver, + innerTypeInfo.Options.GetCanonicalTypeId(new DataTypeName(dataTypeName))) { SupportsWriting = false }; + } + } + + public void AddPolymorphicResolverArrayType(string elementDataTypeName, Func> elementToArrayConverterFactory) + => AddPolymorphicResolverArrayType(GetMapping(typeof(object), elementDataTypeName), elementToArrayConverterFactory); + + public void AddPolymorphicResolverArrayType(TypeInfoMapping elementMapping, Func> elementToArrayConverterFactory) + { + AddPolymorphicResolverArrayType(elementMapping, typeof(object), + (mapping, elemInfo) => new ArrayPolymorphicConverterResolver( + elemInfo.Options.GetCanonicalTypeId(new DataTypeName(mapping.DataTypeName)), elemInfo, elementToArrayConverterFactory(elemInfo.Options)) + , null); + + void AddPolymorphicResolverArrayType(TypeInfoMapping elementMapping, Type type, Func converter, Func? typeMatchPredicate) + { + var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); + var mapping = new TypeInfoMapping(type, arrayDataTypeName, + CreateComposedFactory(typeof(Array), elementMapping, converter, supportsWriting: false)) + { + MatchRequirement = elementMapping.MatchRequirement, + TypeMatchPredicate = typeMatchPredicate + }; + _items.Add(mapping); + } + } + + /// Returns whether type matches any of the types we register pg arrays as. + [UnconditionalSuppressMessage("Trimming", "IL2070", + Justification = "Checking for IList implementing types requires interface list enumeration which isn't compatible with trimming. " + + "However as long as a concrete IList is rooted somewhere in the app, for instance through an `AddArrayType(...)` mapping, every implementation must keep it.")] + // We care about IList implementations if the instantiation is actually rooted by us through an Array mapping. + // Dynamic resolvers are a notable counterexample, but they are all correctly marked with RequiresUnreferencedCode. + public static bool IsArrayLikeType(Type type, [NotNullWhen(true)] out Type? 
elementType) + { + if (type.GetElementType() is { } t) + { + elementType = t; + return true; + } + + if (type.IsConstructedGenericType && type.GetGenericTypeDefinition() is var def && (def == typeof(List<>) || def == typeof(IList<>))) + { + elementType = type.GetGenericArguments()[0]; + return true; + } + + foreach (var inf in type.GetInterfaces()) + { + if (inf.IsConstructedGenericType && inf.GetGenericTypeDefinition() == typeof(IList<>)) + { + elementType = inf.GetGenericArguments()[0]; + return true; + } + } + + elementType = null; + return false; + } + + static string GetArrayDataTypeName(string dataTypeName) + => DataTypeName.IsFullyQualified(dataTypeName.AsSpan()) + ? DataTypeName.ValidatedName(dataTypeName).ToArrayName().Value + : "_" + DataTypeName.FromDisplayName(dataTypeName).UnqualifiedName; + + static ArrayBasedArrayConverter CreateArrayBasedConverter(TypeInfoMapping mapping, PgTypeInfo elemInfo) + { + if (!elemInfo.IsBoxing) + return new ArrayBasedArrayConverter(elemInfo.GetResolution(), mapping.Type); + + ThrowBoxingNotSupported(resolver: false); + return default; + } + + static ListBasedArrayConverter, TElement> CreateListBasedConverter(TypeInfoMapping mapping, PgTypeInfo elemInfo) + { + if (!elemInfo.IsBoxing) + return new ListBasedArrayConverter, TElement>(elemInfo.GetResolution()); + + ThrowBoxingNotSupported(resolver: false); + return default; + } + + static ArrayConverterResolver CreateArrayBasedConverterResolver(TypeInfoMapping mapping, PgResolverTypeInfo elemInfo) + { + if (!elemInfo.IsBoxing) + return new ArrayConverterResolver(elemInfo, mapping.Type); + + ThrowBoxingNotSupported(resolver: true); + return default; + } + + static ArrayConverterResolver, TElement> CreateListBasedConverterResolver(TypeInfoMapping mapping, PgResolverTypeInfo elemInfo) + { + if (!elemInfo.IsBoxing) + return new ArrayConverterResolver, TElement>(elemInfo, mapping.Type); + + ThrowBoxingNotSupported(resolver: true); + return default; + } + + [DoesNotReturn] + static 
void ThrowBoxingNotSupported(bool resolver) + => throw new InvalidOperationException($"Boxing converters are not supported, manually construct a mapping over a casting converter{(resolver ? " resolver" : "")} instead."); +} + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public static class TypeInfoMappingHelpers +{ + internal static bool TryResolveFullyQualifiedName(PgSerializerOptions options, string dataTypeName, out DataTypeName fqDataTypeName) + { + if (DataTypeName.IsFullyQualified(dataTypeName.AsSpan())) + { + fqDataTypeName = new DataTypeName(dataTypeName); + return true; + } + + if (options.DatabaseInfo.TryGetPostgresTypeByName(dataTypeName, out var pgType)) + { + fqDataTypeName = pgType.DataTypeName; + return true; + } + + fqDataTypeName = default; + return false; + } + + internal static PostgresType GetPgType(this TypeInfoMapping mapping, PgSerializerOptions options) + => options.DatabaseInfo.GetPostgresType(new DataTypeName(mapping.DataTypeName)); + + public static PgTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOptions options, PgConverter converter, DataFormat? preferredFormat = null, bool supportsWriting = true) + => new(options, converter, new DataTypeName(mapping.DataTypeName)) + { + PreferredFormat = preferredFormat, + SupportsWriting = supportsWriting + }; + + public static PgResolverTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOptions options, PgConverterResolver resolver, bool includeDataTypeName = true, DataFormat? preferredFormat = null, bool supportsWriting = true) + => new(options, resolver, includeDataTypeName ? 
new DataTypeName(mapping.DataTypeName) : null) + { + PreferredFormat = preferredFormat, + SupportsWriting = supportsWriting + }; +} diff --git a/src/Npgsql/Internal/ValueMetadata.cs b/src/Npgsql/Internal/ValueMetadata.cs new file mode 100644 index 0000000000..b71028c4a1 --- /dev/null +++ b/src/Npgsql/Internal/ValueMetadata.cs @@ -0,0 +1,12 @@ +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public readonly struct ValueMetadata +{ + public required DataFormat Format { get; init; } + public required Size BufferRequirement { get; init; } + public required Size Size { get; init; } + public object? WriteState { get; init; } +} diff --git a/src/Npgsql/KerberosUsernameProvider.cs b/src/Npgsql/KerberosUsernameProvider.cs index 52ceb67bd8..a962a6fdc2 100644 --- a/src/Npgsql/KerberosUsernameProvider.cs +++ b/src/Npgsql/KerberosUsernameProvider.cs @@ -1,93 +1,125 @@ using System; using System.Diagnostics; using System.IO; -using System.Linq; -using Npgsql.Logging; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; -namespace Npgsql -{ - /// - /// Launches MIT Kerberos klist and parses out the default principal from it. - /// Caches the result. - /// - class KerberosUsernameProvider - { - static bool _performedDetection; - static string? _principalWithRealm; - static string? _principalWithoutRealm; +namespace Npgsql; - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(KerberosUsernameProvider)); +/// +/// Launches MIT Kerberos klist and parses out the default principal from it. +/// Caches the result. +/// +sealed class KerberosUsernameProvider +{ + static bool _performedDetection; + static string? _principalWithRealm; + static string? _principalWithoutRealm; - internal static string? 
GetUsername(bool includeRealm) + internal static ValueTask GetUsername(bool async, bool includeRealm, ILogger connectionLogger, CancellationToken cancellationToken) + { + if (_performedDetection) + return new(includeRealm ? _principalWithRealm : _principalWithoutRealm); + var klistPath = FindInPath("klist"); + if (klistPath == null) { - if (!_performedDetection) - { - DetectUsername(); - _performedDetection = true; - } - return includeRealm ? _principalWithRealm : _principalWithoutRealm; + connectionLogger.LogDebug("klist not found in PATH, skipping Kerberos username detection"); + return new((string?)null); } + var processStartInfo = new ProcessStartInfo + { + FileName = klistPath, + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false + }; - static void DetectUsername() + var process = Process.Start(processStartInfo); + if (process is null) { - var klistPath = FindInPath("klist"); - if (klistPath == null) - { - Log.Debug("klist not found in PATH, skipping Kerberos username detection"); - return; - } + connectionLogger.LogDebug("klist process could not be started"); + return new((string?)null); + } - var processStartInfo = new ProcessStartInfo - { - FileName = klistPath, - RedirectStandardOutput = true, - RedirectStandardError = true, - UseShellExecute = false - }; - var process = Process.Start(processStartInfo); - if (process is null) - { - Log.Debug($"klist process could not be started"); - return; - } + return GetUsernameAsyncInternal(); +#pragma warning disable CS1998 + async ValueTask GetUsernameAsyncInternal() +#pragma warning restore CS1998 + { +#if NET5_0_OR_GREATER + if (async) + await process.WaitForExitAsync(cancellationToken).ConfigureAwait(false); + else + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + process.WaitForExit(); +#else + // ReSharper disable once MethodHasAsyncOverload process.WaitForExit(); +#endif + if (process.ExitCode != 0) { - Log.Debug($"klist exited with code 
{process.ExitCode}: {process.StandardError.ReadToEnd()}"); - return; + connectionLogger.LogDebug($"klist exited with code {process.ExitCode}: {process.StandardError.ReadToEnd()}"); + return null; } var line = default(string); for (var i = 0; i < 2; i++) + // ReSharper disable once MethodHasAsyncOverload +#if NET7_0_OR_GREATER + if ((line = async ? await process.StandardOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) : process.StandardOutput.ReadLine()) == null) +#elif NET5_0_OR_GREATER + if ((line = async ? await process.StandardOutput.ReadLineAsync().ConfigureAwait(false) : process.StandardOutput.ReadLine()) == null) +#else if ((line = process.StandardOutput.ReadLine()) == null) +#endif { - Log.Debug("Unexpected output from klist, aborting Kerberos username detection"); - return; + connectionLogger.LogDebug("Unexpected output from klist, aborting Kerberos username detection"); + return null; } - var components = line!.Split(':'); - if (components.Length != 2) - { - Log.Debug("Unexpected output from klist, aborting Kerberos username detection"); - return; - } + return ParseKListOutput(line!, includeRealm, connectionLogger); + } + } - var principalWithRealm = components[1].Trim(); - components = principalWithRealm.Split('@'); - if (components.Length != 2) - { - Log.Debug($"Badly-formed default principal {principalWithRealm} from klist, aborting Kerberos username detection"); - return; - } + static string? 
ParseKListOutput(string line, bool includeRealm, ILogger connectionLogger) + { + var colonIndex = line.IndexOf(':'); + var colonLastIndex = line.LastIndexOf(':'); + if (colonIndex == -1 || colonIndex != colonLastIndex) + { + connectionLogger.LogDebug("Unexpected output from klist, aborting Kerberos username detection"); + return null; + } + var secondPart = line.AsSpan(1 + line.IndexOf(':')); - _principalWithRealm = principalWithRealm; - _principalWithoutRealm = components[0]; + var principalWithRealm = secondPart.Trim(); + var atIndex = principalWithRealm.IndexOf('@'); + var atLastIndex = principalWithRealm.LastIndexOf('@'); + if (atIndex == -1 || atIndex != atLastIndex) + { + connectionLogger.LogDebug( + $"Badly-formed default principal {principalWithRealm.ToString()} from klist, aborting Kerberos username detection"); + return null; + } + + _principalWithRealm = principalWithRealm.ToString(); + _principalWithoutRealm = principalWithRealm.Slice(0, atIndex).ToString(); + _performedDetection = true; + return includeRealm ? _principalWithRealm : _principalWithoutRealm; + } + + static string? FindInPath(string name) + { + foreach (var p in Environment.GetEnvironmentVariable("PATH")?.Split(Path.PathSeparator) ?? Array.Empty()) + { + var path = Path.Combine(p, name); + if (File.Exists(path)) + return path; } - static string? 
FindInPath(string name) => Environment.GetEnvironmentVariable("PATH") - ?.Split(Path.PathSeparator) - .Select(p => Path.Combine(p, name)) - .FirstOrDefault(File.Exists); + return null; } } diff --git a/src/Npgsql/LogMessages.cs b/src/Npgsql/LogMessages.cs new file mode 100644 index 0000000000..8d5f471c27 --- /dev/null +++ b/src/Npgsql/LogMessages.cs @@ -0,0 +1,510 @@ +using System; +using System.Collections.Generic; +using System.Data; +using Microsoft.Extensions.Logging; +using NpgsqlTypes; + +namespace Npgsql; + +// ReSharper disable InconsistentNaming +#pragma warning disable SYSLIB1015 // Argument is not referenced from the logging message +#pragma warning disable SYSLIB1006 // Multiple logging methods are using event id + +static partial class LogMessages +{ + #region Connection + + [LoggerMessage( + EventId = NpgsqlEventId.OpeningConnection, + Level = LogLevel.Trace, + Message = "Opening connection to {Host}:{Port}/{Database}...")] + internal static partial void OpeningConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString); + + [LoggerMessage( + EventId = NpgsqlEventId.OpenedConnection, + Level = LogLevel.Debug, + Message = "Opened connection to {Host}:{Port}/{Database}")] + internal static partial void OpenedConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.OpenedConnection, + Level = LogLevel.Debug, + Message = "Opened multiplexing connection to {Host}:{Port}/{Database}")] + internal static partial void OpenedMultiplexingConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString); + + [LoggerMessage( + EventId = NpgsqlEventId.ClosingConnection, + Level = LogLevel.Trace, + Message = "Closing connection to {Host}:{Port}/{Database}...")] + internal static partial void ClosingConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString, int ConnectorId); + + 
[LoggerMessage( + EventId = NpgsqlEventId.ClosedConnection, + Level = LogLevel.Debug, + Message = "Closed connection to {Host}:{Port}/{Database}")] + internal static partial void ClosedConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ClosedConnection, + Level = LogLevel.Debug, + Message = "Closed multiplexing connection to {Host}:{Port}/{Database}")] + internal static partial void ClosedMultiplexingConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString); + + [LoggerMessage( + EventId = NpgsqlEventId.OpeningPhysicalConnection, + Level = LogLevel.Trace, + Message = "Opening physical connection to {Host}:{Port}/{Database}...")] + internal static partial void OpeningPhysicalConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString); + + [LoggerMessage( + EventId = NpgsqlEventId.OpenedPhysicalConnection, + Level = LogLevel.Debug, + Message = "Opened physical connection to {Host}:{Port}/{Database} (in {DurationMs}ms)")] + internal static partial void OpenedPhysicalConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString, long DurationMs, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ClosingPhysicalConnection, + Level = LogLevel.Trace, + Message = "Closing physical connection to {Host}:{Port}/{Database}...")] + internal static partial void ClosingPhysicalConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ClosedPhysicalConnection, + Level = LogLevel.Debug, + Message = "Closed physical connection to {Host}:{Port}/{Database}")] + internal static partial void ClosedPhysicalConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.StartingWait, + Level = 
LogLevel.Information, + Message = "Starting to wait (timeout={TimeoutMs}ms)...")] + internal static partial void StartingWait(ILogger logger, int TimeoutMs, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ReceivedNotice, + Level = LogLevel.Debug, + Message = "Received notice: {NoticeText}")] + internal static partial void ReceivedNotice(ILogger logger, string NoticeText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ConnectionExceededMaximumLifetime, + Level = LogLevel.Debug, + Message = "Connection has exceeded its maximum lifetime ('{ConnectionMaximumLifeTime}') and will be closed.")] + internal static partial void ConnectionExceededMaximumLifetime(ILogger logger, TimeSpan ConnectionMaximumLifeTime, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.SendingKeepalive, + Level = LogLevel.Trace, + Message = "Sending keepalive...")] + internal static partial void SendingKeepalive(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CompletedKeepalive, + Level = LogLevel.Trace, + Message = "Completed keepalive")] + internal static partial void CompletedKeepalive(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.KeepaliveFailed, + Level = LogLevel.Trace, + Message = "Keepalive failed")] + internal static partial void KeepaliveFailed(ILogger logger, int ConnectorId, Exception exception); + + [LoggerMessage( + EventId = NpgsqlEventId.BreakingConnection, + Level = LogLevel.Trace, + Message = "Breaking connection")] + internal static partial void BreakingConnection(ILogger logger, int ConnectorId, Exception exception); + + [LoggerMessage( + EventId = NpgsqlEventId.CaughtUserExceptionInNoticeEventHandler, + Level = LogLevel.Error, + Message = "User exception caught when emitting notice event")] + internal static partial void CaughtUserExceptionInNoticeEventHandler(ILogger logger, Exception exception); + + [LoggerMessage( + EventId = 
NpgsqlEventId.CaughtUserExceptionInNotificationEventHandler, + Level = LogLevel.Error, + Message = "User exception caught when emitting notification event")] + internal static partial void CaughtUserExceptionInNotificationEventHandler(ILogger logger, Exception exception); + + [LoggerMessage( + EventId = NpgsqlEventId.ExceptionWhenClosingPhysicalConnection, + Level = LogLevel.Warning, + Message = "Exception while closing connector")] + internal static partial void ExceptionWhenClosingPhysicalConnection(ILogger logger, int ConnectorId, Exception exception); + + [LoggerMessage( + EventId = NpgsqlEventId.ExceptionWhenOpeningConnectionForMultiplexing, + Level = LogLevel.Error, + Message = "Exception opening a connection for multiplexing")] + internal static partial void ExceptionWhenOpeningConnectionForMultiplexing(ILogger logger, Exception exception); + + [LoggerMessage( + Level = LogLevel.Trace, + Message = "Start user action")] + internal static partial void StartUserAction(ILogger logger, int ConnectorId); + + [LoggerMessage( + Level = LogLevel.Trace, + Message = "End user action")] + internal static partial void EndUserAction(ILogger logger, int ConnectorId); + + #endregion Connection + + #region Command + + [LoggerMessage( + EventId = NpgsqlEventId.ExecutingCommand, + Level = LogLevel.Debug, + Message = "Executing command: {CommandText}", + SkipEnabledCheck = true)] + internal static partial void ExecutingCommand(ILogger logger, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ExecutingCommand, + Level = LogLevel.Debug, + Message = "Executing command: {CommandText}\n Parameters: {Parameters}", + SkipEnabledCheck = true)] + internal static partial void ExecutingCommandWithParameters(ILogger logger, string CommandText, IEnumerable Parameters, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ExecutingCommand, + Level = LogLevel.Debug, + Message = "Executing batch: {BatchCommands}", + SkipEnabledCheck = true)] + 
internal static partial void ExecutingBatch(ILogger logger, string[] BatchCommands, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ExecutingCommand, + Level = LogLevel.Debug, + Message = "Executing batch: {BatchCommands}", + SkipEnabledCheck = true)] + internal static partial void ExecutingBatchWithParameters(ILogger logger, (string CommandText, object[] Parameters)[] BatchCommands, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CommandExecutionCompleted, + Level = LogLevel.Information, + Message = "Command execution completed (duration={DurationMs}ms): {CommandText}", + SkipEnabledCheck = true)] + internal static partial void CommandExecutionCompleted(ILogger logger, string CommandText, long DurationMs, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CommandExecutionCompleted, + Level = LogLevel.Information, + Message = "Command execution completed (duration={DurationMs}ms): {CommandText}\n Parameters: {Parameters}", + SkipEnabledCheck = true)] + internal static partial void CommandExecutionCompletedWithParameters(ILogger logger, string CommandText, IEnumerable Parameters, long DurationMs, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CommandExecutionCompleted, + Level = LogLevel.Information, + Message = "Batch execution completed (duration={DurationMs}ms): {BatchCommands}", + SkipEnabledCheck = true)] + internal static partial void BatchExecutionCompleted(ILogger logger, string[] BatchCommands, long DurationMs, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CommandExecutionCompleted, + Level = LogLevel.Information, + Message = "Batch execution completed (duration={DurationMs}ms): {BatchCommands}", + SkipEnabledCheck = true)] + internal static partial void BatchExecutionCompletedWithParameters( + ILogger logger, (string CommandText, object[] Parameters)[] BatchCommands, long DurationMs, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CancellingCommand, + Level = 
LogLevel.Debug, + Message = "Sending PostgreSQL cancellation...")] + internal static partial void CancellingCommand(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ExecutingInternalCommand, + Level = LogLevel.Debug, + Message = "Executing internal command: {CommandText}")] + internal static partial void ExecutingInternalCommand(ILogger logger, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.PreparingCommandExplicitly, + Level = LogLevel.Debug, + Message = "Preparing command explicitly: {CommandText}", + SkipEnabledCheck = true)] + internal static partial void PreparingCommandExplicitly(ILogger logger, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CommandPreparedExplicitly, + Level = LogLevel.Information, + Message = "Prepared command explicitly")] + internal static partial void CommandPreparedExplicitly(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.AutoPreparingStatement, + Level = LogLevel.Debug, + Message = "Auto-preparing statement: {CommandText}")] + internal static partial void AutoPreparingStatement(ILogger logger, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.UnpreparingCommand, + Level = LogLevel.Debug, + Message = "Prepared command explicitly")] + internal static partial void UnpreparingCommand(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.DerivingParameters, + Level = LogLevel.Debug, + Message = "Deriving Parameters for query: {CommandText}")] + internal static partial void DerivingParameters(ILogger logger, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ExceptionWhenWritingMultiplexedCommands, + Level = LogLevel.Error, + Message = "Exception while writing multiplexed commands")] + internal static partial void ExceptionWhenWritingMultiplexedCommands(ILogger logger, int ConnectorId, Exception exception); + + 
[LoggerMessage( + Level = LogLevel.Trace, + Message = "Cleaning up reader")] + internal static partial void ReaderCleanup(ILogger logger, int ConnectorId); + + #endregion Command + + #region Transaction + + [LoggerMessage( + EventId = NpgsqlEventId.StartedTransaction, + Level = LogLevel.Debug, + Message = "Starting transaction")] + internal static partial void StartedTransaction(ILogger logger, IsolationLevel IsolationLevel, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CommittedTransaction, + Level = LogLevel.Debug, + Message = "Committed transaction")] + internal static partial void CommittedTransaction(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.RolledBackTransaction, + Level = LogLevel.Debug, + Message = "Rolled back transaction")] + internal static partial void RolledBackTransaction(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CreatingSavepoint, + Level = LogLevel.Debug, + Message = "Creating savepoint '{SavepointName}'")] + internal static partial void CreatingSavepoint(ILogger logger, string SavepointName, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.RolledBackToSavepoint, + Level = LogLevel.Debug, + Message = "Rolled back to savepoint '{SavepointName}'")] + internal static partial void RolledBackToSavepoint(ILogger logger, string SavepointName, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ReleasedSavepoint, + Level = LogLevel.Debug, + Message = "Released savepoint '{SavepointName}'")] + internal static partial void ReleasedSavepoint(ILogger logger, string SavepointName, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ExceptionDuringTransactionDispose, + Level = LogLevel.Error, + Message = "Exception while disposing transaction")] + internal static partial void ExceptionDuringTransactionDispose(ILogger logger, int ConnectorId, Exception exception); + + [LoggerMessage( + EventId = 
NpgsqlEventId.EnlistedVolatileResourceManager, + Level = LogLevel.Debug, + Message = "Enlisted volatile resource manager (local transaction ID={LocalTransactionIdentifier})")] + internal static partial void EnlistedVolatileResourceManager(ILogger logger, string LocalTransactionIdentifier, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CommittingSinglePhaseTransaction, + Level = LogLevel.Debug, + Message = "Committing single-phase transaction (local ID={LocalTransactionIdentifier})")] + internal static partial void CommittingSinglePhaseTransaction(ILogger logger, string LocalTransactionIdentifier, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.RollingBackSinglePhaseTransaction, + Level = LogLevel.Debug, + Message = "Rolling back single-phase transaction (local ID={LocalTransactionIdentifier})")] + internal static partial void RollingBackSinglePhaseTransaction(ILogger logger, string LocalTransactionIdentifier, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.SinglePhaseTransactionRollbackFailed, + Level = LogLevel.Error, + Message = "Exception during single-phase transaction rollback (local ID={LocalTransactionIdentifier})")] + internal static partial void SinglePhaseTransactionRollbackFailed(ILogger logger, string LocalTransactionIdentifier, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.PreparingTwoPhaseTransaction, + Level = LogLevel.Debug, + Message = "Preparing two-phase transaction (local ID={LocalTransactionIdentifier})")] + internal static partial void PreparingTwoPhaseTransaction(ILogger logger, string LocalTransactionIdentifier, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CommittingTwoPhaseTransaction, + Level = LogLevel.Debug, + Message = "Committing two-phase transaction (local ID={LocalTransactionIdentifier})")] + internal static partial void CommittingTwoPhaseTransaction(ILogger logger, string LocalTransactionIdentifier, int ConnectorId); + + [LoggerMessage( + EventId 
= NpgsqlEventId.TwoPhaseTransactionCommitFailed, + Level = LogLevel.Error, + Message = "Exception during two-phase transaction commit (local ID={LocalTransactionIdentifier})")] + internal static partial void TwoPhaseTransactionCommitFailed(ILogger logger, string LocalTransactionIdentifier, int ConnectorId, Exception exception); + + [LoggerMessage( + EventId = NpgsqlEventId.RollingBackTwoPhaseTransaction, + Level = LogLevel.Debug, + Message = "Rolling back two-phase transaction (local ID={LocalTransactionIdentifier})")] + internal static partial void RollingBackTwoPhaseTransaction(ILogger logger, string LocalTransactionIdentifier, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.TwoPhaseTransactionRollbackFailed, + Level = LogLevel.Debug, + Message = "Exception during two-phase transaction rollback (local ID={LocalTransactionIdentifier})")] + internal static partial void TwoPhaseTransactionRollbackFailed(ILogger logger, string LocalTransactionIdentifier, int ConnectorId, Exception exception); + + [LoggerMessage( + EventId = NpgsqlEventId.TwoPhaseTransactionInDoubt, + Level = LogLevel.Warning, + Message = "Two-phase transaction in doubt (local ID={LocalTransactionIdentifier})")] + internal static partial void TwoPhaseTransactionInDoubt(ILogger logger, string LocalTransactionIdentifier, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ConnectionInUseWhenRollingBack, + Level = LogLevel.Warning, + Message = "Connection in use while trying to rollback, will cancel and retry (local ID={LocalTransactionIdentifier}")] + internal static partial void ConnectionInUseWhenRollingBack(ILogger logger, string LocalTransactionIdentifier, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CleaningUpResourceManager, + Level = LogLevel.Trace, + Message = "Cleaning up resource manager (local ID={LocalTransactionIdentifier})")] + internal static partial void CleaningUpResourceManager(ILogger logger, string LocalTransactionIdentifier, int 
ConnectorId); + + #endregion Transaction + + #region Copy + + [LoggerMessage( + EventId = NpgsqlEventId.StartingBinaryExport, + Level = LogLevel.Information, + Message = "Starting binary export")] + internal static partial void StartingBinaryExport(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.StartingBinaryImport, + Level = LogLevel.Information, + Message = "Starting binary import")] + internal static partial void StartingBinaryImport(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.StartingTextExport, + Level = LogLevel.Information, + Message = "Starting text export")] + internal static partial void StartingTextExport(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.StartingTextImport, + Level = LogLevel.Information, + Message = "Starting text import")] + internal static partial void StartingTextImport(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.StartingRawCopy, + Level = LogLevel.Information, + Message = "Starting raw COPY operation")] + internal static partial void StartingRawCopy(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CopyOperationCompleted, + Level = LogLevel.Information, + Message = "Binary COPY operation completed ({Rows} rows transferred)")] + internal static partial void BinaryCopyOperationCompleted(ILogger logger, ulong Rows, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CopyOperationCompleted, + Level = LogLevel.Information, + Message = "COPY operation completed")] + internal static partial void CopyOperationCompleted(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.CopyOperationCancelled, + Level = LogLevel.Information, + Message = "COPY operation was cancelled")] + internal static partial void CopyOperationCancelled(ILogger logger, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ExceptionWhenDisposingCopyOperation, + 
Level = LogLevel.Debug, + Message = "Exception when disposing a COPY operation")] + internal static partial void ExceptionWhenDisposingCopyOperation(ILogger logger, int ConnectorId, Exception exception); + + #endregion Copy + + #region Replication + + [LoggerMessage( + EventId = NpgsqlEventId.CreatingReplicationSlot, + Level = LogLevel.Information, + Message = "Creating replication slot '{SlotName}'")] + internal static partial void CreatingReplicationSlot(ILogger logger, string SlotName, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.DroppingReplicationSlot, + Level = LogLevel.Information, + Message = "Dropping replication slot '{SlotName}'")] + internal static partial void DroppingReplicationSlot(ILogger logger, string SlotName, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.StartingLogicalReplication, + Level = LogLevel.Information, + Message = "Starting logical replication on slot '{SlotName}'")] + internal static partial void StartingLogicalReplication(ILogger logger, string SlotName, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.StartingPhysicalReplication, + Level = LogLevel.Information, + Message = "Starting physical replication on slot: '{SlotName}'")] + internal static partial void StartingPhysicalReplication(ILogger logger, string? 
SlotName, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ExecutingReplicationCommand, + Level = LogLevel.Debug, + Message = "Executing replication command: {CommandText}")] + internal static partial void ExecutingReplicationCommand(ILogger logger, string CommandText, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ReceivedReplicationPrimaryKeepalive, + Level = LogLevel.Trace, + Message = "Received replication primary keepalive message from the server with current end of WAL of {EndLsn} and timestamp of {Timestamp}")] + internal static partial void ReceivedReplicationPrimaryKeepalive(ILogger logger, NpgsqlLogSequenceNumber EndLsn, DateTime Timestamp, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.SendingReplicationStandbyStatusUpdate, + Level = LogLevel.Trace, + Message = "Sending a replication standby status update because {Reason}")] + internal static partial void SendingReplicationStandbyStatusUpdate(ILogger logger, string Reason, int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.SentReplicationFeedbackMessage, + Level = LogLevel.Trace, + Message = "Feedback message sent with LastReceivedLsn={LastReceivedLsn}, LastFlushedLsn={LastFlushedLsn}, LastAppliedLsn={LastAppliedLsn}, Timestamp={Timestamp}", + SkipEnabledCheck = true)] + internal static partial void SentReplicationFeedbackMessage( + ILogger logger, + NpgsqlLogSequenceNumber LastReceivedLsn, + NpgsqlLogSequenceNumber LastFlushedLsn, + NpgsqlLogSequenceNumber LastAppliedLsn, + DateTime Timestamp, + int ConnectorId); + + [LoggerMessage( + EventId = NpgsqlEventId.ReplicationFeedbackMessageSendingFailed, + Level = LogLevel.Error, + Message = "An exception occurred while sending a feedback message")] + internal static partial void ReplicationFeedbackMessageSendingFailed(ILogger logger, int? 
ConnectorId, Exception exception); + + #endregion Replication +} diff --git a/src/Npgsql/Logging/ConsoleLoggingProvider.cs b/src/Npgsql/Logging/ConsoleLoggingProvider.cs deleted file mode 100644 index e049d0e3e8..0000000000 --- a/src/Npgsql/Logging/ConsoleLoggingProvider.cs +++ /dev/null @@ -1,78 +0,0 @@ -using System; -using System.Text; - -namespace Npgsql.Logging -{ - /// - /// An logging provider that outputs Npgsql logging messages to standard error. - /// - public class ConsoleLoggingProvider : INpgsqlLoggingProvider - { - readonly NpgsqlLogLevel _minLevel; - readonly bool _printLevel; - readonly bool _printConnectorId; - - /// - /// Constructs a new - /// - /// Only messages of this level of higher will be logged - /// If true, will output the log level (e.g. WARN). Defaults to false. - /// If true, will output the connector ID. Defaults to false. - public ConsoleLoggingProvider(NpgsqlLogLevel minLevel=NpgsqlLogLevel.Info, bool printLevel=false, bool printConnectorId=false) - { - _minLevel = minLevel; - _printLevel = printLevel; - _printConnectorId = printConnectorId; - } - - /// - /// Creates a new instance of the given name. - /// - public NpgsqlLogger CreateLogger(string name) - { - return new ConsoleLogger(_minLevel, _printLevel, _printConnectorId); - } - } - - class ConsoleLogger : NpgsqlLogger - { - readonly NpgsqlLogLevel _minLevel; - readonly bool _printLevel; - readonly bool _printConnectorId; - - internal ConsoleLogger(NpgsqlLogLevel minLevel, bool printLevel, bool printConnectorId) - { - _minLevel = minLevel; - _printLevel = printLevel; - _printConnectorId = printConnectorId; - } - - public override bool IsEnabled(NpgsqlLogLevel level) => level >= _minLevel; - - public override void Log(NpgsqlLogLevel level, int connectorId, string msg, Exception? 
exception = null) - { - if (!IsEnabled(level)) - return; - - var sb = new StringBuilder(); - if (_printLevel) { - sb.Append(level.ToString().ToUpper()); - sb.Append(' '); - } - - if (_printConnectorId && connectorId != 0) - { - sb.Append("["); - sb.Append(connectorId); - sb.Append("] "); - } - - sb.AppendLine(msg); - - if (exception != null) - sb.AppendLine(exception.ToString()); - - Console.Error.Write(sb.ToString()); - } - } -} diff --git a/src/Npgsql/Logging/INpgsqlLoggingProvider.cs b/src/Npgsql/Logging/INpgsqlLoggingProvider.cs deleted file mode 100644 index 922dc7775e..0000000000 --- a/src/Npgsql/Logging/INpgsqlLoggingProvider.cs +++ /dev/null @@ -1,11 +0,0 @@ -namespace Npgsql.Logging -{ - /// Used to create logger instances of the given name. - public interface INpgsqlLoggingProvider - { - /// - /// Creates a new INpgsqlLogger instance of the given name. - /// - NpgsqlLogger CreateLogger(string name); - } -} diff --git a/src/Npgsql/Logging/NoOpLoggingProvider.cs b/src/Npgsql/Logging/NoOpLoggingProvider.cs deleted file mode 100644 index 09ae96c2dd..0000000000 --- a/src/Npgsql/Logging/NoOpLoggingProvider.cs +++ /dev/null @@ -1,20 +0,0 @@ -using System; - -namespace Npgsql.Logging -{ - class NoOpLoggingProvider : INpgsqlLoggingProvider - { - public NpgsqlLogger CreateLogger(string name) => NoOpLogger.Instance; - } - - class NoOpLogger : NpgsqlLogger - { - internal static NoOpLogger Instance = new NoOpLogger(); - - NoOpLogger() {} - public override bool IsEnabled(NpgsqlLogLevel level) => false; - public override void Log(NpgsqlLogLevel level, int connectorId, string msg, Exception? 
exception = null) - { - } - } -} diff --git a/src/Npgsql/Logging/NpgsqlLogLevel.cs b/src/Npgsql/Logging/NpgsqlLogLevel.cs deleted file mode 100644 index 4f127c0b99..0000000000 --- a/src/Npgsql/Logging/NpgsqlLogLevel.cs +++ /dev/null @@ -1,14 +0,0 @@ -#pragma warning disable 1591 - -namespace Npgsql.Logging -{ - public enum NpgsqlLogLevel - { - Trace = 1, - Debug = 2, - Info = 3, - Warn = 4, - Error = 5, - Fatal = 6, - } -} diff --git a/src/Npgsql/Logging/NpgsqlLogManager.cs b/src/Npgsql/Logging/NpgsqlLogManager.cs deleted file mode 100644 index 3bafb2d120..0000000000 --- a/src/Npgsql/Logging/NpgsqlLogManager.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System; - -namespace Npgsql.Logging -{ - /// - /// Manages logging for Npgsql, used to set the logging provider. - /// - public static class NpgsqlLogManager - { - /// - /// The logging provider used for logging in Npgsql. - /// - public static INpgsqlLoggingProvider Provider - { - get - { - _providerRetrieved = true; - return _provider; - } - set - { - if (_providerRetrieved) - throw new InvalidOperationException("The logging provider must be set before any Npgsql action is taken"); - - _provider = value ?? throw new ArgumentNullException(nameof(value)); - } - } - - /// - /// Determines whether parameter contents will be logged alongside SQL statements - this may reveal sensitive information. - /// Defaults to false. - /// - public static bool IsParameterLoggingEnabled { get; set; } - - static INpgsqlLoggingProvider _provider = new NoOpLoggingProvider(); - static bool _providerRetrieved; - - internal static NpgsqlLogger CreateLogger(string name) => Provider.CreateLogger("Npgsql." 
+ name); - } -} diff --git a/src/Npgsql/Logging/NpgsqlLogger.cs b/src/Npgsql/Logging/NpgsqlLogger.cs deleted file mode 100644 index cd84fd5279..0000000000 --- a/src/Npgsql/Logging/NpgsqlLogger.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System; - -#pragma warning disable 1591 - -namespace Npgsql.Logging -{ - /// - /// A generic interface for logging. - /// - public abstract class NpgsqlLogger - { - public abstract bool IsEnabled(NpgsqlLogLevel level); - public abstract void Log(NpgsqlLogLevel level, int connectorId, string msg, Exception? exception = null); - - internal void Trace(string msg, int connectionId = 0) => Log(NpgsqlLogLevel.Trace, connectionId, msg); - internal void Debug(string msg, int connectionId = 0) => Log(NpgsqlLogLevel.Debug, connectionId, msg); - internal void Info(string msg, int connectionId = 0) => Log(NpgsqlLogLevel.Info, connectionId, msg); - internal void Warn(string msg, int connectionId = 0) => Log(NpgsqlLogLevel.Warn, connectionId, msg); - internal void Error(string msg, int connectionId = 0) => Log(NpgsqlLogLevel.Error, connectionId, msg); - internal void Fatal(string msg, int connectionId = 0) => Log(NpgsqlLogLevel.Fatal, connectionId, msg); - - internal void Trace(string msg, Exception ex, int connectionId = 0) => Log(NpgsqlLogLevel.Trace, connectionId, msg, ex); - internal void Debug(string msg, Exception ex, int connectionId = 0) => Log(NpgsqlLogLevel.Debug, connectionId, msg, ex); - internal void Info(string msg, Exception ex, int connectionId = 0) => Log(NpgsqlLogLevel.Info, connectionId, msg, ex); - internal void Warn(string msg, Exception ex, int connectionId = 0) => Log(NpgsqlLogLevel.Warn, connectionId, msg, ex); - internal void Error(string msg, Exception ex, int connectionId = 0) => Log(NpgsqlLogLevel.Error, connectionId, msg, ex); - internal void Fatal(string msg, Exception ex, int connectionId = 0) => Log(NpgsqlLogLevel.Fatal, connectionId, msg, ex); - } -} diff --git a/src/Npgsql/MetricsReporter.cs 
b/src/Npgsql/MetricsReporter.cs new file mode 100644 index 0000000000..f29f0c47e2 --- /dev/null +++ b/src/Npgsql/MetricsReporter.cs @@ -0,0 +1,275 @@ +using System; + +namespace Npgsql; + +#if NET6_0_OR_GREATER +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.Metrics; +using System.Runtime.InteropServices; +using System.Threading; + +// .NET docs on metric instrumentation: https://learn.microsoft.com/en-us/dotnet/core/diagnostics/metrics-instrumentation +// OpenTelemetry semantic conventions for database metric: https://opentelemetry.io/docs/specs/otel/metrics/semantic_conventions/database-metrics +sealed class MetricsReporter : IDisposable +{ + const string Version = "0.1.0"; + + static readonly Meter Meter; + + static readonly UpDownCounter CommandsExecuting; + static readonly Counter CommandsFailed; + static readonly Histogram CommandDuration; + + static readonly Counter BytesWritten; + static readonly Counter BytesRead; + + static readonly UpDownCounter PendingConnectionRequests; + static readonly Counter ConnectionTimeouts; + static readonly Histogram ConnectionCreateTime; + static readonly ObservableGauge PreparedRatio; + + readonly NpgsqlDataSource _dataSource; + readonly KeyValuePair _poolNameTag; + + static readonly List Reporters = new(); + + CommandCounters _commandCounters; + + [StructLayout(LayoutKind.Explicit)] + struct CommandCounters + { + [FieldOffset(0)] internal int CommandsStarted; + [FieldOffset(4)] internal int PreparedCommandsStarted; + [FieldOffset(0)] internal long All; + } + + static MetricsReporter() + { + Meter = new("Npgsql", Version); + + CommandsExecuting = Meter.CreateUpDownCounter( + "db.client.commands.executing", + unit: "{command}", + description: "The number of currently executing database commands."); + + CommandsFailed = Meter.CreateCounter( + "db.client.commands.failed", + unit: "{command}", + description: "The number of database commands which have failed."); + + CommandDuration = 
Meter.CreateHistogram( + "db.client.commands.duration", + unit: "s", + description: "The duration of database commands, in seconds."); + + BytesWritten = Meter.CreateCounter( + "db.client.commands.bytes_written", + unit: "By", + description: "The number of bytes written."); + + BytesRead = Meter.CreateCounter( + "db.client.commands.bytes_read", + unit: "By", + description: "The number of bytes read."); + + PendingConnectionRequests = Meter.CreateUpDownCounter( + "db.client.connections.pending_requests", + unit: "{request}", + description: "The number of pending requests for an open connection, cumulative for the entire pool."); + + ConnectionTimeouts = Meter.CreateCounter( + "db.client.connections.timeouts", + unit: "{timeout}", + description: "The number of connection timeouts that have occurred trying to obtain a connection from the pool."); + + ConnectionCreateTime = Meter.CreateHistogram( + "db.client.connections.create_time", + unit: "s", + description: "The time it took to create a new connection."); + + // Observable metrics; these are for values we already track internally (and efficiently) inside the connection pool implementation. + Meter.CreateObservableUpDownCounter( + "db.client.connections.usage", + GetConnectionUsage, + unit: "{connection}", + description: "The number of connections that are currently in state described by the state attribute."); + + // It's a bit ridiculous to manage "max connections" as an observable counter, given that it never changes for a given pool. + // However, we can't simply report it once at startup, since clients who connect later wouldn't have it. And since reporting it + // repeatedly isn't possible because we need to provide incremental figures, we just manage it as an observable counter. 
+ Meter.CreateObservableUpDownCounter( + "db.client.connections.max", + GetMaxConnections, + unit: "{connection}", + description: "The maximum number of open connections allowed."); + + PreparedRatio = Meter.CreateObservableGauge( + "db.client.commands.prepared_ratio", + GetPreparedCommandsRatio, + description: "The ratio of prepared command executions."); + } + + public MetricsReporter(NpgsqlDataSource dataSource) + { + _dataSource = dataSource; + _poolNameTag = new KeyValuePair("pool.name", dataSource.Name); + + lock (Reporters) + { + Reporters.Add(this); + Reporters.Sort((x,y) => string.Compare(x._dataSource.Name, y._dataSource.Name, StringComparison.Ordinal)); + } + } + + internal long ReportCommandStart() + { + CommandsExecuting.Add(1, _poolNameTag); + if (PreparedRatio.Enabled) + Interlocked.Increment(ref _commandCounters.CommandsStarted); + + return CommandDuration.Enabled ? Stopwatch.GetTimestamp() : 0; + } + + internal void ReportCommandStop(long startTimestamp) + { + CommandsExecuting.Add(-1, _poolNameTag); + + if (CommandDuration.Enabled && startTimestamp > 0) + { +#if NET7_0_OR_GREATER + var duration = Stopwatch.GetElapsedTime(startTimestamp); +#else + var duration = new TimeSpan((long)((Stopwatch.GetTimestamp() - startTimestamp) * StopWatchTickFrequency)); +#endif + CommandDuration.Record(duration.TotalSeconds, _poolNameTag); + } + } + + internal void CommandStartPrepared() + { + if (PreparedRatio.Enabled) + Interlocked.Increment(ref _commandCounters.PreparedCommandsStarted); + } + + internal void ReportCommandFailed() => CommandsFailed.Add(1, _poolNameTag); + + internal void ReportBytesWritten(long bytesWritten) => BytesWritten.Add(bytesWritten, _poolNameTag); + internal void ReportBytesRead(long bytesRead) => BytesRead.Add(bytesRead, _poolNameTag); + + internal void ReportConnectionPoolTimeout() + => ConnectionTimeouts.Add(1, _poolNameTag); + + internal void ReportPendingConnectionRequestStart() + => PendingConnectionRequests.Add(1, _poolNameTag); + 
internal void ReportPendingConnectionRequestStop() + => PendingConnectionRequests.Add(-1, _poolNameTag); + + internal void ReportConnectionCreateTime(TimeSpan duration) + => ConnectionCreateTime.Record(duration.TotalSeconds, _poolNameTag); + + static IEnumerable> GetConnectionUsage() + { + lock (Reporters) + { + var measurements = new List>(); + + for (var i = 0; i < Reporters.Count; i++) + { + var reporter = Reporters[i]; + + if (reporter._dataSource is PoolingDataSource poolingDataSource) + { + var stats = poolingDataSource.Statistics; + + measurements.Add(new Measurement( + stats.Idle, + reporter._poolNameTag, + new KeyValuePair("state", "idle"))); + + measurements.Add(new Measurement( + stats.Busy, + reporter._poolNameTag, + new KeyValuePair("state", "used"))); + } + } + + return measurements; + } + } + + static IEnumerable> GetMaxConnections() + { + lock (Reporters) + { + var measurements = new List>(); + + foreach (var reporter in Reporters) + { + if (reporter._dataSource is PoolingDataSource poolingDataSource) + { + measurements.Add(new Measurement(poolingDataSource.MaxConnections, reporter._poolNameTag)); + } + } + + return measurements; + } + } + + static IEnumerable> GetPreparedCommandsRatio() + { + lock (Reporters) + { + var measurements = new List>(Reporters.Count); + + for (var i = 0; i < Reporters.Count; i++) + { + var reporter = Reporters[i]; + + var counters = new CommandCounters + { + All = Interlocked.Exchange(ref reporter._commandCounters.All, default) + }; + + var value = (double)counters.PreparedCommandsStarted / counters.CommandsStarted * 100; + + if (double.IsFinite(value)) + measurements.Add(new Measurement(value, reporter._poolNameTag)); + } + + return measurements; + } + } + + public void Dispose() + { + lock (Reporters) + { + Reporters.Remove(this); + } + } + +#if !NET7_0_OR_GREATER + const long TicksPerMicrosecond = 10; + const long TicksPerMillisecond = TicksPerMicrosecond * 1000; + const long TicksPerSecond = TicksPerMillisecond * 
1000; // 10,000,000 + static readonly double StopWatchTickFrequency = (double)TicksPerSecond / Stopwatch.Frequency; +#endif +} +#else +sealed class MetricsReporter : IDisposable +{ + public MetricsReporter(NpgsqlDataSource _) {} + internal long ReportCommandStart() => 0; + internal void ReportCommandStop(long startTimestamp) {} + internal void CommandStartPrepared() {} + internal void ReportCommandFailed() {} + internal void ReportBytesWritten(long bytesWritten) {} + internal void ReportBytesRead(long bytesRead) {} + internal void ReportConnectionPoolTimeout() {} + internal void ReportPendingConnectionRequestStart() {} + internal void ReportPendingConnectionRequestStop() {} + internal void ReportConnectionCreateTime(TimeSpan duration) {} + public void Dispose() {} +} +#endif diff --git a/src/Npgsql/MultiHostDataSourceWrapper.cs b/src/Npgsql/MultiHostDataSourceWrapper.cs new file mode 100644 index 0000000000..4dcded98cc --- /dev/null +++ b/src/Npgsql/MultiHostDataSourceWrapper.cs @@ -0,0 +1,48 @@ +using Npgsql.Internal; +using Npgsql.Util; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using System.Transactions; + +namespace Npgsql; + +sealed class MultiHostDataSourceWrapper : NpgsqlDataSource +{ + internal override bool OwnsConnectors => false; + + readonly NpgsqlMultiHostDataSource _wrappedSource; + + public MultiHostDataSourceWrapper(NpgsqlMultiHostDataSource source, TargetSessionAttributes targetSessionAttributes) + : base(CloneSettingsForTargetSessionAttributes(source.Settings, targetSessionAttributes), source.Configuration) + => _wrappedSource = source; + + static NpgsqlConnectionStringBuilder CloneSettingsForTargetSessionAttributes( + NpgsqlConnectionStringBuilder settings, + TargetSessionAttributes targetSessionAttributes) + { + var clonedSettings = settings.Clone(); + clonedSettings.TargetSessionAttributesParsed = targetSessionAttributes; + return clonedSettings; + } + + internal override (int Total, int 
Idle, int Busy) Statistics => _wrappedSource.Statistics; + + internal override void Clear() => _wrappedSource.Clear(); + internal override ValueTask Get(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + => _wrappedSource.Get(conn, timeout, async, cancellationToken); + internal override bool TryGetIdleConnector([NotNullWhen(true)] out NpgsqlConnector? connector) + => throw new NpgsqlException("Npgsql bug: trying to get an idle connector from " + nameof(MultiHostDataSourceWrapper)); + internal override ValueTask OpenNewConnector(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + => throw new NpgsqlException("Npgsql bug: trying to open a new connector from " + nameof(MultiHostDataSourceWrapper)); + internal override void Return(NpgsqlConnector connector) + => _wrappedSource.Return(connector); + + internal override void AddPendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) + => _wrappedSource.AddPendingEnlistedConnector(connector, transaction); + internal override bool TryRemovePendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) + => _wrappedSource.TryRemovePendingEnlistedConnector(connector, transaction); + internal override bool TryRentEnlistedPending(Transaction transaction, NpgsqlConnection connection, + [NotNullWhen(true)] out NpgsqlConnector? 
connector) + => _wrappedSource.TryRentEnlistedPending(transaction, connection, out connector); +} \ No newline at end of file diff --git a/src/Npgsql/MultiplexingDataSource.cs b/src/Npgsql/MultiplexingDataSource.cs new file mode 100644 index 0000000000..277bc4e835 --- /dev/null +++ b/src/Npgsql/MultiplexingDataSource.cs @@ -0,0 +1,400 @@ +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; +using Npgsql.Util; + +namespace Npgsql; + +sealed class MultiplexingDataSource : PoolingDataSource +{ + readonly ILogger _connectionLogger; + readonly ILogger _commandLogger; + + readonly bool _autoPrepare; + + readonly ChannelReader _multiplexCommandReader; + internal ChannelWriter MultiplexCommandWriter { get; } + + readonly Task _multiplexWriteLoop; + + /// + /// When multiplexing is enabled, determines the maximum number of outgoing bytes to buffer before + /// flushing to the network. + /// + readonly int _writeCoalescingBufferThresholdBytes; + + // TODO: Make this configurable + const int MultiplexingCommandChannelBound = 4096; + + internal MultiplexingDataSource( + NpgsqlConnectionStringBuilder settings, + NpgsqlDataSourceConfiguration dataSourceConfig, + NpgsqlMultiHostDataSource? 
parentPool = null) + : base(settings, dataSourceConfig, parentPool) + { + Debug.Assert(Settings.Multiplexing); + + // TODO: Validate multiplexing options are set only when Multiplexing is on + + _autoPrepare = settings.MaxAutoPrepare > 0; + + _writeCoalescingBufferThresholdBytes = Settings.WriteCoalescingBufferThresholdBytes; + + var multiplexCommandChannel = Channel.CreateBounded( + new BoundedChannelOptions(MultiplexingCommandChannelBound) + { + FullMode = BoundedChannelFullMode.Wait, + SingleReader = true + }); + _multiplexCommandReader = multiplexCommandChannel.Reader; + MultiplexCommandWriter = multiplexCommandChannel.Writer; + + _connectionLogger = dataSourceConfig.LoggingConfiguration.ConnectionLogger; + _commandLogger = dataSourceConfig.LoggingConfiguration.CommandLogger; + + _multiplexWriteLoop = Task.Run(MultiplexingWriteLoop, CancellationToken.None) + .ContinueWith(t => + { + if (t.IsFaulted) + { + // Note that MultiplexingWriteLoop should never throw an exception - everything should be caught and handled internally. + _connectionLogger.LogError(t.Exception, "Exception in multiplexing write loop, this is an Npgsql bug, please file an issue."); + } + }); + } + + async Task MultiplexingWriteLoop() + { + // This method is async, but only ever yields when there are no pending commands in the command channel. + // No I/O should ever be performed asynchronously, as that would block further writing for the entire + // application; whenever I/O cannot complete immediately, we chain a callback with ContinueWith and move + // on to the next connector. + Debug.Assert(_multiplexCommandReader != null); + + var stats = new MultiplexingStats { Stopwatch = new Stopwatch() }; + + while (true) + { + NpgsqlConnector? connector; + NpgsqlCommand? command; + + try + { + // Get a first command out. 
+ if (!_multiplexCommandReader.TryRead(out command)) + command = await _multiplexCommandReader.ReadAsync().ConfigureAwait(false); + } + catch (ChannelClosedException) + { + return; + } + + try + { + // First step is to get a connector on which to execute + var spinwait = new SpinWait(); + while (true) + { + if (TryGetIdleConnector(out connector)) + { + // See increment under over-capacity mode below + Interlocked.Increment(ref connector.CommandsInFlightCount); + break; + } + + connector = await OpenNewConnector( + command.InternalConnection!, + new NpgsqlTimeout(TimeSpan.FromSeconds(Settings.Timeout)), + async: true, + CancellationToken.None).ConfigureAwait(false); + + if (connector != null) + { + // Managed to created a new connector + connector.Connection = null; + + // See increment under over-capacity mode below + Interlocked.Increment(ref connector.CommandsInFlightCount); + + break; + } + + // There were no idle connectors and we're at max capacity, so we can't open a new one. + // Enter over-capacity mode - find an unlocked connector with the least currently in-flight + // commands and sent on it, even though there are already pending commands. + var minInFlight = int.MaxValue; + foreach (var c in Connectors) + { + if (c?.MultiplexAsyncWritingLock == 0 && c.CommandsInFlightCount < minInFlight) + { + minInFlight = c.CommandsInFlightCount; + connector = c; + } + } + + // There could be no writable connectors (all stuck in transaction or flushing). + if (connector == null) + { + // TODO: This is problematic - when absolutely all connectors are both busy *and* currently + // performing (async) I/O, this will spin-wait. + // We could call WaitAsync, but that would wait for an idle connector, whereas we want any + // writeable (non-writing) connector even if it has in-flight commands. Maybe something + // with better back-off. + // On the other hand, this is exactly *one* thread doing spin-wait, maybe not that bad. 
+ spinwait.SpinOnce(); + continue; + } + + // We may be in a race condition with the connector read loop, which may be currently returning + // the connector to the Idle channel (because it has completed all commands). + // Increment the in-flight count to make sure the connector isn't returned as idle. + var newInFlight = Interlocked.Increment(ref connector.CommandsInFlightCount); + if (newInFlight == 1) + { + // The connector's in-flight was 0, so it was idle - abort over-capacity read + // and retry the normal flow. + Interlocked.Decrement(ref connector.CommandsInFlightCount); + spinwait.SpinOnce(); + continue; + } + + break; + } + } + catch (Exception exception) + { + LogMessages.ExceptionWhenOpeningConnectionForMultiplexing(_connectionLogger, exception); + + // Fail the first command in the channel as a way of bubbling the exception up to the user + command.ExecutionCompletion.SetException(exception); + + continue; + } + + // We now have a ready connector, and can start writing commands to it. + Debug.Assert(connector != null); + + try + { + stats.Reset(); + connector.FlagAsNotWritableForMultiplexing(); + command.TraceCommandStart(connector); + + // Read queued commands and write them to the connector's buffer, for as long as we're + // under our write threshold and timer delay. + // Note we already have one command we read above, and have already updated the connector's + // CommandsInFlightCount. Now write that command. + var first = true; + bool writtenSynchronously; + do + { + if (first) + first = false; + else + Interlocked.Increment(ref connector.CommandsInFlightCount); + writtenSynchronously = WriteCommand(connector, command, ref stats); + } while (connector.WriteBuffer.WritePosition < _writeCoalescingBufferThresholdBytes && + writtenSynchronously && + _multiplexCommandReader.TryRead(out command)); + + // If all commands were written synchronously (good path), complete the write here, flushing + // and updating statistics. 
If not, CompleteRewrite is scheduled to run later, when the async + // operations complete, so skip it and continue. + if (writtenSynchronously) + Flush(connector, ref stats); + } + catch (Exception ex) + { + FailWrite(connector, ex); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + bool WriteCommand(NpgsqlConnector connector, NpgsqlCommand command, ref MultiplexingStats stats) + { + // Note: this method *never* awaits on I/O - doing so would suspend all outgoing multiplexing commands + // for the entire pool. In the normal/fast case, writing the command is purely synchronous (serialize + // to buffer in memory), and the actual flush will occur at the level above. For cases where the + // command overflows the buffer, async I/O is done, and we schedule continuations separately - + // but the main thread continues to handle other commands on other connectors. + if (_autoPrepare) + { + // TODO: Need to log based on numPrepared like in non-multiplexing mode... + for (var i = 0; i < command.InternalBatchCommands.Count; i++) + command.InternalBatchCommands[i].TryAutoPrepare(connector); + } + + var written = connector.CommandsInFlightWriter!.TryWrite(command); + Debug.Assert(written, $"Failed to enqueue command to {connector.CommandsInFlightWriter}"); + + // Purposefully don't wait for I/O to complete + var task = command.Write(connector, async: true, flush: false); + stats.NumCommands++; + + switch (task.Status) + { + case TaskStatus.RanToCompletion: + return true; + + case TaskStatus.Faulted: + task.GetAwaiter().GetResult(); // Throw the exception + return true; + + case TaskStatus.WaitingForActivation: + case TaskStatus.Running: + { + // Asynchronous completion, which means the writing is flushing to network and there's actual I/O + // (i.e. a big command which overflowed our buffer). 
+ // We don't (ever) await in the write loop, so remove the connector from the writable list (as it's + // still flushing) and schedule a continuation to continue taking care of this connector. + // The write loop continues to the next connector. + + // Create a copy of the statistics and purposefully box it via the closure. We need a separate + // copy of the stats for the async writing that will continue in parallel with this loop. + var clonedStats = stats.Clone(); + + // ReSharper disable once MethodSupportsCancellation + task.ContinueWith((t, o) => + { + var conn = (NpgsqlConnector)o!; + + if (t.IsFaulted) + { + FailWrite(conn, t.Exception!.InnerException!); + return; + } + + // There's almost certainly more buffered outgoing data for the command, after the flush + // occured. Complete the write, which will flush again (and update statistics). + try + { + Flush(conn, ref clonedStats); + } + catch (Exception e) + { + FailWrite(conn, e); + } + }, connector); + + return false; + } + + default: + Debug.Fail("When writing command to connector, task is in invalid state " + task.Status); + ThrowHelper.ThrowNpgsqlException("When writing command to connector, task is in invalid state " + task.Status); + return false; + } + } + + void Flush(NpgsqlConnector connector, ref MultiplexingStats stats) + { + var task = connector.Flush(async: true); + switch (task.Status) + { + case TaskStatus.RanToCompletion: + CompleteWrite(connector, ref stats); + return; + + case TaskStatus.Faulted: + task.GetAwaiter().GetResult(); // Throw the exception + return; + + case TaskStatus.WaitingForActivation: + case TaskStatus.Running: + { + // Asynchronous completion - the flush didn't complete immediately (e.g. TCP zero window). + + // Create a copy of the statistics and purposefully box it via the closure. We need a separate + // copy of the stats for the async writing that will continue in parallel with this loop. 
+ var clonedStats = stats.Clone(); + + task.ContinueWith((t, o) => + { + var conn = (NpgsqlConnector)o!; + if (t.IsFaulted) + { + FailWrite(conn, t.Exception!.InnerException!); + return; + } + + CompleteWrite(conn, ref clonedStats); + }, connector); + + return; + } + + default: + Debug.Fail("When flushing, task is in invalid state " + task.Status); + ThrowHelper.ThrowNpgsqlException("When flushing, task is in invalid state " + task.Status); + return; + } + } + + void FailWrite(NpgsqlConnector connector, Exception exception) + { + // Note that all commands already passed validation. This means any error here is either an unrecoverable network issue + // (in which case we're already broken), or some other issue while writing (e.g. invalid UTF8 characters in the SQL query) - + // unrecoverable in any case. + + // All commands enqueued in CommandsInFlightWriter will be drained by the reader and failed. + // Note that some of these commands where only written to the connector's buffer, but never + // actually sent - because of a later exception. + // In theory, we could track commands that were only enqueued and not sent, and retry those + // (on another connector), but that would add some book-keeping and complexity, and in any case + // if one connector was broken, chances are that all are (networking). + Debug.Assert(connector.IsBroken); + + LogMessages.ExceptionWhenWritingMultiplexedCommands(_commandLogger, connector.Id, exception); + } + + static void CompleteWrite(NpgsqlConnector connector, ref MultiplexingStats stats) + { + // All I/O has completed, mark this connector as safe for writing again. + // This will allow the connector to be returned to the pool by its read loop, and also to be selected + // for over-capacity write. 
+ connector.FlagAsWritableForMultiplexing(); + + NpgsqlEventSource.Log.MultiplexingBatchSent(stats.NumCommands, stats.Stopwatch); + } + + // ReSharper disable once FunctionNeverReturns + } + + protected override void DisposeBase() + { + MultiplexCommandWriter.Complete(new ObjectDisposedException(nameof(MultiplexingDataSource))); + _multiplexWriteLoop.GetAwaiter().GetResult(); + base.DisposeBase(); + } + + protected override async ValueTask DisposeAsyncBase() + { + MultiplexCommandWriter.Complete(new ObjectDisposedException(nameof(MultiplexingDataSource))); + await _multiplexWriteLoop.ConfigureAwait(false); + await base.DisposeAsyncBase().ConfigureAwait(false); + } + + struct MultiplexingStats + { + internal Stopwatch Stopwatch; + internal int NumCommands; + + internal void Reset() + { + NumCommands = 0; + Stopwatch.Reset(); + } + + internal MultiplexingStats Clone() + { + var clone = new MultiplexingStats { Stopwatch = Stopwatch, NumCommands = NumCommands }; + Stopwatch = new Stopwatch(); + return clone; + } + } +} diff --git a/src/Npgsql/NameTranslation/INpgsqlNameTranslator.cs b/src/Npgsql/NameTranslation/INpgsqlNameTranslator.cs index d249955b6b..1fa188a91e 100644 --- a/src/Npgsql/NameTranslation/INpgsqlNameTranslator.cs +++ b/src/Npgsql/NameTranslation/INpgsqlNameTranslator.cs @@ -1,20 +1,19 @@ -namespace Npgsql +namespace Npgsql; + +/// +/// A component which translates a CLR name (e.g. SomeClass) into a database name (e.g. some_class) +/// according to some scheme. +/// Used for mapping enum and composite types. +/// +public interface INpgsqlNameTranslator { /// - /// A component which translates a CLR name (e.g. SomeClass) into a database name (e.g. some_class) - /// according to some scheme. - /// Used for mapping enum and composite types. + /// Given a CLR type name (e.g class, struct, enum), translates its name to a database type name. 
/// - public interface INpgsqlNameTranslator - { - /// - /// Given a CLR type name (e.g class, struct, enum), translates its name to a database type name. - /// - string TranslateTypeName(string clrName); + string TranslateTypeName(string clrName); - /// - /// Given a CLR member name (property or field), translates its name to a database type name. - /// - string TranslateMemberName(string clrName); - } -} + /// + /// Given a CLR member name (property or field), translates its name to a database type name. + /// + string TranslateMemberName(string clrName); +} \ No newline at end of file diff --git a/src/Npgsql/NameTranslation/NpgsqlNullNameTranslator.cs b/src/Npgsql/NameTranslation/NpgsqlNullNameTranslator.cs index 18b0338964..f754169a72 100644 --- a/src/Npgsql/NameTranslation/NpgsqlNullNameTranslator.cs +++ b/src/Npgsql/NameTranslation/NpgsqlNullNameTranslator.cs @@ -1,20 +1,19 @@ using System; -namespace Npgsql.NameTranslation +namespace Npgsql.NameTranslation; + +/// +/// A name translator which preserves CLR names (e.g. SomeClass) when mapping names to the database. +/// +public sealed class NpgsqlNullNameTranslator : INpgsqlNameTranslator { /// - /// A name translator which preserves CLR names (e.g. SomeClass) when mapping names to the database. + /// Given a CLR type name (e.g class, struct, enum), translates its name to a database type name. /// - public class NpgsqlNullNameTranslator : INpgsqlNameTranslator - { - /// - /// Given a CLR type name (e.g class, struct, enum), translates its name to a database type name. - /// - public string TranslateTypeName(string clrName) => clrName ?? throw new ArgumentNullException(nameof(clrName)); + public string TranslateTypeName(string clrName) => clrName ?? throw new ArgumentNullException(nameof(clrName)); - /// - /// Given a CLR member name (property or field), translates its name to a database type name. - /// - public string TranslateMemberName(string clrName) => clrName ?? 
throw new ArgumentNullException(nameof(clrName)); - } -} + /// + /// Given a CLR member name (property or field), translates its name to a database type name. + /// + public string TranslateMemberName(string clrName) => clrName ?? throw new ArgumentNullException(nameof(clrName)); +} \ No newline at end of file diff --git a/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs b/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs index be8e97942c..760ddb1e5a 100644 --- a/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs +++ b/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs @@ -1,107 +1,139 @@ using System; +using System.Collections.Generic; using System.Globalization; -using System.Linq; using System.Text; -namespace Npgsql.NameTranslation +namespace Npgsql.NameTranslation; + +/// +/// A name translator which converts standard CLR names (e.g. SomeClass) to snake-case database +/// names (some_class) +/// +public sealed class NpgsqlSnakeCaseNameTranslator : INpgsqlNameTranslator { + internal static NpgsqlSnakeCaseNameTranslator Instance { get; } = new(); + + readonly CultureInfo _culture; + /// - /// A name translator which converts standard CLR names (e.g. SomeClass) to snake-case database - /// names (some_class) + /// Creates a new . /// - public class NpgsqlSnakeCaseNameTranslator : INpgsqlNameTranslator + /// + /// An object that supplies culture-specific casing rules. + /// This will be used when converting names to lower case. + /// If then will be used. + /// + public NpgsqlSnakeCaseNameTranslator(CultureInfo? culture = null) + : this(false, culture) { } + + /// + /// Creates a new . + /// + /// + /// Uses the legacy naming convention if , otherwise it uses the new naming convention. + /// + /// + /// An object that supplies culture-specific casing rules. + /// This will be used when converting names to lower case. + /// If then will be used. + /// + public NpgsqlSnakeCaseNameTranslator(bool legacyMode, CultureInfo? 
culture = null) { - /// - /// Creates a new . - /// - public NpgsqlSnakeCaseNameTranslator() - : this(false) { } - - /// - /// Creates a new . - /// - /// Uses the legacy naming convention if true, otherwise it uses the new naming convention. - public NpgsqlSnakeCaseNameTranslator(bool legacyMode) - => LegacyMode = legacyMode; - - bool LegacyMode { get; } - - /// - /// Given a CLR type name (e.g class, struct, enum), translates its name to a database type name. - /// - public string TranslateTypeName(string clrName) => TranslateMemberName(clrName); - - /// - /// Given a CLR member name (property or field), translates its name to a database type name. - /// - public string TranslateMemberName(string clrName) - { - if (clrName == null) - throw new ArgumentNullException(nameof(clrName)); + LegacyMode = legacyMode; + _culture = culture ?? CultureInfo.InvariantCulture; + } - return LegacyMode - ? string.Concat(clrName.Select((c, i) => i > 0 && char.IsUpper(c) ? "_" + c.ToString() : c.ToString())).ToLower() - : ConvertToSnakeCase(clrName); - } + bool LegacyMode { get; } + + /// + /// Given a CLR type name (e.g class, struct, enum), translates its name to a database type name. + /// + public string TranslateTypeName(string clrName) => TranslateMemberName(clrName); + + /// + /// Given a CLR member name (property or field), translates its name to a database type name. + /// + public string TranslateMemberName(string clrName) + { + if (clrName == null) + throw new ArgumentNullException(nameof(clrName)); + + return LegacyMode + ? string.Concat(LegacyModeMap(clrName)).ToLower(_culture) + : ConvertToSnakeCase(clrName, _culture); - /// - /// Converts a string to its snake_case equivalent. - /// - /// The value to convert. - public static string ConvertToSnakeCase(string name) + IEnumerable LegacyModeMap(string clrName) { - if (string.IsNullOrEmpty(name)) - return name; + for (var i = 0; i < clrName.Length; i++) + { + var c = clrName[i]; + yield return i > 0 && char.IsUpper(c) ? 
"_" + c.ToString() : c.ToString(); + } + } + } - var builder = new StringBuilder(name.Length + Math.Min(2, name.Length / 5)); - var previousCategory = default(UnicodeCategory?); + /// + /// Converts a string to its snake_case equivalent. + /// + /// The value to convert. + /// + /// An object that supplies culture-specific casing rules. + /// This will be used when converting names to lower case. + /// If then will be used. + /// + public static string ConvertToSnakeCase(string name, CultureInfo? culture = null) + { + if (string.IsNullOrEmpty(name)) + return name; - for (var currentIndex = 0; currentIndex < name.Length; currentIndex++) + var builder = new StringBuilder(name.Length + Math.Min(2, name.Length / 5)); + var previousCategory = default(UnicodeCategory?); + + for (var currentIndex = 0; currentIndex < name.Length; currentIndex++) + { + var currentChar = name[currentIndex]; + if (currentChar == '_') { - var currentChar = name[currentIndex]; - if (currentChar == '_') + builder.Append('_'); + previousCategory = null; + continue; + } + + var currentCategory = char.GetUnicodeCategory(currentChar); + switch (currentCategory) + { + case UnicodeCategory.UppercaseLetter: + case UnicodeCategory.TitlecaseLetter: + if (previousCategory == UnicodeCategory.SpaceSeparator || + previousCategory == UnicodeCategory.LowercaseLetter || + previousCategory != UnicodeCategory.DecimalDigitNumber && + previousCategory != null && + currentIndex > 0 && + currentIndex + 1 < name.Length && + char.IsLower(name[currentIndex + 1])) { builder.Append('_'); - previousCategory = null; - continue; } - var currentCategory = char.GetUnicodeCategory(currentChar); - switch (currentCategory) - { - case UnicodeCategory.UppercaseLetter: - case UnicodeCategory.TitlecaseLetter: - if (previousCategory == UnicodeCategory.SpaceSeparator || - previousCategory == UnicodeCategory.LowercaseLetter || - previousCategory != UnicodeCategory.DecimalDigitNumber && - previousCategory != null && - currentIndex > 0 && 
- currentIndex + 1 < name.Length && - char.IsLower(name[currentIndex + 1])) - { - builder.Append('_'); - } - - currentChar = char.ToLower(currentChar); - break; - - case UnicodeCategory.LowercaseLetter: - case UnicodeCategory.DecimalDigitNumber: - if (previousCategory == UnicodeCategory.SpaceSeparator) - builder.Append('_'); - break; - - default: - if (previousCategory != null) - previousCategory = UnicodeCategory.SpaceSeparator; - continue; - } + currentChar = char.ToLower(currentChar, culture ?? CultureInfo.InvariantCulture); + break; + + case UnicodeCategory.LowercaseLetter: + case UnicodeCategory.DecimalDigitNumber: + if (previousCategory == UnicodeCategory.SpaceSeparator) + builder.Append('_'); + break; - builder.Append(currentChar); - previousCategory = currentCategory; + default: + if (previousCategory != null) + previousCategory = UnicodeCategory.SpaceSeparator; + continue; } - return builder.ToString(); + builder.Append(currentChar); + previousCategory = currentCategory; } + + return builder.ToString(); } } diff --git a/src/Npgsql/Netstandard20/CodeAnalysis.cs b/src/Npgsql/Netstandard20/CodeAnalysis.cs deleted file mode 100644 index ed41012925..0000000000 --- a/src/Npgsql/Netstandard20/CodeAnalysis.cs +++ /dev/null @@ -1,66 +0,0 @@ -#if NETSTANDARD2_0 - -#pragma warning disable 1591 - -// ReSharper disable once CheckNamespace -namespace System.Diagnostics.CodeAnalysis -{ - [AttributeUsageAttribute(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property)] - sealed class AllowNullAttribute : Attribute - { - } - - [AttributeUsageAttribute(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property)] - sealed class DisallowNullAttribute : Attribute - { - } - - [AttributeUsageAttribute(AttributeTargets.Method)] - sealed class DoesNotReturnAttribute : Attribute - { - } - - [AttributeUsageAttribute(AttributeTargets.Parameter)] - sealed class DoesNotReturnIfAttribute : Attribute - { - public DoesNotReturnIfAttribute(bool 
parameterValue) => ParameterValue = parameterValue; - public bool ParameterValue { get; } - } - - [AttributeUsageAttribute(AttributeTargets.Assembly | AttributeTargets.Class | AttributeTargets.Constructor | AttributeTargets.Event | AttributeTargets.Method | AttributeTargets.Property | AttributeTargets.Struct, AllowMultiple = false)] - sealed class ExcludeFromCodeCoverageAttribute : Attribute - { - } - - [AttributeUsageAttribute(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue)] - sealed class MaybeNullAttribute : Attribute - { - } - - [AttributeUsageAttribute(AttributeTargets.Parameter)] - sealed class MaybeNullWhenAttribute : Attribute - { - public MaybeNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; - public bool ReturnValue { get; } - } - - [AttributeUsageAttribute(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue)] - sealed class NotNullAttribute : Attribute - { - } - - [AttributeUsageAttribute(AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, AllowMultiple = true)] - sealed class NotNullIfNotNullAttribute : Attribute - { - public NotNullIfNotNullAttribute(string parameterName) => ParameterName = parameterName; - public string ParameterName { get; } - } - - [AttributeUsageAttribute(AttributeTargets.Parameter)] - sealed class NotNullWhenAttribute : Attribute - { - public NotNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; - public bool ReturnValue { get; } - } -} -#endif diff --git a/src/Npgsql/NoSynchronizationContextScope.cs b/src/Npgsql/NoSynchronizationContextScope.cs deleted file mode 100644 index 3db068c508..0000000000 --- a/src/Npgsql/NoSynchronizationContextScope.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System; -using System.Threading; - -namespace Npgsql -{ - /// - /// This mechanism is used to temporarily set the current synchronization context to null while - 
/// executing Npgsql code, making all await continuations execute on the thread pool. This replaces - /// the need to place ConfigureAwait(false) everywhere, and should be used in all surface async methods, - /// without exception. - /// - /// Warning: do not use this directly in async methods, use it in sync wrappers of async methods - /// (see https://github.com/npgsql/npgsql/issues/1593) - /// - /// - /// https://stackoverflow.com/a/28307965/640325 - /// - static class NoSynchronizationContextScope - { - internal static Disposable Enter() => new Disposable(SynchronizationContext.Current); - - internal struct Disposable : IDisposable - { - readonly SynchronizationContext? _synchronizationContext; - - internal Disposable(SynchronizationContext? synchronizationContext) - { - if (synchronizationContext != null) - SynchronizationContext.SetSynchronizationContext(null); - - _synchronizationContext = synchronizationContext; - } - - public void Dispose() - => SynchronizationContext.SetSynchronizationContext(_synchronizationContext); - } - } -} diff --git a/src/Npgsql/Npgsql.csproj b/src/Npgsql/Npgsql.csproj index 449f48400c..c0f5b64946 100644 --- a/src/Npgsql/Npgsql.csproj +++ b/src/Npgsql/Npgsql.csproj @@ -1,25 +1,60 @@  + - Shay Rojansky;Yoh Deadfall;Brar Piening;Nikita Kazmin;Austin Drenski;Emil Lenngren;Francisco Figueiredo Jr.;Kenji Uno + Shay Rojansky;Nikita Kazmin;Brar Piening;Nino Floris;Yoh Deadfall;;Austin Drenski;Emil Lenngren;Francisco Figueiredo Jr.;Kenji Uno Npgsql is the open source .NET data provider for PostgreSQL. 
- npgsql postgresql postgres ado ado.net database sql - - netstandard2.0;netstandard2.1;netcoreapp3.1;net5.0 - net5.0 - true + npgsql;postgresql;postgres;ado;ado.net;database;sql + README.md + netstandard2.0;netstandard2.1;net6.0;net7.0;net8.0 + net8.0 + $(NoWarn);CA2017 + $(NoWarn);NPG9001 + $(NoWarn);NPG9002 + - + + + + + + + + + + - - - - + + - - + + + + + + + + + + + + + + + + + ResXFileCodeGenerator + NpgsqlStrings.Designer.cs + + + + + + True + True + NpgsqlStrings.resx + diff --git a/src/Npgsql/NpgsqlActivitySource.cs b/src/Npgsql/NpgsqlActivitySource.cs new file mode 100644 index 0000000000..224bb2e658 --- /dev/null +++ b/src/Npgsql/NpgsqlActivitySource.cs @@ -0,0 +1,120 @@ +using Npgsql.Internal; +using System; +using System.Data; +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; + +namespace Npgsql; + +static class NpgsqlActivitySource +{ + static readonly ActivitySource Source = new("Npgsql", "0.1.0"); + + internal static bool IsEnabled => Source.HasListeners(); + + internal static Activity? CommandStart(NpgsqlConnector connector, string commandText, CommandType commandType) + { + var settings = connector.Settings; + + var dbName = settings.Database ?? connector.InferredUserName; + string? dbOperation = null; + string? dbSqlTable = null; + string activityName; + switch (commandType) + { + case CommandType.StoredProcedure: + dbOperation = NpgsqlCommand.EnableStoredProcedureCompatMode ? "SELECT" : "CALL"; + // In this case our activity name follows the concept of the CommandType.TableDirect case + // (" .") but replaces db.sql.table with the procedure name + // which seems to match the spec's intent without being explicitly specified that way (it suggests + // using the procedure name but doesn't mention using db.operation or db.name in that case). 
+ activityName = $"{dbOperation} {dbName}.{commandText}"; + break; + case CommandType.TableDirect: + dbOperation = "SELECT"; + // The OpenTelemetry spec actually asks to include the database name into db.sql.table + // but then again mixes the concept of database and schema. + // As I interpret it, it actually wants db.sql.table to include the schema name and not the + // database name if the concept of schemas exists in the database system. + // This also makes sense in the context of the activity name which otherwise would include the + // database name twice. + dbSqlTable = commandText; + activityName = $"{dbOperation} {dbName}.{dbSqlTable}"; + break; + case CommandType.Text: + activityName = dbName; + break; + default: + throw new ArgumentOutOfRangeException(nameof(commandType), commandType, null); + } + + var activity = Source.StartActivity(activityName, ActivityKind.Client); + if (activity is not { IsAllDataRequested: true }) + return activity; + + activity.SetTag("db.system", "postgresql"); + activity.SetTag("db.connection_string", connector.UserFacingConnectionString); + activity.SetTag("db.user", connector.InferredUserName); + // We trace the actual (maybe inferred) database name we're connected to, even if it + // wasn't specified in the connection string + activity.SetTag("db.name", dbName); + activity.SetTag("db.statement", commandText); + activity.SetTag("db.connection_id", connector.Id); + if (dbOperation != null) + activity.SetTag("db.operation", dbOperation); + if (dbSqlTable != null) + activity.SetTag("db.sql.table", dbSqlTable); + + var endPoint = connector.ConnectedEndPoint; + Debug.Assert(endPoint is not null); + switch (endPoint) + { + case IPEndPoint ipEndPoint: + activity.SetTag("net.transport", "ip_tcp"); + activity.SetTag("net.peer.ip", ipEndPoint.Address.ToString()); + if (ipEndPoint.Port != 5432) + activity.SetTag("net.peer.port", ipEndPoint.Port); + activity.SetTag("net.peer.name", settings.Host); + break; + + case 
UnixDomainSocketEndPoint: + activity.SetTag("net.transport", "unix"); + activity.SetTag("net.peer.name", settings.Host); + break; + + default: + throw new ArgumentOutOfRangeException("Invalid endpoint type: " + endPoint.GetType()); + } + + return activity; + } + + internal static void ReceivedFirstResponse(Activity activity) + { + var activityEvent = new ActivityEvent("received-first-response"); + activity.AddEvent(activityEvent); + } + + internal static void CommandStop(Activity activity) + { + activity.SetTag("otel.status_code", "OK"); + activity.Dispose(); + } + + internal static void SetException(Activity activity, Exception ex, bool escaped = true) + { + var tags = new ActivityTagsCollection + { + { "exception.type", ex.GetType().FullName }, + { "exception.message", ex.Message }, + { "exception.stacktrace", ex.ToString() }, + { "exception.escaped", escaped } + }; + var activityEvent = new ActivityEvent("exception", tags: tags); + activity.AddEvent(activityEvent); + activity.SetTag("otel.status_code", "ERROR"); + activity.SetTag("otel.status_description", ex is PostgresException pgEx ? pgEx.SqlState : ex.Message); + activity.Dispose(); + } +} diff --git a/src/Npgsql/NpgsqlBatch.cs b/src/Npgsql/NpgsqlBatch.cs new file mode 100644 index 0000000000..446cb4746f --- /dev/null +++ b/src/Npgsql/NpgsqlBatch.cs @@ -0,0 +1,203 @@ +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; + +namespace Npgsql; + +/// +public class NpgsqlBatch : DbBatch +{ + internal const int DefaultBatchCommandsSize = 5; + + private protected NpgsqlCommand Command { get; } + + /// + protected override DbBatchCommandCollection DbBatchCommands => BatchCommands; + + /// + public new NpgsqlBatchCommandCollection BatchCommands { get; } + + /// + public override int Timeout + { + get => Command.CommandTimeout; + set => Command.CommandTimeout = value; + } + + /// + public new NpgsqlConnection? 
Connection + { + get => Command.Connection; + set => Command.Connection = value; + } + + /// + protected override DbConnection? DbConnection + { + get => Connection; + set => Connection = (NpgsqlConnection?)value; + } + + /// + public new NpgsqlTransaction? Transaction + { + get => Command.Transaction; + set => Command.Transaction = value; + } + + /// + protected override DbTransaction? DbTransaction + { + get => Transaction; + set => Transaction = (NpgsqlTransaction?)value; + } + + /// + /// Controls whether to place error barriers between all batch commands within this batch. Default to . + /// + /// + /// + /// By default, any exception in a command causes later commands in the batch to be skipped, and earlier commands to be rolled back. + /// Enabling error barriers ensures that errors do not affect other commands in the batch. + /// + /// + /// Note that if the batch is executed within an explicit transaction, the first error places the transaction in a failed state, + /// causing all later commands to fail in any case. As a result, this option is useful mainly when there is no explicit transaction. + /// + /// + /// At the PostgreSQL wire protocol level, this corresponds to inserting a Sync message between each command, rather than grouping + /// all the batch's commands behind a single terminating Sync. + /// + /// + /// To control error barriers on a command-by-command basis, see . + /// + /// + public bool EnableErrorBarriers + { + get => Command.EnableErrorBarriers; + set => Command.EnableErrorBarriers = value; + } + + /// + /// Marks all of the batch's result columns as either known or unknown. + /// Unknown results column are requested them from PostgreSQL in text format, and Npgsql makes no + /// attempt to parse them. They will be accessible as strings only. + /// + internal bool AllResultTypesAreUnknown + { + get => Command.AllResultTypesAreUnknown; + set => Command.AllResultTypesAreUnknown = value; + } + + /// + /// Initializes a new . 
+ /// + /// A that represents the connection to a PostgreSQL server. + /// The in which the executes. + public NpgsqlBatch(NpgsqlConnection? connection = null, NpgsqlTransaction? transaction = null) + { + GC.SuppressFinalize(this); + Command = new(DefaultBatchCommandsSize); + BatchCommands = new NpgsqlBatchCommandCollection(Command.InternalBatchCommands); + + Connection = connection; + Transaction = transaction; + } + + internal NpgsqlBatch(NpgsqlConnector connector) + { + GC.SuppressFinalize(this); + Command = new(connector, DefaultBatchCommandsSize); + BatchCommands = new NpgsqlBatchCommandCollection(Command.InternalBatchCommands); + } + + private protected NpgsqlBatch(NpgsqlDataSourceCommand command) + { + GC.SuppressFinalize(this); + Command = command; + BatchCommands = new NpgsqlBatchCommandCollection(Command.InternalBatchCommands); + } + + /// + protected override DbBatchCommand CreateDbBatchCommand() => CreateBatchCommand(); + + /// + public new NpgsqlBatchCommand CreateBatchCommand() + => new NpgsqlBatchCommand(); + + /// + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) + => ExecuteReader(behavior); + + /// + public new NpgsqlDataReader ExecuteReader(CommandBehavior behavior = CommandBehavior.Default) + => Command.ExecuteReader(behavior); + + /// + protected override async Task ExecuteDbDataReaderAsync( + CommandBehavior behavior, + CancellationToken cancellationToken) + => await ExecuteReaderAsync(behavior, cancellationToken).ConfigureAwait(false); + + /// + public new Task ExecuteReaderAsync(CancellationToken cancellationToken = default) + => Command.ExecuteReaderAsync(cancellationToken); + + /// + public new Task ExecuteReaderAsync( + CommandBehavior behavior, + CancellationToken cancellationToken = default) + => Command.ExecuteReaderAsync(behavior, cancellationToken); + + /// + public override int ExecuteNonQuery() + => Command.ExecuteNonQuery(); + + /// + public override Task ExecuteNonQueryAsync(CancellationToken 
cancellationToken = default) + => Command.ExecuteNonQueryAsync(cancellationToken); + + /// + public override object? ExecuteScalar() + => Command.ExecuteScalar(); + + /// + public override Task ExecuteScalarAsync(CancellationToken cancellationToken = default) + => Command.ExecuteScalarAsync(cancellationToken); + + /// + public override void Prepare() + => Command.Prepare(); + + /// + public override Task PrepareAsync(CancellationToken cancellationToken = default) + => Command.PrepareAsync(cancellationToken); + + /// + public override void Cancel() => Command.Cancel(); + + /// + public override void Dispose() + { + Command.ResetTransaction(); + if (Command.IsCacheable && Connection is not null && Connection.CachedBatch is null) + { + BatchCommands.Clear(); + Command.Reset(); + Connection.CachedBatch = this; + return; + } + + Command.IsCacheable = false; + } + + internal static NpgsqlBatch CreateCachedBatch(NpgsqlConnection connection) + { + var batch = new NpgsqlBatch(connection); + batch.Command.IsCacheable = true; + return batch; + } +} diff --git a/src/Npgsql/NpgsqlBatchCommand.cs b/src/Npgsql/NpgsqlBatchCommand.cs new file mode 100644 index 0000000000..4123a91506 --- /dev/null +++ b/src/Npgsql/NpgsqlBatchCommand.cs @@ -0,0 +1,303 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using Npgsql.BackendMessages; +using Npgsql.Internal; + +namespace Npgsql; + +/// +public sealed class NpgsqlBatchCommand : DbBatchCommand +{ + internal static readonly List EmptyParameters = new(); + + string _commandText; + + /// + [AllowNull] + public override string CommandText + { + get => _commandText; + set + { + _commandText = value ?? 
string.Empty; + + ResetPreparation(); + // TODO: Technically should do this also if the parameter list (or type) changes + } + } + + /// + public override CommandType CommandType { get; set; } = CommandType.Text; + + /// + protected override DbParameterCollection DbParameterCollection => Parameters; + + internal NpgsqlParameterCollection? _parameters; + /// + public new NpgsqlParameterCollection Parameters => _parameters ??= new(); + +#pragma warning disable CA1822 // Mark members as static + +#if NET8_0_OR_GREATER + /// + public override NpgsqlParameter CreateParameter() +#else + /// + /// Creates a new instance of a object. + /// + /// An object. + public NpgsqlParameter CreateParameter() +#endif + => new(); + +#if NET8_0_OR_GREATER + /// + public override bool CanCreateParameter +#else + /// + /// Returns whether the method is implemented. + /// + public bool CanCreateParameter +#endif + => true; + +#pragma warning restore CA1822 // Mark members as static + + /// + /// Appends an error barrier after this batch command. Defaults to the value of on the + /// batch. + /// + /// + /// + /// By default, any exception in a command causes later commands in the batch to be skipped, and earlier commands to be rolled back. + /// Appending an error barrier ensures that errors from this command (or previous ones) won't cause later commands to be skipped, + /// and that errors from later commands won't cause this command (or previous ones) to be rolled back). + /// + /// + /// Note that if the batch is executed within an explicit transaction, the first error places the transaction in a failed state, + /// causing all later commands to fail in any case. As a result, this option is useful mainly when there is no explicit transaction. + /// + /// + /// At the PostgreSQL wire protocol level, this corresponds to inserting a Sync message after this command, rather than grouping + /// all the batch's commands behind a single terminating Sync. 
+ /// + /// + /// Controlling error barriers on a command-by-command basis is an advanced feature, consider enabling error barriers for the entire + /// batch via . + /// + /// + public bool? AppendErrorBarrier { get; set; } + + /// + /// The number of rows affected or retrieved. + /// + /// + /// See the command tag in the CommandComplete message for the meaning of this value for each , + /// https://www.postgresql.org/docs/current/static/protocol-message-formats.html + /// + public ulong Rows { get; internal set; } + + /// + public override int RecordsAffected + { + get + { + switch (StatementType) + { + case StatementType.Update: + case StatementType.Insert: + case StatementType.Delete: + case StatementType.Copy: + case StatementType.Move: + case StatementType.Merge: + return Rows > int.MaxValue + ? throw new OverflowException($"The number of records affected exceeds int.MaxValue. Use {nameof(Rows)}.") + : (int)Rows; + default: + return -1; + } + } + } + + /// + /// Specifies the type of query, e.g. SELECT. + /// + public StatementType StatementType { get; internal set; } + + /// + /// For an INSERT, the object ID of the inserted row if is 1 and + /// the target table has OIDs; otherwise 0. + /// + public uint OID { get; internal set; } + + /// + /// The SQL as it will be sent to PostgreSQL, after any rewriting performed by Npgsql (e.g. named to positional parameter + /// placeholders). + /// + internal string? FinalCommandText { get; set; } + + /// + /// The list of parameters, ordered positionally, as it will be sent to PostgreSQL. + /// + /// + /// If the user provided positional parameters, this references the (in batching mode) or the list + /// backing (in non-batching) mode. If the user provided named parameters, this is a + /// separate list containing the re-ordered parameters. 
+ /// + internal List PositionalParameters + { + get => _inputParameters ??= _ownedInputParameters ??= new(); + set => _inputParameters = value; + } + + internal bool HasParameters => _inputParameters?.Count > 0 || _ownedInputParameters?.Count > 0; + + internal List CurrentParametersReadOnly => HasParameters ? PositionalParameters : EmptyParameters; + + List? _ownedInputParameters; + List? _inputParameters; + + /// + /// The RowDescription message for this query. If null, the query does not return rows (e.g. INSERT) + /// + internal RowDescriptionMessage? Description + { + get => PreparedStatement == null ? _description : PreparedStatement.Description; + set + { + if (PreparedStatement == null) + _description = value; + else + PreparedStatement.Description = value; + } + } + + RowDescriptionMessage? _description; + + /// + /// If this statement has been automatically prepared, references the . + /// Null otherwise. + /// + internal PreparedStatement? PreparedStatement + { + get => _preparedStatement != null && _preparedStatement.State == PreparedState.Unprepared + ? _preparedStatement = null + : _preparedStatement; + set => _preparedStatement = value; + } + + PreparedStatement? _preparedStatement; + + internal NpgsqlConnector? ConnectorPreparedOn { get; set; } + + internal bool IsPreparing; + + /// + /// Holds the server-side (prepared) ASCII statement name. Empty string for non-prepared statements. + /// + internal byte[] StatementName => PreparedStatement?.Name ?? Array.Empty(); + + /// + /// Whether this statement has already been prepared (including automatic preparation). + /// + internal bool IsPrepared => PreparedStatement?.IsPrepared == true; + + /// + /// Returns a prepared statement for this statement (including automatic preparation). + /// + internal bool TryGetPrepared([NotNullWhen(true)] out PreparedStatement? 
preparedStatement) + { + preparedStatement = PreparedStatement; + return preparedStatement?.IsPrepared == true; + } + + /// + /// Initializes a new . + /// + public NpgsqlBatchCommand() : this(string.Empty) {} + + /// + /// Initializes a new . + /// + /// The text of the . + public NpgsqlBatchCommand(string commandText) + => _commandText = commandText; + + internal bool ExplicitPrepare(NpgsqlConnector connector) + { + if (!IsPrepared) + { + PreparedStatement = connector.PreparedStatementManager.GetOrAddExplicit(this); + + if (PreparedStatement?.State == PreparedState.NotPrepared) + { + PreparedStatement.State = PreparedState.BeingPrepared; + IsPreparing = true; + return true; + } + } + + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal bool TryAutoPrepare(NpgsqlConnector connector) + { + // If this statement isn't prepared, see if it gets implicitly prepared. + // Note that this may return null (not enough usages for automatic preparation). + if (!TryGetPrepared(out var preparedStatement)) + preparedStatement = PreparedStatement = connector.PreparedStatementManager.TryGetAutoPrepared(this); + if (preparedStatement is not null) + { + if (preparedStatement.State == PreparedState.NotPrepared) + { + preparedStatement.State = PreparedState.BeingPrepared; + IsPreparing = true; + } + + return true; + } + + return false; + } + + internal void Reset() + { + CommandText = string.Empty; + StatementType = StatementType.Select; + _description = null; + Rows = 0; + OID = 0; + PreparedStatement = null; + + if (ReferenceEquals(_inputParameters, _ownedInputParameters)) + PositionalParameters.Clear(); + else if (_inputParameters is not null) + _inputParameters = null; // We're pointing at a user's NpgsqlParameterCollection + Debug.Assert(_inputParameters is null || _inputParameters.Count == 0); + Debug.Assert(_ownedInputParameters is null || _ownedInputParameters.Count == 0); + } + + internal void ApplyCommandComplete(CommandCompleteMessage msg) + 
{ + StatementType = msg.StatementType; + Rows = msg.Rows; + OID = msg.OID; + } + + internal void ResetPreparation() + { + PreparedStatement = null; + ConnectorPreparedOn = null; + } + + /// + /// Returns the . + /// + public override string ToString() => CommandText; +} diff --git a/src/Npgsql/NpgsqlBatchCommandCollection.cs b/src/Npgsql/NpgsqlBatchCommandCollection.cs new file mode 100644 index 0000000000..a79afa359b --- /dev/null +++ b/src/Npgsql/NpgsqlBatchCommandCollection.cs @@ -0,0 +1,113 @@ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql; + +/// +public class NpgsqlBatchCommandCollection : DbBatchCommandCollection, IList +{ + readonly List _list; + + internal NpgsqlBatchCommandCollection(List batchCommands) + => _list = batchCommands; + + /// + public override int Count => _list.Count; + + /// + public override bool IsReadOnly => false; + + IEnumerator IEnumerable.GetEnumerator() => _list.GetEnumerator(); + + /// + public override IEnumerator GetEnumerator() => _list.GetEnumerator(); + + /// + public void Add(NpgsqlBatchCommand item) => _list.Add(item); + + /// + public override void Add(DbBatchCommand item) => Add(Cast(item)); + + /// + public override void Clear() => _list.Clear(); + + /// + public bool Contains(NpgsqlBatchCommand item) => _list.Contains(item); + + /// + public override bool Contains(DbBatchCommand item) => Contains(Cast(item)); + + /// + public void CopyTo(NpgsqlBatchCommand[] array, int arrayIndex) => _list.CopyTo(array, arrayIndex); + + /// + public override void CopyTo(DbBatchCommand[] array, int arrayIndex) + { + if (array is NpgsqlBatchCommand[] typedArray) + { + CopyTo(typedArray, arrayIndex); + return; + } + + throw new InvalidCastException( + $"{nameof(array)} is not of type {nameof(NpgsqlBatchCommand)} and cannot be used in this batch command collection."); + } + + /// + public int IndexOf(NpgsqlBatchCommand item) => _list.IndexOf(item); + + 
/// + public override int IndexOf(DbBatchCommand item) => IndexOf(Cast(item)); + + /// + public void Insert(int index, NpgsqlBatchCommand item) => _list.Insert(index, item); + + /// + public override void Insert(int index, DbBatchCommand item) => Insert(index, Cast(item)); + + /// + public bool Remove(NpgsqlBatchCommand item) => _list.Remove(item); + + /// + public override bool Remove(DbBatchCommand item) => Remove(Cast(item)); + + /// + public override void RemoveAt(int index) => _list.RemoveAt(index); + + NpgsqlBatchCommand IList.this[int index] + { + get => _list[index]; + set => _list[index] = value; + } + + /// + public new NpgsqlBatchCommand this[int index] + { + get => _list[index]; + set => _list[index] = value; + } + + /// + protected override DbBatchCommand GetBatchCommand(int index) + => _list[index]; + + /// + protected override void SetBatchCommand(int index, DbBatchCommand batchCommand) + => _list[index] = Cast(batchCommand); + + static NpgsqlBatchCommand Cast(DbBatchCommand? value) + { + var castedValue = value as NpgsqlBatchCommand; + if (castedValue is null) + ThrowInvalidCastException(value); + + return castedValue; + } + + [DoesNotReturn] + static void ThrowInvalidCastException(DbBatchCommand? 
value) => + throw new InvalidCastException( + $"The value \"{value}\" is not of type \"{nameof(NpgsqlBatchCommand)}\" and cannot be used in this batch command collection."); +} diff --git a/src/Npgsql/NpgsqlBinaryExporter.cs b/src/Npgsql/NpgsqlBinaryExporter.cs index 3d5dac08a8..406962d837 100644 --- a/src/Npgsql/NpgsqlBinaryExporter.cs +++ b/src/Npgsql/NpgsqlBinaryExporter.cs @@ -1,445 +1,533 @@ using System; using System.Diagnostics; -using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; -using Npgsql.Logging; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; using NpgsqlTypes; using static Npgsql.Util.Statics; -namespace Npgsql +namespace Npgsql; + +/// +/// Provides an API for a binary COPY TO operation, a high-performance data export mechanism from +/// a PostgreSQL table. Initiated by +/// +public sealed class NpgsqlBinaryExporter : ICancelable { + const int BeforeRow = -2; + const int BeforeColumn = -1; + + #region Fields and Properties + + NpgsqlConnector _connector; + NpgsqlReadBuffer _buf; + bool _isConsumed, _isDisposed; + long _endOfMessagePos; + + short _column; + ulong _rowsExported; + + PgReader PgReader => _buf.PgReader; + + /// + /// The number of columns, as returned from the backend in the CopyInResponse. + /// + int NumColumns { get; set; } + + PgConverterInfo[] _columnInfoCache; + + readonly ILogger _copyLogger; + /// - /// Provides an API for a binary COPY TO operation, a high-performance data export mechanism from - /// a PostgreSQL table. 
Initiated by + /// Current timeout /// - public sealed class NpgsqlBinaryExporter : ICancelable, IAsyncDisposable + public TimeSpan Timeout { - #region Fields and Properties + set => _buf.Timeout = value; + } - NpgsqlConnector _connector; - NpgsqlReadBuffer _buf; - ConnectorTypeMapper _typeMapper; - bool _isConsumed, _isDisposed; - int _leftToReadInDataMsg, _columnLen; + #endregion - short _column; + #region Construction / Initialization - /// - /// The number of columns, as returned from the backend in the CopyInResponse. - /// - internal int NumColumns { get; } + internal NpgsqlBinaryExporter(NpgsqlConnector connector) + { + _connector = connector; + _buf = connector.ReadBuffer; + _column = BeforeRow; + _columnInfoCache = null!; + _copyLogger = connector.LoggingConfiguration.CopyLogger; + } - readonly NpgsqlTypeHandler?[] _typeHandlerCache; + internal async Task Init(string copyToCommand, bool async, CancellationToken cancellationToken = default) + { + await _connector.WriteQuery(copyToCommand, async, cancellationToken).ConfigureAwait(false); + await _connector.Flush(async, cancellationToken).ConfigureAwait(false); - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlBinaryExporter)); + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - /// - /// Current timeout - /// - public TimeSpan Timeout + CopyOutResponseMessage copyOutResponse; + var msg = await _connector.ReadMessage(async).ConfigureAwait(false); + switch (msg.Code) { - set + case BackendMessageCode.CopyOutResponse: + copyOutResponse = (CopyOutResponseMessage)msg; + if (!copyOutResponse.IsBinary) { - _buf.Timeout = value; - // While calling Complete(), we're using the connector, which overwrites the buffer's timeout with it's own - _connector.UserTimeout = (int)value.TotalMilliseconds; + throw _connector.Break( + new ArgumentException("copyToCommand triggered a text transfer, only binary is allowed", + 
nameof(copyToCommand))); } + break; + case BackendMessageCode.CommandComplete: + throw new InvalidOperationException( + "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " + + "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + + "Note that your data has been successfully imported/exported."); + default: + throw _connector.UnexpectedMessageReceived(msg.Code); } - #endregion + NumColumns = copyOutResponse.NumColumns; + _columnInfoCache = new PgConverterInfo[NumColumns]; + _rowsExported = 0; + _endOfMessagePos = _buf.CumulativeReadPosition; + await ReadHeader(async).ConfigureAwait(false); + } - #region Construction / Initialization + async Task ReadHeader(bool async) + { + var msg = await _connector.ReadMessage(async).ConfigureAwait(false); + _endOfMessagePos = _buf.CumulativeReadPosition + Expect(msg, _connector).Length; + var headerLen = NpgsqlRawCopyStream.BinarySignature.Length + 4 + 4; + await _buf.Ensure(headerLen, async).ConfigureAwait(false); - internal NpgsqlBinaryExporter(NpgsqlConnector connector, string copyToCommand) - { - _connector = connector; - _buf = connector.ReadBuffer; - _typeMapper = connector.TypeMapper; - _columnLen = int.MinValue; // Mark that the (first) column length hasn't been read yet - _column = -1; + foreach (var t in NpgsqlRawCopyStream.BinarySignature) + if (_buf.ReadByte() != t) + throw new NpgsqlException("Invalid COPY binary signature at beginning!"); - _connector.WriteQuery(copyToCommand); - _connector.Flush(); + var flags = _buf.ReadInt32(); + if (flags != 0) + throw new NotSupportedException("Unsupported flags in COPY operation (OID inclusion?)"); - using var registration = _connector.StartNestedCancellableOperation(attemptPgCancellation: false); + _buf.ReadInt32(); // Header extensions, currently unused + } - CopyOutResponseMessage copyOutResponse; - var msg = _connector.ReadMessage(async: 
false).GetAwaiter().GetResult(); - switch (msg.Code) - { - case BackendMessageCode.CopyOutResponse: - copyOutResponse = (CopyOutResponseMessage)msg; - if (!copyOutResponse.IsBinary) - { - throw _connector.Break( - new ArgumentException("copyToCommand triggered a text transfer, only binary is allowed", - nameof(copyToCommand))); - } - break; - case BackendMessageCode.CommandComplete: - throw new InvalidOperationException( - "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " + - "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + - "Note that your data has been successfully imported/exported."); - default: - throw _connector.UnexpectedMessageReceived(msg.Code); - } + #endregion - NumColumns = copyOutResponse.NumColumns; - _typeHandlerCache = new NpgsqlTypeHandler[NumColumns]; - ReadHeader(); - } + #region Read - void ReadHeader() - { - _leftToReadInDataMsg = Expect(_connector.ReadMessage(), _connector).Length; - var headerLen = NpgsqlRawCopyStream.BinarySignature.Length + 4 + 4; - _buf.Ensure(headerLen); + /// + /// Starts reading a single row, must be invoked before reading any columns. + /// + /// + /// The number of columns in the row. -1 if there are no further rows. + /// Note: This will currently be the same value for all rows, but this may change in the future. + /// + public int StartRow() => StartRow(false).GetAwaiter().GetResult(); - if (NpgsqlRawCopyStream.BinarySignature.Any(t => _buf.ReadByte() != t)) - throw new NpgsqlException("Invalid COPY binary signature at beginning!"); + /// + /// Starts reading a single row, must be invoked before reading any columns. + /// + /// + /// The number of columns in the row. -1 if there are no further rows. + /// Note: This will currently be the same value for all rows, but this may change in the future. 
+ /// + public ValueTask StartRowAsync(CancellationToken cancellationToken = default) => StartRow(true, cancellationToken); + + async ValueTask StartRow(bool async, CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + if (_isConsumed) + return -1; - var flags = _buf.ReadInt32(); - if (flags != 0) - throw new NotSupportedException("Unsupported flags in COPY operation (OID inclusion?)"); + using var registration = _connector.StartNestedCancellableOperation(cancellationToken); - _buf.ReadInt32(); // Header extensions, currently unused - _leftToReadInDataMsg -= headerLen; + // Consume and advance any active column. + if (_column >= 0) + { + if (async) + await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); + else + PgReader.Commit(resuming: false); + _column++; } - #endregion - - #region Read - - /// - /// Starts reading a single row, must be invoked before reading any columns. - /// - /// - /// The number of columns in the row. -1 if there are no further rows. - /// Note: This will currently be the same value for all rows, but this may change in the future. - /// - public int StartRow() => StartRow(false).GetAwaiter().GetResult(); - - /// - /// Starts reading a single row, must be invoked before reading any columns. - /// - /// - /// The number of columns in the row. -1 if there are no further rows. - /// Note: This will currently be the same value for all rows, but this may change in the future. - /// - public ValueTask StartRowAsync(CancellationToken cancellationToken = default) + // The very first row (i.e. _column == -1) is included in the header's CopyData message. + // Otherwise we need to read in a new CopyData row (the docs specify that there's a CopyData + // message per row). 
+ if (_column == NumColumns) { - using (NoSynchronizationContextScope.Enter()) - return StartRow(true, cancellationToken); + var msg = Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + _endOfMessagePos = _buf.CumulativeReadPosition + msg.Length; } + else if (_column != BeforeRow) + ThrowHelper.ThrowInvalidOperationException("Already in the middle of a row"); - async ValueTask StartRow(bool async, CancellationToken cancellationToken = default) + await _buf.Ensure(2, async).ConfigureAwait(false); + + var numColumns = _buf.ReadInt16(); + if (numColumns == -1) { - CheckDisposed(); - if (_isConsumed) - return -1; + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + _column = BeforeRow; + _isConsumed = true; + return -1; + } - using var registration = _connector.StartNestedCancellableOperation(cancellationToken); + Debug.Assert(numColumns == NumColumns); - // The very first row (i.e. _column == -1) is included in the header's CopyData message. - // Otherwise we need to read in a new CopyData row (the docs specify that there's a CopyData - // message per row). - if (_column == NumColumns) - _leftToReadInDataMsg = Expect(await _connector.ReadMessage(async), _connector).Length; - else if (_column != -1) - throw new InvalidOperationException("Already in the middle of a row"); + _column = BeforeColumn; + _rowsExported++; + return NumColumns; + } - await _buf.Ensure(2, async); - _leftToReadInDataMsg -= 2; + /// + /// Reads the current column, returns its value and moves ahead to the next column. + /// If the column is null an exception is thrown. + /// + /// + /// The type of the column to be read. This must correspond to the actual type or data + /// corruption will occur. If in doubt, use to manually + /// specify the type. 
+ /// + /// The value of the column + public T Read() + => Read(null); - var numColumns = _buf.ReadInt16(); - if (numColumns == -1) - { - Debug.Assert(_leftToReadInDataMsg == 0); - Expect(await _connector.ReadMessage(async), _connector); - Expect(await _connector.ReadMessage(async), _connector); - Expect(await _connector.ReadMessage(async), _connector); - _column = -1; - _isConsumed = true; - return -1; - } + /// + /// Reads the current column, returns its value and moves ahead to the next column. + /// If the column is null an exception is thrown. + /// + /// + /// The type of the column to be read. This must correspond to the actual type or data + /// corruption will occur. If in doubt, use to manually + /// specify the type. + /// + /// The value of the column + public ValueTask ReadAsync(CancellationToken cancellationToken = default) + => ReadAsync(null, cancellationToken); - Debug.Assert(numColumns == NumColumns); + /// + /// Reads the current column, returns its value according to and + /// moves ahead to the next column. + /// If the column is null an exception is thrown. + /// + /// + /// In some cases isn't enough to infer the data type coming in from the + /// database. This parameter can be used to unambiguously specify the type. An example is the JSONB + /// type, for which will be a simple string but for which + /// must be specified as . + /// + /// The .NET type of the column to be read. + /// The value of the column + public T Read(NpgsqlDbType type) + => Read((NpgsqlDbType?)type); - _column = 0; - return NumColumns; - } + /// + /// Reads the current column, returns its value according to and + /// moves ahead to the next column. + /// If the column is null an exception is thrown. + /// + /// + /// In some cases isn't enough to infer the data type coming in from the + /// database. This parameter can be used to unambiguously specify the type. 
An example is the JSONB + /// type, for which will be a simple string but for which + /// must be specified as . + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The .NET type of the column to be read. + /// The value of the column + public ValueTask ReadAsync(NpgsqlDbType type, CancellationToken cancellationToken = default) + => ReadAsync((NpgsqlDbType?)type, cancellationToken); + + T Read(NpgsqlDbType? type) + { + ThrowIfNotOnRow(); - /// - /// Reads the current column, returns its value and moves ahead to the next column. - /// If the column is null an exception is thrown. - /// - /// - /// The type of the column to be read. This must correspond to the actual type or data - /// corruption will occur. If in doubt, use to manually - /// specify the type. - /// - /// The value of the column - public T Read() => Read(false).GetAwaiter().GetResult(); - - /// - /// Reads the current column, returns its value and moves ahead to the next column. - /// If the column is null an exception is thrown. - /// - /// - /// The type of the column to be read. This must correspond to the actual type or data - /// corruption will occur. If in doubt, use to manually - /// specify the type. 
- /// - /// The value of the column - public ValueTask ReadAsync(CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return Read(true, cancellationToken); - } + if (!IsInitializedAndAtStart) + MoveNextColumn(resumableOp: false); - ValueTask Read(bool async, CancellationToken cancellationToken = default) + var reader = PgReader; + try { - CheckDisposed(); + if (reader.FieldSize is -1) + return DbNullOrThrow(); - if (_column == -1 || _column == NumColumns) - throw new InvalidOperationException("Not reading a row"); + var info = GetInfo(typeof(T), type, out var asObject); - var type = typeof(T); - var handler = _typeHandlerCache[_column]; - if (handler == null) - handler = _typeHandlerCache[_column] = _typeMapper.GetByClrType(type); + reader.StartRead(info.BufferRequirement); + var result = asObject + ? (T)info.Converter.ReadAsObject(reader) + : info.Converter.UnsafeDowncast().Read(reader); + reader.EndRead(); - return DoRead(handler, async, cancellationToken); + return result; } - - /// - /// Reads the current column, returns its value according to and - /// moves ahead to the next column. - /// If the column is null an exception is thrown. - /// - /// - /// In some cases isn't enough to infer the data type coming in from the - /// database. This parameter and be used to unambiguously specify the type. An example is the JSONB - /// type, for which will be a simple string but for which - /// must be specified as . - /// - /// The .NET type of the column to be read. - /// The value of the column - public T Read(NpgsqlDbType type) => Read(type, false).GetAwaiter().GetResult(); - - /// - /// Reads the current column, returns its value according to and - /// moves ahead to the next column. - /// If the column is null an exception is thrown. - /// - /// - /// In some cases isn't enough to infer the data type coming in from the - /// database. This parameter and be used to unambiguously specify the type. 
An example is the JSONB - /// type, for which will be a simple string but for which - /// must be specified as . - /// - /// - /// The .NET type of the column to be read. - /// The value of the column - public ValueTask ReadAsync(NpgsqlDbType type, CancellationToken cancellationToken = default) + finally { - using (NoSynchronizationContextScope.Enter()) - return Read(type, true, cancellationToken); + // Don't delay committing the current column, just do it immediately (as opposed to on the next action: Read, IsNull, Skip). + // Zero length columns would otherwise create an edge-case where we'd have to immediately commit as we won't know whether we're at the end. + // To guarantee the commit happens in that case we would still need this try finally, at which point it's just better to be consistent. + reader.Commit(resuming: false); } + } - ValueTask Read(NpgsqlDbType type, bool async, CancellationToken cancellationToken = default) + async ValueTask ReadAsync(NpgsqlDbType? type, CancellationToken cancellationToken) + { + ThrowIfNotOnRow(); + + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + + if (!IsInitializedAndAtStart) + await MoveNextColumnAsync(resumableOp: false).ConfigureAwait(false); + + var reader = PgReader; + try { - CheckDisposed(); - if (_column == -1 || _column == NumColumns) - throw new InvalidOperationException("Not reading a row"); + if (reader.FieldSize is -1) + return DbNullOrThrow(); - var handler = _typeHandlerCache[_column]; - if (handler == null) - handler = _typeHandlerCache[_column] = _typeMapper.GetByNpgsqlDbType(type); + var info = GetInfo(typeof(T), type, out var asObject); - return DoRead(handler, async, cancellationToken); - } + await reader.StartReadAsync(info.BufferRequirement, cancellationToken).ConfigureAwait(false); + var result = asObject + ? 
(T)await info.Converter.ReadAsObjectAsync(reader, cancellationToken).ConfigureAwait(false) + : await info.Converter.UnsafeDowncast().ReadAsync(reader, cancellationToken).ConfigureAwait(false); + await reader.EndReadAsync().ConfigureAwait(false); - async ValueTask DoRead(NpgsqlTypeHandler handler, bool async, CancellationToken cancellationToken = default) + return result; + } + finally { - try - { - using var registration = _connector.StartNestedCancellableOperation(cancellationToken); - - await ReadColumnLenIfNeeded(async); - - if (_columnLen == -1) - { -#pragma warning disable CS8653 // A default expression introduces a null value when 'T' is a non-nullable reference type. - // When T is a Nullable, we support returning null - if (NullableHandler.Exists) - return default!; -#pragma warning restore CS8653 - throw new InvalidCastException("Column is null"); - } - - // If we know the entire column is already in memory, use the code path without async - var result = NullableHandler.Exists - ? _columnLen <= _buf.ReadBytesLeft - ? NullableHandler.Read(handler, _buf, _columnLen) - : await NullableHandler.ReadAsync(handler, _buf, _columnLen, async) - : _columnLen <= _buf.ReadBytesLeft - ? handler.Read(_buf, _columnLen) - : await handler.Read(_buf, _columnLen, async); - - _leftToReadInDataMsg -= _columnLen; - _columnLen = int.MinValue; // Mark that the (next) column length hasn't been read yet - _column++; - return result; - } - catch (Exception e) - { - _connector.Break(e); - Cleanup(); - throw; - } + // Don't delay committing the current column, just do it immediately (as opposed to on the next action: Read, IsNull, Skip). + // Zero length columns would otherwise create an edge-case where we'd have to immediately commit as we won't know whether we're at the end. + // To guarantee the commit happens in that case we would still need this try finally, at which point it's just better to be consistent. 
+ await reader.CommitAsync(resuming: false).ConfigureAwait(false); } + } + + static T DbNullOrThrow() + { + // When T is a Nullable, we support returning null + if (default(T) is null && typeof(T).IsValueType) + return default!; + throw new InvalidCastException("Column is null"); + } + + PgConverterInfo GetInfo(Type type, NpgsqlDbType? npgsqlDbType, out bool asObject) + { + ref var cachedInfo = ref _columnInfoCache[_column]; + var converterInfo = cachedInfo.IsDefault ? cachedInfo = CreateConverterInfo(type, npgsqlDbType) : cachedInfo; + asObject = converterInfo.IsBoxingConverter; + return converterInfo; + } - /// - /// Returns whether the current column is null. - /// - public bool IsNull + PgConverterInfo CreateConverterInfo(Type type, NpgsqlDbType? npgsqlDbType = null) + { + var options = _connector.SerializerOptions; + PgTypeId? pgTypeId = null; + if (npgsqlDbType.HasValue) { - get - { - ReadColumnLenIfNeeded(false).GetAwaiter().GetResult(); - return _columnLen == -1; - } + pgTypeId = npgsqlDbType.Value.ToDataTypeName() is { } name + ? options.GetCanonicalTypeId(name) + // Handle plugin types via lookup. + : GetRepresentationalOrDefault(npgsqlDbType.Value.ToUnqualifiedDataTypeNameOrThrow()); } + var info = options.GetTypeInfo(type, pgTypeId) + ?? throw new NotSupportedException($"Reading is not supported for type '{type}'{(npgsqlDbType is null ? "" : $" and NpgsqlDbType '{npgsqlDbType}'")}"); - /// - /// Skips the current column without interpreting its value. - /// - public void Skip() => Skip(false).GetAwaiter().GetResult(); + // Binary export has no type info so we only do caller-directed interpretation of data. + return info.Bind(new Field("?", + info.PgTypeId ?? ((PgResolverTypeInfo)info).GetDefaultResolution(null).PgTypeId, -1), DataFormat.Binary); - /// - /// Skips the current column without interpreting its value. 
- /// - public Task SkipAsync(CancellationToken cancellationToken = default) + PgTypeId GetRepresentationalOrDefault(string dataTypeName) { - using (NoSynchronizationContextScope.Enter()) - return Skip(true, cancellationToken); + var type = options.DatabaseInfo.GetPostgresType(dataTypeName); + return options.ToCanonicalTypeId(type.GetRepresentationalType()); } + } - async Task Skip(bool async, CancellationToken cancellationToken = default) + /// + /// Returns whether the current column is null. + /// + public bool IsNull + { + get { - CheckDisposed(); + ThrowIfNotOnRow(); + if (!IsInitializedAndAtStart) + return MoveNextColumn(resumableOp: true) is -1; - using var registration = _connector.StartNestedCancellableOperation(cancellationToken); + return PgReader.FieldSize is - 1; + } + } - await ReadColumnLenIfNeeded(async); - if (_columnLen != -1) - await _buf.Skip(_columnLen, async); + /// + /// Skips the current column without interpreting its value. + /// + public void Skip() + { + ThrowIfNotOnRow(); - _columnLen = int.MinValue; - _column++; - } + if (!IsInitializedAndAtStart) + MoveNextColumn(resumableOp: false); - #endregion + PgReader.Commit(resuming: false); + } - #region Utilities + /// + /// Skips the current column without interpreting its value. 
+ /// + public async Task SkipAsync(CancellationToken cancellationToken = default) + { + ThrowIfNotOnRow(); - async Task ReadColumnLenIfNeeded(bool async) - { - if (_columnLen == int.MinValue) - { - await _buf.Ensure(4, async); - _columnLen = _buf.ReadInt32(); - _leftToReadInDataMsg -= 4; - } - } + using var registration = _connector.StartNestedCancellableOperation(cancellationToken); - void CheckDisposed() - { - if (_isDisposed) - throw new ObjectDisposedException(GetType().FullName, "The COPY operation has already ended."); - } + if (!IsInitializedAndAtStart) + await MoveNextColumnAsync(resumableOp: false).ConfigureAwait(false); - #endregion + await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); + } - #region Cancel / Close / Dispose + #endregion - /// - /// Cancels an ongoing export. - /// - public void Cancel() => _connector.PerformUserCancellation(); + #region Utilities - /// - /// Completes that binary export and sets the connection back to idle state - /// - public void Dispose() => DisposeAsync(false).GetAwaiter().GetResult(); + bool IsInitializedAndAtStart => PgReader.Initialized && (PgReader.FieldSize is -1 || PgReader.FieldOffset is 0); - /// - /// Async completes that binary export and sets the connection back to idle state - /// - /// - public ValueTask DisposeAsync() + int MoveNextColumn(bool resumableOp) + { + PgReader.Commit(resuming: false); + + if (_column + 1 == NumColumns) + ThrowHelper.ThrowInvalidOperationException("No more columns left in the current row"); + _column++; + _buf.Ensure(sizeof(int)); + var columnLen = _buf.ReadInt32(); + PgReader.Init(columnLen, DataFormat.Binary, resumableOp); + return PgReader.FieldSize; + } + + async ValueTask MoveNextColumnAsync(bool resumableOp) + { + await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); + + if (_column + 1 == NumColumns) + ThrowHelper.ThrowInvalidOperationException("No more columns left in the current row"); + _column++; + await _buf.Ensure(sizeof(int), async: 
true).ConfigureAwait(false); + var columnLen = _buf.ReadInt32(); + PgReader.Init(columnLen, DataFormat.Binary, resumableOp); + return PgReader.FieldSize; + } + + void ThrowIfNotOnRow() + { + ThrowIfDisposed(); + if (_column is BeforeRow) + ThrowHelper.ThrowInvalidOperationException("Not reading a row"); + } + + void ThrowIfDisposed() + { + if (_isDisposed) + ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlBinaryExporter), "The COPY operation has already ended."); + } + + #endregion + + #region Cancel / Close / Dispose + + /// + /// Cancels an ongoing export. + /// + public void Cancel() => _connector.PerformUserCancellation(); + + /// + /// Async cancels an ongoing export. + /// + public Task CancelAsync() + { + Cancel(); + return Task.CompletedTask; + } + + /// + /// Completes that binary export and sets the connection back to idle state + /// + public void Dispose() => DisposeAsync(async: false).GetAwaiter().GetResult(); + + /// + /// Async completes that binary export and sets the connection back to idle state + /// + /// + public ValueTask DisposeAsync() => DisposeAsync(async: true); + + async ValueTask DisposeAsync(bool async) + { + if (_isDisposed) + return; + + if (_isConsumed) { - using (NoSynchronizationContextScope.Enter()) - return DisposeAsync(true); + LogMessages.BinaryCopyOperationCompleted(_copyLogger, _rowsExported, _connector.Id); } - - async ValueTask DisposeAsync(bool async) + else if (!_connector.IsBroken) { - if (_isDisposed) - return; - - if (!_isConsumed) + try { - try - { - using var registration = _connector.StartNestedCancellableOperation(attemptPgCancellation: false); - // Finish the current CopyData message - _buf.Skip(_leftToReadInDataMsg); - // Read to the end - _connector.SkipUntil(BackendMessageCode.CopyDone); - // We intentionally do not pass a CancellationToken since we don't want to cancel cleanup - Expect(await _connector.ReadMessage(async), _connector); - Expect(await _connector.ReadMessage(async), _connector); - } - catch 
(OperationCanceledException e) when (e.InnerException is PostgresException pg && pg.SqlState == PostgresErrorCodes.QueryCanceled) - { - Log.Debug($"Caught an exception while disposing the {nameof(NpgsqlBinaryExporter)}, indicating that it was cancelled.", e, _connector.Id); - } - catch (Exception e) - { - Log.Error($"Caught an exception while disposing the {nameof(NpgsqlBinaryExporter)}.", e, _connector.Id); - } + using var registration = _connector.StartNestedCancellableOperation(attemptPgCancellation: false); + // Be sure to commit the reader. + if (async) + await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); + else + PgReader.Commit(resuming: false); + // Finish the current CopyData message + await _buf.Skip(checked((int)(_endOfMessagePos - _buf.CumulativeReadPosition)), async).ConfigureAwait(false); + // Read to the end + _connector.SkipUntil(BackendMessageCode.CopyDone); + // We intentionally do not pass a CancellationToken since we don't want to cancel cleanup + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + } + catch (OperationCanceledException e) when (e.InnerException is PostgresException { SqlState: PostgresErrorCodes.QueryCanceled }) + { + LogMessages.CopyOperationCancelled(_copyLogger, _connector.Id); + } + catch (Exception e) + { + LogMessages.ExceptionWhenDisposingCopyOperation(_copyLogger, _connector.Id, e); } - - _connector.EndUserAction(); - Cleanup(); } -#pragma warning disable CS8625 + _connector.EndUserAction(); + Cleanup(); + void Cleanup() { + Debug.Assert(!_isDisposed); var connector = _connector; - Log.Debug("COPY operation ended", connector?.Id ?? 
-1); - if (connector != null) + if (!ReferenceEquals(connector, null)) { connector.CurrentCopyOperation = null; _connector.Connection?.EndBindingScope(ConnectorBindingScope.Copy); - _connector = null; + _connector = null!; } - _typeMapper = null; - _buf = null; + _buf = null!; _isDisposed = true; } -#pragma warning restore CS8625 - - #endregion } + + #endregion } diff --git a/src/Npgsql/NpgsqlBinaryImporter.cs b/src/Npgsql/NpgsqlBinaryImporter.cs index 2b42f29e86..f80807af3e 100644 --- a/src/Npgsql/NpgsqlBinaryImporter.cs +++ b/src/Npgsql/NpgsqlBinaryImporter.cs @@ -1,601 +1,578 @@ using System; -using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using JetBrains.Annotations; +using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; -using Npgsql.Logging; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; using NpgsqlTypes; using static Npgsql.Util.Statics; -namespace Npgsql +namespace Npgsql; + +/// +/// Provides an API for a binary COPY FROM operation, a high-performance data import mechanism to +/// a PostgreSQL table. Initiated by +/// +/// +/// See https://www.postgresql.org/docs/current/static/sql-copy.html. +/// +public sealed class NpgsqlBinaryImporter : ICancelable { - /// - /// Provides an API for a binary COPY FROM operation, a high-performance data import mechanism to - /// a PostgreSQL table. Initiated by - /// - /// - /// See https://www.postgresql.org/docs/current/static/sql-copy.html. - /// - public sealed class NpgsqlBinaryImporter : ICancelable, IAsyncDisposable - { - #region Fields and Properties + #region Fields and Properties - NpgsqlConnector _connector; - NpgsqlWriteBuffer _buf; + NpgsqlConnector _connector; + NpgsqlWriteBuffer _buf; - ImporterState _state; + ImporterState _state; - /// - /// The number of columns in the current (not-yet-written) row. - /// - short _column; + /// + /// The number of columns in the current (not-yet-written) row. 
+ /// + short _column; + ulong _rowsImported; - /// - /// The number of columns, as returned from the backend in the CopyInResponse. - /// - internal int NumColumns { get; } + /// + /// The number of columns, as returned from the backend in the CopyInResponse. + /// + int NumColumns => _params.Length; - bool InMiddleOfRow => _column != -1 && _column != NumColumns; + bool InMiddleOfRow => _column != -1 && _column != NumColumns; - readonly NpgsqlParameter?[] _params; + NpgsqlParameter?[] _params; - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlBinaryImporter)); + readonly ILogger _copyLogger; + PgWriter _pgWriter = null!; // Setup in Init - /// - /// Current timeout - /// - public TimeSpan Timeout + /// + /// Current timeout + /// + public TimeSpan Timeout + { + set { - set - { - _buf.Timeout = value; - // While calling Complete(), we're using the connector, which overwrites the buffer's timeout with it's own - _connector.UserTimeout = (int)value.TotalMilliseconds; - } + _buf.Timeout = value; + _connector.ReadBuffer.Timeout = value; } + } - #endregion + #endregion - #region Construction / Initialization + #region Construction / Initialization - internal NpgsqlBinaryImporter(NpgsqlConnector connector, string copyFromCommand) - { - _connector = connector; - _buf = connector.WriteBuffer; - _column = -1; + internal NpgsqlBinaryImporter(NpgsqlConnector connector) + { + _connector = connector; + _buf = connector.WriteBuffer; + _column = -1; + _params = null!; + _copyLogger = connector.LoggingConfiguration.CopyLogger; + } - _connector.WriteQuery(copyFromCommand); - _connector.Flush(); + internal async Task Init(string copyFromCommand, bool async, CancellationToken cancellationToken = default) + { + await _connector.WriteQuery(copyFromCommand, async, cancellationToken).ConfigureAwait(false); + await _connector.Flush(async, cancellationToken).ConfigureAwait(false); - using var registration = 
_connector.StartNestedCancellableOperation(attemptPgCancellation: false); + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - CopyInResponseMessage copyInResponse; - var msg = _connector.ReadMessage(async: false).GetAwaiter().GetResult(); - switch (msg.Code) + CopyInResponseMessage copyInResponse; + var msg = await _connector.ReadMessage(async).ConfigureAwait(false); + switch (msg.Code) + { + case BackendMessageCode.CopyInResponse: + copyInResponse = (CopyInResponseMessage)msg; + if (!copyInResponse.IsBinary) { - case BackendMessageCode.CopyInResponse: - copyInResponse = (CopyInResponseMessage)msg; - if (!copyInResponse.IsBinary) - { - throw _connector.Break( - new ArgumentException("copyFromCommand triggered a text transfer, only binary is allowed", - nameof(copyFromCommand))); - } - break; - case BackendMessageCode.CommandComplete: - throw new InvalidOperationException( - "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " + - "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + - "Note that your data has been successfully imported/exported."); - default: - throw _connector.UnexpectedMessageReceived(msg.Code); + throw _connector.Break( + new ArgumentException("copyFromCommand triggered a text transfer, only binary is allowed", + nameof(copyFromCommand))); } - - NumColumns = copyInResponse.NumColumns; - _params = new NpgsqlParameter[NumColumns]; - _buf.StartCopyMode(); - WriteHeader(); + break; + case BackendMessageCode.CommandComplete: + throw new InvalidOperationException( + "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " + + "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. 
" + + "Note that your data has been successfully imported/exported."); + default: + throw _connector.UnexpectedMessageReceived(msg.Code); } - void WriteHeader() - { - _buf.WriteBytes(NpgsqlRawCopyStream.BinarySignature, 0, NpgsqlRawCopyStream.BinarySignature.Length); - _buf.WriteInt32(0); // Flags field. OID inclusion not supported at the moment. - _buf.WriteInt32(0); // Header extension area length - } + _params = new NpgsqlParameter[copyInResponse.NumColumns]; + _rowsImported = 0; + _buf.StartCopyMode(); + WriteHeader(); + // Only init after header. + _pgWriter = _buf.GetWriter(_connector.DatabaseInfo); + } - #endregion + void WriteHeader() + { + _buf.WriteBytes(NpgsqlRawCopyStream.BinarySignature, 0, NpgsqlRawCopyStream.BinarySignature.Length); + _buf.WriteInt32(0); // Flags field. OID inclusion not supported at the moment. + _buf.WriteInt32(0); // Header extension area length + } - #region Write + #endregion - /// - /// Starts writing a single row, must be invoked before writing any columns. - /// - public void StartRow() => StartRow(false).GetAwaiter().GetResult(); + #region Write - /// - /// Starts writing a single row, must be invoked before writing any columns. - /// - public Task StartRowAsync(CancellationToken cancellationToken = default) - { - if (cancellationToken.IsCancellationRequested) - return Task.FromCanceled(cancellationToken); - using (NoSynchronizationContextScope.Enter()) - return StartRow(true, cancellationToken); - } + /// + /// Starts writing a single row, must be invoked before writing any columns. + /// + public void StartRow() => StartRow(false).GetAwaiter().GetResult(); - async Task StartRow(bool async, CancellationToken cancellationToken = default) - { - CheckReady(); + /// + /// Starts writing a single row, must be invoked before writing any columns. 
+ /// + public Task StartRowAsync(CancellationToken cancellationToken = default) => StartRow(async: true, cancellationToken); - if (_column != -1 && _column != NumColumns) - ThrowHelper.ThrowInvalidOperationException_BinaryImportParametersMismatch(NumColumns, _column); + async Task StartRow(bool async, CancellationToken cancellationToken = default) + { + CheckReady(); + cancellationToken.ThrowIfCancellationRequested(); + if (_column is not -1 && _column != NumColumns) + ThrowColumnMismatch(); + + if (_buf.WriteSpaceLeft < 2) + await _buf.Flush(async, cancellationToken).ConfigureAwait(false); + _buf.WriteInt16((short)NumColumns); + + _pgWriter.RefreshBuffer(); + _column = 0; + _rowsImported++; + } - try - { - if (_buf.WriteSpaceLeft < 2) - await _buf.Flush(async, cancellationToken); - _buf.WriteInt16(NumColumns); + /// + /// Writes a single column in the current row. + /// + /// The value to be written + /// + /// The type of the column to be written. This must correspond to the actual type or data + /// corruption will occur. If in doubt, use to manually + /// specify the type. + /// + public void Write(T value) + => Write(async: false, value, npgsqlDbType: null, dataTypeName: null).GetAwaiter().GetResult(); - _column = 0; - } - catch - { - // An exception here will have already broken the connection etc. - Cleanup(); - throw; - } - } + /// + /// Writes a single column in the current row. + /// + /// The value to be written + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// + /// The type of the column to be written. This must correspond to the actual type or data + /// corruption will occur. If in doubt, use to manually + /// specify the type. + /// + public Task WriteAsync(T value, CancellationToken cancellationToken = default) + => Write(async: true, value, npgsqlDbType: null, dataTypeName: null, cancellationToken); - /// - /// Writes a single column in the current row. 
- /// - /// The value to be written - /// - /// The type of the column to be written. This must correspond to the actual type or data - /// corruption will occur. If in doubt, use to manually - /// specify the type. - /// - public void Write([AllowNull] T value) => Write(value, false).GetAwaiter().GetResult(); - - /// - /// Writes a single column in the current row. - /// - /// The value to be written - /// - /// - /// The type of the column to be written. This must correspond to the actual type or data - /// corruption will occur. If in doubt, use to manually - /// specify the type. - /// - public Task WriteAsync([AllowNull] T value, CancellationToken cancellationToken = default) - { - if (cancellationToken.IsCancellationRequested) - return Task.FromCanceled(cancellationToken); - using (NoSynchronizationContextScope.Enter()) - return Write(value, true, cancellationToken); - } + /// + /// Writes a single column in the current row as type . + /// + /// The value to be written + /// + /// In some cases isn't enough to infer the data type to be written to + /// the database. This parameter can be used to unambiguously specify the type. An example is + /// the JSONB type, for which will be a simple string but for which + /// must be specified as . + /// + /// The .NET type of the column to be written. + public void Write(T value, NpgsqlDbType npgsqlDbType) => + Write(async: false, value, npgsqlDbType, dataTypeName: null).GetAwaiter().GetResult(); - Task Write([AllowNull] T value, bool async, CancellationToken cancellationToken = default) - { - CheckColumnIndex(); + /// + /// Writes a single column in the current row as type . + /// + /// The value to be written + /// + /// In some cases isn't enough to infer the data type to be written to + /// the database. This parameter can be used to unambiguously specify the type. An example is + /// the JSONB type, for which will be a simple string but for which + /// must be specified as . 
+ /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The .NET type of the column to be written. + public Task WriteAsync(T value, NpgsqlDbType npgsqlDbType, CancellationToken cancellationToken = default) + => Write(async: true, value, npgsqlDbType, dataTypeName: null, cancellationToken); - var p = _params[_column]; - if (p == null) - { - // First row, create the parameter objects - _params[_column] = p = typeof(T) == typeof(object) - ? new NpgsqlParameter() - : new NpgsqlParameter(); - } + /// + /// Writes a single column in the current row as type . + /// + /// The value to be written + /// + /// In some cases isn't enough to infer the data type to be written to + /// the database. This parameter and be used to unambiguously specify the type. + /// + /// The .NET type of the column to be written. + public void Write(T value, string dataTypeName) => + Write(async: false, value, npgsqlDbType: null, dataTypeName).GetAwaiter().GetResult(); - return Write(value, p, async, cancellationToken); - } + /// + /// Writes a single column in the current row as type . + /// + /// The value to be written + /// + /// In some cases isn't enough to infer the data type to be written to + /// the database. This parameter and be used to unambiguously specify the type. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The .NET type of the column to be written. + public Task WriteAsync(T value, string dataTypeName, CancellationToken cancellationToken = default) + => Write(async: true, value, npgsqlDbType: null, dataTypeName, cancellationToken); + + Task Write(bool async, T value, NpgsqlDbType? npgsqlDbType, string? dataTypeName, CancellationToken cancellationToken = default) + { + // Statically handle DBNull for backwards compatibility, generic parameters where T = DBNull normally won't find a mapping. 
+ // Also handle null values for object typed parameters, as parameters only accept DBNull.Value when T = object. + if (typeof(T) == typeof(DBNull) || (typeof(T) == typeof(object) && value is null)) + return WriteNull(async, cancellationToken); - /// - /// Writes a single column in the current row as type . - /// - /// The value to be written - /// - /// In some cases isn't enough to infer the data type to be written to - /// the database. This parameter and be used to unambiguously specify the type. An example is - /// the JSONB type, for which will be a simple string but for which - /// must be specified as . - /// - /// The .NET type of the column to be written. - public void Write([AllowNull] T value, NpgsqlDbType npgsqlDbType) => - Write(value, npgsqlDbType, false).GetAwaiter().GetResult(); - - /// - /// Writes a single column in the current row as type . - /// - /// The value to be written - /// - /// In some cases isn't enough to infer the data type to be written to - /// the database. This parameter and be used to unambiguously specify the type. An example is - /// the JSONB type, for which will be a simple string but for which - /// must be specified as . - /// - /// - /// The .NET type of the column to be written. - public Task WriteAsync([AllowNull] T value, NpgsqlDbType npgsqlDbType, CancellationToken cancellationToken = default) - { - if (cancellationToken.IsCancellationRequested) - return Task.FromCanceled(cancellationToken); - using (NoSynchronizationContextScope.Enter()) - return Write(value, npgsqlDbType, true, cancellationToken); - } + return Core(async, value, npgsqlDbType, dataTypeName, cancellationToken); - Task Write([AllowNull] T value, NpgsqlDbType npgsqlDbType, bool async, CancellationToken cancellationToken = default) + async Task Core(bool async, T value, NpgsqlDbType? npgsqlDbType, string? 
dataTypeName, CancellationToken cancellationToken = default) { + CheckReady(); + cancellationToken.ThrowIfCancellationRequested(); CheckColumnIndex(); - var p = _params[_column]; - if (p == null) + // Create the parameter objects for the first row or if the value type changes. + var newParam = false; + if (_params[_column] is not NpgsqlParameter param) { - // First row, create the parameter objects - _params[_column] = p = typeof(T) == typeof(object) - ? new NpgsqlParameter() - : new NpgsqlParameter(); - p.NpgsqlDbType = npgsqlDbType; + newParam = true; + param = new NpgsqlParameter(); + if (npgsqlDbType is not null) + param._npgsqlDbType = npgsqlDbType; + if (dataTypeName is not null) + param._dataTypeName = dataTypeName; } - if (npgsqlDbType != p.NpgsqlDbType) - throw new InvalidOperationException($"Can't change {nameof(p.NpgsqlDbType)} from {p.NpgsqlDbType} to {npgsqlDbType}"); - - return Write(value, p, async, cancellationToken); - } - - /// - /// Writes a single column in the current row as type . - /// - /// The value to be written - /// - /// In some cases isn't enough to infer the data type to be written to - /// the database. This parameter and be used to unambiguously specify the type. - /// - /// The .NET type of the column to be written. - public void Write([AllowNull] T value, string dataTypeName) => - Write(value, dataTypeName, false).GetAwaiter().GetResult(); - - /// - /// Writes a single column in the current row as type . - /// - /// The value to be written - /// - /// In some cases isn't enough to infer the data type to be written to - /// the database. This parameter and be used to unambiguously specify the type. - /// - /// - /// The .NET type of the column to be written. 
- public Task WriteAsync([AllowNull] T value, string dataTypeName, CancellationToken cancellationToken = default) - { - if (cancellationToken.IsCancellationRequested) - return Task.FromCanceled(cancellationToken); - using (NoSynchronizationContextScope.Enter()) - return Write(value, dataTypeName, true, cancellationToken); - } - - Task Write([AllowNull] T value, string dataTypeName, bool async, CancellationToken cancellationToken = default) - { - CheckColumnIndex(); - - var p = _params[_column]; - if (p == null) + // We only retrieve previous values if anything actually changed. + // For object typed parameters we must do so whenever setting NpgsqlParameter.Value would reset the type info. + PgTypeInfo? previousTypeInfo = null; + PgConverter? previousConverter = null; + PgTypeId previousTypeId = default; + if (!newParam && ( + (typeof(T) == typeof(object) && param.ShouldResetObjectTypeInfo(value)) + || param._npgsqlDbType != npgsqlDbType + || param._dataTypeName != dataTypeName)) { - // First row, create the parameter objects - _params[_column] = p = typeof(T) == typeof(object) - ? new NpgsqlParameter() - : new NpgsqlParameter(); - p.DataTypeName = dataTypeName; + param.GetResolutionInfo(out previousTypeInfo, out previousConverter, out previousTypeId); + if (!newParam) + { + param.ResetDbType(); + if (npgsqlDbType is not null) + param._npgsqlDbType = npgsqlDbType; + if (dataTypeName is not null) + param._dataTypeName = dataTypeName; + } } - //if (dataTypeName!= p.DataTypeName) - // throw new InvalidOperationException($"Can't change {nameof(p.DataTypeName)} from {p.DataTypeName} to {dataTypeName}"); + // These actions can reset or change the type info, we'll check afterwards whether we're still consistent with the original values. 
+ param.TypedValue = value; + param.ResolveTypeInfo(_connector.SerializerOptions); - return Write(value, p, async, cancellationToken); - } - - async Task Write([AllowNull] T value, NpgsqlParameter param, bool async, CancellationToken cancellationToken = default) - { - CheckReady(); - if (_column == -1) - throw new InvalidOperationException("A row hasn't been started"); - - if (value == null || value is DBNull) + if (previousTypeInfo is not null && previousConverter is not null && param.PgTypeId != previousTypeId) { - await WriteNull(async, cancellationToken); - return; + var currentPgTypeId = param.PgTypeId; + // We should only rollback values when the stored instance was used. We'll throw before writing the new instance back anyway. + // Also always rolling back could set PgTypeInfos that were resolved for a type that doesn't match the T of the NpgsqlParameter. + if (!newParam) + param.SetResolutionInfo(previousTypeInfo, previousConverter, previousTypeId); + throw new InvalidOperationException($"Write for column {_column} resolves to a different PostgreSQL type: {currentPgTypeId} than the first row resolved to ({previousTypeId}). " + + $"Please make sure to use clr types that resolve to the same PostgreSQL type across rows. " + + $"Alternatively pass the same NpgsqlDbType or DataTypeName to ensure the PostgreSQL type ends up to be identical." 
); } + if (newParam) + _params[_column] = param; + + param.Bind(out _, out _, requiredFormat: DataFormat.Binary); + try { - if (typeof(T) == typeof(object)) - { - param.Value = value; - } - else - { - if (!(param is NpgsqlParameter typedParam)) - { - _params[_column] = typedParam = new NpgsqlParameter(); - typedParam.NpgsqlDbType = param.NpgsqlDbType; - } - typedParam.TypedValue = value; - } - param.ResolveHandler(_connector.TypeMapper); - param.ValidateAndGetLength(); - param.LengthCache?.Rewind(); - await param.WriteWithLength(_buf, async, cancellationToken); - param.LengthCache?.Clear(); - _column++; + await param.Write(async, _pgWriter.WithFlushMode(async ? FlushMode.NonBlocking : FlushMode.Blocking), cancellationToken) + .ConfigureAwait(false); } - catch + catch (Exception ex) { - // An exception here will have already broken the connection etc. - Cleanup(); + _connector.Break(ex); throw; } + + _column++; } + } - /// - /// Writes a single null column value. - /// - public void WriteNull() => WriteNull(false).GetAwaiter().GetResult(); + /// + /// Writes a single null column value. + /// + public void WriteNull() => WriteNull(false).GetAwaiter().GetResult(); - /// - /// Writes a single null column value. - /// - public Task WriteNullAsync(CancellationToken cancellationToken = default) - { - if (cancellationToken.IsCancellationRequested) - return Task.FromCanceled(cancellationToken); - using (NoSynchronizationContextScope.Enter()) - return WriteNull(true, cancellationToken); - } + /// + /// Writes a single null column value. 
+ /// + public Task WriteNullAsync(CancellationToken cancellationToken = default) => WriteNull(async: true, cancellationToken); - async Task WriteNull(bool async, CancellationToken cancellationToken = default) - { - CheckReady(); - if (_column == -1) - throw new InvalidOperationException("A row hasn't been started"); + async Task WriteNull(bool async, CancellationToken cancellationToken = default) + { + CheckReady(); + if (cancellationToken.IsCancellationRequested) + cancellationToken.ThrowIfCancellationRequested(); + CheckColumnIndex(); - try - { - if (_buf.WriteSpaceLeft < 4) - await _buf.Flush(async, cancellationToken); + if (_buf.WriteSpaceLeft < 4) + await _buf.Flush(async, cancellationToken).ConfigureAwait(false); - _buf.WriteInt32(-1); - _column++; - } - catch - { - // An exception here will have already broken the connection etc. - Cleanup(); - throw; - } - } + _buf.WriteInt32(-1); + _pgWriter.RefreshBuffer(); + _column++; + } - /// - /// Writes an entire row of columns. - /// Equivalent to calling , followed by multiple - /// on each value. - /// - /// An array of column values to be written as a single row - public void WriteRow(params object[] values) => WriteRow(false, CancellationToken.None, values).GetAwaiter().GetResult(); - - /// - /// Writes an entire row of columns. - /// Equivalent to calling , followed by multiple - /// on each value. - /// - /// - /// An array of column values to be written as a single row - public Task WriteRowAsync(CancellationToken cancellationToken = default, params object[] values) - { - if (cancellationToken.IsCancellationRequested) - return Task.FromCanceled(cancellationToken); - using (NoSynchronizationContextScope.Enter()) - return WriteRow(true, cancellationToken, values); - } + /// + /// Writes an entire row of columns. + /// Equivalent to calling , followed by multiple + /// on each value. 
+ /// + /// An array of column values to be written as a single row + public void WriteRow(params object?[] values) => WriteRow(false, CancellationToken.None, values).GetAwaiter().GetResult(); - async Task WriteRow(bool async, CancellationToken cancellationToken = default, params object[] values) - { - await StartRow(async, cancellationToken); - foreach (var value in values) - await Write(value, async, cancellationToken); - } + /// + /// Writes an entire row of columns. + /// Equivalent to calling , followed by multiple + /// on each value. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// An array of column values to be written as a single row + public Task WriteRowAsync(CancellationToken cancellationToken = default, params object?[] values) + => WriteRow(async: true, cancellationToken, values); + + async Task WriteRow(bool async, CancellationToken cancellationToken = default, params object?[] values) + { + await StartRow(async, cancellationToken).ConfigureAwait(false); + foreach (var value in values) + await Write(async, value, npgsqlDbType: null, dataTypeName: null, cancellationToken).ConfigureAwait(false); + } - void CheckColumnIndex() + void CheckColumnIndex() + { + if (_column is -1 || _column >= NumColumns) + Throw(); + + [MethodImpl(MethodImplOptions.NoInlining)] + void Throw() { + if (_column is -1) + throw new InvalidOperationException("A row hasn't been started"); + if (_column >= NumColumns) - ThrowHelper.ThrowInvalidOperationException_BinaryImportParametersMismatch(NumColumns, _column + 1); + ThrowColumnMismatch(); } + } - #endregion + #endregion - #region Commit / Cancel / Close / Dispose + #region Commit / Cancel / Close / Dispose - /// - /// Completes the import operation. The writer is unusable after this operation. - /// - public ulong Complete() => Complete(false).GetAwaiter().GetResult(); + /// + /// Completes the import operation. The writer is unusable after this operation. 
+ /// + public ulong Complete() => Complete(false).GetAwaiter().GetResult(); - /// - /// Completes the import operation. The writer is unusable after this operation. - /// - public ValueTask CompleteAsync(CancellationToken cancellationToken = default) + /// + /// Completes the import operation. The writer is unusable after this operation. + /// + public ValueTask CompleteAsync(CancellationToken cancellationToken = default) => Complete(async: true, cancellationToken); + + async ValueTask Complete(bool async, CancellationToken cancellationToken = default) + { + CheckReady(); + + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + + if (InMiddleOfRow) { - if (cancellationToken.IsCancellationRequested) - return new ValueTask(Task.FromCanceled(cancellationToken)); - using (NoSynchronizationContextScope.Enter()) - return Complete(true, cancellationToken); + await Cancel(async, cancellationToken).ConfigureAwait(false); + throw new InvalidOperationException("Binary importer closed in the middle of a row, cancelling import."); } - async ValueTask Complete(bool async, CancellationToken cancellationToken = default) + try { - CheckReady(); + // Write trailer + if (_buf.WriteSpaceLeft < 2) + await _buf.Flush(async, cancellationToken).ConfigureAwait(false); + _buf.WriteInt16(-1); - using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + await _buf.Flush(async, cancellationToken).ConfigureAwait(false); + _buf.EndCopyMode(); + await _connector.WriteCopyDone(async, cancellationToken).ConfigureAwait(false); + await _connector.Flush(async, cancellationToken).ConfigureAwait(false); + var cmdComplete = Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + _state = ImporterState.Committed; + return cmdComplete.Rows; + } + catch + { + Cleanup(); + 
throw; + } + } - if (InMiddleOfRow) - { - await Cancel(async, cancellationToken); - throw new InvalidOperationException("Binary importer closed in the middle of a row, cancelling import."); - } + void ICancelable.Cancel() => Close(); - try - { - await WriteTrailer(async, cancellationToken); - await _buf.Flush(async, cancellationToken); - _buf.EndCopyMode(); - await _connector.WriteCopyDone(async, cancellationToken); - await _connector.Flush(async, cancellationToken); - var cmdComplete = Expect(await _connector.ReadMessage(async), _connector); - Expect(await _connector.ReadMessage(async), _connector); - _state = ImporterState.Committed; - return cmdComplete.Rows; - } - catch - { - // An exception here will have already broken the connection etc. - Cleanup(); - throw; - } - } + async Task ICancelable.CancelAsync() => await CloseAsync().ConfigureAwait(false); - void ICancelable.Cancel() => Close(); + /// + /// + /// Terminates the ongoing binary import and puts the connection back into the idle state, where regular commands can be executed. + /// + /// + /// Note that if hasn't been invoked before calling this, the import will be cancelled and all changes will + /// be reverted. + /// + /// + public void Dispose() => Close(); - /// - /// Cancels that binary import and sets the connection back to idle state - /// - public void Dispose() => Close(); + /// + /// + /// Async terminates the ongoing binary import and puts the connection back into the idle state, where regular commands can be executed. + /// + /// + /// Note that if hasn't been invoked before calling this, the import will be cancelled and all changes will + /// be reverted. 
+ /// + /// + public ValueTask DisposeAsync() => CloseAsync(true); - /// - /// Async cancels that binary import and sets the connection back to idle state - /// - /// - public ValueTask DisposeAsync() + async Task Cancel(bool async, CancellationToken cancellationToken = default) + { + _state = ImporterState.Cancelled; + _buf.Clear(); + _buf.EndCopyMode(); + await _connector.WriteCopyFail(async, cancellationToken).ConfigureAwait(false); + await _connector.Flush(async, cancellationToken).ConfigureAwait(false); + try { - using (NoSynchronizationContextScope.Enter()) - return CloseAsync(true); + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + var msg = await _connector.ReadMessage(async).ConfigureAwait(false); + // The CopyFail should immediately trigger an exception from the read above. + throw _connector.Break( + new NpgsqlException("Expected ErrorResponse when cancelling COPY but got: " + msg.Code)); } - - async Task Cancel(bool async, CancellationToken cancellationToken = default) + catch (PostgresException e) { - _state = ImporterState.Cancelled; - _buf.Clear(); - _buf.EndCopyMode(); - await _connector.WriteCopyFail(async, cancellationToken); - await _connector.Flush(async, cancellationToken); - try - { - using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - var msg = await _connector.ReadMessage(async); - // The CopyFail should immediately trigger an exception from the read above. - throw _connector.Break( - new NpgsqlException("Expected ErrorResponse when cancelling COPY but got: " + msg.Code)); - } - catch (PostgresException e) - { - if (e.SqlState != PostgresErrorCodes.QueryCanceled) - throw; - } + if (e.SqlState != PostgresErrorCodes.QueryCanceled) + throw; } + } - /// - /// Completes the import process and signals to the database to write everything. 
- /// - public void Close() => CloseAsync(false).GetAwaiter().GetResult(); - - /// - /// Async completes the import process and signals to the database to write everything. - /// - /// - /// - public ValueTask CloseAsync(CancellationToken cancellationToken = default) - { - if (cancellationToken.IsCancellationRequested) - return new ValueTask(Task.FromCanceled(cancellationToken)); - using (NoSynchronizationContextScope.Enter()) - return CloseAsync(true, cancellationToken); - } + /// + /// + /// Terminates the ongoing binary import and puts the connection back into the idle state, where regular commands can be executed. + /// + /// + /// Note that if hasn't been invoked before calling this, the import will be cancelled and all changes will + /// be reverted. + /// + /// + public void Close() => CloseAsync(async: false).GetAwaiter().GetResult(); - async ValueTask CloseAsync(bool async, CancellationToken cancellationToken = default) - { - switch (_state) - { - case ImporterState.Disposed: - return; - case ImporterState.Ready: - await Cancel(async, cancellationToken); - break; - case ImporterState.Cancelled: - case ImporterState.Committed: - break; - default: - throw new Exception("Invalid state: " + _state); - } + /// + /// + /// Async terminates the ongoing binary import and puts the connection back into the idle state, where regular commands can be executed. + /// + /// + /// Note that if hasn't been invoked before calling this, the import will be cancelled and all changes will + /// be reverted. 
+ /// + /// + public ValueTask CloseAsync(CancellationToken cancellationToken = default) => CloseAsync(async: true, cancellationToken); - _connector.EndUserAction(); - Cleanup(); + async ValueTask CloseAsync(bool async, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + switch (_state) + { + case ImporterState.Disposed: + return; + case ImporterState.Ready: + await Cancel(async, cancellationToken).ConfigureAwait(false); + break; + case ImporterState.Cancelled: + case ImporterState.Committed: + break; + default: + throw new Exception("Invalid state: " + _state); } + Cleanup(); + } + #pragma warning disable CS8625 - void Cleanup() - { - var connector = _connector; - Log.Debug("COPY operation ended", connector?.Id ?? -1); + void Cleanup() + { + if (_state == ImporterState.Disposed) + return; + var connector = _connector; - if (connector != null) - { - connector.CurrentCopyOperation = null; - _connector.Connection?.EndBindingScope(ConnectorBindingScope.Copy); - _connector = null; - } + LogMessages.BinaryCopyOperationCompleted(_copyLogger, _rowsImported, connector?.Id ?? 
-1); - _buf = null; - _state = ImporterState.Disposed; + if (connector != null) + { + connector.EndUserAction(); + connector.CurrentCopyOperation = null; + connector.Connection?.EndBindingScope(ConnectorBindingScope.Copy); + _connector = null; } + + _buf = null; + _state = ImporterState.Disposed; + } #pragma warning restore CS8625 - async Task WriteTrailer(bool async, CancellationToken cancellationToken = default) - { - if (_buf.WriteSpaceLeft < 2) - await _buf.Flush(async, cancellationToken); - _buf.WriteInt16(-1); - } + void CheckReady() + { + if (_state is not ImporterState.Ready and var state) + Throw(state); - void CheckReady() - { - switch (_state) + [MethodImpl(MethodImplOptions.NoInlining)] + static void Throw(ImporterState state) + => throw (state switch { - case ImporterState.Ready: - return; - case ImporterState.Disposed: - throw new ObjectDisposedException(GetType().FullName, "The COPY operation has already ended."); - case ImporterState.Cancelled: - throw new InvalidOperationException("The COPY operation has already been cancelled."); - case ImporterState.Committed: - throw new InvalidOperationException("The COPY operation has already been committed."); - default: - throw new Exception("Invalid state: " + _state); - } - } - - #endregion + ImporterState.Disposed => new ObjectDisposedException(typeof(NpgsqlBinaryImporter).FullName, + "The COPY operation has already ended."), + ImporterState.Cancelled => new InvalidOperationException("The COPY operation has already been cancelled."), + ImporterState.Committed => new InvalidOperationException("The COPY operation has already been committed."), + _ => new Exception("Invalid state: " + state) + }); + } - #region Enums + #endregion - enum ImporterState - { - Ready, - Committed, - Cancelled, - Disposed - } + #region Enums - #endregion Enums + enum ImporterState + { + Ready, + Committed, + Cancelled, + Disposed } + + #endregion Enums + + void ThrowColumnMismatch() + => throw new InvalidOperationException($"The 
binary import operation was started with {NumColumns} column(s), but {_column + 1} value(s) were provided."); } diff --git a/src/Npgsql/NpgsqlCommand.cs b/src/Npgsql/NpgsqlCommand.cs index ac7b95699f..5f9b6528a0 100644 --- a/src/Npgsql/NpgsqlCommand.cs +++ b/src/Npgsql/NpgsqlCommand.cs @@ -4,353 +4,413 @@ using System.Data; using System.Data.Common; using System.Diagnostics; -using System.Linq; using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Globalization; using Npgsql.BackendMessages; -using Npgsql.Logging; -using Npgsql.TypeMapping; using Npgsql.Util; using NpgsqlTypes; using static Npgsql.Util.Statics; -using System.Collections; using System.Diagnostics.CodeAnalysis; - -namespace Npgsql +using System.Threading.Channels; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; +using Npgsql.Properties; + +namespace Npgsql; + +/// +/// Represents a SQL statement or function (stored procedure) to execute +/// against a PostgreSQL database. This class cannot be inherited. +/// +// ReSharper disable once RedundantNameQualifier +[System.ComponentModel.DesignerCategory("")] +public class NpgsqlCommand : DbCommand, ICloneable, IComponent { + #region Fields + + NpgsqlTransaction? _transaction; + + readonly NpgsqlConnector? _connector; + /// - /// Represents a SQL statement or function (stored procedure) to execute - /// against a PostgreSQL database. This class cannot be inherited. + /// If this command is (explicitly) prepared, references the connector on which the preparation happened. + /// Used to detect when the connector was changed (i.e. connection open/close), meaning that the command + /// is no longer prepared. /// - // ReSharper disable once RedundantNameQualifier - [System.ComponentModel.DesignerCategory("")] - public sealed class NpgsqlCommand : DbCommand, ICloneable - { - #region Fields + NpgsqlConnector? _connectorPreparedOn; - NpgsqlConnection? 
_connection; + string _commandText; + CommandBehavior _behavior; + int? _timeout; + internal NpgsqlParameterCollection? _parameters; - /// - /// If this command is (explicitly) prepared, references the connector on which the preparation happened. - /// Used to detect when the connector was changed (i.e. connection open/close), meaning that the command - /// is no longer prepared. - /// - NpgsqlConnector? _connectorPreparedOn; + /// + /// Whether this is wrapped by an . + /// + internal bool IsWrappedByBatch { get; } - string _commandText; - CommandBehavior _behavior; - int? _timeout; - readonly NpgsqlParameterCollection _parameters; + internal List InternalBatchCommands { get; } - internal readonly List _statements; + Activity? CurrentActivity; - /// - /// Returns details about each statement that this command has executed. - /// Is only populated when an Execute* method is called. - /// - public IReadOnlyList Statements => _statements.AsReadOnly(); + /// + /// Returns details about each statement that this command has executed. + /// Is only populated when an Execute* method is called. + /// + [Obsolete("Use the new DbBatch API")] + public IReadOnlyList Statements => InternalBatchCommands.AsReadOnly(); - UpdateRowSource _updateRowSource = UpdateRowSource.Both; + UpdateRowSource _updateRowSource = UpdateRowSource.Both; - bool IsExplicitlyPrepared => _connectorPreparedOn != null; + bool IsExplicitlyPrepared => _connectorPreparedOn != null; - static readonly List EmptyParameters = new List(); + /// + /// Whether this command is cached by and returned by . 
+ /// + internal bool IsCacheable { get; set; } - static readonly SingleThreadSynchronizationContext SingleThreadSynchronizationContext = new SingleThreadSynchronizationContext("NpgsqlRemainingAsyncSendWorker"); +#if DEBUG + internal static bool EnableSqlRewriting; + internal static bool EnableStoredProcedureCompatMode; +#else + internal static readonly bool EnableSqlRewriting; + internal static readonly bool EnableStoredProcedureCompatMode; +#endif - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlCommand)); + internal bool EnableErrorBarriers { get; set; } - #endregion Fields + static readonly TaskScheduler ConstrainedConcurrencyScheduler = + new ConcurrentExclusiveSchedulerPair(TaskScheduler.Default, Math.Max(1, Environment.ProcessorCount / 2)).ConcurrentScheduler; - #region Constants + #endregion Fields - internal const int DefaultTimeout = 30; + #region Constants - #endregion + internal const int DefaultTimeout = 30; - #region Constructors + #endregion - /// - /// Initializes a new instance of the NpgsqlCommand class. - /// - public NpgsqlCommand() : this(null, null, null) {} + #region Constructors - /// - /// Initializes a new instance of the NpgsqlCommand class with the text of the query. - /// - /// The text of the query. - // ReSharper disable once IntroduceOptionalParameters.Global - public NpgsqlCommand(string? cmdText) : this(cmdText, null, null) {} + static NpgsqlCommand() + { + EnableSqlRewriting = !AppContext.TryGetSwitch("Npgsql.EnableSqlRewriting", out var enabled) || enabled; + EnableStoredProcedureCompatMode = AppContext.TryGetSwitch("Npgsql.EnableStoredProcedureCompatMode", out enabled) && enabled; + } - /// - /// Initializes a new instance of the NpgsqlCommand class with the text of the query and a NpgsqlConnection. - /// - /// The text of the query. - /// A NpgsqlConnection that represents the connection to a PostgreSQL server. 
- // ReSharper disable once IntroduceOptionalParameters.Global - public NpgsqlCommand(string? cmdText, NpgsqlConnection? connection) : this(cmdText, connection, null) {} + /// + /// Initializes a new instance of the class. + /// + public NpgsqlCommand() : this(null, null, null) {} - /// - /// Initializes a new instance of the NpgsqlCommand class with the text of the query, a NpgsqlConnection, and the NpgsqlTransaction. - /// - /// The text of the query. - /// A NpgsqlConnection that represents the connection to a PostgreSQL server. - /// The NpgsqlTransaction in which the NpgsqlCommand executes. - public NpgsqlCommand(string? cmdText, NpgsqlConnection? connection, NpgsqlTransaction? transaction) - { - GC.SuppressFinalize(this); - _statements = new List(1); - _parameters = new NpgsqlParameterCollection(); - _commandText = cmdText ?? string.Empty; - _connection = connection; - Transaction = transaction; - CommandType = CommandType.Text; - } + /// + /// Initializes a new instance of the class with the text of the query. + /// + /// The text of the query. + // ReSharper disable once IntroduceOptionalParameters.Global + public NpgsqlCommand(string? cmdText) : this(cmdText, null, null) {} - #endregion Constructors + /// + /// Initializes a new instance of the class with the text of the query and a + /// . + /// + /// The text of the query. + /// A that represents the connection to a PostgreSQL server. + // ReSharper disable once IntroduceOptionalParameters.Global + public NpgsqlCommand(string? cmdText, NpgsqlConnection? connection) + { + GC.SuppressFinalize(this); + InternalBatchCommands = new List(1); + _commandText = cmdText ?? string.Empty; + InternalConnection = connection; + CommandType = CommandType.Text; + } - #region Public properties + /// + /// Initializes a new instance of the class with the text of the query, a + /// , and the . + /// + /// The text of the query. + /// A that represents the connection to a PostgreSQL server. + /// The in which the executes. 
+ public NpgsqlCommand(string? cmdText, NpgsqlConnection? connection, NpgsqlTransaction? transaction) + : this(cmdText, connection) + => Transaction = transaction; - /// - /// Gets or sets the SQL statement or function (stored procedure) to execute at the data source. - /// - /// The Transact-SQL statement or stored procedure to execute. The default is an empty string. - [AllowNull, DefaultValue("")] - [Category("Data")] - public override string CommandText - { - get => _commandText; - set - { - _commandText = State == CommandState.Idle - ? value ?? string.Empty - : throw new InvalidOperationException("An open data reader exists for this command."); + /// + /// Used when this instance is wrapped inside an . + /// + internal NpgsqlCommand(int batchCommandCapacity, NpgsqlConnection? connection = null) + { + GC.SuppressFinalize(this); + InternalBatchCommands = new List(batchCommandCapacity); + InternalConnection = connection; + CommandType = CommandType.Text; + IsWrappedByBatch = true; + + // These can/should never be used in this mode + _commandText = null!; + _parameters = null!; + } - ResetExplicitPreparation(); - // TODO: Technically should do this also if the parameter list (or type) changes - } - } + internal NpgsqlCommand(string? cmdText, NpgsqlConnector connector) : this(cmdText) + => _connector = connector; - /// - /// Gets or sets the wait time (in seconds) before terminating the attempt to execute a command and generating an error. - /// - /// The time (in seconds) to wait for the command to execute. The default value is 30 seconds. - [DefaultValue(DefaultTimeout)] - public override int CommandTimeout - { - get => _timeout ?? (_connection?.CommandTimeout ?? DefaultTimeout); - set - { - if (value < 0) { - throw new ArgumentOutOfRangeException(nameof(value), value, "CommandTimeout can't be less than zero."); - } + /// + /// Used when this instance is wrapped inside an . 
+ /// + internal NpgsqlCommand(NpgsqlConnector connector, int batchCommandCapacity) + : this(batchCommandCapacity) + => _connector = connector; - _timeout = value; - } - } + internal static NpgsqlCommand CreateCachedCommand(NpgsqlConnection connection) + => new(null, connection) { IsCacheable = true }; - /// - /// Gets or sets a value indicating how the - /// CommandText property is to be interpreted. - /// - /// One of the CommandType values. The default is CommandType.Text. - [DefaultValue(CommandType.Text)] - [Category("Data")] - public override CommandType CommandType { get; set; } - - /// - /// DB connection. - /// - protected override DbConnection? DbConnection - { - get => _connection; - set => _connection = (NpgsqlConnection?)value; - } + #endregion Constructors - /// - /// Gets or sets the NpgsqlConnection - /// used by this instance of the NpgsqlCommand. - /// - /// The connection to a data source. The default value is a null reference. - [DefaultValue(null)] - [Category("Behavior")] - public new NpgsqlConnection? Connection + #region Public properties + + /// + /// Gets or sets the SQL statement or function (stored procedure) to execute at the data source. + /// + /// The SQL statement or function (stored procedure) to execute. The default is an empty string. + [AllowNull, DefaultValue("")] + [Category("Data")] + public override string CommandText + { + get => _commandText; + set { - get => _connection; - set - { - if (_connection == value) - return; + Debug.Assert(!IsWrappedByBatch); - _connection = State == CommandState.Idle - ? value - : throw new InvalidOperationException("An open data reader exists for this command."); + if (State != CommandState.Idle) + ThrowHelper.ThrowInvalidOperationException("An open data reader exists for this command."); - Transaction = null; - } + _commandText = value ?? 
string.Empty; + + ResetPreparation(); + // TODO: Technically should do this also if the parameter list (or type) changes } + } - /// - /// Design time visible. - /// - public override bool DesignTimeVisible { get; set; } - - /// - /// Gets or sets how command results are applied to the DataRow when used by the - /// DbDataAdapter.Update(DataSet) method. - /// - /// One of the UpdateRowSource values. - [Category("Behavior"), DefaultValue(UpdateRowSource.Both)] - public override UpdateRowSource UpdatedRowSource + /// + /// Gets or sets the wait time (in seconds) before terminating the attempt to execute a command and generating an error. + /// + /// The time (in seconds) to wait for the command to execute. The default value is 30 seconds. + [DefaultValue(DefaultTimeout)] + public override int CommandTimeout + { + get => _timeout ?? (InternalConnection?.CommandTimeout ?? DefaultTimeout); + set { - get => _updateRowSource; - set - { - switch (value) - { - // validate value (required based on base type contract) - case UpdateRowSource.None: - case UpdateRowSource.OutputParameters: - case UpdateRowSource.FirstReturnedRecord: - case UpdateRowSource.Both: - _updateRowSource = value; - break; - default: - throw new ArgumentOutOfRangeException(); - } + if (value < 0) { + throw new ArgumentOutOfRangeException(nameof(value), value, "CommandTimeout can't be less than zero."); } + + _timeout = value; } + } + + /// + /// Gets or sets a value indicating how the property is to be interpreted. + /// + /// + /// One of the values. The default is . + /// + [DefaultValue(CommandType.Text)] + [Category("Data")] + public override CommandType CommandType { get; set; } + + internal NpgsqlConnection? InternalConnection { get; private set; } - /// - /// Returns whether this query will execute as a prepared (compiled) query. 
- /// - public bool IsPrepared => - _connectorPreparedOn == _connection?.Connector && - _statements.Any() && _statements.All(s => s.PreparedStatement?.IsPrepared == true); + /// + /// DB connection. + /// + protected override DbConnection? DbConnection + { + get => InternalConnection; + set + { + if (InternalConnection == value) + return; + + InternalConnection = State == CommandState.Idle + ? (NpgsqlConnection?)value + : throw new InvalidOperationException("An open data reader exists for this command."); + + Transaction = null; + } + } - #endregion Public properties + /// + /// Gets or sets the used by this instance of the . + /// + /// The connection to a data source. The default value is . + [DefaultValue(null)] + [Category("Behavior")] + public new NpgsqlConnection? Connection + { + get => (NpgsqlConnection?)DbConnection; + set => DbConnection = value; + } - #region Known/unknown Result Types Management + /// + /// Design time visible. + /// + public override bool DesignTimeVisible { get; set; } - /// - /// Marks all of the query's result columns as either known or unknown. - /// Unknown results column are requested them from PostgreSQL in text format, and Npgsql makes no - /// attempt to parse them. They will be accessible as strings only. - /// - public bool AllResultTypesAreUnknown + /// + /// Gets or sets how command results are applied to the DataRow when used by the + /// DbDataAdapter.Update(DataSet) method. + /// + /// One of the values. 
+ [Category("Behavior"), DefaultValue(UpdateRowSource.Both)] + public override UpdateRowSource UpdatedRowSource + { + get => _updateRowSource; + set { - get => _allResultTypesAreUnknown; - set + switch (value) { - // TODO: Check that this isn't modified after calling prepare - _unknownResultTypeList = null; - _allResultTypesAreUnknown = value; + // validate value (required based on base type contract) + case UpdateRowSource.None: + case UpdateRowSource.OutputParameters: + case UpdateRowSource.FirstReturnedRecord: + case UpdateRowSource.Both: + _updateRowSource = value; + break; + default: + throw new ArgumentOutOfRangeException(); } } + } - bool _allResultTypesAreUnknown; - - /// - /// Marks the query's result columns as known or unknown, on a column-by-column basis. - /// Unknown results column are requested them from PostgreSQL in text format, and Npgsql makes no - /// attempt to parse them. They will be accessible as strings only. - /// - /// - /// If the query includes several queries (e.g. SELECT 1; SELECT 2), this will only apply to the first - /// one. The rest of the queries will be fetched and parsed as usual. - /// - /// The array size must correspond exactly to the number of result columns the query returns, or an - /// error will be raised. - /// - public bool[]? UnknownResultTypeList + /// + /// Returns whether this query will execute as a prepared (compiled) query. + /// + public bool IsPrepared + { + get { - get => _unknownResultTypeList; - set + return _connectorPreparedOn == (InternalConnection?.Connector ?? _connector) && AllPrepared(); + + bool AllPrepared() { - // TODO: Check that this isn't modified after calling prepare - _allResultTypesAreUnknown = false; - _unknownResultTypeList = value; + if (InternalBatchCommands.Count is 0) + return false; + + foreach (var s in InternalBatchCommands) + if (s.PreparedStatement is null || !s.PreparedStatement.IsPrepared) + return false; + return true; } } + } - bool[]? 
_unknownResultTypeList; + #endregion Public properties + + #region Known/unknown Result Types Management + + /// + /// Marks all of the query's result columns as either known or unknown. + /// Unknown result columns are requested from PostgreSQL in text format, and Npgsql makes no + /// attempt to parse them. They will be accessible as strings only. + /// + public bool AllResultTypesAreUnknown + { + get => _allResultTypesAreUnknown; + set + { + // TODO: Check that this isn't modified after calling prepare + _unknownResultTypeList = null; + _allResultTypesAreUnknown = value; + } + } - #endregion + bool _allResultTypesAreUnknown; - #region Result Types Management + /// + /// Marks the query's result columns as known or unknown, on a column-by-column basis. + /// Unknown result columns are requested from PostgreSQL in text format, and Npgsql makes no + /// attempt to parse them. They will be accessible as strings only. + /// + /// + /// If the query includes several queries (e.g. SELECT 1; SELECT 2), this will only apply to the first + /// one. The rest of the queries will be fetched and parsed as usual. + /// + /// The array size must correspond exactly to the number of result columns the query returns, or an + /// error will be raised. + /// + public bool[]? UnknownResultTypeList + { + get => _unknownResultTypeList; + set + { + // TODO: Check that this isn't modified after calling prepare + _allResultTypesAreUnknown = false; + _unknownResultTypeList = value; + } + } - /// - /// Marks result types to be used when using GetValue on a data reader, on a column-by-column basis. - /// Used for Entity Framework 5-6 compability. - /// Only primitive numerical types and DateTimeOffset are supported. - /// Set the whole array or just a value to null to use default type. - /// - internal Type[]? ObjectResultTypes { get; set; } + bool[]? 
_unknownResultTypeList; - #endregion + #endregion - #region State management + #region State management - int _state; + volatile int _state; - /// - /// The current state of the command - /// - internal CommandState State + /// + /// The current state of the command + /// + internal CommandState State + { + get => (CommandState)_state; + set { - private get { return (CommandState)_state; } - set - { - var newState = (int)value; - if (newState == _state) - return; - Interlocked.Exchange(ref _state, newState); - } + var newState = (int)value; + if (newState == _state) + return; + _state = newState; } + } - void ResetExplicitPreparation() => _connectorPreparedOn = null; + internal void ResetPreparation() => _connectorPreparedOn = null; - #endregion State management + #endregion State management - #region Parameters + #region Parameters - /// - /// Creates a new instance of an DbParameter object. - /// - /// An DbParameter object. - protected override DbParameter CreateDbParameter() - { - return CreateParameter(); - } + /// + /// Creates a new instance of an object. + /// + /// A object. + protected override DbParameter CreateDbParameter() => CreateParameter(); - /// - /// Creates a new instance of a NpgsqlParameter object. - /// - /// A NpgsqlParameter object. - public new NpgsqlParameter CreateParameter() - { - return new NpgsqlParameter(); - } + /// + /// Creates a new instance of a object. + /// + /// An object. + public new NpgsqlParameter CreateParameter() => new(); - /// - /// DB parameter collection. - /// - protected override DbParameterCollection DbParameterCollection => Parameters; + /// + /// DB parameter collection. + /// + protected override DbParameterCollection DbParameterCollection => Parameters; - /// - /// Gets the NpgsqlParameterCollection. - /// - /// The parameters of the SQL statement or function (stored procedure). The default is an empty collection. - public new NpgsqlParameterCollection Parameters => _parameters; + /// + /// Gets the . 
+ /// + /// The parameters of the SQL statement or function (stored procedure). The default is an empty collection. + public new NpgsqlParameterCollection Parameters => _parameters ??= new(); - #endregion + #endregion - #region DeriveParameters + #region DeriveParameters - const string DeriveParametersForFunctionQuery = @" + const string DeriveParametersForFunctionQuery = @" SELECT CASE WHEN pg_proc.proargnames IS NULL THEN array_cat(array_fill(''::name,ARRAY[pg_proc.pronargs]),array_agg(pg_attribute.attname ORDER BY pg_attribute.attnum)) @@ -372,137 +432,152 @@ FROM pg_proc GROUP BY pg_proc.proargnames, pg_proc.proargtypes, pg_proc.proallargtypes, pg_proc.proargmodes, pg_proc.pronargs; "; - internal void DeriveParameters() - { - var conn = CheckAndGetConnection(); + internal void DeriveParameters() + { + var conn = CheckAndGetConnection(); + Debug.Assert(conn is not null); + + if (string.IsNullOrEmpty(CommandText)) + throw new InvalidOperationException("CommandText property has not been initialized"); - using var _ = conn.StartTemporaryBindingScope(out var connector); + using var _ = conn.StartTemporaryBindingScope(out var connector); - if (Statements.Any(s => s.PreparedStatement?.IsExplicit == true)) + foreach (var s in InternalBatchCommands) + if (s.PreparedStatement?.IsExplicit == true) throw new NpgsqlException("Deriving parameters isn't supported for commands that are already prepared."); - // Here we unprepare statements that possibly are auto-prepared - Unprepare(); + // Here we unprepare statements that possibly are auto-prepared + Unprepare(); - Parameters.Clear(); + Parameters.Clear(); - switch (CommandType) - { - case CommandType.Text: - DeriveParametersForQuery(connector); - break; - case CommandType.StoredProcedure: - DeriveParametersForFunction(); - break; - default: - throw new NotSupportedException("Cannot derive parameters for CommandType " + CommandType); - } + switch (CommandType) + { + case CommandType.Text: + 
DeriveParametersForQuery(connector); + break; + case CommandType.StoredProcedure: + DeriveParametersForFunction(); + break; + default: + throw new NotSupportedException("Cannot derive parameters for CommandType " + CommandType); } + } - void DeriveParametersForFunction() - { - using var c = new NpgsqlCommand(DeriveParametersForFunctionQuery, _connection); - c.Parameters.Add(new NpgsqlParameter("proname", NpgsqlDbType.Text)); - c.Parameters[0].Value = CommandText; + void DeriveParametersForFunction() + { + using var c = new NpgsqlCommand(DeriveParametersForFunctionQuery, InternalConnection); + c.Parameters.Add(new NpgsqlParameter("proname", NpgsqlDbType.Text)); + c.Parameters[0].Value = CommandText; - string[]? names = null; - uint[]? types = null; - char[]? modes = null; + string[]? names = null; + uint[]? types = null; + char[]? modes = null; - using (var rdr = c.ExecuteReader(CommandBehavior.SingleRow | CommandBehavior.SingleResult)) + using (var rdr = c.ExecuteReader(CommandBehavior.SingleRow | CommandBehavior.SingleResult)) + { + if (rdr.Read()) { - if (rdr.Read()) + if (!rdr.IsDBNull(0)) + names = rdr.GetFieldValue(0); + if (!rdr.IsDBNull(2)) + types = rdr.GetFieldValue(2); + if (!rdr.IsDBNull(3)) + modes = rdr.GetFieldValue(3); + if (types == null) { - if (!rdr.IsDBNull(0)) - names = rdr.GetFieldValue(0); - if (!rdr.IsDBNull(2)) - types = rdr.GetFieldValue(2); - if (!rdr.IsDBNull(3)) - modes = rdr.GetFieldValue(3); - if (types == null) - { - if (rdr.IsDBNull(1) || rdr.GetFieldValue(1).Length == 0) - return; // Parameter-less function - types = rdr.GetFieldValue(1); - } + if (rdr.IsDBNull(1) || rdr.GetFieldValue(1).Length == 0) + return; // Parameter-less function + types = rdr.GetFieldValue(1); } - else - throw new InvalidOperationException($"{CommandText} does not exist in pg_proc"); } + else + throw new InvalidOperationException($"{CommandText} does not exist in pg_proc"); + } - var typeMapper = c._connection!.Connector!.TypeMapper; + var serializerOptions 
= c.InternalConnection!.Connector!.SerializerOptions; - for (var i = 0; i < types.Length; i++) + for (var i = 0; i < types.Length; i++) + { + var param = new NpgsqlParameter(); + + var postgresType = serializerOptions.DatabaseInfo.GetPostgresType(types[i]); + var npgsqlDbType = postgresType.DataTypeName.ToNpgsqlDbType(); + param.DataTypeName = postgresType.DisplayName; + param.PostgresType = postgresType; + if (npgsqlDbType.HasValue) + param.NpgsqlDbType = npgsqlDbType.Value; + + if (names != null && i < names.Length) + param.ParameterName = names[i]; + else + param.ParameterName = "parameter" + (i + 1); + + if (modes == null) // All params are IN, or server < 8.1.0 (and only IN is supported) + param.Direction = ParameterDirection.Input; + else { - var param = new NpgsqlParameter(); - - var (npgsqlDbType, postgresType) = typeMapper.GetTypeInfoByOid(types[i]); - - param.DataTypeName = postgresType.DisplayName; - param.PostgresType = postgresType; - if (npgsqlDbType.HasValue) - param.NpgsqlDbType = npgsqlDbType.Value; - - if (names != null && i < names.Length) - param.ParameterName = names[i]; - else - param.ParameterName = "parameter" + (i + 1); - - if (modes == null) // All params are IN, or server < 8.1.0 (and only IN is supported) - param.Direction = ParameterDirection.Input; - else + param.Direction = modes[i] switch { - param.Direction = modes[i] switch - { - 'i' => ParameterDirection.Input, - 'o' => ParameterDirection.Output, - 't' => ParameterDirection.Output, - 'b' => ParameterDirection.InputOutput, - 'v' => throw new NotSupportedException("Cannot derive function parameter of type VARIADIC"), - _ => throw new ArgumentOutOfRangeException("Unknown code in proargmodes while deriving: " + modes[i]) - }; - } - - Parameters.Add(param); + 'i' => ParameterDirection.Input, + 'o' => ParameterDirection.Output, + 't' => ParameterDirection.Output, + 'b' => ParameterDirection.InputOutput, + 'v' => throw new NotSupportedException("Cannot derive function parameter of type 
VARIADIC"), + _ => throw new ArgumentOutOfRangeException("Unknown code in proargmodes while deriving: " + modes[i]) + }; } + + Parameters.Add(param); } + } - void DeriveParametersForQuery(NpgsqlConnector connector) + void DeriveParametersForQuery(NpgsqlConnector connector) + { + using (connector.StartUserAction()) { - using (connector.StartUserAction()) - { - Log.Debug($"Deriving Parameters for query: {CommandText}", connector.Id); - ProcessRawQuery(true); + LogMessages.DerivingParameters(connector.CommandLogger, CommandText, connector.Id); - var sendTask = SendDeriveParameters(connector, false); - if (sendTask.IsFaulted) - sendTask.GetAwaiter().GetResult(); + if (IsWrappedByBatch) + foreach (var batchCommand in InternalBatchCommands) + connector.SqlQueryParser.ParseRawQuery(batchCommand, connector.UseConformingStrings, deriveParameters: true); + else + connector.SqlQueryParser.ParseRawQuery(this, connector.UseConformingStrings, deriveParameters: true); + + var sendTask = SendDeriveParameters(connector, false); + if (sendTask.IsFaulted) + sendTask.GetAwaiter().GetResult(); - foreach (var statement in _statements) + try + { + foreach (var batchCommand in InternalBatchCommands) { Expect( connector.ReadMessage(async: false).GetAwaiter().GetResult(), connector); var paramTypeOIDs = Expect( connector.ReadMessage(async: false).GetAwaiter().GetResult(), connector).TypeOIDs; - if (statement.InputParameters.Count != paramTypeOIDs.Count) + if (batchCommand.PositionalParameters.Count != paramTypeOIDs.Count) { connector.SkipUntil(BackendMessageCode.ReadyForQuery); Parameters.Clear(); - throw new NpgsqlException("There was a mismatch in the number of derived parameters between the Npgsql SQL parser and the PostgreSQL parser. Please report this as bug to the Npgsql developers (https://github.com/npgsql/npgsql/issues)."); + throw new NpgsqlException( + "There was a mismatch in the number of derived parameters between the Npgsql SQL parser and the PostgreSQL parser. 
Please report this as bug to the Npgsql developers (https://github.com/npgsql/npgsql/issues)."); } for (var i = 0; i < paramTypeOIDs.Count; i++) { try { - var param = statement.InputParameters[i]; + var param = batchCommand.PositionalParameters[i]; var paramOid = paramTypeOIDs[i]; - var (npgsqlDbType, postgresType) = connector.TypeMapper.GetTypeInfoByOid(paramOid); - + var postgresType = connector.SerializerOptions.DatabaseInfo.GetPostgresType(paramOid); + // We want to keep any domain types visible on the parameter, it will internally do a representational lookup again if necessary. + var npgsqlDbType = postgresType.GetRepresentationalType().DataTypeName.ToNpgsqlDbType(); if (param.NpgsqlDbType != NpgsqlDbType.Unknown && param.NpgsqlDbType != npgsqlDbType) - throw new NpgsqlException("The backend parser inferred different types for parameters with the same name. Please try explicit casting within your SQL statement or batch or use different placeholder names."); + throw new NpgsqlException( + "The backend parser inferred different types for parameters with the same name. Please try explicit casting within your SQL statement or batch or use different placeholder names."); param.DataTypeName = postgresType.DisplayName; param.PostgresType = postgresType; @@ -529,100 +604,128 @@ void DeriveParametersForQuery(NpgsqlConnector connector) } Expect(connector.ReadMessage(async: false).GetAwaiter().GetResult(), connector); - sendTask.GetAwaiter().GetResult(); + } + finally + { + try + { + // Make sure sendTask is complete so we don't race against asynchronous flush + sendTask.GetAwaiter().GetResult(); + } + catch + { + // ignored + } } } + } - #endregion + #endregion - #region Prepare + #region Prepare - /// - /// Creates a server-side prepared statement on the PostgreSQL server. - /// This will make repeated future executions of this command much faster. 
- /// - public override void Prepare() => Prepare(false).GetAwaiter().GetResult(); + /// + /// Creates a server-side prepared statement on the PostgreSQL server. + /// This will make repeated future executions of this command much faster. + /// + public override void Prepare() => Prepare(false).GetAwaiter().GetResult(); - /// - /// Creates a server-side prepared statement on the PostgreSQL server. - /// This will make repeated future executions of this command much faster. - /// - /// The token to monitor for cancellation requests. The default value is . + /// + /// Creates a server-side prepared statement on the PostgreSQL server. + /// This will make repeated future executions of this command much faster. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// #if NETSTANDARD2_0 - public Task PrepareAsync(CancellationToken cancellationToken = default) + public virtual Task PrepareAsync(CancellationToken cancellationToken = default) #else - public override Task PrepareAsync(CancellationToken cancellationToken = default) + public override Task PrepareAsync(CancellationToken cancellationToken = default) #endif - { - using (NoSynchronizationContextScope.Enter()) - return Prepare(true, cancellationToken); - } + => Prepare(async: true, cancellationToken); + + Task Prepare(bool async, CancellationToken cancellationToken = default) + { + var connection = CheckAndGetConnection(); + Debug.Assert(connection is not null); + if (connection.Settings.Multiplexing) + throw new NotSupportedException("Explicit preparation not supported with multiplexing"); + var connector = connection.Connector!; + var logger = connector.CommandLogger; + + var needToPrepare = false; - Task Prepare(bool async, CancellationToken cancellationToken = default) + if (IsWrappedByBatch) { - var connection = CheckAndGetConnection(); - if (connection.Settings.Multiplexing) - throw new NotSupportedException("Explicit preparation not supported with 
multiplexing"); - var connector = connection.Connector!; + foreach (var batchCommand in InternalBatchCommands) + { + batchCommand._parameters?.ProcessParameters(connector.SerializerOptions, validateValues: false, batchCommand.CommandType); + ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand); - for (var i = 0; i < Parameters.Count; i++) - Parameters[i].Bind(connector.TypeMapper); + needToPrepare = batchCommand.ExplicitPrepare(connector) || needToPrepare; + batchCommand.ConnectorPreparedOn = connector; + } - ProcessRawQuery(); - Log.Debug($"Preparing: {CommandText}", connector.Id); + if (logger.IsEnabled(LogLevel.Debug) && needToPrepare) + LogMessages.PreparingCommandExplicitly(logger, string.Join("; ", CommandTexts()), connector.Id); - var needToPrepare = false; - foreach (var statement in _statements) + IEnumerable CommandTexts() { - if (statement.IsPrepared) - continue; - statement.PreparedStatement = connector.PreparedStatementManager.GetOrAddExplicit(statement); - if (statement.PreparedStatement?.State == PreparedState.NotPrepared) - { - statement.PreparedStatement.State = PreparedState.BeingPrepared; - statement.IsPreparing = true; - needToPrepare = true; - } + foreach (var c in InternalBatchCommands) + yield return c.CommandText; } + } + else + { + _parameters?.ProcessParameters(connector.SerializerOptions, validateValues: false, CommandType); + ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand: null); + + foreach (var batchCommand in InternalBatchCommands) + needToPrepare = batchCommand.ExplicitPrepare(connector) || needToPrepare; + + if (logger.IsEnabled(LogLevel.Debug) && needToPrepare) + LogMessages.PreparingCommandExplicitly(logger, CommandText, connector.Id); + } - _connectorPreparedOn = connector; + _connectorPreparedOn = connector; - // It's possible the command was already prepared, or that persistent prepared statements were found for - // all statements. 
Nothing to do here, move along. - return needToPrepare - ? PrepareLong(this, async, connector, cancellationToken) - : Task.CompletedTask; + // It's possible the command was already prepared, or that persistent prepared statements were found for + // all statements. Nothing to do here, move along. + return needToPrepare + ? PrepareLong(this, async, connector, cancellationToken) + : Task.CompletedTask; - static async Task PrepareLong(NpgsqlCommand command, bool async, NpgsqlConnector connector, CancellationToken cancellationToken) + static async Task PrepareLong(NpgsqlCommand command, bool async, NpgsqlConnector connector, CancellationToken cancellationToken) + { + try { using (connector.StartUserAction(cancellationToken)) { - var sendTask = command.SendPrepare(connector, async, cancellationToken); + var sendTask = command.SendPrepare(connector, async, CancellationToken.None); if (sendTask.IsFaulted) sendTask.GetAwaiter().GetResult(); - // Loop over statements, skipping those that are already prepared (because they were persisted) - var isFirst = true; - for (var i = 0; i < command._statements.Count; i++) + try { - var statement = command._statements[i]; - if (!statement.IsPreparing) - continue; + // Loop over statements, skipping those that are already prepared (because they were persisted) + var isFirst = true; + foreach (var batchCommand in command.InternalBatchCommands) + { + if (!batchCommand.IsPreparing) + continue; - var pStatement = statement.PreparedStatement!; + var pStatement = batchCommand.PreparedStatement!; - try - { if (pStatement.StatementBeingReplaced != null) { - Expect(await connector.ReadMessage(async), connector); + Expect(await connector.ReadMessage(async).ConfigureAwait(false), connector); pStatement.StatementBeingReplaced.CompleteUnprepare(); pStatement.StatementBeingReplaced = null; } - Expect(await connector.ReadMessage(async), connector); - Expect(await connector.ReadMessage(async), connector); - var msg = await 
connector.ReadMessage(async); + Expect(await connector.ReadMessage(async).ConfigureAwait(false), connector); + Expect(await connector.ReadMessage(async).ConfigureAwait(false), connector); + var msg = await connector.ReadMessage(async).ConfigureAwait(false); switch (msg.Code) { case BackendMessageCode.RowDescription: @@ -630,849 +733,1173 @@ static async Task PrepareLong(NpgsqlCommand command, bool async, NpgsqlConnector // by the connection) var description = ((RowDescriptionMessage)msg).Clone(); command.FixupRowDescription(description, isFirst); - statement.Description = description; + batchCommand.Description = description; break; case BackendMessageCode.NoData: - statement.Description = null; + batchCommand.Description = null; break; default: throw connector.UnexpectedMessageReceived(msg.Code); } - statement.IsPreparing = false; - pStatement.CompletePrepare(); + pStatement.State = PreparedState.Prepared; + connector.PreparedStatementManager.NumPrepared++; + batchCommand.IsPreparing = false; isFirst = false; } + + Expect(await connector.ReadMessage(async).ConfigureAwait(false), connector); + } + finally + { + try + { + // Make sure sendTask is complete so we don't race against asynchronous flush + if (async) + await sendTask.ConfigureAwait(false); + else + sendTask.GetAwaiter().GetResult(); + } catch { - // The statement wasn't prepared successfully, update the bookkeeping for it and - // all following statements - for (; i < command._statements.Count; i++) - { - statement = command._statements[i]; - if (statement.IsPreparing) - { - statement.IsPreparing = false; - statement.PreparedStatement!.CompleteUnprepare(); - } - } - - throw; + // ignored } } - - Expect(await connector.ReadMessage(async), connector); - - if (async) - await sendTask; - else - sendTask.GetAwaiter().GetResult(); } + + LogMessages.CommandPreparedExplicitly(connector.CommandLogger, connector.Id); } - } + catch + { + // The statements weren't prepared successfully, update the bookkeeping for 
them + foreach (var batchCommand in command.InternalBatchCommands) + { + if (batchCommand.IsPreparing) + { + batchCommand.IsPreparing = false; + batchCommand.PreparedStatement!.AbortPrepare(); + } + } - /// - /// Unprepares a command, closing server-side statements associated with it. - /// Note that this only affects commands explicitly prepared with , not - /// automatically prepared statements. - /// - public void Unprepare() - => Unprepare(false).GetAwaiter().GetResult(); - - /// - /// Unprepares a command, closing server-side statements associated with it. - /// Note that this only affects commands explicitly prepared with , not - /// automatically prepared statements. - /// - /// The token to monitor for cancellation requests. The default value is . - public Task UnprepareAsync(CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return Unprepare(true, cancellationToken); + throw; + } } + } - async Task Unprepare(bool async, CancellationToken cancellationToken = default) - { - var connection = CheckAndGetConnection(); - if (connection.Settings.Multiplexing) - throw new NotSupportedException("Explicit preparation not supported with multiplexing"); - if (_statements.All(s => !s.IsPrepared)) - return; + /// + /// Unprepares a command, closing server-side statements associated with it. + /// Note that this only affects commands explicitly prepared with , not + /// automatically prepared statements. + /// + public void Unprepare() + => Unprepare(false).GetAwaiter().GetResult(); - var connector = connection.Connector!; + /// + /// Unprepares a command, closing server-side statements associated with it. + /// Note that this only affects commands explicitly prepared with , not + /// automatically prepared statements. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + public Task UnprepareAsync(CancellationToken cancellationToken = default) + => Unprepare(async: true, cancellationToken); - Log.Debug("Closing command's prepared statements", connector.Id); - using (connector.StartUserAction(cancellationToken)) + async Task Unprepare(bool async, CancellationToken cancellationToken = default) + { + var connection = CheckAndGetConnection(); + Debug.Assert(connection is not null); + if (connection.Settings.Multiplexing) + throw new NotSupportedException("Explicit preparation not supported with multiplexing"); + + var forall = true; + foreach (var statement in InternalBatchCommands) + if (statement.IsPrepared) { - var sendTask = SendClose(connector, async, cancellationToken); - if (sendTask.IsFaulted) - sendTask.GetAwaiter().GetResult(); - foreach (var statement in _statements) - if (statement.PreparedStatement?.State == PreparedState.BeingUnprepared) - { - Expect(await connector.ReadMessage(async), connector); - statement.PreparedStatement.CompleteUnprepare(); - statement.PreparedStatement = null; - } - Expect(await connector.ReadMessage(async), connector); - if (async) - await sendTask; - else - sendTask.GetAwaiter().GetResult(); + forall = false; + break; } - } + if (forall) + return; - #endregion Prepare + var connector = connection.Connector!; - #region Query analysis + LogMessages.UnpreparingCommand(connector.CommandLogger, connector.Id); - internal void ProcessRawQuery(bool deriveParameters = false) + using (connector.StartUserAction(cancellationToken)) { - if (string.IsNullOrEmpty(CommandText)) - throw new InvalidOperationException("CommandText property has not been initialized"); - - NpgsqlStatement statement; - switch (CommandType) { - case CommandType.Text: - var parser = new SqlQueryParser(); - parser.ParseRawQuery(CommandText, _parameters, _statements, deriveParameters); - - if (_statements.Count > 1 && _parameters.HasOutputParameters) - throw new NotSupportedException("Commands with multipl e queries cannot have 
out parameters"); - break; + // Just wait for SendClose to complete since each statement takes no more than 20 bytes + await SendClose(connector, async, cancellationToken).ConfigureAwait(false); - case CommandType.TableDirect: - if (_statements.Count == 0) - statement = new NpgsqlStatement(); - else + foreach (var batchCommand in InternalBatchCommands) + { + if (batchCommand.PreparedStatement?.State == PreparedState.BeingUnprepared) { - statement = _statements[0]; - statement.Reset(); - _statements.Clear(); - } - _statements.Add(statement); - statement.SQL = "SELECT * FROM " + CommandText; - break; + Expect(await connector.ReadMessage(async).ConfigureAwait(false), connector); - case CommandType.StoredProcedure: - var inputList = _parameters.Where(p => p.IsInputDirection).ToList(); - var numInput = inputList.Count; - var sb = new StringBuilder(); - sb.Append("SELECT * FROM "); - sb.Append(CommandText); - sb.Append('('); - var hasWrittenFirst = false; - for (var i = 1; i <= numInput; i++) { - var param = inputList[i - 1]; - if (param.TrimmedName == "") - { - if (hasWrittenFirst) - sb.Append(','); - sb.Append('$'); - sb.Append(i); - hasWrittenFirst = true; - } - } - for (var i = 1; i <= numInput; i++) - { - var param = inputList[i - 1]; - if (param.TrimmedName != "") - { - if (hasWrittenFirst) - sb.Append(','); - sb.Append('"'); - sb.Append(param.TrimmedName.Replace("\"", "\"\"")); - sb.Append("\" := "); - sb.Append('$'); - sb.Append(i); - hasWrittenFirst = true; - } - } - sb.Append(')'); + var pStatement = batchCommand.PreparedStatement; + pStatement.CompleteUnprepare(); - if (_statements.Count == 0) - statement = new NpgsqlStatement(); - else - { - statement = _statements[0]; - statement.Reset(); - _statements.Clear(); + if (!pStatement.IsExplicit) + connector.PreparedStatementManager.AutoPrepared[pStatement.AutoPreparedSlotIndex] = null; + + batchCommand.PreparedStatement = null; } - statement.SQL = sb.ToString(); - statement.InputParameters.AddRange(inputList); - 
_statements.Add(statement); - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {CommandType} of enum {nameof(CommandType)}. Please file a bug."); } - foreach (var s in _statements) - if (s.InputParameters.Count > ushort.MaxValue) - throw new NpgsqlException($"A statement cannot have more than {ushort.MaxValue} parameters"); + Expect(await connector.ReadMessage(async).ConfigureAwait(false), connector); } + } + + #endregion Prepare - #endregion + #region Query analysis - #region Execute + internal void ProcessRawQuery(SqlQueryParser? parser, bool standardConformingStrings, NpgsqlBatchCommand? batchCommand) + { + var (commandText, commandType, parameters) = batchCommand is null + ? (CommandText, CommandType, _parameters) + : (batchCommand.CommandText, batchCommand.CommandType, batchCommand._parameters); - void ValidateParameters(ConnectorTypeMapper typeMapper) + if (string.IsNullOrEmpty(commandText)) + ThrowHelper.ThrowInvalidOperationException("CommandText property has not been initialized"); + + switch (commandType) { - for (var i = 0; i < Parameters.Count; i++) + case CommandType.Text: + switch (parameters?.PlaceholderType ?? PlaceholderType.NoParameters) { - var p = Parameters[i]; - if (!p.IsInputDirection) - continue; - p.Bind(typeMapper); - p.LengthCache?.Clear(); - p.ValidateAndGetLength(); - } - } + case PlaceholderType.Positional: + // In positional parameter mode, we don't need to parse/rewrite the CommandText or reorder the parameters - just use + // them as is. If the SQL contains a semicolon (legacy batching) when positional parameters are in use, we just send + // that and PostgreSQL will error (this behavior is by-design - use the new batching API). 
+ if (batchCommand is null) + { + batchCommand = TruncateStatementsToOne(); + batchCommand.FinalCommandText = CommandText; + if (parameters is not null) + { + batchCommand.PositionalParameters = parameters.InternalList; + batchCommand._parameters = parameters; + } + } + else + { + batchCommand.FinalCommandText = batchCommand.CommandText; + if (parameters is not null) + batchCommand.PositionalParameters = parameters.InternalList; + } - #endregion + ValidateParameterCount(batchCommand); + break; - #region Message Creation / Population + case PlaceholderType.NoParameters: + // Unless the EnableSqlRewriting AppContext switch is explicitly disabled, queries with no parameters are parsed just + // like queries with named parameters, since they may contain a semicolon (legacy batching). + if (EnableSqlRewriting) + goto case PlaceholderType.Named; + goto case PlaceholderType.Positional; - internal bool FlushOccurred { get; set; } + case PlaceholderType.Named: + if (!EnableSqlRewriting) + ThrowHelper.ThrowNotSupportedException($"Named parameters are not supported when Npgsql.{nameof(EnableSqlRewriting)} is disabled"); - void BeginSend(NpgsqlConnector connector) - { - connector.WriteBuffer.Timeout = TimeSpan.FromSeconds(CommandTimeout); - connector.WriteBuffer.CurrentCommand = this; - FlushOccurred = false; - } + // The parser is cached on NpgsqlConnector - unless we're in multiplexing mode. 
+ parser ??= new SqlQueryParser(); - void CleanupSend() - { - // ReSharper disable once ConditionIsAlwaysTrueOrFalse - if (SynchronizationContext.Current != null) // Check first because SetSynchronizationContext allocates - SynchronizationContext.SetSynchronizationContext(null); - } + if (batchCommand is null) + { + parser.ParseRawQuery(this, standardConformingStrings); + if (InternalBatchCommands.Count > 1 && _parameters?.HasOutputParameters == true) + ThrowHelper.ThrowNotSupportedException("Commands with multiple queries cannot have out parameters"); + for (var i = 0; i < InternalBatchCommands.Count; i++) + ValidateParameterCount(InternalBatchCommands[i]); + } + else + { + parser.ParseRawQuery(batchCommand, standardConformingStrings); + ValidateParameterCount(batchCommand); + } - internal Task Write(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) - { - return (_behavior & CommandBehavior.SchemaOnly) == 0 - ? WriteExecute(connector, async) - : WriteExecuteSchemaOnly(connector, async); + break; - async Task WriteExecute(NpgsqlConnector connector, bool async) - { - for (var i = 0; i < _statements.Count; i++) - { - // The following is only for deadlock avoidance when doing sync I/O (so never in multiplexing) - async = ForceAsyncIfNecessary(async, i); + case PlaceholderType.Mixed: + ThrowHelper.ThrowNotSupportedException("Mixing named and positional parameters isn't supported"); + break; - var statement = _statements[i]; - var pStatement = statement.PreparedStatement; + default: + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(PlaceholderType), $"Unknown {nameof(PlaceholderType)} value: {{0}}", _parameters?.PlaceholderType ?? PlaceholderType.NoParameters); + break; + } - if (pStatement == null || statement.IsPreparing) - { - // The statement should either execute unprepared, or is being auto-prepared. 
- // Send Parse, Bind, Describe + break; - // We may have a prepared statement that replaces an existing statement - close the latter first. - if (pStatement?.StatementBeingReplaced != null) - await connector.WriteClose(StatementOrPortal.Statement, pStatement.StatementBeingReplaced.Name!, async, cancellationToken); + case CommandType.TableDirect: + batchCommand ??= TruncateStatementsToOne(); + batchCommand.FinalCommandText = "SELECT * FROM " + CommandText; + break; - await connector.WriteParse(statement.SQL, statement.StatementName, statement.InputParameters, async, cancellationToken); + case CommandType.StoredProcedure: + var sqlBuilder = new StringBuilder() + .Append(EnableStoredProcedureCompatMode ? "SELECT * FROM " : "CALL ") + .Append(commandText) + .Append('('); - await connector.WriteBind( - statement.InputParameters, string.Empty, statement.StatementName, AllResultTypesAreUnknown, - i == 0 ? UnknownResultTypeList : null, - async, cancellationToken); + var isFirstParam = true; + var seenNamedParam = false; + var inputParameters = NpgsqlBatchCommand.EmptyParameters; + if (parameters is not null) + { + inputParameters = new List(parameters.Count); + for (var i = 0; i < parameters.Count; i++) + { + var parameter = parameters[i]; - await connector.WriteDescribe(StatementOrPortal.Portal, string.Empty, async, cancellationToken); - } + // With functions, output parameters are never present when calling the function (they only define the schema of the + // returned table). With stored procedures they must be specified in the CALL argument list (see below). + if (EnableStoredProcedureCompatMode && parameter.Direction == ParameterDirection.Output) + continue; + + if (isFirstParam) + isFirstParam = false; else + sqlBuilder.Append(", "); + + if (parameter.IsPositional) { - // The statement is already prepared, only a Bind is needed - await connector.WriteBind( - statement.InputParameters, string.Empty, statement.StatementName, AllResultTypesAreUnknown, - i == 0 ? 
UnknownResultTypeList : null, - async, cancellationToken); + if (seenNamedParam) + ThrowHelper.ThrowArgumentException(NpgsqlStrings.PositionalParameterAfterNamed); } + else + { + seenNamedParam = true; - await connector.WriteExecute(0, async, cancellationToken); + sqlBuilder + .Append('"') + .Append(parameter.TrimmedName.Replace("\"", "\"\"")) + .Append("\" := "); + } - if (pStatement != null) - pStatement.LastUsed = DateTime.UtcNow; + if (parameter.Direction == ParameterDirection.Output) + sqlBuilder.Append("NULL"); + else + { + inputParameters!.Add(parameter); + sqlBuilder.Append('$').Append(inputParameters.Count); + } } - - await connector.WriteSync(async, cancellationToken); } - async Task WriteExecuteSchemaOnly(NpgsqlConnector connector, bool async) - { - var wroteSomething = false; - for (var i = 0; i < _statements.Count; i++) - { - async = ForceAsyncIfNecessary(async, i); + sqlBuilder.Append(')'); - var statement = _statements[i]; + batchCommand ??= TruncateStatementsToOne(); + batchCommand.FinalCommandText = sqlBuilder.ToString(); + batchCommand._parameters = parameters; + batchCommand.PositionalParameters.AddRange(inputParameters); + ValidateParameterCount(batchCommand); - if (statement.PreparedStatement?.State == PreparedState.Prepared) - continue; // Prepared, we already have the RowDescription - Debug.Assert(statement.PreparedStatement == null); + break; - await connector.WriteParse(statement.SQL, string.Empty, statement.InputParameters, async, cancellationToken); - await connector.WriteDescribe(StatementOrPortal.Statement, statement.StatementName, async, cancellationToken); - wroteSomething = true; - } + default: + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(CommandType), $"Internal Npgsql bug: unexpected value {{0}} of enum {nameof(CommandType)}. 
Please file a bug.", commandType); + break; + } - if (wroteSomething) - await connector.WriteSync(async, cancellationToken); - } + static void ValidateParameterCount(NpgsqlBatchCommand batchCommand) + { + if (batchCommand.HasParameters && batchCommand.PositionalParameters.Count > ushort.MaxValue) + ThrowHelper.ThrowNpgsqlException("A statement cannot have more than 65535 parameters"); } + } + + #endregion + + #region Message Creation / Population + + void BeginSend(NpgsqlConnector connector) + => connector.WriteBuffer.Timeout = TimeSpan.FromSeconds(CommandTimeout); - async Task SendDeriveParameters(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) + internal Task Write(NpgsqlConnector connector, bool async, bool flush, CancellationToken cancellationToken = default) + { + return (_behavior & CommandBehavior.SchemaOnly) == 0 + ? WriteExecute(connector, async, flush, cancellationToken) + : WriteExecuteSchemaOnly(connector, async, flush, cancellationToken); + + async Task WriteExecute(NpgsqlConnector connector, bool async, bool flush, CancellationToken cancellationToken) { - BeginSend(connector); + NpgsqlBatchCommand? 
batchCommand = null; - for (var i = 0; i < _statements.Count; i++) + var syncCaller = !async; + for (var i = 0; i < InternalBatchCommands.Count; i++) { - async = ForceAsyncIfNecessary(async, i); + // The following is only for deadlock avoidance when doing sync I/O (so never in multiplexing) + if (syncCaller && ShouldSchedule(ref async, i)) + await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); - var statement = _statements[i]; + batchCommand = InternalBatchCommands[i]; + var pStatement = batchCommand.PreparedStatement; - await connector.WriteParse(statement.SQL, string.Empty, EmptyParameters, async, cancellationToken); - await connector.WriteDescribe(StatementOrPortal.Statement, string.Empty, async, cancellationToken); - } + Debug.Assert(batchCommand.FinalCommandText is not null); - await connector.WriteSync(async, cancellationToken); - await connector.Flush(async, cancellationToken); + if (pStatement == null || batchCommand.IsPreparing) + { + // The statement should either execute unprepared, or is being auto-prepared. + // Send Parse, Bind, Describe - CleanupSend(); - } + // We may have a prepared statement that replaces an existing statement - close the latter first. + if (pStatement?.StatementBeingReplaced != null) + await connector.WriteClose(StatementOrPortal.Statement, pStatement.StatementBeingReplaced.Name!, async, cancellationToken).ConfigureAwait(false); - async Task SendPrepare(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) - { - BeginSend(connector); + await connector.WriteParse(batchCommand.FinalCommandText, batchCommand.StatementName, + batchCommand.CurrentParametersReadOnly, async, cancellationToken).ConfigureAwait(false); - for (var i = 0; i < _statements.Count; i++) - { - async = ForceAsyncIfNecessary(async, i); + await connector.WriteBind( + batchCommand.CurrentParametersReadOnly, + string.Empty, batchCommand.StatementName, AllResultTypesAreUnknown, + i == 0 ? 
UnknownResultTypeList : null, + async, cancellationToken).ConfigureAwait(false); - var statement = _statements[i]; - var pStatement = statement.PreparedStatement; + await connector.WriteDescribe(StatementOrPortal.Portal, Array.Empty(), async, cancellationToken).ConfigureAwait(false); + } + else + { + // The statement is already prepared, only a Bind is needed + await connector.WriteBind( + batchCommand.CurrentParametersReadOnly, + string.Empty, batchCommand.StatementName, AllResultTypesAreUnknown, + i == 0 ? UnknownResultTypeList : null, + async, cancellationToken).ConfigureAwait(false); + } - // A statement may be already prepared, already in preparation (i.e. same statement twice - // in the same command), or we can't prepare (overloaded SQL) - if (!statement.IsPreparing) - continue; + await connector.WriteExecute(0, async, cancellationToken).ConfigureAwait(false); - // We may have a prepared statement that replaces an existing statement - close the latter first. - var statementToClose = pStatement!.StatementBeingReplaced; - if (statementToClose != null) - await connector.WriteClose(StatementOrPortal.Statement, statementToClose.Name!, async, cancellationToken); + if (batchCommand.AppendErrorBarrier ?? EnableErrorBarriers) + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); - await connector.WriteParse(statement.SQL, pStatement.Name!, statement.InputParameters, async, cancellationToken); - await connector.WriteDescribe(StatementOrPortal.Statement, pStatement.Name!, async, cancellationToken); + pStatement?.RefreshLastUsed(); } - await connector.WriteSync(async, cancellationToken); - await connector.Flush(async, cancellationToken); + if (batchCommand is null || !(batchCommand.AppendErrorBarrier ?? 
EnableErrorBarriers)) + { + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + } - CleanupSend(); + if (flush) + await connector.Flush(async, cancellationToken).ConfigureAwait(false); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - bool ForceAsyncIfNecessary(bool async, int numberOfStatementInBatch) + async Task WriteExecuteSchemaOnly(NpgsqlConnector connector, bool async, bool flush, CancellationToken cancellationToken) { - if (!async && FlushOccurred && numberOfStatementInBatch > 0) + var wroteSomething = false; + var syncCaller = !async; + for (var i = 0; i < InternalBatchCommands.Count; i++) { - // We're synchronously sending the non-first statement in a batch and a flush - // has already occured. Switch to async. See long comment in Execute() above. - async = true; - SynchronizationContext.SetSynchronizationContext(SingleThreadSynchronizationContext); + if (syncCaller && ShouldSchedule(ref async, i)) + await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); + + var batchCommand = InternalBatchCommands[i]; + + if (batchCommand.PreparedStatement?.State == PreparedState.Prepared) + continue; // Prepared, we already have the RowDescription + + await connector.WriteParse(batchCommand.FinalCommandText!, batchCommand.StatementName, + batchCommand.CurrentParametersReadOnly, + async, cancellationToken).ConfigureAwait(false); + await connector.WriteDescribe(StatementOrPortal.Statement, batchCommand.StatementName, async, cancellationToken).ConfigureAwait(false); + wroteSomething = true; } - return async; + if (wroteSomething) + { + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + if (flush) + await connector.Flush(async, cancellationToken).ConfigureAwait(false); + } } + } - async Task SendClose(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) + async Task SendDeriveParameters(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) + 
{ + BeginSend(connector); + + var syncCaller = !async; + for (var i = 0; i < InternalBatchCommands.Count; i++) { - BeginSend(connector); + if (syncCaller && ShouldSchedule(ref async, i)) + await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); - foreach (var statement in _statements.Where(s => s.IsPrepared)) - { - if (FlushOccurred) - { - async = true; - SynchronizationContext.SetSynchronizationContext(SingleThreadSynchronizationContext); - } + var batchCommand = InternalBatchCommands[i]; - await connector.WriteClose(StatementOrPortal.Statement, statement.StatementName, async, cancellationToken); - statement.PreparedStatement!.State = PreparedState.BeingUnprepared; - } + await connector.WriteParse(batchCommand.FinalCommandText!, Array.Empty(), NpgsqlBatchCommand.EmptyParameters, async, cancellationToken).ConfigureAwait(false); + await connector.WriteDescribe(StatementOrPortal.Statement, Array.Empty(), async, cancellationToken).ConfigureAwait(false); + } + + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + await connector.Flush(async, cancellationToken).ConfigureAwait(false); + } + + async Task SendPrepare(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) + { + BeginSend(connector); + + var syncCaller = !async; + for (var i = 0; i < InternalBatchCommands.Count; i++) + { + if (syncCaller && ShouldSchedule(ref async, i)) + await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); + + var batchCommand = InternalBatchCommands[i]; + var pStatement = batchCommand.PreparedStatement; + + // A statement may be already prepared, already in preparation (i.e. same statement twice + // in the same command), or we can't prepare (overloaded SQL) + if (!batchCommand.IsPreparing) + continue; - await connector.WriteSync(async, cancellationToken); - await connector.Flush(async, cancellationToken); + // We may have a prepared statement that replaces an existing statement - close the latter first. 
+ var statementToClose = pStatement!.StatementBeingReplaced; + if (statementToClose != null) + await connector.WriteClose(StatementOrPortal.Statement, statementToClose.Name!, async, cancellationToken).ConfigureAwait(false); - CleanupSend(); + await connector.WriteParse(batchCommand.FinalCommandText!, pStatement.Name!, batchCommand.CurrentParametersReadOnly, async, + cancellationToken).ConfigureAwait(false); + await connector.WriteDescribe(StatementOrPortal.Statement, pStatement.Name!, async, cancellationToken).ConfigureAwait(false); } - #endregion + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + await connector.Flush(async, cancellationToken).ConfigureAwait(false); + } - #region Execute Non Query + [MethodImpl(MethodImplOptions.AggressiveInlining)] + bool ShouldSchedule(ref bool async, int indexOfStatementInBatch) + { + if (indexOfStatementInBatch <= 0) + return false; + + // We're synchronously sending the non-first statement in a batch - switch to async writing. + // See long comment in Execute() above. + + // TODO: we can simply do all batch writing asynchronously, instead of starting with the 2nd statement. + // For now, writing the first statement synchronously gives us a better chance of handling and bubbling up errors correctly + // (see sendTask.IsFaulted in Execute()). Once #1323 is done, that shouldn't be needed any more and entire batches should + // be written asynchronously. + async = true; + return TaskScheduler.Current != ConstrainedConcurrencyScheduler; + } - /// - /// Executes a SQL statement against the connection and returns the number of rows affected. - /// - /// The number of rows affected if known; -1 otherwise. 
- public override int ExecuteNonQuery() => ExecuteNonQuery(false, CancellationToken.None).GetAwaiter().GetResult(); + async Task SendClose(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) + { + BeginSend(connector); - /// - /// Asynchronous version of - /// - /// The token to monitor for cancellation requests. - /// A task representing the asynchronous operation, with the number of rows affected if known; -1 otherwise. - public override Task ExecuteNonQueryAsync(CancellationToken cancellationToken) + foreach (var batchCommand in InternalBatchCommands) { - using (NoSynchronizationContextScope.Enter()) - return ExecuteNonQuery(true, cancellationToken); + if (!batchCommand.IsPrepared) + continue; + // No need to force async here since each statement takes no more than 20 bytes + await connector.WriteClose(StatementOrPortal.Statement, batchCommand.StatementName, async, cancellationToken).ConfigureAwait(false); + batchCommand.PreparedStatement!.State = PreparedState.BeingUnprepared; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - async Task ExecuteNonQuery(bool async, CancellationToken cancellationToken) + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + await connector.Flush(async, cancellationToken).ConfigureAwait(false); + } + + #endregion + + #region Execute Non Query + + /// + /// Executes a SQL statement against the connection and returns the number of rows affected. + /// + /// The number of rows affected if known; -1 otherwise. + public override int ExecuteNonQuery() => ExecuteNonQuery(false, CancellationToken.None).GetAwaiter().GetResult(); + + /// + /// Asynchronous version of + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task representing the asynchronous operation, with the number of rows affected if known; -1 otherwise. 
+ public override Task ExecuteNonQueryAsync(CancellationToken cancellationToken) + => ExecuteNonQuery(async: true, cancellationToken); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + async Task ExecuteNonQuery(bool async, CancellationToken cancellationToken) + { + var reader = await ExecuteReader(async, CommandBehavior.Default, cancellationToken).ConfigureAwait(false); + try { - using var reader = await ExecuteReader(CommandBehavior.Default, async, cancellationToken); - while (async ? await reader.NextResultAsync(cancellationToken) : reader.NextResult()) ; + while (async ? await reader.NextResultAsync(cancellationToken).ConfigureAwait(false) : reader.NextResult()) ; return reader.RecordsAffected; } - - #endregion Execute Non Query - - #region Execute Scalar - - /// - /// Executes the query, and returns the first column of the first row - /// in the result set returned by the query. Extra columns or rows are ignored. - /// - /// The first column of the first row in the result set, - /// or a null reference if the result set is empty. - public override object? ExecuteScalar() => ExecuteScalar(false, CancellationToken.None).GetAwaiter().GetResult(); - - /// - /// Asynchronous version of - /// - /// The token to monitor for cancellation requests. - /// A task representing the asynchronous operation, with the first column of the - /// first row in the result set, or a null reference if the result set is empty. 
- public override Task ExecuteScalarAsync(CancellationToken cancellationToken) + finally { - using (NoSynchronizationContextScope.Enter()) - return ExecuteScalar(true, cancellationToken).AsTask(); + if (async) + await reader.DisposeAsync().ConfigureAwait(false); + else + reader.Dispose(); } + } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - async ValueTask ExecuteScalar(bool async, CancellationToken cancellationToken) - { - var behavior = CommandBehavior.SingleRow; - if (!Parameters.HasOutputParameters) - behavior |= CommandBehavior.SequentialAccess; + #endregion Execute Non Query - using var reader = await ExecuteReader(behavior, async, cancellationToken); - return reader.Read() && reader.FieldCount != 0 ? reader.GetValue(0) : null; - } + #region Execute Scalar - #endregion Execute Scalar - - #region Execute Reader - - /// - /// Executes the command text against the connection. - /// - /// A task representing the operation. - protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) - => ExecuteReader(behavior); - - /// - /// Executes the command text against the connection. - /// - /// An instance of . - /// The token to monitor for cancellation requests. - /// A task representing the asynchronous operation. - protected override async Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) - => await ExecuteReaderAsync(behavior, cancellationToken); - - /// - /// Executes the against the - /// and returns a . - /// - /// One of the enumeration values that specified the command behavior. - /// A task representing the operation. - public new NpgsqlDataReader ExecuteReader(CommandBehavior behavior = CommandBehavior.Default) - => ExecuteReader(behavior, async: false, CancellationToken.None).GetAwaiter().GetResult(); - - /// - /// An asynchronous version of , which executes - /// the against the - /// and returns a . - /// - /// The token to monitor for cancellation requests. The default value is . 
- /// A task representing the asynchronous operation. - public new Task ExecuteReaderAsync(CancellationToken cancellationToken = default) - => ExecuteReaderAsync(CommandBehavior.Default, cancellationToken); - - /// - /// An asynchronous version of , - /// which executes the against the - /// and returns a . - /// - /// One of the enumeration values that specified the command behavior. - /// The token to monitor for cancellation requests. - /// A task representing the asynchronous operation. - public new Task ExecuteReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken = default) + /// + /// Executes the query, and returns the first column of the first row + /// in the result set returned by the query. Extra columns or rows are ignored. + /// + /// The first column of the first row in the result set, + /// or a null reference if the result set is empty. + public override object? ExecuteScalar() => ExecuteScalar(false, CancellationToken.None).GetAwaiter().GetResult(); + + /// + /// Asynchronous version of + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task representing the asynchronous operation, with the first column of the + /// first row in the result set, or a null reference if the result set is empty. + public override Task ExecuteScalarAsync(CancellationToken cancellationToken) + => ExecuteScalar(async: true, cancellationToken).AsTask(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + async ValueTask ExecuteScalar(bool async, CancellationToken cancellationToken) + { + var behavior = CommandBehavior.SingleRow; + if (IsWrappedByBatch || _parameters?.HasOutputParameters != true) + behavior |= CommandBehavior.SequentialAccess; + + var reader = await ExecuteReader(async, behavior, cancellationToken).ConfigureAwait(false); + try { - using (NoSynchronizationContextScope.Enter()) - return ExecuteReader(behavior, async: true, cancellationToken).AsTask(); + var read = async ? 
await reader.ReadAsync(cancellationToken).ConfigureAwait(false) : reader.Read(); + return read && reader.FieldCount != 0 ? reader.GetValue(0) : null; } + finally + { + if (async) + await reader.DisposeAsync().ConfigureAwait(false); + else + reader.Dispose(); + } + } + + #endregion Execute Scalar - // TODO: Maybe pool these? - internal ManualResetValueTaskSource ExecutionCompletion { get; } - = new ManualResetValueTaskSource(); + #region Execute Reader - internal async ValueTask ExecuteReader(CommandBehavior behavior, bool async, CancellationToken cancellationToken) + /// + /// Executes the command text against the connection. + /// + /// A task representing the operation. + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) + => ExecuteReader(behavior); + + /// + /// Executes the command text against the connection. + /// + /// An instance of . + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task representing the asynchronous operation. + protected override async Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) + => await ExecuteReaderAsync(behavior, cancellationToken).ConfigureAwait(false); + + /// + /// Executes the against the + /// and returns a . + /// + /// One of the enumeration values that specifies the command behavior. + /// A task representing the operation. + public new NpgsqlDataReader ExecuteReader(CommandBehavior behavior = CommandBehavior.Default) + => ExecuteReader(async: false, behavior, CancellationToken.None).GetAwaiter().GetResult(); + + /// + /// An asynchronous version of , which executes + /// the against the + /// and returns a . + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task representing the asynchronous operation. 
+ public new Task ExecuteReaderAsync(CancellationToken cancellationToken = default) + => ExecuteReaderAsync(CommandBehavior.Default, cancellationToken); + + /// + /// An asynchronous version of , + /// which executes the against the + /// and returns a . + /// + /// One of the enumeration values that specifies the command behavior. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task representing the asynchronous operation. + public new Task ExecuteReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken = default) + => ExecuteReader(async: true, behavior, cancellationToken).AsTask(); + + // TODO: Maybe pool these? + internal ManualResetValueTaskSource ExecutionCompletion { get; } + = new(); + + internal virtual async ValueTask ExecuteReader(bool async, CommandBehavior behavior, CancellationToken cancellationToken) + { + var conn = CheckAndGetConnection(); + _behavior = behavior; + + NpgsqlConnector? connector; + if (_connector is not null) + { + Debug.Assert(conn is null); + if (behavior.HasFlag(CommandBehavior.CloseConnection)) + ThrowHelper.ThrowArgumentException($"{nameof(CommandBehavior.CloseConnection)} is not supported with {nameof(NpgsqlConnector)}", nameof(behavior)); + connector = _connector; + } + else { - var conn = CheckAndGetConnection(); - _behavior = behavior; + Debug.Assert(conn is not null); + conn.TryGetBoundConnector(out connector); + } - try + try + { + if (connector is not null) { - if (conn.TryGetBoundConnector(out var connector)) - { - connector.StartUserAction(ConnectorState.Executing, this, CancellationToken.None); + var logger = connector.CommandLogger; - Task? 
sendTask = null; + cancellationToken.ThrowIfCancellationRequested(); + // We cannot pass a token here, as we'll cancel a non-send query + // Also, we don't pass the cancellation token to StartUserAction, since that would make it scope to the entire action (command execution) + // whereas it should only be scoped to the Execute method. + connector.StartUserAction(ConnectorState.Executing, this, CancellationToken.None); - try + Task? sendTask; + + var validateParameterValues = !behavior.HasFlag(CommandBehavior.SchemaOnly); + long startTimestamp; + + try + { + switch (IsExplicitlyPrepared) { - ValidateParameters(connector.TypeMapper); + case true: + Debug.Assert(_connectorPreparedOn != null); + if (IsWrappedByBatch) + { + foreach (var batchCommand in InternalBatchCommands) + { + if (batchCommand.ConnectorPreparedOn != connector) + { + foreach (var s in InternalBatchCommands) + s.ResetPreparation(); + ResetPreparation(); + goto case false; + } - switch (IsExplicitlyPrepared) + batchCommand._parameters?.ProcessParameters(connector.SerializerOptions, validateParameterValues, batchCommand.CommandType); + } + } + else { - case true: - Debug.Assert(_connectorPreparedOn != null); if (_connectorPreparedOn != connector) { // The command was prepared, but since then the connector has changed. Detach all prepared statements. 
- foreach (var s in _statements) + foreach (var s in InternalBatchCommands) s.PreparedStatement = null; - ResetExplicitPreparation(); + ResetPreparation(); goto case false; } + _parameters?.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); + } - NpgsqlEventSource.Log.CommandStartPrepared(); - break; + NpgsqlEventSource.Log.CommandStartPrepared(); + connector.DataSource.MetricsReporter.CommandStartPrepared(); + break; - case false: - ProcessRawQuery(); + case false: + var numPrepared = 0; - if (connector.Settings.MaxAutoPrepare > 0) + if (IsWrappedByBatch) + { + for (var i = 0; i < InternalBatchCommands.Count; i++) { - var numPrepared = 0; - foreach (var statement in _statements) - { - // If this statement isn't prepared, see if it gets implicitly prepared. - // Note that this may return null (not enough usages for automatic preparation). - if (!statement.IsPrepared) - statement.PreparedStatement = connector.PreparedStatementManager.TryGetAutoPrepared(statement); - if (statement.PreparedStatement is PreparedStatement pStatement) - { - numPrepared++; - if (pStatement?.State == PreparedState.NotPrepared) - { - pStatement.State = PreparedState.BeingPrepared; - statement.IsPreparing = true; - } - } - } + var batchCommand = InternalBatchCommands[i]; + + batchCommand._parameters?.ProcessParameters(connector.SerializerOptions, validateParameterValues, batchCommand.CommandType); + ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand); - if (numPrepared > 0) + if (connector.Settings.MaxAutoPrepare > 0 && batchCommand.TryAutoPrepare(connector)) { - _connectorPreparedOn = connector; - if (numPrepared == _statements.Count) - NpgsqlEventSource.Log.CommandStartPrepared(); + batchCommand.ConnectorPreparedOn = connector; + numPrepared++; } } - - break; } + else + { + _parameters?.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); + ProcessRawQuery(connector.SqlQueryParser, 
connector.UseConformingStrings, batchCommand: null); - State = CommandState.InProgress; - - if (Log.IsEnabled(NpgsqlLogLevel.Debug)) - LogCommand(); - NpgsqlEventSource.Log.CommandStart(CommandText); + if (connector.Settings.MaxAutoPrepare > 0) + for (var i = 0; i < InternalBatchCommands.Count; i++) + if (InternalBatchCommands[i].TryAutoPrepare(connector)) + numPrepared++; + } - // If a cancellation is in progress, wait for it to "complete" before proceeding (#615) - lock (connector.CancelLock) + if (numPrepared > 0) { + _connectorPreparedOn = connector; + if (numPrepared == InternalBatchCommands.Count) + { + NpgsqlEventSource.Log.CommandStartPrepared(); + connector.DataSource.MetricsReporter.CommandStartPrepared(); + } } - // We do not wait for the entire send to complete before proceeding to reading - - // the sending continues in parallel with the user's reading. Waiting for the - // entire send to complete would trigger a deadlock for multi-statement commands, - // where PostgreSQL sends large results for the first statement, while we're sending large - // parameter data for the second. See #641. - // Instead, all sends for non-first statements and for non-first buffers are performed - // asynchronously (even if the user requested sync), in a special synchronization context - // to prevents a dependency on the thread pool (which would also trigger deadlocks). - // The WriteBuffer notifies this command when the first buffer flush occurs, so that the - // send functions can switch to the special async mode when needed. - sendTask = NonMultiplexingWriteWrapper(connector, async, CancellationToken.None); - - // The following is a hack. It raises an exception if one was thrown in the first phases - // of the send (i.e. in parts of the send that executed synchronously). Exceptions may - // still happen later and aren't properly handled. See #1323. 
- if (sendTask.IsFaulted) - sendTask.GetAwaiter().GetResult(); + break; } - catch + + State = CommandState.InProgress; + + if (logger.IsEnabled(LogLevel.Information)) { - conn.Connector?.EndUserAction(); - throw; - } + connector.QueryLogStopWatch.Restart(); - // TODO: DRY the following with multiplexing, but be careful with the cancellation registration... - var reader = connector.DataReader; - reader.Init(this, behavior, _statements, sendTask); - connector.CurrentReader = reader; - if (async) - await reader.NextResultAsync(cancellationToken); - else - reader.NextResult(); + if (logger.IsEnabled(LogLevel.Debug)) + LogExecutingCompleted(connector, executing: true); + } - return reader; + NpgsqlEventSource.Log.CommandStart(CommandText); + startTimestamp = connector.DataSource.MetricsReporter.ReportCommandStart(); + TraceCommandStart(connector); + + // If a cancellation is in progress, wait for it to "complete" before proceeding (#615) + connector.ResetCancellation(); + + // We do not wait for the entire send to complete before proceeding to reading - + // the sending continues in parallel with the user's reading. Waiting for the + // entire send to complete would trigger a deadlock for multi-statement commands, + // where PostgreSQL sends large results for the first statement, while we're sending large + // parameter data for the second. See #641. + // Instead, all sends for non-first statements are performed asynchronously (even if the user requested sync), + // in a special synchronization context to prevents a dependency on the thread pool (which would also trigger + // deadlocks). + BeginSend(connector); + sendTask = Write(connector, async, flush: true, CancellationToken.None); + + // The following is a hack. It raises an exception if one was thrown in the first phases + // of the send (i.e. in parts of the send that executed synchronously). Exceptions may + // still happen later and aren't properly handled. See #1323. 
+ if (sendTask.IsFaulted) + sendTask.GetAwaiter().GetResult(); } + catch + { + connector.EndUserAction(); + throw; + } + + // TODO: DRY the following with multiplexing, but be careful with the cancellation registration... + var reader = connector.DataReader; + reader.Init(this, behavior, InternalBatchCommands, startTimestamp, sendTask); + connector.CurrentReader = reader; + if (async) + await reader.NextResultAsync(cancellationToken).ConfigureAwait(false); else + reader.NextResult(); + + TraceReceivedFirstResponse(); + + return reader; + } + else + { + Debug.Assert(conn is not null); + Debug.Assert(conn.Settings.Multiplexing); + + // The connection isn't bound to a connector - it's multiplexing time. + var dataSource = (MultiplexingDataSource)conn.NpgsqlDataSource; + + if (!async) { - // The connection isn't bound to a connector - it's multiplexing time. + // The waiting on the ExecutionCompletion ManualResetValueTaskSource is necessarily + // asynchronous, so allowing sync would mean sync-over-async. + ThrowHelper.ThrowNotSupportedException("Synchronous command execution is not supported when multiplexing is on"); + } - if (!async) + if (IsWrappedByBatch) + { + foreach (var batchCommand in InternalBatchCommands) { - // The waiting on the ExecutionCompletion ManualResetValueTaskSource is necessarily - // asynchronous, so allowing sync would mean sync-over-async. 
- throw new NotSupportedException( - "Synchronous command execution is not supported when multiplexing is on"); + batchCommand._parameters?.ProcessParameters(dataSource.SerializerOptions, validateValues: true, batchCommand.CommandType); + ProcessRawQuery(null, standardConformingStrings: true, batchCommand); } + } + else + { + _parameters?.ProcessParameters(dataSource.SerializerOptions, validateValues: true, CommandType); + ProcessRawQuery(null, standardConformingStrings: true, batchCommand: null); + } - ValidateParameters(conn.Pool!.MultiplexingTypeMapper!); - ProcessRawQuery(); - - State = CommandState.InProgress; + State = CommandState.InProgress; - // TODO: Experiment: do we want to wait on *writing* here, or on *reading*? - // Previous behavior was to wait on reading, which throw the exception from ExecuteReader (and not from - // the first read). But waiting on writing would allow us to do sync writing and async reading. - ExecutionCompletion.Reset(); - await conn.Pool!.MultiplexCommandWriter!.WriteAsync(this, cancellationToken); - connector = await new ValueTask(ExecutionCompletion, ExecutionCompletion.Version); - // TODO: Overload of StartBindingScope? - conn.Connector = connector; - connector.Connection = conn; - conn.ConnectorBindingScope = ConnectorBindingScope.Reader; - - var reader = connector.DataReader; - reader.Init(this, behavior, _statements); - connector.CurrentReader = reader; - await reader.NextResultAsync(cancellationToken); - - return reader; + // TODO: Experiment: do we want to wait on *writing* here, or on *reading*? + // Previous behavior was to wait on reading, which throw the exception from ExecuteReader (and not from + // the first read). But waiting on writing would allow us to do sync writing and async reading. 
+ ExecutionCompletion.Reset(); + try + { + await dataSource.MultiplexCommandWriter.WriteAsync(this, cancellationToken).ConfigureAwait(false); + } + catch (ChannelClosedException ex) + { + Debug.Assert(ex.InnerException is not null); + throw ex.InnerException; } + connector = await new ValueTask(ExecutionCompletion, ExecutionCompletion.Version).ConfigureAwait(false); + // TODO: Overload of StartBindingScope? + conn.Connector = connector; + connector.Connection = conn; + conn.ConnectorBindingScope = ConnectorBindingScope.Reader; + + var reader = connector.DataReader; + reader.Init(this, behavior, InternalBatchCommands); + connector.CurrentReader = reader; + await reader.NextResultAsync(cancellationToken).ConfigureAwait(false); + + return reader; } - catch (Exception e) - { - var reader = conn.Connector?.CurrentReader; - if (!(e is NpgsqlOperationInProgressException) && reader != null) - await reader.Cleanup(async); + } + catch (Exception e) + { + var reader = connector?.CurrentReader; + if (e is not NpgsqlOperationInProgressException && reader is not null) + await reader.Cleanup(async).ConfigureAwait(false); - State = CommandState.Idle; + TraceSetException(e); - // Reader disposal contains logic for closing the connection if CommandBehavior.CloseConnection is - // specified. However, close here as well in case of an error before the reader was even instantiated - // (e.g. write I/O error) - if ((behavior & CommandBehavior.CloseConnection) == CommandBehavior.CloseConnection) - conn.Close(); - throw; - } + State = CommandState.Idle; - async Task NonMultiplexingWriteWrapper(NpgsqlConnector connector, bool async, CancellationToken cancellationToken2) + // Reader disposal contains logic for closing the connection if CommandBehavior.CloseConnection is + // specified. However, close here as well in case of an error before the reader was even instantiated + // (e.g. 
write I/O error) + if ((behavior & CommandBehavior.CloseConnection) == CommandBehavior.CloseConnection) { - BeginSend(connector); - await Write(connector, async, cancellationToken2); - await connector.Flush(async, cancellationToken2); - CleanupSend(); + Debug.Assert(_connector is null && conn is not null); + conn.Close(); } + + throw; } + } + + #endregion + + #region Transactions + + /// + /// DB transaction. + /// + protected override DbTransaction? DbTransaction + { + get => _transaction; + set => _transaction = (NpgsqlTransaction?)value; + } + + /// + /// This property is ignored by Npgsql. PostgreSQL only supports a single transaction at a given time on + /// a given connection, and all commands implicitly run inside the current transaction started via + /// + /// + public new NpgsqlTransaction? Transaction + { + get => (NpgsqlTransaction?)DbTransaction; + set => DbTransaction = value; + } - #endregion + #endregion Transactions - #region Transactions + #region Cancel - /// - /// DB transaction. - /// - protected override DbTransaction? DbTransaction + /// + /// Attempts to cancel the execution of an . + /// + /// As per the specs, no exception will be thrown by this method in case of failure. + public override void Cancel() + { + if (State != CommandState.InProgress) + return; + + var connector = Connection?.Connector ?? _connector; + if (connector is null) + return; + + connector.PerformUserCancellation(); + } + + #endregion Cancel + + #region Dispose + + /// + protected override void Dispose(bool disposing) + { + ResetTransaction(); + + State = CommandState.Disposed; + + if (IsCacheable && InternalConnection is not null && InternalConnection.CachedCommand is null) { - get => Transaction; - set => Transaction = (NpgsqlTransaction?)value; + Reset(); + InternalConnection.CachedCommand = this; + return; } - /// - /// This property is ignored by Npgsql. 
PostgreSQL only supports a single transaction at a given time on - /// a given connection, and all commands implicitly run inside the current transaction started via - /// - /// - public new NpgsqlTransaction? Transaction { get; set; } - - #endregion Transactions - - #region Cancel - - /// - /// Attempts to cancel the execution of an . - /// - /// As per the specs, no exception will be thrown by this method in case of failure. - public override void Cancel() + + IsCacheable = false; + } + + internal void Reset() + { + // TODO: Optimize NpgsqlParameterCollection to recycle NpgsqlParameter instances as well + // TODO: Statements isn't cleared/recycled, leaving this for now, since it'll be replaced by the new batching API + _commandText = string.Empty; + CommandType = CommandType.Text; + // Can be null if it's owned by batch + _parameters?.Clear(); + _timeout = null; + _allResultTypesAreUnknown = false; + EnableErrorBarriers = false; + } + + internal void ResetTransaction() => _transaction = null; + + #endregion + + #region Tracing + + internal void TraceCommandStart(NpgsqlConnector connector) + { + Debug.Assert(CurrentActivity is null); + if (NpgsqlActivitySource.IsEnabled) + CurrentActivity = NpgsqlActivitySource.CommandStart(connector, CommandText, CommandType); + } + + internal void TraceReceivedFirstResponse() + { + if (CurrentActivity is not null) { - if (State != CommandState.InProgress) - return; + NpgsqlActivitySource.ReceivedFirstResponse(CurrentActivity); + } + } - var connection = Connection; - if (connection is null) - return; - if (!connection.IsBound) - throw new NotSupportedException("Cancellation not supported with multiplexing"); + internal void TraceCommandStop() + { + if (CurrentActivity is not null) + { + NpgsqlActivitySource.CommandStop(CurrentActivity); + CurrentActivity = null; + } + } - connection.Connector?.PerformUserCancellation(); + internal void TraceSetException(Exception e) + { + if (CurrentActivity is not null) + { + 
NpgsqlActivitySource.SetException(CurrentActivity, e); + CurrentActivity = null; } + } - #endregion Cancel + #endregion Tracing - #region Dispose + #region Misc - /// - /// Releases the resources used by the NpgsqlCommand. - /// - protected override void Dispose(bool disposing) + NpgsqlBatchCommand TruncateStatementsToOne() + { + switch (InternalBatchCommands.Count) { - if (State == CommandState.Disposed) - return; - Transaction = null; - _connection = null; - State = CommandState.Disposed; - base.Dispose(disposing); + case 0: + var statement = new NpgsqlBatchCommand(); + InternalBatchCommands.Add(statement); + return statement; + + case 1: + statement = InternalBatchCommands[0]; + statement.Reset(); + return statement; + + default: + statement = InternalBatchCommands[0]; + statement.Reset(); + InternalBatchCommands.Clear(); + InternalBatchCommands.Add(statement); + return statement; } + } - #endregion + /// + /// Fixes up the text/binary flag on result columns. + /// Since Prepare() describes a statement rather than a portal, the resulting RowDescription + /// will have text format on all result columns. Fix that up. + /// + /// + /// Note that UnknownResultTypeList only applies to the first query, while AllResultTypesAreUnknown applies + /// to all of them. + /// + internal void FixupRowDescription(RowDescriptionMessage rowDescription, bool isFirst) + { + for (var i = 0; i < rowDescription.Count; i++) + { + var field = rowDescription[i]; + field.DataFormat = (UnknownResultTypeList == null || !isFirst ? AllResultTypesAreUnknown : UnknownResultTypeList[i]) + ? DataFormat.Text + : DataFormat.Binary; + } + } - #region Misc + internal void LogExecutingCompleted(NpgsqlConnector connector, bool executing) + { + var logParameters = connector.LoggingConfiguration.IsParameterLoggingEnabled || connector.Settings.LogParameters; + var logger = connector.LoggingConfiguration.CommandLogger; - /// - /// Fixes up the text/binary flag on result columns. 
- /// Since Prepare() describes a statement rather than a portal, the resulting RowDescription - /// will have text format on all result columns. Fix that up. - /// - /// - /// Note that UnknownResultTypeList only applies to the first query, while AllResultTypesAreUnknown applies - /// to all of them. - /// - internal void FixupRowDescription(RowDescriptionMessage rowDescription, bool isFirst) + if (InternalBatchCommands.Count == 1) { - for (var i = 0; i < rowDescription.NumFields; i++) + var singleCommand = InternalBatchCommands[0]; + + if (logParameters && singleCommand.HasParameters) { - var field = rowDescription[i]; - field.FormatCode = (UnknownResultTypeList == null || !isFirst ? AllResultTypesAreUnknown : UnknownResultTypeList[i]) - ? FormatCode.Text - : FormatCode.Binary; - field.ResolveHandler(); + if (executing) + { + LogMessages.ExecutingCommandWithParameters( + logger, + singleCommand.FinalCommandText!, + ParametersDbNullAsString(singleCommand), + connector.Id); + } + else + { + LogMessages.CommandExecutionCompletedWithParameters( + logger, + singleCommand.FinalCommandText!, + ParametersDbNullAsString(singleCommand), + connector.QueryLogStopWatch.ElapsedMilliseconds, + connector.Id); + } + } + else + { + if (executing) + LogMessages.ExecutingCommand(logger, singleCommand.FinalCommandText!, connector.Id); + else + LogMessages.CommandExecutionCompleted(logger, singleCommand.FinalCommandText!, connector.QueryLogStopWatch.ElapsedMilliseconds, connector.Id); } } - - void LogCommand() + else { - var connector = _connection!.Connector!; - var sb = new StringBuilder(); - sb.AppendLine("Executing statement(s):"); - foreach (var s in _statements) + if (logParameters) { - sb.Append("\t").AppendLine(s.SQL); - var p = s.InputParameters; - if (p.Count > 0 && (NpgsqlLogManager.IsParameterLoggingEnabled || connector.Settings.LogParameters)) - { - for (var i = 0; i < p.Count; i++) - { - sb.Append("\t").Append("Parameters $").Append(i + 1).Append(":"); - switch 
(p[i].Value) - { - case IList list: - for (var j = 0; j < list.Count; j++) - { - sb.Append("\t#").Append(j).Append(": ").Append(Convert.ToString(list[j], CultureInfo.InvariantCulture)); - } - break; - case DBNull _: - case null: - sb.Append("\t").Append(Convert.ToString("null", CultureInfo.InvariantCulture)); - break; - default: - sb.Append("\t").Append(Convert.ToString(p[i].Value, CultureInfo.InvariantCulture)); - break; - } - sb.AppendLine(); - } - } + var commands = new (string, object[])[InternalBatchCommands.Count]; + for (var i = 0; i < InternalBatchCommands.Count; i++) + commands[i] = (InternalBatchCommands[i].FinalCommandText!, ParametersDbNullAsString(InternalBatchCommands[i])); + + if (executing) + LogMessages.ExecutingBatchWithParameters(logger, commands, connector.Id); + else + LogMessages.BatchExecutionCompletedWithParameters(logger, commands, connector.QueryLogStopWatch.ElapsedMilliseconds, connector.Id); + } + else + { + var commands = new string[InternalBatchCommands.Count]; + for (var i = 0; i < InternalBatchCommands.Count; i++) + commands[i] = InternalBatchCommands[i].FinalCommandText!; + if (executing) + LogMessages.ExecutingBatch(logger, commands, connector.Id); + else + LogMessages.BatchExecutionCompleted(logger, commands, connector.QueryLogStopWatch.ElapsedMilliseconds, connector.Id); } - Log.Debug(sb.ToString(), connector.Id); - connector.QueryLogStopWatch.Start(); } - /// - /// Create a new command based on this one. - /// - /// A new NpgsqlCommand object. - object ICloneable.Clone() => Clone(); - - /// - /// Create a new command based on this one. - /// - /// A new NpgsqlCommand object. 
- public NpgsqlCommand Clone() + object[] ParametersDbNullAsString(NpgsqlBatchCommand c) { - var clone = new NpgsqlCommand(CommandText, _connection, Transaction) - { - CommandTimeout = CommandTimeout, CommandType = CommandType, DesignTimeVisible = DesignTimeVisible, _allResultTypesAreUnknown = _allResultTypesAreUnknown, _unknownResultTypeList = _unknownResultTypeList, ObjectResultTypes = ObjectResultTypes - }; - _parameters.CloneTo(clone._parameters); - return clone; + var positionalParameters = c.CurrentParametersReadOnly; + var parameters = new object[positionalParameters.Count]; + for (var i = 0; i < positionalParameters.Count; i++) + parameters[i] = positionalParameters[i].Value == DBNull.Value ? "NULL" : positionalParameters[i].Value!; + return parameters; } + } + + /// + /// Create a new command based on this one. + /// + /// A new NpgsqlCommand object. + object ICloneable.Clone() => Clone(); + + /// + /// Create a new command based on this one. + /// + /// A new NpgsqlCommand object. + public virtual NpgsqlCommand Clone() + { + var clone = new NpgsqlCommand(CommandText, InternalConnection, Transaction) + { + CommandTimeout = CommandTimeout, + CommandType = CommandType, + DesignTimeVisible = DesignTimeVisible, + _allResultTypesAreUnknown = _allResultTypesAreUnknown, + _unknownResultTypeList = _unknownResultTypeList + }; + _parameters?.CloneTo(clone.Parameters); + return clone; + } + + NpgsqlConnection? 
CheckAndGetConnection() + { + if (State is CommandState.Disposed) + ThrowHelper.ThrowObjectDisposedException(GetType().FullName); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - NpgsqlConnection CheckAndGetConnection() + var conn = InternalConnection; + if (conn is null) { - if (State == CommandState.Disposed) - throw new ObjectDisposedException(GetType().FullName); - if (_connection == null) - throw new InvalidOperationException("Connection property has not been initialized."); - switch (_connection.FullState) - { - case ConnectionState.Open: - case ConnectionState.Connecting: - case ConnectionState.Open | ConnectionState.Executing: - case ConnectionState.Open | ConnectionState.Fetching: - return _connection; - default: - throw new InvalidOperationException("Connection is not open"); - } + if (_connector is null) + ThrowHelper.ThrowInvalidOperationException("Connection property has not been initialized."); + return null; } - #endregion + if (!conn.FullState.HasFlag(ConnectionState.Open)) + ThrowHelper.ThrowInvalidOperationException("Connection is not open"); + + return conn; } - enum CommandState + /// + /// This event is unsupported by Npgsql. Use instead. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public new event EventHandler? Disposed + { + add => throw new NotSupportedException("The Disposed event isn't supported by Npgsql. Use DbConnection.StateChange instead."); + remove => throw new NotSupportedException("The Disposed event isn't supported by Npgsql. Use DbConnection.StateChange instead."); + } + + event EventHandler? 
IComponent.Disposed { - Idle, - InProgress, - Disposed + add => Disposed += value; + remove => Disposed -= value; } + + #endregion +} + +enum CommandState +{ + Idle, + InProgress, + Disposed } diff --git a/src/Npgsql/NpgsqlCommandBuilder.cs b/src/Npgsql/NpgsqlCommandBuilder.cs index 03080346fa..9665b8356c 100644 --- a/src/Npgsql/NpgsqlCommandBuilder.cs +++ b/src/Npgsql/NpgsqlCommandBuilder.cs @@ -5,192 +5,192 @@ using System.Globalization; using NpgsqlTypes; -namespace Npgsql +namespace Npgsql; + +/// +/// This class creates database commands for automatic insert, update and delete operations. +/// +[System.ComponentModel.DesignerCategory("")] +public sealed class NpgsqlCommandBuilder : DbCommandBuilder { - /// - /// This class is responsible to create database commands for automatic insert, update and delete operations. - /// - [System.ComponentModel.DesignerCategory("")] - public sealed class NpgsqlCommandBuilder : DbCommandBuilder - { - // Commented out because SetRowUpdatingHandler() is commented, and causes an "is never used" warning - // private NpgsqlRowUpdatingEventHandler rowUpdatingHandler; + // Commented out because SetRowUpdatingHandler() is commented, and causes an "is never used" warning + // private NpgsqlRowUpdatingEventHandler rowUpdatingHandler; - /// - /// Initializes a new instance of the class. - /// - public NpgsqlCommandBuilder() - : this(null) - { - } + /// + /// Initializes a new instance of the class. + /// + public NpgsqlCommandBuilder() + : this(null) + { + } - /// - /// Initializes a new instance of the class. - /// - /// The adapter. - public NpgsqlCommandBuilder(NpgsqlDataAdapter? adapter) - { - DataAdapter = adapter; - QuotePrefix = "\""; - QuoteSuffix = "\""; - } + /// + /// Initializes a new instance of the class. + /// + /// The adapter. + public NpgsqlCommandBuilder(NpgsqlDataAdapter? 
adapter) + { + DataAdapter = adapter; + QuotePrefix = "\""; + QuoteSuffix = "\""; + } - /// - /// Gets or sets the beginning character or characters to use when specifying database objects (for example, tables or columns) whose names contain characters such as spaces or reserved tokens. - /// - /// - /// The beginning character or characters to use. The default is an empty string. - /// - /// - /// - /// - [AllowNull] - public override string QuotePrefix + /// + /// Gets or sets the beginning character or characters to use when specifying database objects (for example, tables or columns) whose names contain characters such as spaces or reserved tokens. + /// + /// + /// The beginning character or characters to use. The default is an empty string. + /// + /// + /// + /// + [AllowNull] + public override string QuotePrefix + { + get => base.QuotePrefix; + // TODO: Why should it be possible to remove the QuotePrefix? + set { - get => base.QuotePrefix; - // TODO: Why should it be possible to remove the QuotePrefix? - set + if (string.IsNullOrEmpty(value)) + { + base.QuotePrefix = value; + } + else { - if (string.IsNullOrEmpty(value)) - { - base.QuotePrefix = value; - } - else - { - base.QuotePrefix = "\""; - } + base.QuotePrefix = "\""; } } + } - /// - /// Gets or sets the ending character or characters to use when specifying database objects (for example, tables or columns) whose names contain characters such as spaces or reserved tokens. - /// - /// - /// The ending character or characters to use. The default is an empty string. - /// - /// - /// - /// - [AllowNull] - public override string QuoteSuffix + /// + /// Gets or sets the ending character or characters to use when specifying database objects (for example, tables or columns) whose names contain characters such as spaces or reserved tokens. + /// + /// + /// The ending character or characters to use. The default is an empty string. 
+ /// + /// + /// + /// + [AllowNull] + public override string QuoteSuffix + { + get => base.QuoteSuffix; + // TODO: Why should it be possible to remove the QuoteSuffix? + set { - get => base.QuoteSuffix; - // TODO: Why should it be possible to remove the QuoteSuffix? - set + if (string.IsNullOrEmpty(value)) { - if (string.IsNullOrEmpty(value)) - { - base.QuoteSuffix = value; - } - else - { - base.QuoteSuffix = "\""; - } + base.QuoteSuffix = value; + } + else + { + base.QuoteSuffix = "\""; } } + } - /// - /// - /// This method is responsible to derive the command parameter list with values obtained from function definition. - /// It clears the Parameters collection of command. Also, if there is any parameter type which is not supported by Npgsql, an InvalidOperationException will be thrown. - /// Parameters name will be parameter1, parameter2, ... for CommandType.StoredProcedure and named after the placeholder for CommandType.Text - /// - /// NpgsqlCommand whose function parameters will be obtained. - public static void DeriveParameters(NpgsqlCommand command) => command.DeriveParameters(); + /// + /// + /// This method is responsible to derive the command parameter list with values obtained from function definition. + /// It clears the Parameters collection of command. Also, if there is any parameter type which is not supported by Npgsql, an InvalidOperationException will be thrown. + /// Parameters name will be parameter1, parameter2, ... for CommandType.StoredProcedure and named after the placeholder for CommandType.Text + /// + /// NpgsqlCommand whose function parameters will be obtained. + public static void DeriveParameters(NpgsqlCommand command) => command.DeriveParameters(); - /// - /// Gets the automatically generated object required - /// to perform insertions at the data source. - /// - /// - /// The automatically generated object required to perform insertions. 
- /// - public new NpgsqlCommand GetInsertCommand() => GetInsertCommand(false); + /// + /// Gets the automatically generated object required + /// to perform insertions at the data source. + /// + /// + /// The automatically generated object required to perform insertions. + /// + public new NpgsqlCommand GetInsertCommand() => GetInsertCommand(false); - /// - /// Gets the automatically generated object required to perform insertions - /// at the data source, optionally using columns for parameter names. - /// - /// - /// If true, generate parameter names matching column names, if possible. - /// If false, generate @p1, @p2, and so on. - /// - /// - /// The automatically generated object required to perform insertions. - /// - public new NpgsqlCommand GetInsertCommand(bool useColumnsForParameterNames) - { - var cmd = (NpgsqlCommand) base.GetInsertCommand(useColumnsForParameterNames); - cmd.UpdatedRowSource = UpdateRowSource.None; - return cmd; - } + /// + /// Gets the automatically generated object required to perform insertions + /// at the data source, optionally using columns for parameter names. + /// + /// + /// If , generate parameter names matching column names, if possible. + /// If , generate @p1, @p2, and so on. + /// + /// + /// The automatically generated object required to perform insertions. + /// + public new NpgsqlCommand GetInsertCommand(bool useColumnsForParameterNames) + { + var cmd = (NpgsqlCommand) base.GetInsertCommand(useColumnsForParameterNames); + cmd.UpdatedRowSource = UpdateRowSource.None; + return cmd; + } - /// - /// Gets the automatically generated System.Data.Common.DbCommand object required - /// to perform updates at the data source. - /// - /// - /// The automatically generated System.Data.Common.DbCommand object required to perform updates. 
- /// - public new NpgsqlCommand GetUpdateCommand() => GetUpdateCommand(false); + /// + /// Gets the automatically generated System.Data.Common.DbCommand object required + /// to perform updates at the data source. + /// + /// + /// The automatically generated System.Data.Common.DbCommand object required to perform updates. + /// + public new NpgsqlCommand GetUpdateCommand() => GetUpdateCommand(false); - /// - /// Gets the automatically generated object required to perform updates - /// at the data source, optionally using columns for parameter names. - /// - /// - /// If true, generate parameter names matching column names, if possible. - /// If false, generate @p1, @p2, and so on. - /// - /// - /// The automatically generated object required to perform updates. - /// - public new NpgsqlCommand GetUpdateCommand(bool useColumnsForParameterNames) - { - var cmd = (NpgsqlCommand)base.GetUpdateCommand(useColumnsForParameterNames); - cmd.UpdatedRowSource = UpdateRowSource.None; - return cmd; - } + /// + /// Gets the automatically generated object required to perform updates + /// at the data source, optionally using columns for parameter names. + /// + /// + /// If , generate parameter names matching column names, if possible. + /// If , generate @p1, @p2, and so on. + /// + /// + /// The automatically generated object required to perform updates. + /// + public new NpgsqlCommand GetUpdateCommand(bool useColumnsForParameterNames) + { + var cmd = (NpgsqlCommand)base.GetUpdateCommand(useColumnsForParameterNames); + cmd.UpdatedRowSource = UpdateRowSource.None; + return cmd; + } - /// - /// Gets the automatically generated System.Data.Common.DbCommand object required - /// to perform deletions at the data source. - /// - /// - /// The automatically generated System.Data.Common.DbCommand object required to perform deletions. 
- /// - public new NpgsqlCommand GetDeleteCommand() => GetDeleteCommand(false); + /// + /// Gets the automatically generated System.Data.Common.DbCommand object required + /// to perform deletions at the data source. + /// + /// + /// The automatically generated System.Data.Common.DbCommand object required to perform deletions. + /// + public new NpgsqlCommand GetDeleteCommand() => GetDeleteCommand(false); - /// - /// Gets the automatically generated object required to perform deletions - /// at the data source, optionally using columns for parameter names. - /// - /// - /// If true, generate parameter names matching column names, if possible. - /// If false, generate @p1, @p2, and so on. - /// - /// - /// The automatically generated object required to perform deletions. - /// - public new NpgsqlCommand GetDeleteCommand(bool useColumnsForParameterNames) - { - var cmd = (NpgsqlCommand) base.GetDeleteCommand(useColumnsForParameterNames); - cmd.UpdatedRowSource = UpdateRowSource.None; - return cmd; - } + /// + /// Gets the automatically generated object required to perform deletions + /// at the data source, optionally using columns for parameter names. + /// + /// + /// If , generate parameter names matching column names, if possible. + /// If , generate @p1, @p2, and so on. + /// + /// + /// The automatically generated object required to perform deletions. + /// + public new NpgsqlCommand GetDeleteCommand(bool useColumnsForParameterNames) + { + var cmd = (NpgsqlCommand) base.GetDeleteCommand(useColumnsForParameterNames); + cmd.UpdatedRowSource = UpdateRowSource.None; + return cmd; + } - //never used - //private string QualifiedTableName(string schema, string tableName) - //{ - // if (schema == null || schema.Length == 0) - // { - // return tableName; - // } - // else - // { - // return schema + "." 
+ tableName; - // } - //} + //never used + //private string QualifiedTableName(string schema, string tableName) + //{ + // if (schema == null || schema.Length == 0) + // { + // return tableName; + // } + // else + // { + // return schema + "." + tableName; + // } + //} /* private static void SetParameterValuesFromRow(NpgsqlCommand command, DataRow row) @@ -202,114 +202,113 @@ private static void SetParameterValuesFromRow(NpgsqlCommand command, DataRow row } */ - /// - /// Applies the parameter information. - /// - /// The parameter. - /// The row. - /// Type of the statement. - /// if set to true [where clause]. - protected override void ApplyParameterInfo(DbParameter p, DataRow row, System.Data.StatementType statementType, bool whereClause) - { - var param = (NpgsqlParameter)p; - param.NpgsqlDbType = (NpgsqlDbType)row[SchemaTableColumn.ProviderType]; - } + /// + /// Applies the parameter information. + /// + /// The parameter. + /// The row. + /// Type of the statement. + /// If set to [where clause]. + protected override void ApplyParameterInfo(DbParameter p, DataRow row, System.Data.StatementType statementType, bool whereClause) + { + var param = (NpgsqlParameter)p; + param.NpgsqlDbType = (NpgsqlDbType)row[SchemaTableColumn.ProviderType]; + } - /// - /// Returns the name of the specified parameter in the format of @p#. - /// - /// The number to be included as part of the parameter's name.. - /// - /// The name of the parameter with the specified number appended as part of the parameter name. - /// - protected override string GetParameterName(int parameterOrdinal) - => string.Format(CultureInfo.InvariantCulture, "@p{0}", parameterOrdinal); + /// + /// Returns the name of the specified parameter in the format of @p#. + /// + /// The number to be included as part of the parameter's name.. + /// + /// The name of the parameter with the specified number appended as part of the parameter name. 
+ /// + protected override string GetParameterName(int parameterOrdinal) + => string.Format(CultureInfo.InvariantCulture, "@p{0}", parameterOrdinal); - /// - /// Returns the full parameter name, given the partial parameter name. - /// - /// The partial name of the parameter. - /// - /// The full parameter name corresponding to the partial parameter name requested. - /// - protected override string GetParameterName(string parameterName) - => string.Format(CultureInfo.InvariantCulture, "@{0}", parameterName); + /// + /// Returns the full parameter name, given the partial parameter name. + /// + /// The partial name of the parameter. + /// + /// The full parameter name corresponding to the partial parameter name requested. + /// + protected override string GetParameterName(string parameterName) + => string.Format(CultureInfo.InvariantCulture, "@{0}", parameterName); - /// - /// Returns the placeholder for the parameter in the associated SQL statement. - /// - /// The number to be included as part of the parameter's name. - /// - /// The name of the parameter with the specified number appended. - /// - protected override string GetParameterPlaceholder(int parameterOrdinal) - => GetParameterName(parameterOrdinal); + /// + /// Returns the placeholder for the parameter in the associated SQL statement. + /// + /// The number to be included as part of the parameter's name. + /// + /// The name of the parameter with the specified number appended. + /// + protected override string GetParameterPlaceholder(int parameterOrdinal) + => GetParameterName(parameterOrdinal); - /// - /// Registers the to handle the event for a . - /// - /// The to be used for the update. - protected override void SetRowUpdatingHandler(DbDataAdapter adapter) - { - var npgsqlAdapter = adapter as NpgsqlDataAdapter; - if (npgsqlAdapter == null) - throw new ArgumentException("adapter needs to be a NpgsqlDataAdapter", nameof(adapter)); + /// + /// Registers the to handle the event for a . 
+ /// + /// The to be used for the update. + protected override void SetRowUpdatingHandler(DbDataAdapter adapter) + { + var npgsqlAdapter = adapter as NpgsqlDataAdapter; + if (npgsqlAdapter == null) + throw new ArgumentException("adapter needs to be a NpgsqlDataAdapter", nameof(adapter)); - // Being called twice for the same adapter means unregister - if (adapter == DataAdapter) - npgsqlAdapter.RowUpdating -= RowUpdatingHandler; - else - npgsqlAdapter.RowUpdating += RowUpdatingHandler; - } + // Being called twice for the same adapter means unregister + if (adapter == DataAdapter) + npgsqlAdapter.RowUpdating -= RowUpdatingHandler; + else + npgsqlAdapter.RowUpdating += RowUpdatingHandler; + } - /// - /// Adds an event handler for the event. - /// - /// The sender - /// A instance containing information about the event. - void RowUpdatingHandler(object sender, NpgsqlRowUpdatingEventArgs e) => base.RowUpdatingHandler(e); + /// + /// Adds an event handler for the event. + /// + /// The sender + /// A instance containing information about the event. + void RowUpdatingHandler(object sender, NpgsqlRowUpdatingEventArgs e) => base.RowUpdatingHandler(e); - /// - /// Given an unquoted identifier in the correct catalog case, returns the correct quoted form of that identifier, including properly escaping any embedded quotes in the identifier. - /// - /// The original unquoted identifier. - /// - /// The quoted version of the identifier. Embedded quotes within the identifier are properly escaped. - /// - /// - /// - /// - /// Unquoted identifier parameter cannot be null - public override string QuoteIdentifier(string unquotedIdentifier) - => unquotedIdentifier == null - ? 
throw new ArgumentNullException(nameof(unquotedIdentifier), "Unquoted identifier parameter cannot be null") - : $"{QuotePrefix}{unquotedIdentifier.Replace(QuotePrefix, QuotePrefix + QuotePrefix)}{QuoteSuffix}"; + /// + /// Given an unquoted identifier in the correct catalog case, returns the correct quoted form of that identifier, including properly escaping any embedded quotes in the identifier. + /// + /// The original unquoted identifier. + /// + /// The quoted version of the identifier. Embedded quotes within the identifier are properly escaped. + /// + /// + /// + /// + /// Unquoted identifier parameter cannot be null + public override string QuoteIdentifier(string unquotedIdentifier) + => unquotedIdentifier == null + ? throw new ArgumentNullException(nameof(unquotedIdentifier), "Unquoted identifier parameter cannot be null") + : $"{QuotePrefix}{unquotedIdentifier.Replace(QuotePrefix, QuotePrefix + QuotePrefix)}{QuoteSuffix}"; - /// - /// Given a quoted identifier, returns the correct unquoted form of that identifier, including properly un-escaping any embedded quotes in the identifier. - /// - /// The identifier that will have its embedded quotes removed. - /// - /// The unquoted identifier, with embedded quotes properly un-escaped. - /// - /// - /// - /// - /// Quoted identifier parameter cannot be null - public override string UnquoteIdentifier(string quotedIdentifier) - { - if (quotedIdentifier == null) - throw new ArgumentNullException(nameof(quotedIdentifier), "Quoted identifier parameter cannot be null"); + /// + /// Given a quoted identifier, returns the correct unquoted form of that identifier, including properly un-escaping any embedded quotes in the identifier. + /// + /// The identifier that will have its embedded quotes removed. + /// + /// The unquoted identifier, with embedded quotes properly un-escaped. 
+ /// + /// + /// + /// + /// Quoted identifier parameter cannot be null + public override string UnquoteIdentifier(string quotedIdentifier) + { + if (quotedIdentifier == null) + throw new ArgumentNullException(nameof(quotedIdentifier), "Quoted identifier parameter cannot be null"); - var unquotedIdentifier = quotedIdentifier.Trim().Replace(QuotePrefix + QuotePrefix, QuotePrefix); + var unquotedIdentifier = quotedIdentifier.Trim().Replace(QuotePrefix + QuotePrefix, QuotePrefix); - if (unquotedIdentifier.StartsWith(QuotePrefix)) - unquotedIdentifier = unquotedIdentifier.Remove(0, 1); + if (unquotedIdentifier.StartsWith(QuotePrefix, StringComparison.Ordinal)) + unquotedIdentifier = unquotedIdentifier.Remove(0, 1); - if (unquotedIdentifier.EndsWith(QuoteSuffix)) - unquotedIdentifier = unquotedIdentifier.Remove(unquotedIdentifier.Length - 1, 1); + if (unquotedIdentifier.EndsWith(QuoteSuffix, StringComparison.Ordinal)) + unquotedIdentifier = unquotedIdentifier.Remove(unquotedIdentifier.Length - 1, 1); - return unquotedIdentifier; - } + return unquotedIdentifier; } } diff --git a/src/Npgsql/NpgsqlConnection.cs b/src/Npgsql/NpgsqlConnection.cs index 1d6915ec60..638cd602e7 100644 --- a/src/Npgsql/NpgsqlConnection.cs +++ b/src/Npgsql/NpgsqlConnection.cs @@ -13,1847 +13,2002 @@ using System.Threading; using System.Threading.Tasks; using System.Transactions; -using JetBrains.Annotations; -using Npgsql.Logging; -using Npgsql.NameTranslation; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; using Npgsql.TypeMapping; using Npgsql.Util; -using NpgsqlTypes; using IsolationLevel = System.Data.IsolationLevel; -using static Npgsql.Util.Statics; -namespace Npgsql +namespace Npgsql; + +/// +/// This class represents a connection to a PostgreSQL server. 
+/// +// ReSharper disable once RedundantNameQualifier +[System.ComponentModel.DesignerCategory("")] +public sealed class NpgsqlConnection : DbConnection, ICloneable, IComponent { + #region Fields + + // Set this when disposed is called. + bool _disposed; + + /// + /// The connection string, without the password after open (unless Persist Security Info=true) + /// + string _userFacingConnectionString = string.Empty; + + /// + /// The original connection string provided by the user, including the password. + /// + string _connectionString = string.Empty; + + ConnectionState _fullState; + + /// + /// The physical connection to the database. This is when the connection is closed, + /// and also when it is open in multiplexing mode and unbound (e.g. not in a transaction). + /// + internal NpgsqlConnector? Connector { get; set; } + /// - /// This class represents a connection to a PostgreSQL server. + /// The parsed connection string. Set only after the connection is opened. /// - // ReSharper disable once RedundantNameQualifier - [System.ComponentModel.DesignerCategory("")] - public sealed class NpgsqlConnection : DbConnection, ICloneable + internal NpgsqlConnectionStringBuilder Settings { get; private set; } = DefaultSettings; + + static readonly NpgsqlConnectionStringBuilder DefaultSettings = new(); + + NpgsqlDataSource? _dataSource; + + internal NpgsqlDataSource NpgsqlDataSource { - #region Fields + get + { + Debug.Assert(_dataSource is not null); + return _dataSource; + } + } + + /// + /// Flag used to make sure we never double-close a connection, returning it twice to the pool. + /// + int _closing; + + internal Transaction? EnlistedTransaction { get; set; } + + /// + /// The global type mapper, which contains defaults used by all new connections. + /// Modify mappings on this mapper to affect your entire application. 
+ /// + [Obsolete("Global-level type mapping has been replaced with data source mapping, see the 7.0 release notes.")] + public static INpgsqlTypeMapper GlobalTypeMapper => TypeMapping.GlobalTypeMapper.Instance; + + /// + /// Connection-level type mapping is no longer supported. See the 7.0 release notes for configuring type mapping on NpgsqlDataSource. + /// + [Obsolete("Connection-level type mapping is no longer supported. See the 7.0 release notes for configuring type mapping on NpgsqlDataSource.", true)] + public INpgsqlTypeMapper TypeMapper + => throw new NotSupportedException(); + + static Func? _cloningInstantiator; + + /// + /// The default TCP/IP port for PostgreSQL. + /// + public const int DefaultPort = 5432; - // Set this when disposed is called. - bool _disposed; + /// + /// Maximum value for connection timeout. + /// + internal const int TimeoutLimit = 1024; - /// - /// The connection string, without the password after open (unless Persist Security Info=true) - /// - string _userFacingConnectionString = string.Empty; + /// + /// Tracks when this connection was bound to a physical connector (e.g. at open-time, when a transaction + /// was started...). + /// + internal ConnectorBindingScope ConnectorBindingScope { get; set; } - /// - /// The original connection string provided by the user, including the password. - /// - string _connectionString = string.Empty; + ILogger _connectionLogger = default!; // Initialized in Open, shouldn't be used otherwise - internal string OriginalConnectionString => _connectionString; + static readonly StateChangeEventArgs ClosedToOpenEventArgs = new(ConnectionState.Closed, ConnectionState.Open); + static readonly StateChangeEventArgs OpenToClosedEventArgs = new(ConnectionState.Open, ConnectionState.Closed); - ConnectionState _fullState; + #endregion Fields - /// - /// The physical connection to the database. This is null when the connection is closed, - /// and also when it is open in multiplexing mode and unbound (e.g. 
not in a transaction). - /// - internal NpgsqlConnector? Connector { get; set; } + #region Constructors / Init / Open - /// - /// The parsed connection string set by the user - /// - internal NpgsqlConnectionStringBuilder Settings { get; private set; } = DefaultSettings; + /// + /// Initializes a new instance of the class. + /// + public NpgsqlConnection() + => GC.SuppressFinalize(this); - static readonly NpgsqlConnectionStringBuilder DefaultSettings = new NpgsqlConnectionStringBuilder(); + /// + /// Initializes a new instance of with the given connection string. + /// + /// The connection used to open the PostgreSQL database. - ConnectorPool? _pool; - internal ConnectorPool? Pool => _pool; + public NpgsqlConnection(string? connectionString) : this() + => ConnectionString = connectionString; - /// - /// Flag used to make sure we never double-close a connection, returning it twice to the pool. - /// - int _closing; + internal NpgsqlConnection(NpgsqlDataSource dataSource, NpgsqlConnector connector) : this() + { + _dataSource = dataSource; + Settings = dataSource.Settings; + _userFacingConnectionString = dataSource.ConnectionString; + + Connector = connector; + connector.Connection = this; + ConnectorBindingScope = ConnectorBindingScope.Connection; + FullState = ConnectionState.Open; + } - internal Transaction? EnlistedTransaction { get; set; } + internal static NpgsqlConnection FromDataSource(NpgsqlDataSource dataSource) + => new() + { + _dataSource = dataSource, + Settings = dataSource.Settings, + _userFacingConnectionString = dataSource.ConnectionString, + }; - /// - /// The global type mapper, which contains defaults used by all new connections. - /// Modify mappings on this mapper to affect your entire application. - /// - public static INpgsqlTypeMapper GlobalTypeMapper => TypeMapping.GlobalTypeMapper.Instance; + /// + /// Opens a database connection with the property settings specified by the . 
+ /// + public override void Open() => Open(false, CancellationToken.None).GetAwaiter().GetResult(); - /// - /// The connection-specific type mapper - all modifications affect this connection only, - /// and are lost when it is closed. - /// - public INpgsqlTypeMapper TypeMapper + /// + /// This is the asynchronous version of . + /// + /// + /// Do not invoke other methods and properties of the object until the returned Task is complete. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task representing the asynchronous operation. + public override Task OpenAsync(CancellationToken cancellationToken) => Open(async: true, cancellationToken); + + void SetupDataSource() + { + // Fast path: a pool already corresponds to this exact version of the connection string. + if (PoolManager.Pools.TryGetValue(_connectionString, out _dataSource)) { - get - { - if (Settings.Multiplexing) - throw new NotSupportedException("Connection-specific type mapping is unsupported when multiplexing is enabled."); + Settings = _dataSource.Settings; // Great, we already have a pool + return; + } - CheckReady(); - return Connector!.TypeMapper!; - } + // Connection string hasn't been seen before. Check for empty and parse (slow one-time path). + if (_connectionString == string.Empty) + { + Settings = DefaultSettings; + _dataSource = null; + return; } - /// - /// The default TCP/IP port for PostgreSQL. - /// - public const int DefaultPort = 5432; - - /// - /// Maximum value for connection timeout. - /// - internal const int TimeoutLimit = 1024; - - /// - /// Tracks when this connection was bound to a physical connector (e.g. at open-time, when a transaction - /// was started...). 
- /// - internal ConnectorBindingScope ConnectorBindingScope { get; set; } - - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlConnection)); - - static readonly StateChangeEventArgs ClosedToOpenEventArgs = new StateChangeEventArgs(ConnectionState.Closed, ConnectionState.Open); - static readonly StateChangeEventArgs OpenToClosedEventArgs = new StateChangeEventArgs(ConnectionState.Open, ConnectionState.Closed); - - #endregion Fields - - #region Constructors / Init / Open - - /// - /// Initializes a new instance of the - /// NpgsqlConnection class. - /// - public NpgsqlConnection() - => GC.SuppressFinalize(this); - - /// - /// Initializes a new instance of with the given connection string. - /// - /// The connection used to open the PostgreSQL database. - public NpgsqlConnection(string? connectionString) : this() - => ConnectionString = connectionString; - - /// - /// Opens a database connection with the property settings specified by the - /// ConnectionString. - /// - public override void Open() => Open(false, CancellationToken.None).GetAwaiter().GetResult(); - - /// - /// This is the asynchronous version of . - /// - /// - /// Do not invoke other methods and properties of the object until the returned Task is complete. - /// - /// The token to monitor for cancellation requests. - /// A task representing the asynchronous operation. - public override Task OpenAsync(CancellationToken cancellationToken) + var settings = new NpgsqlConnectionStringBuilder(_connectionString); + settings.PostProcessAndValidate(); + Settings = settings; + + // The connection string may be equivalent to one that has already been seen though (e.g. different + // ordering). Have NpgsqlConnectionStringBuilder produce a canonical string representation + // and recheck. + // Note that we remove TargetSessionAttributes to make all connection strings that are otherwise identical point to the same pool. 
+ var canonical = settings.ConnectionStringForMultipleHosts; + + if (PoolManager.Pools.TryGetValue(canonical, out _dataSource)) { - using (NoSynchronizationContextScope.Enter()) - return Open(true, cancellationToken); + // If this is a multi-host data source and the user specified a TargetSessionAttributes, create a wrapper in front of the + // MultiHostDataSource with that TargetSessionAttributes. + if (_dataSource is NpgsqlMultiHostDataSource multiHostDataSource && settings.TargetSessionAttributesParsed.HasValue) + _dataSource = multiHostDataSource.WithTargetSession(settings.TargetSessionAttributesParsed.Value); + + // The pool was found, but only under the canonical key - we're using a different version + // for the first time. Map it via our own key for next time. + _dataSource = PoolManager.Pools.GetOrAdd(_connectionString, _dataSource); + return; } - void GetPoolAndSettings() + // Really unseen, need to create a new pool + // The canonical pool is the 'base' pool so we need to set that up first. If someone beats us to it use what they put. + // The connection string pool can either be added here or above, if it's added above we should just use that. + var dataSourceBuilder = new NpgsqlDataSourceBuilder(canonical); + dataSourceBuilder.UseLoggerFactory(NpgsqlLoggingConfiguration.GlobalLoggerFactory); + dataSourceBuilder.EnableParameterLogging(NpgsqlLoggingConfiguration.GlobalIsParameterLoggingEnabled); + var newDataSource = dataSourceBuilder.Build(); + + // See Clone() on the following line: + _cloningInstantiator = s => new NpgsqlConnection(s); + + _dataSource = PoolManager.Pools.GetOrAdd(canonical, newDataSource); + if (_dataSource == newDataSource) { - if (PoolManager.TryGetValue(_connectionString, out _pool)) + Debug.Assert(_dataSource is not MultiHostDataSourceWrapper); + // If the pool we created was the one that ended up being stored we need to increment the appropriate counter. 
+ // Avoids a race condition where multiple threads will create a pool but only one will be stored. + if (_dataSource is NpgsqlMultiHostDataSource multiHostConnectorPool) + foreach (var hostPool in multiHostConnectorPool.Pools) + NpgsqlEventSource.Log.DataSourceCreated(hostPool); + else { - Settings = _pool.Settings; // Great, we already have a pool - return; + NpgsqlEventSource.Log.DataSourceCreated(newDataSource); } + } + else + newDataSource.Dispose(); - // Connection string hasn't been seen before. Parse it. - var settings = new NpgsqlConnectionStringBuilder(_connectionString); - settings.Validate(); - Settings = settings; + // If this is a multi-host data source and the user specified a TargetSessionAttributes, create a wrapper in front of the + // MultiHostDataSource with that TargetSessionAttributes. + if (_dataSource is NpgsqlMultiHostDataSource multiHostDataSource2 && settings.TargetSessionAttributesParsed.HasValue) + _dataSource = multiHostDataSource2.WithTargetSession(settings.TargetSessionAttributesParsed.Value); - // Maybe pooling is off - if (!Settings.Pooling) - return; + _dataSource = PoolManager.Pools.GetOrAdd(_connectionString, _dataSource); + } - // The connection string may be equivalent to one that has already been seen though (e.g. different - // ordering). Have NpgsqlConnectionStringBuilder produce a canonical string representation - // and recheck. 
- var canonical = Settings.ConnectionString; + internal Task Open(bool async, CancellationToken cancellationToken) + { + CheckClosed(); + Debug.Assert(Connector == null); + + if (_dataSource is null) + { + Debug.Assert(string.IsNullOrEmpty(_connectionString)); + ThrowHelper.ThrowInvalidOperationException("The ConnectionString property has not been initialized."); + } - if (PoolManager.TryGetValue(canonical, out _pool)) + _userFacingConnectionString = _dataSource.ConnectionString; + _connectionLogger = _dataSource.LoggingConfiguration.ConnectionLogger; + if (_connectionLogger.IsEnabled(LogLevel.Trace)) + LogMessages.OpeningConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); + + if (Settings.Multiplexing) + { + if (Settings.Enlist && Transaction.Current != null) { - // The pool was found, but only under the canonical key - we're using a different version - // for the first time. Map it via our own key for next time. - _pool = PoolManager.GetOrAdd(_connectionString, _pool); - return; + // TODO: Keep in mind that the TransactionScope can be disposed + ThrowHelper.ThrowNotSupportedException(); } - // Really unseen, need to create a new pool - // The canonical pool is the 'base' pool so we need to set that up first. If someone beats us to it use what they put. - // The connection string pool can either be added here or above, if it's added above we should just use that. - var newPool = new ConnectorPool(Settings, canonical); - _pool = PoolManager.GetOrAdd(canonical, newPool); + // We're opening in multiplexing mode, without a transaction. We don't actually do anything. - // If the pool we created was the one that ended up being stored we need to increment the appropriate counter. - // Avoids a race condition where multiple threads will create a pool but only one will be stored. 
- if (_pool == newPool) + // If we've never connected with this connection string, open a physical connector in order to generate + // any exception (bad user/password, IP address...). This reproduces the standard error behavior. + if (!_dataSource.IsBootstrapped) { - // If the pool we created was the one that ended up being stored we need to increment the appropriate counter. - // Avoids a race condition where multiple threads will create a pool but only one will be stored. - NpgsqlEventSource.Log.PoolCreated(); + FullState = ConnectionState.Connecting; + return PerformMultiplexingStartupCheck(async, cancellationToken); } - _pool = PoolManager.GetOrAdd(_connectionString, _pool); + if (_connectionLogger.IsEnabled(LogLevel.Debug)) + LogMessages.OpenedMultiplexingConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); + FullState = ConnectionState.Open; + + return Task.CompletedTask; } - internal Task Open(bool async, CancellationToken cancellationToken) + return OpenAsync(async, cancellationToken); + + async Task OpenAsync(bool async, CancellationToken cancellationToken) { - CheckClosed(); - Debug.Assert(Connector == null); + Debug.Assert(!Settings.Multiplexing); - Log.Trace("Opening connection..."); FullState = ConnectionState.Connecting; - - if (Settings.Multiplexing) + NpgsqlConnector? connector = null; + try { - Debug.Assert(_pool != null, "Multiplexing is off by default, and cannot be on without pooling"); + var connectionTimeout = TimeSpan.FromSeconds(ConnectionTimeout); + var timeout = new NpgsqlTimeout(connectionTimeout); + + var enlistToTransaction = Settings.Enlist ? Transaction.Current : null; - if (Settings.Enlist && Transaction.Current != null) + // First, check to see if we there's an ambient transaction, and we have a connection enlisted + // to this transaction which has been closed. 
If so, return that as an optimization rather than + // opening a new one and triggering escalation to a distributed transaction. + // Otherwise just get a new connector and enlist below. + if (enlistToTransaction is not null && _dataSource.TryRentEnlistedPending(enlistToTransaction, this, out connector)) { - // TODO: Keep in mind that the TransactionScope can be disposed - throw new NotImplementedException(); + EnlistedTransaction = enlistToTransaction; + enlistToTransaction = null; } + else + connector = await _dataSource.Get(this, timeout, async, cancellationToken).ConfigureAwait(false); - // We're opening in multiplexing mode, without a transaction. We don't actually do anything. - _userFacingConnectionString = _pool.UserFacingConnectionString; - - // If we've never connected with this connection string, open a physical connector in order to generate - // any exception (bad user/password, IP address...). This reproduces the standard error behavior. - if (!_pool.IsBootstrapped) - return BootstrapMultiplexing(cancellationToken); + Debug.Assert(connector.Connection is null, + $"Connection for opened connector '{Connector?.Id.ToString() ?? "???"}' is bound to another connection"); - CompleteOpen(); - return Task.CompletedTask; - } + ConnectorBindingScope = ConnectorBindingScope.Connection; + connector.Connection = this; + Connector = connector; - return OpenAsync(cancellationToken); + if (enlistToTransaction is not null) + EnlistTransaction(enlistToTransaction); - async Task OpenAsync(CancellationToken cancellationToken2) + LogMessages.OpenedConnection(_connectionLogger, Host!, Port, Database, _userFacingConnectionString, connector.Id); + FullState = ConnectionState.Open; + } + catch { - Debug.Assert(!Settings.Multiplexing); + FullState = ConnectionState.Closed; + ConnectorBindingScope = ConnectorBindingScope.None; + Connector = null; + EnlistedTransaction = null; - NpgsqlConnector? 
connector = null; - try - { - var timeout = new NpgsqlTimeout(TimeSpan.FromSeconds(ConnectionTimeout)); - - if (_pool == null) // Un-pooled connection (or user forgot to set connection string) - { - if (string.IsNullOrEmpty(_connectionString)) - throw new InvalidOperationException("The ConnectionString property has not been initialized."); - - if (!Settings.PersistSecurityInfo) - _userFacingConnectionString = Settings.ToStringWithoutPassword(); - - connector = new NpgsqlConnector(this); - await connector.Open(timeout, async, cancellationToken2); - } - else - { - _userFacingConnectionString = _pool.UserFacingConnectionString; - - if (Settings.Enlist && Transaction.Current is Transaction transaction) - { - // First, check to see if we there's an ambient transaction, and we have a connection enlisted - // to this transaction which has been closed. If so, return that as an optimization rather than - // opening a new one and triggering escalation to a distributed transaction. - // Otherwise just get a new connector and enlist. - if (_pool.TryRentEnlistedPending(transaction, out connector)) - { - connector.Connection = this; - EnlistedTransaction = transaction; - } - else - { - connector = await _pool.Rent(this, timeout, async, cancellationToken2); - ConnectorBindingScope = ConnectorBindingScope.Connection; - Connector = connector; - EnlistTransaction(Transaction.Current); - } - } - else - connector = await _pool.Rent(this, timeout, async, cancellationToken2); - } - - Debug.Assert(connector.Connection == this, - $"Connection for opened connector {Connector} isn't the same as this connection"); - - ConnectorBindingScope = ConnectorBindingScope.Connection; - Connector = connector; - - // Since this connector was last used, PostgreSQL types (e.g. enums) may have been added - // (and ReloadTypes() called), or global mappings may have changed by the user. - // Bring this up to date if needed. 
- // Note that in multiplexing execution, the pool-wide type mapper is used so no - // need to update the connector type mapper (this is why this is here). - if (connector.TypeMapper.ChangeCounter != TypeMapping.GlobalTypeMapper.Instance.ChangeCounter) - await connector.LoadDatabaseInfo(false, timeout, async, cancellationToken); - - CompleteOpen(); - } - catch + if (connector is not null) { - FullState = ConnectionState.Closed; - ConnectorBindingScope = ConnectorBindingScope.None; - Connector = null; - - if (connector != null) - { - if (_pool == null) - connector.Close(); - else - _pool.Return(connector); - } - - throw; + connector.Connection = null; + connector.Return(); } - } - async Task BootstrapMultiplexing(CancellationToken cancellationToken2) - { - try - { - var timeout = new NpgsqlTimeout(TimeSpan.FromSeconds(ConnectionTimeout)); - await _pool!.BootstrapMultiplexing(this, timeout, async, cancellationToken2); - CompleteOpen(); - } - catch - { - FullState = ConnectionState.Closed; - throw; - } + throw; } + } - void CompleteOpen() + async Task PerformMultiplexingStartupCheck(bool async, CancellationToken cancellationToken) + { + try { - Log.Debug("Connection opened (multiplexing)"); + var timeout = new NpgsqlTimeout(TimeSpan.FromSeconds(ConnectionTimeout)); + + _ = await StartBindingScope(ConnectorBindingScope.Connection, timeout, async, cancellationToken).ConfigureAwait(false); + EndBindingScope(ConnectorBindingScope.Connection); + + LogMessages.OpenedMultiplexingConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); + FullState = ConnectionState.Open; - OnStateChange(ClosedToOpenEventArgs); + } + catch + { + FullState = ConnectionState.Closed; + throw; } } + } - #endregion Open / Init + #endregion Open / Init - #region Connection string management + #region Connection string management - /// - /// Gets or sets the string used to connect to a PostgreSQL database. See the manual for details. 
- /// - /// The connection string that includes the server name, - /// the database name, and other parameters needed to establish - /// the initial connection. The default value is an empty string. - /// - [AllowNull] - public override string ConnectionString + /// + /// Gets or sets the string used to connect to a PostgreSQL database. See the manual for details. + /// + /// The connection string that includes the server name, + /// the database name, and other parameters needed to establish + /// the initial connection. The default value is an empty string. + /// + [AllowNull] + public override string ConnectionString + { + get => _userFacingConnectionString; + set { - get => _userFacingConnectionString; - set - { - CheckClosed(); + CheckClosed(); - _userFacingConnectionString = _connectionString = value ?? string.Empty; - GetPoolAndSettings(); - } + _userFacingConnectionString = _connectionString = value ?? string.Empty; + SetupDataSource(); } + } - /// - /// Gets or sets the delegate used to generate a password for new database connections. - /// - /// - /// This delegate is executed when a new database connection is opened that requires a password. - /// Password and - /// Passfile connection string - /// properties have precedence over this delegate. It will not be executed if a password is - /// specified, or the specified or default Passfile contains a valid entry. - /// Due to connection pooling this delegate is only executed when a new physical connection - /// is opened, not when reusing a connection that was previously opened from the pool. - /// - public ProvidePasswordCallback? ProvidePasswordCallback { get; set; } - - #endregion Connection string management - - #region Configuration settings - - /// - /// Backend server host name. - /// - [Browsable(true)] - public string? Host => Settings.Host; - - /// - /// Backend server port. 
- /// - [Browsable(true)] - public int Port => Settings.Port; - - /// - /// Gets the time (in seconds) to wait while trying to establish a connection - /// before terminating the attempt and generating an error. - /// - /// The time (in seconds) to wait for a connection to open. The default value is 15 seconds. - public override int ConnectionTimeout => Settings.Timeout; - - /// - /// Gets the time (in seconds) to wait while trying to execute a command - /// before terminating the attempt and generating an error. - /// - /// The time (in seconds) to wait for a command to complete. The default value is 20 seconds. - public int CommandTimeout => Settings.CommandTimeout; - - /// - /// Gets the name of the current database or the database to be used after a connection is opened. - /// - /// The name of the current database or the name of the database to be - /// used after a connection is opened. The default value is the empty string. - public override string Database => Settings.Database ?? Settings.Username ?? ""; - - /// - /// Gets the string identifying the database server (host and port) - /// - /// - /// The name of the database server (host and port). If the connection uses a Unix-domain socket, - /// the path to that socket is returned. The default value is the empty string. - /// - public override string DataSource => Settings.DataSourceCached; - - /// - /// Whether to use Windows integrated security to log in. - /// - public bool IntegratedSecurity => Settings.IntegratedSecurity; - - /// - /// User name. - /// - public string? UserName => Settings.Username; - - internal string? Password => Settings.Password; - - // The following two lines are here for backwards compatibility with the EF6 provider - // ReSharper disable UnusedMember.Global - internal string? EntityTemplateDatabase => Settings.EntityTemplateDatabase; - internal string? 
EntityAdminDatabase => Settings.EntityAdminDatabase; - // ReSharper restore UnusedMember.Global - - #endregion Configuration settings - - #region State management - - /// - /// Gets the current state of the connection. - /// - /// A bitwise combination of the ConnectionState values. The default is Closed. - [Browsable(false)] - public ConnectionState FullState - { - // Note: we allow accessing the state after dispose, #164 - get => _fullState switch - { - ConnectionState.Open => Connector == null - ? ConnectionState.Open // When unbound, we only know we're open - : Connector.State switch - { - ConnectorState.Ready => ConnectionState.Open, - ConnectorState.Executing => ConnectionState.Open | ConnectionState.Executing, - ConnectorState.Fetching => ConnectionState.Open | ConnectionState.Fetching, - ConnectorState.Copy => ConnectionState.Open | ConnectionState.Fetching, - ConnectorState.Replication => ConnectionState.Open | ConnectionState.Fetching, - ConnectorState.Waiting => ConnectionState.Open | ConnectionState.Fetching, - ConnectorState.Connecting => ConnectionState.Connecting, - ConnectorState.Broken => ConnectionState.Broken, - ConnectorState.Closed => throw new InvalidOperationException("Internal Npgsql bug: connection is in state Open but connector is in state Closed"), - _ => throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {Connector.State} of enum {nameof(ConnectorState)}. Please file a bug.") - }, - _ => _fullState - }; - internal set => _fullState = value; - } + /// + /// Gets or sets the delegate used to generate a password for new database connections. + /// + /// + ///

+ /// This delegate is executed when a new database connection is opened that requires a password. + ///

+ ///

+ /// The and connection + /// string properties have precedence over this delegate: it will not be executed if a password is specified, or if the specified or + /// default Passfile contains a valid entry. + ///

+ ///

+ /// Due to connection pooling this delegate is only executed when a new physical connection is opened, not when reusing a connection + /// that was previously opened from the pool. + ///

+ ///
+ [Obsolete("Use NpgsqlDataSourceBuilder.UsePeriodicPasswordProvider or inject passwords directly into NpgsqlDataSource.Password")] + public ProvidePasswordCallback? ProvidePasswordCallback { get; set; } + + #endregion Connection string management + + #region Configuration settings - /// - /// Gets whether the current state of the connection is Open or Closed - /// - /// ConnectionState.Open, ConnectionState.Closed or ConnectionState.Connecting - [Browsable(false)] - public override ConnectionState State - { - get - { - var s = FullState; - if ((s & ConnectionState.Open) != 0) - return ConnectionState.Open; - if ((s & ConnectionState.Connecting) != 0) - return ConnectionState.Connecting; - return ConnectionState.Closed; - } - } + /// + /// Backend server host name. + /// + [Browsable(true)] + public string? Host => Connector?.Host; - #endregion State management + /// + /// Backend server port. + /// + [Browsable(true)] + public int Port => Connector?.Port ?? 0; - #region Commands + /// + /// Gets the time (in seconds) to wait while trying to establish a connection + /// before terminating the attempt and generating an error. + /// + /// The time (in seconds) to wait for a connection to open. The default value is 15 seconds. + public override int ConnectionTimeout => Settings.Timeout; - /// - /// Creates and returns a DbCommand - /// object associated with the IDbConnection. - /// - /// A DbCommand object. - protected override DbCommand CreateDbCommand() - { - return CreateCommand(); - } + /// + /// Gets the time (in seconds) to wait while trying to execute a command + /// before terminating the attempt and generating an error. + /// + /// The time (in seconds) to wait for a command to complete. The default value is 30 seconds. + public int CommandTimeout => Settings.CommandTimeout; - /// - /// Creates and returns a NpgsqlCommand - /// object associated with the NpgsqlConnection. - /// - /// A NpgsqlCommand object. 
- public new NpgsqlCommand CreateCommand() - { - CheckDisposed(); - return new NpgsqlCommand("", this); - } + /// + /// Gets the name of the current database or the database to be used after a connection is opened. + /// + /// The name of the current database or the name of the database to be + /// used after a connection is opened. The default value is the empty string. + public override string Database => Settings.Database ?? Settings.Username ?? ""; - #endregion Commands - - #region Transactions - - /// - /// Begins a database transaction with the specified isolation level. - /// - /// The isolation level under which the transaction should run. - /// An DbTransaction - /// object representing the new transaction. - /// - /// Currently the IsolationLevel ReadCommitted and Serializable are supported by the PostgreSQL backend. - /// There's no support for nested transactions. - /// - protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => BeginTransaction(isolationLevel); - - /// - /// Begins a database transaction. - /// - /// A NpgsqlTransaction - /// object representing the new transaction. - /// - /// Currently there's no support for nested transactions. Transactions created by this method will have Read Committed isolation level. - /// - public new NpgsqlTransaction BeginTransaction() - => BeginTransaction(IsolationLevel.Unspecified); - - /// - /// Begins a database transaction with the specified isolation level. - /// - /// The isolation level under which the transaction should run. - /// A NpgsqlTransaction - /// object representing the new transaction. - /// - /// Currently the IsolationLevel ReadCommitted and Serializable are supported by the PostgreSQL backend. - /// There's no support for nested transactions. 
- /// - public new NpgsqlTransaction BeginTransaction(IsolationLevel level) - => BeginTransaction(level, async: false, CancellationToken.None).GetAwaiter().GetResult(); - - async ValueTask BeginTransaction(IsolationLevel level, bool async, CancellationToken cancellationToken) - { - if (level == IsolationLevel.Chaos) - throw new NotSupportedException("Unsupported IsolationLevel: " + level); + /// + /// Gets the string identifying the database server (host and port) + /// + /// + /// The name of the database server (host and port). If the connection uses a Unix-domain socket, + /// the path to that socket is returned. The default value is the empty string. + /// + public override string DataSource => Connector?.Settings.DataSourceCached ?? _dataSource?.Settings.DataSourceCached ?? string.Empty; - CheckReady(); - if (Connector != null && Connector.InTransaction) - throw new InvalidOperationException("A transaction is already in progress; nested/concurrent transactions aren't supported."); + /// + /// User name. + /// + public string? UserName => Settings.Username; - // There was a commited/rollbacked transaction, but it was not disposed - var connector = ConnectorBindingScope == ConnectorBindingScope.Transaction ? - Connector - : await StartBindingScope(ConnectorBindingScope.Transaction, NpgsqlTimeout.Infinite, async, - cancellationToken); + #endregion Configuration settings - Debug.Assert(connector != null); + #region State management - try - { - // Note that beginning a transaction doesn't actually send anything to the backend (only prepends). - // But we start a user action to check the cancellation token and generate exceptions - using var _ = connector.StartUserAction(cancellationToken); + /// + /// Gets the current state of the connection. + /// + /// A bitwise combination of the values. The default is Closed. 
+ [Browsable(false)] + public ConnectionState FullState + { + // Note: we allow accessing the state after dispose, #164 + get + { + if (_fullState != ConnectionState.Open) + return _fullState; - connector.Transaction ??= new NpgsqlTransaction(connector); - connector.Transaction.Init(level); - return connector.Transaction; - } - catch + if (Connector is null) + return ConnectionState.Open; // When unbound, we only know we're open + + switch (Connector.State) { - EndBindingScope(ConnectorBindingScope.Transaction); - throw; + case ConnectorState.Ready: + return ConnectionState.Open; + case ConnectorState.Executing: + return ConnectionState.Open | ConnectionState.Executing; + case ConnectorState.Fetching: + case ConnectorState.Copy: + case ConnectorState.Replication: + case ConnectorState.Waiting: + return ConnectionState.Open | ConnectionState.Fetching; + case ConnectorState.Connecting: + return ConnectionState.Connecting; + case ConnectorState.Broken: + return ConnectionState.Broken; + case ConnectorState.Closed: + ThrowHelper.ThrowInvalidOperationException("Internal Npgsql bug: connection is in state Open but connector is in state Closed"); + return ConnectionState.Broken; + default: + ThrowHelper.ThrowInvalidOperationException($"Internal Npgsql bug: unexpected value {{0}} of enum {nameof(ConnectorState)}. Please file a bug.", Connector.State); + return ConnectionState.Broken; } } - -#if !NETSTANDARD2_0 - /// - /// Asynchronously begins a database transaction. - /// - /// An optional token to cancel the asynchronous operation. The default value is None. - /// A task whose Result property is an object representing the new transaction. - /// - /// Currently there's no support for nested transactions. Transactions created by this method will have Read Committed isolation level. 
- /// - public new ValueTask BeginTransactionAsync(CancellationToken cancellationToken = default) - => BeginTransactionAsync(IsolationLevel.Unspecified, cancellationToken); - - /// - /// Asynchronously begins a database transaction. - /// - /// The isolation level under which the transaction should run. - /// An optional token to cancel the asynchronous operation. The default value is None. - /// A task whose Result property is an object representing the new transaction. - /// - /// Currently the IsolationLevel ReadCommitted and Serializable are supported by the PostgreSQL backend. - /// There's no support for nested transactions. - /// - public new ValueTask BeginTransactionAsync(IsolationLevel level, CancellationToken cancellationToken = default) + internal set { - using (NoSynchronizationContextScope.Enter()) - return BeginTransaction(level, async: true, cancellationToken); - } -#endif + if (value is < 0 or > ConnectionState.Broken) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(value), "Unknown connection state", value); - /// - /// Enlist transaction. - /// - public override void EnlistTransaction(Transaction? transaction) - { - if (Settings.Multiplexing) - throw new NotSupportedException("Ambient transactions aren't yet implemented for multiplexing"); + var originalOpen = _fullState.HasFlag(ConnectionState.Open); - if (EnlistedTransaction != null) + _fullState = value; + var currentOpen = _fullState.HasFlag(ConnectionState.Open); + if (currentOpen != originalOpen) { - if (EnlistedTransaction.Equals(transaction)) - return; - try - { - if (EnlistedTransaction.TransactionInformation.Status == System.Transactions.TransactionStatus.Active) - throw new InvalidOperationException($"Already enlisted to transaction (localid={EnlistedTransaction.TransactionInformation.LocalIdentifier})"); - } - catch (ObjectDisposedException) - { - // The MSDTC 2nd phase is asynchronous, so we may end up checking the TransactionInformation on - // a disposed transaction. 
To be extra safe we catch that, and understand that the transaction - // has ended - no problem for reenlisting. - } + OnStateChange(currentOpen + ? ClosedToOpenEventArgs + : OpenToClosedEventArgs); } + } + } - CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Transaction); + /// + /// Gets whether the current state of the connection is Open or Closed + /// + /// ConnectionState.Open, ConnectionState.Closed or ConnectionState.Connecting + [Browsable(false)] + public override ConnectionState State + { + get + { + var fullState = FullState; + if (fullState.HasFlag(ConnectionState.Connecting)) + return ConnectionState.Connecting; - EnlistedTransaction = transaction; - if (transaction == null) - return; + if (fullState.HasFlag(ConnectionState.Open)) + return ConnectionState.Open; - // Until #1378 is implemented, we have no recovery, and so no need to enlist as a durable resource manager - // (or as promotable single phase). + return ConnectionState.Closed; + } + } - // Note that even when #1378 is implemented in some way, we should check for mono and go volatile in any case - - // distributed transactions aren't supported. + #endregion State management - transaction.EnlistVolatile(new VolatileResourceManager(this, transaction), EnlistmentOptions.None); - Log.Debug($"Enlisted volatile resource manager (localid={transaction.TransactionInformation.LocalIdentifier})", connector.Id); - } + #region Command / Batch creation - #endregion + /// + /// A cached command handed out by , which is returned when disposed. Useful for reducing allocations. + /// + internal NpgsqlCommand? CachedCommand { get; set; } - #region Close + /// + /// Creates and returns a + /// object associated with the . + /// + /// A object. + protected override DbCommand CreateDbCommand() => CreateCommand(); - /// - /// Releases the connection. If the connection is pooled, it will be returned to the pool and made available for re-use. 
- /// If it is non-pooled, the physical connection will be closed. - /// - public override void Close() => Close(async: false); + /// + /// Creates and returns a object associated with the . + /// + /// A object. + public new NpgsqlCommand CreateCommand() + { + CheckDisposed(); - /// - /// Releases the connection. If the connection is pooled, it will be returned to the pool and made available for re-use. - /// If it is non-pooled, the physical connection will be closed. - /// -#if NETSTANDARD2_0 - public Task CloseAsync() -#else - public override Task CloseAsync() -#endif + var cachedCommand = CachedCommand; + if (cachedCommand is not null) { - using (NoSynchronizationContextScope.Enter()) - return Close(async: true); + CachedCommand = null; + cachedCommand.State = CommandState.Idle; + return cachedCommand; } - internal Task Close(bool async, CancellationToken cancellationToken = default) - { - // Even though NpgsqlConnection isn't thread safe we'll make sure this part is. - // Because we really don't want double returns to the pool. - if (Interlocked.Exchange(ref _closing, 1) == 1) - return Task.CompletedTask; + return NpgsqlCommand.CreateCachedCommand(this); + } - switch (FullState) - { - case ConnectionState.Open: - case ConnectionState.Open | ConnectionState.Executing: - case ConnectionState.Open | ConnectionState.Fetching: - case ConnectionState.Broken: - break; - case ConnectionState.Closed: - Volatile.Write(ref _closing, 0); - return Task.CompletedTask; - case ConnectionState.Connecting: - Volatile.Write(ref _closing, 0); - throw new InvalidOperationException("Can't close, connection is in state " + FullState); - default: - Volatile.Write(ref _closing, 0); - throw new ArgumentOutOfRangeException("Unknown connection state: " + FullState); - } + /// + /// A cached batch handed out by , which is returned when disposed. Useful for reducing allocations. + /// + internal NpgsqlBatch? 
CachedBatch { get; set; } - // TODO: The following shouldn't exist - we need to flow down the regular path to close any - // open reader / COPY. See test CloseDuringRead with multiplexing. - if (Settings.Multiplexing && ConnectorBindingScope == ConnectorBindingScope.None) - { - // TODO: Consider falling through to the regular reset logic. This adds some unneeded conditions - // and assignment but actual perf impact should be negligible (measure). - Debug.Assert(Connector == null); - FullState = ConnectionState.Closed; - Log.Debug("Connection closed (multiplexing)"); - OnStateChange(OpenToClosedEventArgs); - Volatile.Write(ref _closing, 0); - return Task.CompletedTask; - } +#if NET6_0_OR_GREATER + /// + public override bool CanCreateBatch => true; - return CloseAsync(cancellationToken); + /// + protected override DbBatch CreateDbBatch() => CreateBatch(); - async Task CloseAsync(CancellationToken cancellationToken) - { - Debug.Assert(Connector != null); - var connector = Connector; - Log.Trace("Closing connection...", connector.Id); + /// + public new NpgsqlBatch CreateBatch() + { + CheckDisposed(); - using var _ = Defer(() => Volatile.Write(ref _closing, 0)); + var cachedBatch = CachedBatch; + if (cachedBatch is not null) + { + CachedBatch = null; + return cachedBatch; + } - if (connector.CurrentReader != null || connector.CurrentCopyOperation != null) - { - // This method could re-enter connection.Close() due to an underlying connection failure. - await connector.CloseOngoingOperations(async, cancellationToken); - } + return NpgsqlBatch.CreateCachedBatch(this); + } +#else + /// + /// Creates and returns a object associated with the . + /// + /// A object. 
+ public NpgsqlBatch CreateBatch() => new(this); +#endif - Debug.Assert(connector.IsReady || connector.IsBroken); - Debug.Assert(connector.CurrentReader == null); - Debug.Assert(connector.CurrentCopyOperation == null); + #endregion Command / Batch creation - if (connector.IsBroken) - { - connector.Connection = null; + #region Transactions - if (_pool == null) - connector.Close(); - else - _pool.Return(connector); + /// + /// Begins a database transaction with the specified isolation level. + /// + /// The isolation level under which the transaction should run. + /// A object representing the new transaction. + /// Nested transactions are not supported. + protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => BeginTransaction(isolationLevel); - EnlistedTransaction = null; - } - else if (EnlistedTransaction != null) - { - // A System.Transactions transaction is still in progress + /// + /// Begins a database transaction. + /// + /// A object representing the new transaction. + /// + /// Nested transactions are not supported. + /// Transactions created by this method will have the isolation level. + /// + public new NpgsqlTransaction BeginTransaction() + => BeginTransaction(IsolationLevel.Unspecified); - connector.Connection = null; + /// + /// Begins a database transaction with the specified isolation level. + /// + /// The isolation level under which the transaction should run. + /// A object representing the new transaction. + /// Nested transactions are not supported. + public new NpgsqlTransaction BeginTransaction(IsolationLevel level) + => BeginTransaction(async: false, level, CancellationToken.None).GetAwaiter().GetResult(); - // If pooled, close the connection and disconnect it from the resource manager but leave the - // connector in an enlisted pending list in the pool. 
If another connection is opened within - // the same transaction scope, we will reuse this connector to avoid escalating to a distributed - // transaction - // If a *non-pooled* connection is being closed but is enlisted in an ongoing - // TransactionScope, we do nothing - simply detach the connector from the connection and leave - // it open. It will be closed when the TransactionScope is disposed. - _pool?.AddPendingEnlistedConnector(connector, EnlistedTransaction); + async ValueTask BeginTransaction(bool async, IsolationLevel level, CancellationToken cancellationToken) + { + if (level == IsolationLevel.Chaos) + ThrowHelper.ThrowNotSupportedException($"Unsupported IsolationLevel: {nameof(IsolationLevel.Chaos)}"); - EnlistedTransaction = null; - } - else - { - if (_pool == null) - connector.Close(); - else - { - // Clear the buffer, roll back any pending transaction and prepend a reset message if needed - // Also returns the connector to the pool, if there is an open transaction and multiplexing is on - await connector.Reset(async, cancellationToken); - - if (Settings.Multiplexing) - { - // We've already closed ongoing operations rolled back any transaction and the connector is already in the pool, - // so we must be unbound. Nothing to do. 
- Debug.Assert(ConnectorBindingScope == ConnectorBindingScope.None, - $"When closing a multiplexed connection, the connection was supposed to be unbound, but {nameof(ConnectorBindingScope)} was {ConnectorBindingScope}"); - } - else - { - connector.Connection = null; - _pool.Return(connector); - } - } - } + CheckReady(); + if (Connector is { InTransaction: true }) + ThrowHelper.ThrowInvalidOperationException("A transaction is already in progress; nested/concurrent transactions aren't supported."); - Connector = null; - ConnectorBindingScope = ConnectorBindingScope.None; - FullState = ConnectionState.Closed; - Log.Debug("Connection closed", connector.Id); - OnStateChange(OpenToClosedEventArgs); - } - } + // There was a committed/rolled back transaction, but it was not disposed + var connector = ConnectorBindingScope == ConnectorBindingScope.Transaction + ? Connector + : await StartBindingScope(ConnectorBindingScope.Transaction, NpgsqlTimeout.Infinite, async, cancellationToken).ConfigureAwait(false); + + Debug.Assert(connector != null); - /// - /// Releases all resources used by the NpgsqlConnection. - /// - /// true when called from Dispose(); - /// false when being called from the finalizer. - protected override void Dispose(bool disposing) + try { - if (_disposed) - return; - if (disposing) - Close(); - _disposed = true; - } + // Note that beginning a transaction doesn't actually send anything to the backend (only prepends). + // But we start a user action to check the cancellation token and generate exceptions + using var _ = connector.StartUserAction(cancellationToken); - /// - /// Releases all resources used by the NpgsqlConnection. 
- /// -#if NETSTANDARD2_0 - public async ValueTask DisposeAsync() -#else - public override async ValueTask DisposeAsync() -#endif + connector.Transaction ??= new NpgsqlTransaction(connector); + connector.Transaction.Init(level); + return connector.Transaction; + } + catch { - if (_disposed) - return; - await CloseAsync(); - _disposed = true; + EndBindingScope(ConnectorBindingScope.Transaction); + throw; } + } - #endregion - - #region Notifications and Notices - - /// - /// Fires when PostgreSQL notices are received from PostgreSQL. - /// - /// - /// PostgreSQL notices are non-critical messages generated by PostgreSQL, either as a result of a user query - /// (e.g. as a warning or informational notice), or due to outside activity (e.g. if the database administrator - /// initiates a "fast" database shutdown). - /// - /// Note that notices are very different from notifications (see the event). - /// - public event NoticeEventHandler? Notice; - - /// - /// Fires when PostgreSQL notifications are received from PostgreSQL. - /// - /// - /// PostgreSQL notifications are sent when your connection has registered for notifications on a specific channel via the - /// LISTEN command. NOTIFY can be used to generate such notifications, allowing for an inter-connection communication channel. - /// - /// Note that notifications are very different from notices (see the event). - /// - public event NotificationEventHandler? Notification; - - internal void OnNotice(PostgresNotice e) +#if !NETSTANDARD2_0 + /// + /// Asynchronously begins a database transaction. + /// + /// The isolation level under which the transaction should run. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task whose property is an object representing the new transaction. + /// + /// Nested transactions are not supported. 
+ /// + protected override async ValueTask BeginDbTransactionAsync(IsolationLevel isolationLevel, CancellationToken cancellationToken) + => await BeginTransactionAsync(isolationLevel, cancellationToken).ConfigureAwait(false); + + /// + /// Asynchronously begins a database transaction. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task whose Result property is an object representing the new transaction. + /// + /// Nested transactions are not supported. + /// Transactions created by this method will have the isolation level. + /// + public new ValueTask BeginTransactionAsync(CancellationToken cancellationToken = default) + => BeginTransactionAsync(IsolationLevel.Unspecified, cancellationToken); + + /// + /// Asynchronously begins a database transaction. + /// + /// The isolation level under which the transaction should run. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task whose property is an object representing the new transaction. + /// + /// Nested transactions are not supported. + /// + public new ValueTask BeginTransactionAsync(IsolationLevel level, CancellationToken cancellationToken = default) + => BeginTransaction(async: true, level, cancellationToken); +#endif + + /// + /// Enlist transaction. + /// + public override void EnlistTransaction(Transaction? 
transaction) + { + if (Settings.Multiplexing) + throw new NotSupportedException("Ambient transactions aren't yet implemented for multiplexing"); + + if (EnlistedTransaction != null) { + if (EnlistedTransaction.Equals(transaction)) + return; try { - Notice?.Invoke(this, new NpgsqlNoticeEventArgs(e)); + if (EnlistedTransaction.TransactionInformation.Status == System.Transactions.TransactionStatus.Active) + throw new InvalidOperationException($"Already enlisted to transaction (localid={EnlistedTransaction.TransactionInformation.LocalIdentifier})"); } - catch (Exception ex) + catch (ObjectDisposedException) { - // Block all exceptions bubbling up from the user's event handler - Log.Error("User exception caught when emitting notice event", ex); + // The MSDTC 2nd phase is asynchronous, so we may end up checking the TransactionInformation on + // a disposed transaction. To be extra safe we catch that, and understand that the transaction + // has ended - no problem for reenlisting. } } - internal void OnNotification(NpgsqlNotificationEventArgs e) + CheckReady(); + var connector = StartBindingScope(ConnectorBindingScope.Transaction); + + EnlistedTransaction = transaction; + if (transaction == null) { - try - { - Notification?.Invoke(this, e); - } - catch (Exception ex) - { - // Block all exceptions bubbling up from the user's event handler - Log.Error("User exception caught when emitting notification event", ex); - } + EndBindingScope(ConnectorBindingScope.Transaction); + return; } - #endregion Notifications and Notices - - #region SSL - - /// - /// Returns whether SSL is being used for the connection. 
- /// - internal bool IsSecure => CheckOpenAndRunInTemporaryScope(c => c.IsSecure); - - /// - /// Returns whether SCRAM-SHA256 is being user for the connection - /// - internal bool IsScram => CheckOpenAndRunInTemporaryScope(c => c.IsScram); - - /// - /// Returns whether SCRAM-SHA256-PLUS is being user for the connection - /// - internal bool IsScramPlus => CheckOpenAndRunInTemporaryScope(c => c.IsScramPlus); - - /// - /// Selects the local Secure Sockets Layer (SSL) certificate used for authentication. - /// - /// - /// See - /// - public ProvideClientCertificatesCallback? ProvideClientCertificatesCallback { get; set; } - - /// - /// Verifies the remote Secure Sockets Layer (SSL) certificate used for authentication. - /// Ignored if is set. - /// - /// - /// See - /// - public RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; set; } - - #endregion SSL - - #region Backend version, capabilities, settings - - // TODO: We should probably move DatabaseInfo from each connector to the pool (but remember unpooled) - - /// - /// Version of the PostgreSQL backend. - /// This can only be called when there is an active connection. - /// - [Browsable(false)] - public Version PostgreSqlVersion => CheckOpenAndRunInTemporaryScope(c => c.DatabaseInfo.Version); - - /// - /// PostgreSQL server version. - /// - public override string ServerVersion => PostgreSqlVersion.ToString(); - - /// - /// Process id of backend server. - /// This can only be called when there is an active connection. - /// - [Browsable(false)] - // ReSharper disable once InconsistentNaming - public int ProcessID - { - get - { - CheckOpen(); + // Until #1378 is implemented, we have no recovery, and so no need to enlist as a durable resource manager + // (or as promotable single phase). - return TryGetBoundConnector(out var connector) - ? 
connector.BackendProcessId - : throw new InvalidOperationException("No bound physical connection (using multiplexing)"); - } - } + // Note that even when #1378 is implemented in some way, we should check for mono and go volatile in any case - + // distributed transactions aren't supported. - /// - /// Reports whether the backend uses the newer integer timestamp representation. - /// Note that the old floating point representation is not supported. - /// Meant for use by type plugins (e.g. NodaTime) - /// - [Browsable(false)] - public bool HasIntegerDateTimes => CheckOpenAndRunInTemporaryScope(c => c.DatabaseInfo.HasIntegerDateTimes); - - /// - /// The connection's timezone as reported by PostgreSQL, in the IANA/Olson database format. - /// - [Browsable(false)] - public string Timezone => CheckOpenAndRunInTemporaryScope(c => c.Timezone); - - /// - /// Holds all PostgreSQL parameters received for this connection. Is updated if the values change - /// (e.g. as a result of a SET command). - /// - [Browsable(false)] - public IReadOnlyDictionary PostgresParameters - => CheckOpenAndRunInTemporaryScope(c => c.PostgresParameters); - - #endregion Backend version, capabilities, settings - - #region Copy - - /// - /// Begins a binary COPY FROM STDIN operation, a high-performance data import mechanism to a PostgreSQL table. - /// - /// A COPY FROM STDIN SQL command - /// A which can be used to write rows and columns - /// - /// See https://www.postgresql.org/docs/current/static/sql-copy.html. 
- /// - public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) - { - if (copyFromCommand == null) - throw new ArgumentNullException(nameof(copyFromCommand)); - if (!copyFromCommand.TrimStart().ToUpper().StartsWith("COPY")) - throw new ArgumentException("Must contain a COPY FROM STDIN command!", nameof(copyFromCommand)); + var volatileResourceManager = new VolatileResourceManager(this, transaction); + transaction.EnlistVolatile(volatileResourceManager, EnlistmentOptions.None); + volatileResourceManager.Init(); + EnlistedTransaction = transaction; - CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + LogMessages.EnlistedVolatileResourceManager( + Connector!.LoggingConfiguration.TransactionLogger, + transaction.TransactionInformation.LocalIdentifier, + connector.Id); + } - Log.Debug("Starting binary import", connector.Id); - connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); - try - { - var importer = new NpgsqlBinaryImporter(connector, copyFromCommand); - connector.CurrentCopyOperation = importer; - return importer; - } - catch - { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); - throw; - } - } + #endregion - /// - /// Begins a binary COPY TO STDOUT operation, a high-performance data export mechanism from a PostgreSQL table. - /// - /// A COPY TO STDOUT SQL command - /// A which can be used to read rows and columns - /// - /// See https://www.postgresql.org/docs/current/static/sql-copy.html. - /// - public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand) - { - if (copyToCommand == null) - throw new ArgumentNullException(nameof(copyToCommand)); - if (!copyToCommand.TrimStart().ToUpper().StartsWith("COPY")) - throw new ArgumentException("Must contain a COPY TO STDOUT command!", nameof(copyToCommand)); + #region Close - CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + /// + /// Releases the connection. 
If the connection is pooled, it will be returned to the pool and made available for re-use. + /// If it is non-pooled, the physical connection will be closed. + /// + public override void Close() => Close(async: false).GetAwaiter().GetResult(); - Log.Debug("Starting binary export", connector.Id); - connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); - try - { - var exporter = new NpgsqlBinaryExporter(connector, copyToCommand); - connector.CurrentCopyOperation = exporter; - return exporter; - } - catch - { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); - throw; - } + /// + /// Releases the connection. If the connection is pooled, it will be returned to the pool and made available for re-use. + /// If it is non-pooled, the physical connection will be closed. + /// +#if NETSTANDARD2_0 + public Task CloseAsync() +#else + public override Task CloseAsync() +#endif + => Close(async: true); + + internal bool TakeCloseLock() => Interlocked.Exchange(ref _closing, 1) == 0; + + internal void ReleaseCloseLock() => Volatile.Write(ref _closing, 0); + + internal Task Close(bool async) + { + // Even though NpgsqlConnection isn't thread safe we'll make sure this part is. + // Because we really don't want double returns to the pool. 
+ if (!TakeCloseLock()) + return Task.CompletedTask; + + switch (FullState) + { + case ConnectionState.Open: + case ConnectionState.Open | ConnectionState.Executing: + case ConnectionState.Open | ConnectionState.Fetching: + break; + case ConnectionState.Broken: + FullState = ConnectionState.Closed; + goto case ConnectionState.Closed; + case ConnectionState.Closed: + ReleaseCloseLock(); + return Task.CompletedTask; + case ConnectionState.Connecting: + ReleaseCloseLock(); + throw new InvalidOperationException("Can't close, connection is in state " + FullState); + default: + ReleaseCloseLock(); + throw new ArgumentOutOfRangeException("Unknown connection state: " + FullState); } - /// - /// Begins a textual COPY FROM STDIN operation, a data import mechanism to a PostgreSQL table. - /// It is the user's responsibility to send the textual input according to the format specified - /// in . - /// - /// A COPY FROM STDIN SQL command - /// - /// A TextWriter that can be used to send textual data. - /// - /// See https://www.postgresql.org/docs/current/static/sql-copy.html. - /// - public TextWriter BeginTextImport(string copyFromCommand) + // TODO: The following shouldn't exist - we need to flow down the regular path to close any + // open reader / COPY. See test CloseDuringRead with multiplexing. + if (Settings.Multiplexing && ConnectorBindingScope == ConnectorBindingScope.None) { - if (copyFromCommand == null) - throw new ArgumentNullException(nameof(copyFromCommand)); - if (!copyFromCommand.TrimStart().ToUpper().StartsWith("COPY")) - throw new ArgumentException("Must contain a COPY FROM STDIN command!", nameof(copyFromCommand)); + // TODO: Consider falling through to the regular reset logic. This adds some unneeded conditions + // and assignment but actual perf impact should be negligible (measure). 
+ Debug.Assert(Connector == null); + ReleaseCloseLock(); - CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + FullState = ConnectionState.Closed; + LogMessages.ClosedMultiplexingConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); - Log.Debug("Starting text import", connector.Id); - connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); - try - { - var writer = new NpgsqlCopyTextWriter(connector, new NpgsqlRawCopyStream(connector, copyFromCommand)); - connector.CurrentCopyOperation = writer; - return writer; - } - catch - { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); - throw; - } + return Task.CompletedTask; } - /// - /// Begins a textual COPY TO STDOUT operation, a data export mechanism from a PostgreSQL table. - /// It is the user's responsibility to parse the textual input according to the format specified - /// in . - /// - /// A COPY TO STDOUT SQL command - /// - /// A TextReader that can be used to read textual data. - /// - /// See https://www.postgresql.org/docs/current/static/sql-copy.html. 
- /// - public TextReader BeginTextExport(string copyToCommand) - { - if (copyToCommand == null) - throw new ArgumentNullException(nameof(copyToCommand)); - if (!copyToCommand.TrimStart().ToUpper().StartsWith("COPY")) - throw new ArgumentException("Must contain a COPY TO STDOUT command!", nameof(copyToCommand)); + return CloseAsync(async); + } - CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + async Task CloseAsync(bool async) + { + Debug.Assert(Connector != null); + Debug.Assert(ConnectorBindingScope != ConnectorBindingScope.None); - Log.Debug("Starting text export", connector.Id); - connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); - try + try + { + var connector = Connector; + LogMessages.ClosingConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString, connector.Id); + + if (connector.CurrentReader != null || connector.CurrentCopyOperation != null) { - var reader = new NpgsqlCopyTextReader(connector, new NpgsqlRawCopyStream(connector, copyToCommand)); - connector.CurrentCopyOperation = reader; - return reader; + // This method could re-enter connection.Close() due to an underlying connection failure. 
+ await connector.CloseOngoingOperations(async).ConfigureAwait(false); + + if (ConnectorBindingScope == ConnectorBindingScope.None) + { + Debug.Assert(Settings.Multiplexing); + Debug.Assert(Connector is null); + + FullState = ConnectionState.Closed; + LogMessages.ClosedMultiplexingConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); + return; + } } - catch + + Debug.Assert(connector.IsReady || connector.IsBroken, $"Connector is not ready or broken during close, it's {connector.State}"); + Debug.Assert(connector.CurrentReader == null); + Debug.Assert(connector.CurrentCopyOperation == null); + + if (EnlistedTransaction != null) { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); - throw; - } - } + // A System.Transactions transaction is still in progress - /// - /// Begins a raw binary COPY operation (TO STDOUT or FROM STDIN), a high-performance data export/import mechanism to a PostgreSQL table. - /// Note that unlike the other COPY API methods, doesn't implement any encoding/decoding - /// and is unsuitable for structured import/export operation. It is useful mainly for exporting a table as an opaque - /// blob, for the purpose of importing it back later. - /// - /// A COPY TO STDOUT or COPY FROM STDIN SQL command - /// A that can be used to read or write raw binary data. - /// - /// See https://www.postgresql.org/docs/current/static/sql-copy.html. 
- /// - public NpgsqlRawCopyStream BeginRawBinaryCopy(string copyCommand) - { - if (copyCommand == null) - throw new ArgumentNullException(nameof(copyCommand)); - if (!copyCommand.TrimStart().ToUpper().StartsWith("COPY")) - throw new ArgumentException("Must contain a COPY TO STDOUT OR COPY FROM STDIN command!", nameof(copyCommand)); + connector.Connection = null; - CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + // Close the connection and disconnect it from the resource manager but leave the + // connector in an enlisted pending list in the data source. If another connection is opened within + // the same transaction scope, we will reuse this connector to avoid escalating to a distributed + // transaction + _dataSource?.AddPendingEnlistedConnector(connector, EnlistedTransaction); - Log.Debug("Starting raw COPY operation", connector.Id); - connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); - try + EnlistedTransaction = null; + } + else { - var stream = new NpgsqlRawCopyStream(connector, copyCommand); - if (!stream.IsBinary) + if (Settings.Pooling) { - // TODO: Stop the COPY operation gracefully, no breaking - throw connector.Break(new ArgumentException( - "copyToCommand triggered a text transfer, only binary is allowed", nameof(copyCommand))); + // Clear the buffer, roll back any pending transaction and prepend a reset message if needed + // Also returns the connector to the pool, if there is an open transaction and multiplexing is on + // Note that we're doing this only for pooled connections + await connector.Reset(async).ConfigureAwait(false); + } + else + { + // We're already doing the same in the NpgsqlConnector.Reset for pooled connections + // TODO: move reset logic to ConnectorSource.Return + connector.Transaction?.UnbindIfNecessary(); + } + + if (Settings.Multiplexing) + { + // We've already closed ongoing operations rolled back any transaction and the connector is already in the pool, + // so we 
must be unbound. Nothing to do. + Debug.Assert(ConnectorBindingScope == ConnectorBindingScope.None, + $"When closing a multiplexed connection, the connection was supposed to be unbound, but {nameof(ConnectorBindingScope)} was {ConnectorBindingScope}"); + } + else + { + connector.Connection = null; + connector.Return(); } - connector.CurrentCopyOperation = stream; - return stream; - } - catch - { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); - throw; } - } - #endregion - - #region Enum mapping - - /// - /// Maps a CLR enum to a PostgreSQL enum type for use with this connection. - /// - /// - /// CLR enum labels are mapped by name to PostgreSQL enum labels. - /// The translation strategy can be controlled by the parameter, - /// which defaults to . - /// You can also use the on your enum fields to manually specify a PostgreSQL enum label. - /// If there is a discrepancy between the .NET and database labels while an enum is read or written, - /// an exception will be raised. - /// - /// Can only be invoked on an open connection; if the connection is closed the mapping is lost. - /// - /// To avoid mapping the type for each connection, use the method. - /// - /// - /// A PostgreSQL type name for the corresponding enum type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - /// The .NET enum type to be mapped - [Obsolete("Use NpgsqlConnection.TypeMapper.MapEnum() instead")] - public void MapEnum(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - where TEnum : struct, Enum - => TypeMapper.MapEnum(pgName, nameTranslator); - - /// - /// Maps a CLR enum to a PostgreSQL enum type for use with all connections created from now on. Existing connections aren't affected. - /// - /// - /// CLR enum labels are mapped by name to PostgreSQL enum labels. 
- /// The translation strategy can be controlled by the parameter, - /// which defaults to . - /// You can also use the on your enum fields to manually specify a PostgreSQL enum label. - /// If there is a discrepancy between the .NET and database labels while an enum is read or written, - /// an exception will be raised. - /// - /// To map the type for a specific connection, use the method. - /// - /// - /// A PostgreSQL type name for the corresponding enum type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - /// The .NET enum type to be mapped - [Obsolete("Use NpgsqlConnection.GlobalTypeMapper.MapEnum() instead")] - public static void MapEnumGlobally(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - where TEnum : struct, Enum - => GlobalTypeMapper.MapEnum(pgName, nameTranslator); - - /// - /// Removes a previous global enum mapping. - /// - /// - /// A PostgreSQL type name for the corresponding enum type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - [Obsolete("Use NpgsqlConnection.GlobalTypeMapper.UnmapEnum() instead")] - public static void UnmapEnumGlobally(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - where TEnum : struct, Enum - => GlobalTypeMapper.UnmapEnum(pgName, nameTranslator); - - #endregion - - #region Composite registration - - /// - /// Maps a CLR type to a PostgreSQL composite type for use with this connection. - /// - /// - /// CLR fields and properties by string to PostgreSQL enum labels. - /// The translation strategy can be controlled by the parameter, - /// which defaults to . 
- /// You can also use the on your members to manually specify a PostgreSQL enum label. - /// If there is a discrepancy between the .NET and database labels while a composite is read or written, - /// an exception will be raised. - /// - /// Can only be invoked on an open connection; if the connection is closed the mapping is lost. - /// - /// To avoid mapping the type for each connection, use the method. - /// - /// - /// A PostgreSQL type name for the corresponding enum type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - /// The .NET type to be mapped - [Obsolete("Use NpgsqlConnection.TypeMapper.MapComposite() instead")] - public void MapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where T : new() - => TypeMapper.MapComposite(pgName, nameTranslator); - - /// - /// Maps a CLR type to a PostgreSQL composite type for use with all connections created from now on. Existing connections aren't affected. - /// - /// - /// CLR fields and properties by string to PostgreSQL enum labels. - /// The translation strategy can be controlled by the parameter, - /// which defaults to . - /// You can also use the on your members to manually specify a PostgreSQL enum label. - /// If there is a discrepancy between the .NET and database labels while a composite is read or written, - /// an exception will be raised. - /// - /// To map the type for a specific connection, use the method. - /// - /// - /// A PostgreSQL type name for the corresponding enum type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). 
- /// Defaults to - /// - /// The .NET type to be mapped - [Obsolete("Use NpgsqlConnection.GlobalTypeMapper.MapComposite() instead")] - public static void MapCompositeGlobally(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where T : new() - => GlobalTypeMapper.MapComposite(pgName, nameTranslator); - - /// - /// Removes a previous global enum mapping. - /// - /// - /// A PostgreSQL type name for the corresponding enum type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - [Obsolete("Use NpgsqlConnection.GlobalTypeMapper.UnmapComposite() instead")] - public static void UnmapCompositeGlobally(string pgName, INpgsqlNameTranslator? nameTranslator = null) where T : new() - => GlobalTypeMapper.UnmapComposite(pgName, nameTranslator); - - #endregion - - #region Wait - - /// - /// Waits until an asynchronous PostgreSQL messages (e.g. a notification) arrives, and - /// exits immediately. The asynchronous message is delivered via the normal events - /// (, ). - /// - /// - /// The time-out value, in milliseconds, passed to . - /// The default value is 0, which indicates an infinite time-out period. - /// Specifying -1 also indicates an infinite time-out period. - /// - /// true if an asynchronous message was received, false if timed out. 
- public bool Wait(int timeout) + LogMessages.ClosedConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString, connector.Id); + Connector = null; + ConnectorBindingScope = ConnectorBindingScope.None; + FullState = ConnectionState.Closed; + } + finally { - if (timeout != -1 && timeout < 0) - throw new ArgumentException("Argument must be -1, 0 or positive", nameof(timeout)); - if (Settings.Multiplexing) - throw new NotSupportedException($"{nameof(Wait)} isn't supported in multiplexing mode"); + ReleaseCloseLock(); + } + } + + /// + /// Releases all resources used by the . + /// + /// when called from ; + /// when being called from the finalizer. + protected override void Dispose(bool disposing) + { + if (_disposed) + return; + if (disposing) + Close(); + _disposed = true; + } + + /// + /// Releases all resources used by the . + /// +#if NETSTANDARD2_0 + public async ValueTask DisposeAsync() +#else + public override async ValueTask DisposeAsync() +#endif + { + if (_disposed) + return; + + await CloseAsync().ConfigureAwait(false); + _disposed = true; + } + + internal void MakeDisposed() + => _disposed = true; + + #endregion - CheckReady(); + #region Notifications and Notices - Log.Debug($"Starting to wait (timeout={timeout})...", Connector!.Id); - return Connector!.Wait(async: false, timeout, CancellationToken.None).GetAwaiter().GetResult(); + /// + /// Fires when PostgreSQL notices are received from PostgreSQL. + /// + /// + /// PostgreSQL notices are non-critical messages generated by PostgreSQL, either as a result of a user query + /// (e.g. as a warning or informational notice), or due to outside activity (e.g. if the database administrator + /// initiates a "fast" database shutdown). + /// + /// Note that notices are very different from notifications (see the event). + /// + public event NoticeEventHandler? Notice; + + /// + /// Fires when PostgreSQL notifications are received from PostgreSQL. 
+ /// + /// + /// PostgreSQL notifications are sent when your connection has registered for notifications on a specific channel via the + /// LISTEN command. NOTIFY can be used to generate such notifications, allowing for an inter-connection communication channel. + /// + /// Note that notifications are very different from notices (see the event). + /// + public event NotificationEventHandler? Notification; + + internal void OnNotice(PostgresNotice e) + { + try + { + Notice?.Invoke(this, new NpgsqlNoticeEventArgs(e)); + } + catch (Exception ex) + { + // Block all exceptions bubbling up from the user's event handler + LogMessages.CaughtUserExceptionInNoticeEventHandler(_connectionLogger, ex); } + } - /// - /// Waits until an asynchronous PostgreSQL messages (e.g. a notification) arrives, and - /// exits immediately. The asynchronous message is delivered via the normal events - /// (, ). - /// - /// - /// The time-out value is passed to . - /// - /// true if an asynchronous message was received, false if timed out. - public bool Wait(TimeSpan timeout) => Wait((int)timeout.TotalMilliseconds); - - /// - /// Waits until an asynchronous PostgreSQL messages (e.g. a notification) arrives, and - /// exits immediately. The asynchronous message is delivered via the normal events - /// (, ). - /// - public void Wait() => Wait(0); - - /// - /// Waits asynchronously until an asynchronous PostgreSQL messages (e.g. a notification) - /// arrives, and exits immediately. The asynchronous message is delivered via the normal events - /// (, ). - /// - /// - /// The time-out value, in milliseconds. - /// The default value is 0, which indicates an infinite time-out period. - /// Specifying -1 also indicates an infinite time-out period. - /// - /// The token to monitor for cancellation requests. The default value is . - /// true if an asynchronous message was received, false if timed out. 
- public Task WaitAsync(int timeout, CancellationToken cancellationToken = default) + internal void OnNotification(NpgsqlNotificationEventArgs e) + { + try + { + Notification?.Invoke(this, e); + } + catch (Exception ex) { - if (Settings.Multiplexing) - throw new NotSupportedException($"{nameof(Wait)} isn't supported in multiplexing mode"); + // Block all exceptions bubbling up from the user's event handler + LogMessages.CaughtUserExceptionInNotificationEventHandler(_connectionLogger, ex); + } + } - CheckReady(); + #endregion Notifications and Notices - Log.Debug("Starting to wait asynchronously...", Connector!.Id); - using (NoSynchronizationContextScope.Enter()) - return Connector!.Wait(async: true, timeout, cancellationToken); - } + #region SSL + + /// + /// Returns whether SSL is being used for the connection. + /// + internal bool IsSecure => CheckOpenAndRunInTemporaryScope(c => c.IsSecure); + + /// + /// Returns whether SCRAM-SHA256 is being user for the connection + /// + internal bool IsScram => CheckOpenAndRunInTemporaryScope(c => c.IsScram); + + /// + /// Returns whether SCRAM-SHA256-PLUS is being user for the connection + /// + internal bool IsScramPlus => CheckOpenAndRunInTemporaryScope(c => c.IsScramPlus); + + /// + /// Selects the local Secure Sockets Layer (SSL) certificate used for authentication. + /// + /// + /// See + /// + public ProvideClientCertificatesCallback? ProvideClientCertificatesCallback { get; set; } + + /// + /// When using SSL/TLS, this is a callback that allows customizing how the PostgreSQL-provided certificate is verified. This is an + /// advanced API, consider using or instead. + /// + /// + /// + /// Cannot be used in conjunction with , and + /// . + /// + /// + /// See . + /// + /// + public RemoteCertificateValidationCallback? 
UserCertificateValidationCallback { get; set; } + + #endregion SSL + + #region Backend version, capabilities, settings + + // TODO: We should probably move DatabaseInfo from each connector to the pool (but remember unpooled) + + /// + /// The version of the PostgreSQL server we're connected to. + /// + ///

+ /// This can only be called when the connection is open. + ///

+ ///

+ /// In case of a development or pre-release version this field will contain + /// the version of the next version to be released from this branch. + ///

+ ///
+ ///
+ [Browsable(false)] + public Version PostgreSqlVersion => CheckOpenAndRunInTemporaryScope(c => c.DatabaseInfo.Version); - /// - /// Waits asynchronously until an asynchronous PostgreSQL messages (e.g. a notification) - /// arrives, and exits immediately. The asynchronous message is delivered via the normal events - /// (, ). - /// - /// - /// The time-out value as - /// - /// The token to monitor for cancellation requests. The default value is . - /// true if an asynchronous message was received, false if timed out. - public Task WaitAsync(TimeSpan timeout, CancellationToken cancellationToken = default) => WaitAsync((int)timeout.TotalMilliseconds, cancellationToken); - - /// - /// Waits asynchronously until an asynchronous PostgreSQL messages (e.g. a notification) - /// arrives, and exits immediately. The asynchronous message is delivered via the normal events - /// (, ). - /// - /// The token to monitor for cancellation requests. The default value is . - public Task WaitAsync(CancellationToken cancellationToken = default) => WaitAsync(0, cancellationToken); - - #endregion - - #region State checks - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void CheckOpen() + /// + /// The PostgreSQL server version as returned by the server_version option. + /// + /// This can only be called when the connection is open. + /// + /// + public override string ServerVersion => CheckOpenAndRunInTemporaryScope( + c => c.DatabaseInfo.ServerVersion); + + /// + /// Process id of backend server. + /// This can only be called when there is an active connection. 
+ /// + [Browsable(false)] + // ReSharper disable once InconsistentNaming + public int ProcessID + { + get { - CheckDisposed(); + CheckOpen(); - switch (FullState) - { - case ConnectionState.Open: - case ConnectionState.Open | ConnectionState.Executing: - case ConnectionState.Open | ConnectionState.Fetching: - case ConnectionState.Connecting: - break; - case ConnectionState.Closed: - case ConnectionState.Broken: - throw new InvalidOperationException("Connection is not open"); - default: - throw new ArgumentOutOfRangeException(); - } + return TryGetBoundConnector(out var connector) + ? connector.BackendProcessId + : throw new InvalidOperationException("No bound physical connection (using multiplexing)"); } + } + + /// + /// Reports whether the backend uses the newer integer timestamp representation. + /// Note that the old floating point representation is not supported. + /// Meant for use by type plugins (e.g. NodaTime) + /// + [Browsable(false)] + public bool HasIntegerDateTimes => CheckOpenAndRunInTemporaryScope(c => c.DatabaseInfo.HasIntegerDateTimes); + + /// + /// The connection's timezone as reported by PostgreSQL, in the IANA/Olson database format. + /// + [Browsable(false)] + public string Timezone => CheckOpenAndRunInTemporaryScope(c => c.Timezone); + + /// + /// Holds all PostgreSQL parameters received for this connection. Is updated if the values change + /// (e.g. as a result of a SET command). + /// + [Browsable(false)] + public IReadOnlyDictionary PostgresParameters + => CheckOpenAndRunInTemporaryScope(c => c.PostgresParameters); + + #endregion Backend version, capabilities, settings + + #region Copy + + /// + /// Begins a binary COPY FROM STDIN operation, a high-performance data import mechanism to a PostgreSQL table. + /// + /// A COPY FROM STDIN SQL command + /// A which can be used to write rows and columns + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. 
+ /// + public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) + => BeginBinaryImport(async: false, copyFromCommand, CancellationToken.None).GetAwaiter().GetResult(); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void CheckClosed() + /// + /// Begins a binary COPY FROM STDIN operation, a high-performance data import mechanism to a PostgreSQL table. + /// + /// A COPY FROM STDIN SQL command + /// An optional token to cancel the asynchronous operation. The default value is None. + /// A which can be used to write rows and columns + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. + /// + public Task BeginBinaryImportAsync(string copyFromCommand, CancellationToken cancellationToken = default) + => BeginBinaryImport(async: true, copyFromCommand, cancellationToken); + + async Task BeginBinaryImport(bool async, string copyFromCommand, CancellationToken cancellationToken = default) + { + if (copyFromCommand == null) + throw new ArgumentNullException(nameof(copyFromCommand)); + if (!IsValidCopyCommand(copyFromCommand)) + throw new ArgumentException("Must contain a COPY FROM STDIN command!", nameof(copyFromCommand)); + + CheckReady(); + var connector = StartBindingScope(ConnectorBindingScope.Copy); + + LogMessages.StartingBinaryImport(connector.LoggingConfiguration.CopyLogger, connector.Id); + // no point in passing a cancellationToken here, as we register the cancellation in the Init method + connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + try { - CheckDisposed(); + var importer = new NpgsqlBinaryImporter(connector); + await importer.Init(copyFromCommand, async, cancellationToken).ConfigureAwait(false); + connector.CurrentCopyOperation = importer; + return importer; + } + catch + { + connector.EndUserAction(); + EndBindingScope(ConnectorBindingScope.Copy); + throw; + } + } - switch (FullState) - { - case ConnectionState.Closed: - case ConnectionState.Broken: - break; - case 
ConnectionState.Open: - case ConnectionState.Connecting: - case ConnectionState.Open | ConnectionState.Executing: - case ConnectionState.Open | ConnectionState.Fetching: - throw new InvalidOperationException("Connection already open"); - default: - throw new ArgumentOutOfRangeException(); - } + /// + /// Begins a binary COPY TO STDOUT operation, a high-performance data export mechanism from a PostgreSQL table. + /// + /// A COPY TO STDOUT SQL command + /// A which can be used to read rows and columns + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. + /// + public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand) + => BeginBinaryExport(async: false, copyToCommand, CancellationToken.None).GetAwaiter().GetResult(); + + /// + /// Begins a binary COPY TO STDOUT operation, a high-performance data export mechanism from a PostgreSQL table. + /// + /// A COPY TO STDOUT SQL command + /// An optional token to cancel the asynchronous operation. The default value is None. + /// A which can be used to read rows and columns + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. 
+ /// + public Task BeginBinaryExportAsync(string copyToCommand, CancellationToken cancellationToken = default) + => BeginBinaryExport(async: true, copyToCommand, cancellationToken); + + async Task BeginBinaryExport(bool async, string copyToCommand, CancellationToken cancellationToken = default) + { + if (copyToCommand == null) + throw new ArgumentNullException(nameof(copyToCommand)); + if (!IsValidCopyCommand(copyToCommand)) + throw new ArgumentException("Must contain a COPY TO STDOUT command!", nameof(copyToCommand)); + + CheckReady(); + var connector = StartBindingScope(ConnectorBindingScope.Copy); + + LogMessages.StartingBinaryExport(connector.LoggingConfiguration.CopyLogger, connector.Id); + // no point in passing a cancellationToken here, as we register the cancellation in the Init method + connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + try + { + var exporter = new NpgsqlBinaryExporter(connector); + await exporter.Init(copyToCommand, async, cancellationToken).ConfigureAwait(false); + connector.CurrentCopyOperation = exporter; + return exporter; + } + catch + { + connector.EndUserAction(); + EndBindingScope(ConnectorBindingScope.Copy); + throw; } + } + + /// + /// Begins a textual COPY FROM STDIN operation, a data import mechanism to a PostgreSQL table. + /// It is the user's responsibility to send the textual input according to the format specified + /// in . + /// + /// A COPY FROM STDIN SQL command + /// + /// A TextWriter that can be used to send textual data. + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. + /// + public TextWriter BeginTextImport(string copyFromCommand) + => BeginTextImport(async: false, copyFromCommand, CancellationToken.None).GetAwaiter().GetResult(); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void CheckDisposed() + /// + /// Begins a textual COPY FROM STDIN operation, a data import mechanism to a PostgreSQL table. 
+ /// It is the user's responsibility to send the textual input according to the format specified + /// in . + /// + /// A COPY FROM STDIN SQL command + /// An optional token to cancel the asynchronous operation. The default value is None. + /// + /// A TextWriter that can be used to send textual data. + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. + /// + public Task BeginTextImportAsync(string copyFromCommand, CancellationToken cancellationToken = default) + => BeginTextImport(async: true, copyFromCommand, cancellationToken); + + async Task BeginTextImport(bool async, string copyFromCommand, CancellationToken cancellationToken = default) + { + if (copyFromCommand == null) + throw new ArgumentNullException(nameof(copyFromCommand)); + if (!IsValidCopyCommand(copyFromCommand)) + throw new ArgumentException("Must contain a COPY FROM STDIN command!", nameof(copyFromCommand)); + + CheckReady(); + var connector = StartBindingScope(ConnectorBindingScope.Copy); + + LogMessages.StartingTextImport(connector.LoggingConfiguration.CopyLogger, connector.Id); + // no point in passing a cancellationToken here, as we register the cancellation in the Init method + connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + try + { + var copyStream = new NpgsqlRawCopyStream(connector); + await copyStream.Init(copyFromCommand, async, cancellationToken).ConfigureAwait(false); + var writer = new NpgsqlCopyTextWriter(connector, copyStream); + connector.CurrentCopyOperation = writer; + return writer; + } + catch { - if (_disposed) - throw new ObjectDisposedException(typeof(NpgsqlConnection).Name); + connector.EndUserAction(); + EndBindingScope(ConnectorBindingScope.Copy); + throw; } + } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void CheckReady() + /// + /// Begins a textual COPY TO STDOUT operation, a data export mechanism from a PostgreSQL table. 
+ /// It is the user's responsibility to parse the textual input according to the format specified + /// in . + /// + /// A COPY TO STDOUT SQL command + /// + /// A TextReader that can be used to read textual data. + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. + /// + public TextReader BeginTextExport(string copyToCommand) + => BeginTextExport(async: false, copyToCommand, CancellationToken.None).GetAwaiter().GetResult(); + + /// + /// Begins a textual COPY TO STDOUT operation, a data export mechanism from a PostgreSQL table. + /// It is the user's responsibility to parse the textual input according to the format specified + /// in . + /// + /// A COPY TO STDOUT SQL command + /// An optional token to cancel the asynchronous operation. The default value is None. + /// + /// A TextReader that can be used to read textual data. + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. + /// + public Task BeginTextExportAsync(string copyToCommand, CancellationToken cancellationToken = default) + => BeginTextExport(async: true, copyToCommand, cancellationToken); + + async Task BeginTextExport(bool async, string copyToCommand, CancellationToken cancellationToken = default) + { + if (copyToCommand == null) + throw new ArgumentNullException(nameof(copyToCommand)); + if (!IsValidCopyCommand(copyToCommand)) + throw new ArgumentException("Must contain a COPY TO STDOUT command!", nameof(copyToCommand)); + + CheckReady(); + var connector = StartBindingScope(ConnectorBindingScope.Copy); + + LogMessages.StartingTextExport(connector.LoggingConfiguration.CopyLogger, connector.Id); + // no point in passing a cancellationToken here, as we register the cancellation in the Init method + connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + try + { + var copyStream = new NpgsqlRawCopyStream(connector); + await copyStream.Init(copyToCommand, async, cancellationToken).ConfigureAwait(false); + var reader = new 
NpgsqlCopyTextReader(connector, copyStream); + connector.CurrentCopyOperation = reader; + return reader; + } + catch { - CheckDisposed(); + connector.EndUserAction(); + EndBindingScope(ConnectorBindingScope.Copy); + throw; + } + } + + /// + /// Begins a raw binary COPY operation (TO STDOUT or FROM STDIN), a high-performance data export/import mechanism to a PostgreSQL table. + /// Note that unlike the other COPY API methods, doesn't implement any encoding/decoding + /// and is unsuitable for structured import/export operation. It is useful mainly for exporting a table as an opaque + /// blob, for the purpose of importing it back later. + /// + /// A COPY TO STDOUT or COPY FROM STDIN SQL command + /// A that can be used to read or write raw binary data. + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. + /// + public NpgsqlRawCopyStream BeginRawBinaryCopy(string copyCommand) + => BeginRawBinaryCopy(async: false, copyCommand, CancellationToken.None).GetAwaiter().GetResult(); - switch (FullState) + /// + /// Begins a raw binary COPY operation (TO STDOUT or FROM STDIN), a high-performance data export/import mechanism to a PostgreSQL table. + /// Note that unlike the other COPY API methods, doesn't implement any encoding/decoding + /// and is unsuitable for structured import/export operation. It is useful mainly for exporting a table as an opaque + /// blob, for the purpose of importing it back later. + /// + /// A COPY TO STDOUT or COPY FROM STDIN SQL command + /// An optional token to cancel the asynchronous operation. The default value is None. + /// A that can be used to read or write raw binary data. + /// + /// See https://www.postgresql.org/docs/current/static/sql-copy.html. 
+ /// + public Task BeginRawBinaryCopyAsync(string copyCommand, CancellationToken cancellationToken = default) + => BeginRawBinaryCopy(async: true, copyCommand, cancellationToken); + + async Task BeginRawBinaryCopy(bool async, string copyCommand, CancellationToken cancellationToken = default) + { + if (copyCommand == null) + throw new ArgumentNullException(nameof(copyCommand)); + if (!IsValidCopyCommand(copyCommand)) + throw new ArgumentException("Must contain a COPY TO STDOUT OR COPY FROM STDIN command!", nameof(copyCommand)); + + CheckReady(); + var connector = StartBindingScope(ConnectorBindingScope.Copy); + + LogMessages.StartingRawCopy(connector.LoggingConfiguration.CopyLogger, connector.Id); + // no point in passing a cancellationToken here, as we register the cancellation in the Init method + connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + try + { + var stream = new NpgsqlRawCopyStream(connector); + await stream.Init(copyCommand, async, cancellationToken).ConfigureAwait(false); + if (!stream.IsBinary) { - case ConnectionState.Open: - case ConnectionState.Connecting: // We need to do type loading as part of connecting - break; - case ConnectionState.Closed: - case ConnectionState.Broken: - throw new InvalidOperationException("Connection is not open"); - case ConnectionState.Open | ConnectionState.Executing: - case ConnectionState.Open | ConnectionState.Fetching: - throw new InvalidOperationException("Connection is busy"); - default: - throw new ArgumentOutOfRangeException(); + // TODO: Stop the COPY operation gracefully, no breaking + throw connector.Break(new ArgumentException( + "copyToCommand triggered a text transfer, only binary is allowed", nameof(copyCommand))); } + connector.CurrentCopyOperation = stream; + return stream; + } + catch + { + connector.EndUserAction(); + EndBindingScope(ConnectorBindingScope.Copy); + throw; } + } - #endregion State checks + static bool IsValidCopyCommand(string copyCommand) + { + #if 
NET6_0_OR_GREATER || NETSTANDARD2_1 + return copyCommand.AsSpan().TrimStart().StartsWith("COPY", StringComparison.OrdinalIgnoreCase); + #else + return copyCommand.TrimStart().StartsWith("COPY", StringComparison.OrdinalIgnoreCase); + #endif + } + #endregion - #region Connector binding + #region Wait - /// - /// Returns whether the connection is currently bound to a connector. - /// - internal bool IsBound => ConnectorBindingScope != ConnectorBindingScope.None; + /// + /// Waits until an asynchronous PostgreSQL messages (e.g. a notification) arrives, and + /// exits immediately. The asynchronous message is delivered via the normal events + /// (, ). + /// + /// + /// The time-out value, in milliseconds, passed to . + /// The default value is 0, which indicates an infinite time-out period. + /// Specifying -1 also indicates an infinite time-out period. + /// + /// true if an asynchronous message was received, false if timed out. + public bool Wait(int timeout) + { + if (timeout != -1 && timeout < 0) + throw new ArgumentException("Argument must be -1, 0 or positive", nameof(timeout)); + if (Settings.Multiplexing) + throw new NotSupportedException($"{nameof(Wait)} isn't supported in multiplexing mode"); + + CheckReady(); + + LogMessages.StartingWait(_connectionLogger, timeout, Connector!.Id); + return Connector!.Wait(async: false, timeout, CancellationToken.None).GetAwaiter().GetResult(); + } + + /// + /// Waits until an asynchronous PostgreSQL messages (e.g. a notification) arrives, and + /// exits immediately. The asynchronous message is delivered via the normal events + /// (, ). + /// + /// + /// The time-out value is passed to . + /// + /// true if an asynchronous message was received, false if timed out. + public bool Wait(TimeSpan timeout) => Wait((int)timeout.TotalMilliseconds); - /// - /// Checks whether the connection is currently bound to a connector, and if so, returns it via - /// . 
- /// - internal bool TryGetBoundConnector([NotNullWhen(true)] out NpgsqlConnector? connector) + /// + /// Waits until an asynchronous PostgreSQL messages (e.g. a notification) arrives, and + /// exits immediately. The asynchronous message is delivered via the normal events + /// (, ). + /// + public void Wait() => Wait(0); + + /// + /// Waits asynchronously until an asynchronous PostgreSQL messages (e.g. a notification) + /// arrives, and exits immediately. The asynchronous message is delivered via the normal events + /// (, ). + /// + /// + /// The time-out value, in milliseconds. + /// The default value is 0, which indicates an infinite time-out period. + /// Specifying -1 also indicates an infinite time-out period. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// true if an asynchronous message was received, false if timed out. + public Task WaitAsync(int timeout, CancellationToken cancellationToken = default) + { + if (Settings.Multiplexing) + throw new NotSupportedException($"{nameof(Wait)} isn't supported in multiplexing mode"); + + CheckReady(); + + LogMessages.StartingWait(_connectionLogger, timeout, Connector!.Id); + return Connector!.Wait(async: true, timeout, cancellationToken); + } + + /// + /// Waits asynchronously until an asynchronous PostgreSQL messages (e.g. a notification) + /// arrives, and exits immediately. The asynchronous message is delivered via the normal events + /// (, ). + /// + /// + /// The time-out value as + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// true if an asynchronous message was received, false if timed out. + public Task WaitAsync(TimeSpan timeout, CancellationToken cancellationToken = default) => WaitAsync((int)timeout.TotalMilliseconds, cancellationToken); + + /// + /// Waits asynchronously until an asynchronous PostgreSQL messages (e.g. a notification) + /// arrives, and exits immediately. 
The asynchronous message is delivered via the normal events + /// (, ). + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + public Task WaitAsync(CancellationToken cancellationToken = default) => WaitAsync(0, cancellationToken); + + #endregion + + #region State checks + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void CheckOpen() + { + CheckDisposed(); + + switch (FullState) { - if (ConnectorBindingScope == ConnectorBindingScope.None) - { - Debug.Assert(Connector == null, $"Binding scope is None but {Connector} exists"); - connector = null; - return false; - } - Debug.Assert(Connector != null, $"Binding scope is {ConnectorBindingScope} but {Connector} is null"); - Debug.Assert(Connector.Connection == this, $"Bound connector {Connector} does not reference this connection"); - connector = Connector; - return true; + case ConnectionState.Open: + case ConnectionState.Open | ConnectionState.Executing: + case ConnectionState.Open | ConnectionState.Fetching: + case ConnectionState.Connecting: + return; + case ConnectionState.Closed: + case ConnectionState.Broken: + ThrowHelper.ThrowInvalidOperationException("Connection is not open"); + return; + default: + ThrowHelper.ThrowArgumentOutOfRangeException(); + return; } + } - /// - /// Binds this connection to a physical connector. This happens when opening a non-multiplexing connection, - /// or when starting a transaction on a multiplexed connection. 
- /// - internal ValueTask StartBindingScope( - ConnectorBindingScope scope, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + void CheckClosed() + { + CheckDisposed(); + + var fullState = FullState; + if (fullState is ConnectionState.Connecting || fullState.HasFlag(ConnectionState.Open)) + ThrowHelper.ThrowInvalidOperationException("Connection already open"); + } + + void CheckDisposed() + { + if (_disposed) + ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlConnection)); + } + + internal void CheckReady() + { + CheckDisposed(); + + switch (FullState) { - // If the connection is around bound at a higher scope, we do nothing (e.g. copy operation started - // within a transaction on a multiplexing connection). - // Note that if we're in an ambient transaction, that means we're already bound and so we do nothing here. - if (ConnectorBindingScope != ConnectorBindingScope.None) - { - Debug.Assert(Connector != null, $"Connection bound with scope {ConnectorBindingScope} but has no connector"); - Debug.Assert(scope != ConnectorBindingScope, $"Binding scopes aren't reentrant ({ConnectorBindingScope})"); - return new ValueTask(Connector); - } + case ConnectionState.Open: + case ConnectionState.Connecting: // We need to do type loading as part of connecting + return; + case ConnectionState.Closed: + case ConnectionState.Broken: + ThrowHelper.ThrowInvalidOperationException("Connection is not open"); + return; + case ConnectionState.Open | ConnectionState.Executing: + case ConnectionState.Open | ConnectionState.Fetching: + ThrowHelper.ThrowInvalidOperationException("Connection is busy"); + return; + default: + ThrowHelper.ThrowArgumentOutOfRangeException(); + return; + } + } - return StartBindingScopeAsync(); + #endregion State checks - async ValueTask StartBindingScopeAsync() + #region Connector binding + + /// + /// Checks whether the connection is currently bound to a connector, and if so, returns it via + /// . 
+ /// + internal bool TryGetBoundConnector([NotNullWhen(true)] out NpgsqlConnector? connector) + { + if (ConnectorBindingScope == ConnectorBindingScope.None) + { + Debug.Assert(Connector == null, $"Binding scope is None but {Connector} exists"); + connector = null; + return false; + } + Debug.Assert(Connector != null, $"Binding scope is {ConnectorBindingScope} but {Connector} is null"); + Debug.Assert(Connector.Connection == this, $"Bound connector {Connector} does not reference this connection"); + connector = Connector; + return true; + } + + /// + /// Binds this connection to a physical connector. This happens when opening a non-multiplexing connection, + /// or when starting a transaction on a multiplexed connection. + /// + internal ValueTask StartBindingScope( + ConnectorBindingScope scope, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + { + // If the connection is around bound at a higher scope, we do nothing (e.g. copy operation started + // within a transaction on a multiplexing connection). + // Note that if we're in an ambient transaction, that means we're already bound and so we do nothing here. 
+ if (ConnectorBindingScope != ConnectorBindingScope.None) + { + Debug.Assert(Connector != null, $"Connection bound with scope {ConnectorBindingScope} but has no connector"); + Debug.Assert(scope != ConnectorBindingScope, $"Binding scopes aren't reentrant ({ConnectorBindingScope})"); + return new ValueTask(Connector); + } + + return StartBindingScopeAsync(); + + async ValueTask StartBindingScopeAsync() + { + try { Debug.Assert(Settings.Multiplexing); - Debug.Assert(_pool != null); + Debug.Assert(_dataSource != null); - var connector = await _pool.Rent(this, timeout, async, cancellationToken); + var connector = await _dataSource.Get(this, timeout, async, cancellationToken).ConfigureAwait(false); + Connector = connector; + connector.Connection = this; ConnectorBindingScope = scope; return connector; } + catch + { + FullState = ConnectionState.Broken; + throw; + } } + } - internal NpgsqlConnector StartBindingScope(ConnectorBindingScope scope) - => StartBindingScope(scope, NpgsqlTimeout.Infinite, async: false, CancellationToken.None) - .GetAwaiter().GetResult(); + internal NpgsqlConnector StartBindingScope(ConnectorBindingScope scope) + => StartBindingScope(scope, NpgsqlTimeout.Infinite, async: false, CancellationToken.None) + .GetAwaiter().GetResult(); - internal EndScopeDisposable StartTemporaryBindingScope(out NpgsqlConnector connector) - { - connector = StartBindingScope(ConnectorBindingScope.Temporary); - return new EndScopeDisposable(this); - } + internal EndScopeDisposable StartTemporaryBindingScope(out NpgsqlConnector connector) + { + connector = StartBindingScope(ConnectorBindingScope.Temporary); + return new EndScopeDisposable(this); + } - internal T CheckOpenAndRunInTemporaryScope(Func func) - { - CheckOpen(); + internal T CheckOpenAndRunInTemporaryScope(Func func) + { + CheckOpen(); - using var _ = StartTemporaryBindingScope(out var connector); - var result = func(connector); - return result; - } + using var _ = StartTemporaryBindingScope(out var 
connector); + var result = func(connector); + return result; + } - /// - /// Ends binding scope to the physical connection and returns it to the pool. Only useful with multiplexing on. - /// - /// - /// After this method is called, under no circumstances the physical connection (connector) should ever be used if multiplexing is on. - /// See #3249. - /// - internal void EndBindingScope(ConnectorBindingScope scope) - { - Debug.Assert(ConnectorBindingScope != ConnectorBindingScope.None, $"Ending binding scope {scope} but connection's scope is null"); + /// + /// Ends binding scope to the physical connection and returns it to the pool. Only useful with multiplexing on. + /// + /// + /// After this method is called, under no circumstances the physical connection (connector) should ever be used if multiplexing is on. + /// See #3249. + /// + internal void EndBindingScope(ConnectorBindingScope scope) + { + Debug.Assert(ConnectorBindingScope != ConnectorBindingScope.None || FullState == ConnectionState.Broken, + $"Ending binding scope {scope} but connection's scope is null"); + + if (scope != ConnectorBindingScope) + return; + + Debug.Assert(Connector != null, $"Ending binding scope {scope} but connector is null"); + Debug.Assert(_dataSource != null, $"Ending binding scope {scope} but _pool is null"); + Debug.Assert(Settings.Multiplexing, $"Ending binding scope {scope} but multiplexing is disabled"); + + // TODO: If enlisted transaction scope is still active, need to AddPendingEnlistedConnector, just like Close + var connector = Connector; + Connector = null; + connector.Connection = null; + connector.Transaction?.UnbindIfNecessary(); + connector.Return(); + ConnectorBindingScope = ConnectorBindingScope.None; + } - if (scope != ConnectorBindingScope) - return; + #endregion Connector binding - Debug.Assert(Connector != null, $"Ending binding scope {scope} but connector is null"); - Debug.Assert(_pool != null, $"Ending binding scope {scope} but _pool is null"); - 
Debug.Assert(Settings.Multiplexing, $"Ending binding scope {scope} but multiplexing is disabled"); + #region Schema operations - // TODO: If enlisted transaction scope is still active, need to AddPendingEnlistedConnector, just like Close - var connector = Connector; - Connector = null; - connector.Connection = null; - connector.Transaction?.UnbindIfNecessary(); - _pool.Return(connector); - ConnectorBindingScope = ConnectorBindingScope.None; - } + /// + /// Returns the supported collections + /// + public override DataTable GetSchema() + => GetSchema("MetaDataCollections", null); + + /// + /// Returns the schema collection specified by the collection name. + /// + /// The collection name. + /// The collection specified. + public override DataTable GetSchema(string? collectionName) => GetSchema(collectionName, null); + + /// + /// Returns the schema collection specified by the collection name filtered by the restrictions. + /// + /// The collection name. + /// + /// The restriction values to filter the results. A description of the restrictions is contained + /// in the Restrictions collection. + /// + /// The collection specified. + public override DataTable GetSchema(string? collectionName, string?[]? restrictions) + => NpgsqlSchema.GetSchema(async: false, this, collectionName, restrictions).GetAwaiter().GetResult(); - #endregion Connector binding - - #region Schema operations - - /// - /// Returns the supported collections - /// - public override DataTable GetSchema() - => GetSchema("MetaDataCollections", null); - - /// - /// Returns the schema collection specified by the collection name. - /// - /// The collection name. - /// The collection specified. - public override DataTable GetSchema(string? collectionName) => GetSchema(collectionName, null); - - /// - /// Returns the schema collection specified by the collection name filtered by the restrictions. - /// - /// The collection name. - /// - /// The restriction values to filter the results. 
A description of the restrictions is contained - /// in the Restrictions collection. - /// - /// The collection specified. - public override DataTable GetSchema(string? collectionName, string?[]? restrictions) - => NpgsqlSchema.GetSchema(this, collectionName, restrictions, async: false).GetAwaiter().GetResult(); - - /// - /// Asynchronously returns the supported collections. - /// - /// The token to monitor for cancellation requests. The default value is None. - /// The collection specified. -#if NET - public override Task GetSchemaAsync(CancellationToken cancellationToken = default) + /// + /// Asynchronously returns the supported collections. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The collection specified. +#if NET5_0_OR_GREATER + public override Task GetSchemaAsync(CancellationToken cancellationToken = default) #else - public Task GetSchemaAsync(CancellationToken cancellationToken = default) + public Task GetSchemaAsync(CancellationToken cancellationToken = default) #endif - => GetSchemaAsync("MetaDataCollections", null, cancellationToken); - - /// - /// Asynchronously returns the schema collection specified by the collection name. - /// - /// The collection name. - /// The token to monitor for cancellation requests. The default value is None. - /// The collection specified. -#if NET - public override Task GetSchemaAsync(string collectionName, CancellationToken cancellationToken = default) + => GetSchemaAsync("MetaDataCollections", null, cancellationToken); + + /// + /// Asynchronously returns the schema collection specified by the collection name. + /// + /// The collection name. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The collection specified. 
+#if NET5_0_OR_GREATER + public override Task GetSchemaAsync(string collectionName, CancellationToken cancellationToken = default) #else - public Task GetSchemaAsync(string collectionName, CancellationToken cancellationToken = default) + public Task GetSchemaAsync(string collectionName, CancellationToken cancellationToken = default) #endif - => GetSchemaAsync(collectionName, null, cancellationToken); - - /// - /// Asynchronously returns the schema collection specified by the collection name filtered by the restrictions. - /// - /// The collection name. - /// - /// The restriction values to filter the results. A description of the restrictions is contained - /// in the Restrictions collection. - /// - /// The token to monitor for cancellation requests. The default value is None. - /// The collection specified. -#if NET - public override Task GetSchemaAsync(string collectionName, string?[]? restrictions, CancellationToken cancellationToken = default) + => GetSchemaAsync(collectionName, null, cancellationToken); + + /// + /// Asynchronously returns the schema collection specified by the collection name filtered by the restrictions. + /// + /// The collection name. + /// + /// The restriction values to filter the results. A description of the restrictions is contained + /// in the Restrictions collection. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The collection specified. +#if NET5_0_OR_GREATER + public override Task GetSchemaAsync(string collectionName, string?[]? restrictions, CancellationToken cancellationToken = default) #else - public Task GetSchemaAsync(string collectionName, string?[]? restrictions, CancellationToken cancellationToken = default) + public Task GetSchemaAsync(string collectionName, string?[]? 
restrictions, CancellationToken cancellationToken = default) #endif - { - using (NoSynchronizationContextScope.Enter()) - return NpgsqlSchema.GetSchema(this, collectionName, restrictions, async: true, cancellationToken); - } + { + return NpgsqlSchema.GetSchema(async: true, this, collectionName, restrictions, cancellationToken); + } - #endregion Schema operations + #endregion Schema operations - #region Misc + #region Misc - /// - /// Creates a closed connection with the connection string and authentication details of this message. - /// - object ICloneable.Clone() - { - CheckDisposed(); - var conn = new NpgsqlConnection(_connectionString) { - ProvideClientCertificatesCallback = ProvideClientCertificatesCallback, - UserCertificateValidationCallback = UserCertificateValidationCallback, - ProvidePasswordCallback = ProvidePasswordCallback, - _userFacingConnectionString = _userFacingConnectionString - }; - return conn; - } + /// + /// Creates a closed connection with the connection string and authentication details of this message. + /// + object ICloneable.Clone() + { + CheckDisposed(); + + // For NativeAOT code size reduction, we avoid instantiating a connection here directly with + // `new NpgsqlConnection(_connectionString)`, since that would bring in the default data source builder, and with it various + // features which significantly increase binary size (ranges, System.Text.Json...). Instead, we pass through a "cloning + // instantiator" abstraction, where the implementation only ever gets set if SetupDataSource above is called (in which case the + // default data source is brought in anyway). + Debug.Assert(_dataSource is not null || _cloningInstantiator is not null); + var conn = _dataSource is null + ? 
_cloningInstantiator!(_connectionString) + : _dataSource.CreateConnection(); + + conn.ProvideClientCertificatesCallback = ProvideClientCertificatesCallback; + conn.UserCertificateValidationCallback = UserCertificateValidationCallback; +#pragma warning disable CS0618 // Obsolete + conn.ProvidePasswordCallback = ProvidePasswordCallback; +#pragma warning restore CS0618 + conn._userFacingConnectionString = _userFacingConnectionString; + + return conn; + } - /// - /// Clones this connection, replacing its connection string with the given one. - /// This allows creating a new connection with the same security information - /// (password, SSL callbacks) while changing other connection parameters (e.g. - /// database or pooling) - /// - public NpgsqlConnection CloneWith(string connectionString) - { - CheckDisposed(); - var csb = new NpgsqlConnectionStringBuilder(connectionString); - if (csb.Password == null && Password != null) - csb.Password = Password; - if (csb.PersistSecurityInfo && !Settings.PersistSecurityInfo) - csb.PersistSecurityInfo = false; - return new NpgsqlConnection(csb.ToString()) { - ProvideClientCertificatesCallback = ProvideClientCertificatesCallback, - UserCertificateValidationCallback = UserCertificateValidationCallback, - ProvidePasswordCallback = ProvidePasswordCallback, - }; - } + /// + /// Clones this connection, replacing its connection string with the given one. + /// This allows creating a new connection with the same security information + /// (password, SSL callbacks) while changing other connection parameters (e.g. 
+ /// database or pooling) + /// + public NpgsqlConnection CloneWith(string connectionString) + { + CheckDisposed(); + var csb = new NpgsqlConnectionStringBuilder(connectionString); + csb.Password ??= _dataSource?.GetPassword(async: false).GetAwaiter().GetResult(); + if (csb.PersistSecurityInfo && !Settings.PersistSecurityInfo) + csb.PersistSecurityInfo = false; - /// - /// This method changes the current database by disconnecting from the actual - /// database and connecting to the specified. - /// - /// The name of the database to use in place of the current database. - public override void ChangeDatabase(string dbName) + return new NpgsqlConnection(csb.ToString()) { - if (dbName == null) - throw new ArgumentNullException(nameof(dbName)); - if (string.IsNullOrEmpty(dbName)) - throw new ArgumentOutOfRangeException(nameof(dbName), dbName, $"Invalid database name: {dbName}"); + ProvideClientCertificatesCallback = + ProvideClientCertificatesCallback ?? + (_dataSource?.ClientCertificatesCallback is { } clientCertificatesCallback + ? (ProvideClientCertificatesCallback)(certs => clientCertificatesCallback(certs)) + : null), + UserCertificateValidationCallback = UserCertificateValidationCallback ?? _dataSource?.UserCertificateValidationCallback, +#pragma warning disable CS0618 // Obsolete + ProvidePasswordCallback = ProvidePasswordCallback, +#pragma warning restore CS0618 + }; + } - CheckOpen(); - Close(); + /// + /// This method changes the current database by disconnecting from the actual + /// database and connecting to the specified. + /// + /// The name of the database to use in place of the current database. 
+ public override void ChangeDatabase(string dbName) + { + if (dbName == null) + throw new ArgumentNullException(nameof(dbName)); + if (string.IsNullOrEmpty(dbName)) + throw new ArgumentOutOfRangeException(nameof(dbName), dbName, $"Invalid database name: {dbName}"); - _pool = null; - Settings = Settings.Clone(); - Settings.Database = dbName; - ConnectionString = Settings.ToString(); + CheckOpen(); + Close(); - Open(); - } + _dataSource = null; + Settings = Settings.Clone(); + Settings.Database = dbName; + ConnectionString = Settings.ToString(); - /// - /// DB provider factory. - /// - protected override DbProviderFactory DbProviderFactory => NpgsqlFactory.Instance; - - /// - /// Clears the connection pool. All idle physical connections in the pool of the given connection are - /// immediately closed, and any busy connections which were opened before was called - /// will be closed when returned to the pool. - /// - public static void ClearPool(NpgsqlConnection connection) => PoolManager.Clear(connection._connectionString); - - /// - /// Clear all connection pools. All idle physical connections in all pools are immediately closed, and any busy - /// connections which were opened before was called will be closed when returned - /// to their pool. - /// - public static void ClearAllPools() => PoolManager.ClearAll(); - - /// - /// Unprepares all prepared statements on this connection. - /// - public void UnprepareAll() - { - if (Settings.Multiplexing) - throw new NotSupportedException("Explicit preparation not supported with multiplexing"); + Open(); + } - CheckReady(); + /// + /// DB provider factory. + /// + protected override DbProviderFactory DbProviderFactory => NpgsqlFactory.Instance; - using (Connector!.StartUserAction()) - Connector.UnprepareAll(); - } + /// + /// Clears the connection pool. 
All idle physical connections in the pool of the given connection are + /// immediately closed, and any busy connections which were opened before was called + /// will be closed when returned to the pool. + /// + public static void ClearPool(NpgsqlConnection connection) => PoolManager.Clear(connection._connectionString); - /// - /// Flushes the type cache for this connection's connection string and reloads the types for this connection only. - /// Type changes will appear for other connections only after they are re-opened from the pool. - /// - public void ReloadTypes() - { - CheckReady(); - using var scope = StartTemporaryBindingScope(out var connector); - connector.LoadDatabaseInfo( - forceReload: true, - NpgsqlTimeout.Infinite, - async: false, - CancellationToken.None).GetAwaiter().GetResult(); - // Increment the change counter on the global type mapper. This will make conn.Open() pick up the - // new DatabaseInfo and set up a new connection type mapper - TypeMapping.GlobalTypeMapper.Instance.RecordChange(); - } + /// + /// Clear all connection pools. All idle physical connections in all pools are immediately closed, and any busy + /// connections which were opened before was called will be closed when returned + /// to their pool. + /// + public static void ClearAllPools() => PoolManager.ClearAll(); + + /// + /// Unprepares all prepared statements on this connection. + /// + public void UnprepareAll() + { + if (Settings.Multiplexing) + throw new NotSupportedException("Explicit preparation not supported with multiplexing"); + + CheckReady(); - #endregion Misc + using (Connector!.StartUserAction()) + Connector.UnprepareAll(); } - enum ConnectorBindingScope + /// + /// Flushes the type cache for this connection's connection string and reloads the types for this connection only. + /// Type changes will appear for other connections only after they are re-opened from the pool. 
+ /// + public void ReloadTypes() { - /// - /// The connection is currently not bound to a connector. - /// - None, - - /// - /// The connection is bound to its connector for the scope of the entire connection - /// (i.e. non-multiplexed connection). - /// - Connection, - - /// - /// The connection is bound to its connector for the scope of a transaction. - /// - Transaction, - - /// - /// The connection is bound to its connector for the scope of a COPY operation. - /// - Copy, - - /// - /// The connection is bound to its connector for the scope of a single reader. - /// - Reader, - - /// - /// The connection is bound to its connector for the scope of establishing a new physical connection. - /// - PhysicalConnecting, - - /// - /// The connection is bound to its connector for an unspecified, temporary scope; the code that initiated - /// the binding is also responsible to unbind it. - /// - Temporary + CheckReady(); + + using var scope = StartTemporaryBindingScope(out var connector); + + _dataSource!.Bootstrap( + connector, + NpgsqlTimeout.Infinite, + forceReload: true, + async: false, + CancellationToken.None) + .GetAwaiter().GetResult(); } - readonly struct EndScopeDisposable : IDisposable + /// + /// Flushes the type cache for this connection's connection string and reloads the types for this connection only. + /// Type changes will appear for other connections only after they are re-opened from the pool. + /// + public async Task ReloadTypesAsync() + { + CheckReady(); + + using var scope = StartTemporaryBindingScope(out var connector); + + await _dataSource!.Bootstrap( + connector, + NpgsqlTimeout.Infinite, + forceReload: true, + async: true, + CancellationToken.None).ConfigureAwait(false); + } + + /// + /// This event is unsupported by Npgsql. Use instead. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public new event EventHandler? Disposed + { + add => throw new NotSupportedException("The Disposed event isn't supported by Npgsql. 
Use DbConnection.StateChange instead."); + remove => throw new NotSupportedException("The Disposed event isn't supported by Npgsql. Use DbConnection.StateChange instead."); + } + + event EventHandler? IComponent.Disposed { - readonly NpgsqlConnection _connection; - public EndScopeDisposable(NpgsqlConnection connection) => _connection = connection; - public void Dispose() => _connection.EndBindingScope(ConnectorBindingScope.Temporary); + add => Disposed += value; + remove => Disposed -= value; } - #region Delegates + #endregion Misc +} +enum ConnectorBindingScope +{ /// - /// Represents a method that handles the event. + /// The connection is currently not bound to a connector. /// - /// The source of the event. - /// A that contains the notice information (e.g. message, severity...). - public delegate void NoticeEventHandler(object sender, NpgsqlNoticeEventArgs e); + None, /// - /// Represents a method that handles the event. + /// The connection is bound to its connector for the scope of the entire connection + /// (i.e. non-multiplexed connection). /// - /// The source of the event. - /// A that contains the notification payload. - public delegate void NotificationEventHandler(object sender, NpgsqlNotificationEventArgs e); + Connection, /// - /// Represents the method that allows the application to provide a certificate collection to be used for SSL client authentication + /// The connection is bound to its connector for the scope of a transaction. /// - /// A X509CertificateCollection to be filled with one or more client certificates. - public delegate void ProvideClientCertificatesCallback(X509CertificateCollection certificates); + Transaction, /// - /// Represents the method that allows the application to provide a password at connection time in code rather than configuration + /// The connection is bound to its connector for the scope of a COPY operation. 
/// - /// Hostname - /// Port - /// Database Name - /// User - /// A valid password for connecting to the database - public delegate string ProvidePasswordCallback(string host, int port, string database, string username); + Copy, - #endregion + /// + /// The connection is bound to its connector for the scope of a single reader. + /// + Reader, + + /// + /// The connection is bound to its connector for an unspecified, temporary scope; the code that initiated + /// the binding is also responsible to unbind it. + /// + Temporary +} + +readonly struct EndScopeDisposable : IDisposable +{ + readonly NpgsqlConnection _connection; + public EndScopeDisposable(NpgsqlConnection connection) => _connection = connection; + public void Dispose() => _connection.EndBindingScope(ConnectorBindingScope.Temporary); } + +#region Delegates + +/// +/// Represents a method that handles the event. +/// +/// The source of the event. +/// A that contains the notice information (e.g. message, severity...). +public delegate void NoticeEventHandler(object sender, NpgsqlNoticeEventArgs e); + +/// +/// Represents a method that handles the event. +/// +/// The source of the event. +/// A that contains the notification payload. +public delegate void NotificationEventHandler(object sender, NpgsqlNotificationEventArgs e); + +/// +/// Represents a method that allows the application to provide a certificate collection to be used for SSL client authentication +/// +/// +/// A to be filled with one or more client +/// certificates. 
+/// +public delegate void ProvideClientCertificatesCallback(X509CertificateCollection certificates); + +/// +/// Represents a method that allows the application to provide a password at connection time in code rather than configuration +/// +/// Hostname +/// Port +/// Database Name +/// User +/// A valid password for connecting to the database +[Obsolete("Use NpgsqlDataSourceBuilder.UsePeriodicPasswordProvider or inject passwords directly into NpgsqlDataSource.Password")] +public delegate string ProvidePasswordCallback(string host, int port, string database, string username); + +#endregion diff --git a/src/Npgsql/NpgsqlConnectionStringBuilder.cs b/src/Npgsql/NpgsqlConnectionStringBuilder.cs index e9d743c8fd..0047510590 100644 --- a/src/Npgsql/NpgsqlConnectionStringBuilder.cs +++ b/src/Npgsql/NpgsqlConnectionStringBuilder.cs @@ -6,1689 +6,1727 @@ using System.Data.Common; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Reflection; -using System.Text; +using Npgsql.Internal; +using Npgsql.Netstandard20; using Npgsql.Replication; -namespace Npgsql +namespace Npgsql; + +/// +/// Provides a simple way to create and manage the contents of connection strings used by +/// the class. +/// +[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2112:ReflectionToRequiresUnreferencedCode", + Justification = "Suppressing the same warnings as suppressed in the base DbConnectionStringBuilder. See https://github.com/dotnet/runtime/issues/97057")] +[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2113:ReflectionToRequiresUnreferencedCode", + Justification = "Suppressing the same warnings as suppressed in the base DbConnectionStringBuilder. See https://github.com/dotnet/runtime/issues/97057")] +public sealed partial class NpgsqlConnectionStringBuilder : DbConnectionStringBuilder, IDictionary { + #region Fields + /// - /// Provides a simple way to create and manage the contents of connection strings used by - /// the class. 
+ /// Cached DataSource value to reduce allocations on NpgsqlConnection.DataSource.get /// - public sealed class NpgsqlConnectionStringBuilder : DbConnectionStringBuilder, IDictionary - { - #region Fields - - /// - /// Makes all valid keywords for a property to that property (e.g. User Name -> Username, UserId -> Username...) - /// - static readonly Dictionary PropertiesByKeyword; + string? _dataSourceCached; - /// - /// Maps CLR property names (e.g. BufferSize) to their canonical keyword name, which is the - /// property's [DisplayName] (e.g. Buffer Size) - /// - static readonly Dictionary PropertyNameToCanonicalKeyword; + internal string? DataSourceCached + => _dataSourceCached ??= _host is null || _host.Contains(",") + ? null + : IsUnixSocket(_host, _port, out var socketPath, replaceForAbstract: false) + ? socketPath + : $"tcp://{_host}:{_port}"; - /// - /// Maps each property to its [DefaultValue] - /// - static readonly Dictionary PropertyDefaults; + // Note that we can't cache the result due to nullable's assignment not being thread safe + internal TimeSpan HostRecheckSecondsTranslated + => TimeSpan.FromSeconds(HostRecheckSeconds == 0 ? -1 : HostRecheckSeconds); - /// - /// Cached DataSource value to reduce allocations on NpgsqlConnection.DataSource.get - /// - string? _dataSourceCached; - - internal string DataSourceCached - => _dataSourceCached ??= _host is null - ? string.Empty - : Path.IsPathRooted(_host) - ? Path.Combine(_host, $".s.PGSQL.{_port}") - : $"tcp://{_host}:{_port}"; + #endregion - #endregion + #region Constructors - #region Constructors + /// + /// Initializes a new instance of the NpgsqlConnectionStringBuilder class. + /// + public NpgsqlConnectionStringBuilder() => Init(); - /// - /// Initializes a new instance of the NpgsqlConnectionStringBuilder class. - /// - public NpgsqlConnectionStringBuilder() { Init(); } + /// + /// Initializes a new instance of the NpgsqlConnectionStringBuilder class, optionally using ODBC rules for quoting values. 
+ /// + /// true to use {} to delimit fields; false to use quotation marks. + public NpgsqlConnectionStringBuilder(bool useOdbcRules) : base(useOdbcRules) => Init(); - /// - /// Initializes a new instance of the NpgsqlConnectionStringBuilder class, optionally using ODBC rules for quoting values. - /// - /// true to use {} to delimit fields; false to use quotation marks. - public NpgsqlConnectionStringBuilder(bool useOdbcRules) : base(useOdbcRules) { Init(); } + /// + /// Initializes a new instance of the NpgsqlConnectionStringBuilder class and sets its . + /// + public NpgsqlConnectionStringBuilder(string? connectionString) + { + Init(); + ConnectionString = connectionString; + } - /// - /// Initializes a new instance of the NpgsqlConnectionStringBuilder class and sets its . - /// - public NpgsqlConnectionStringBuilder(string? connectionString) - { - Init(); - ConnectionString = connectionString; - } + // Method fake-returns an int only to make sure it's code-generated + private partial int Init(); - void Init() - { - // Set the strongly-typed properties to their default values - foreach (var kv in PropertyDefaults) - kv.Key.SetValue(this, kv.Value); - // Setting the strongly-typed properties here also set the string-based properties in the base class. - // Clear them (default settings = empty connection string) - base.Clear(); - } + /// + /// GeneratedAction and GeneratedActions exist to be able to produce a streamlined binary footprint for NativeAOT. + /// An idiomatic approach where each action has its own method would double the binary size of NpgsqlConnectionStringBuilder. + /// + enum GeneratedAction + { + Set, + Get, + Remove, + GetCanonical + } + private partial bool GeneratedActions(GeneratedAction action, string keyword, ref object? value); - #endregion + #endregion - #region Static initialization + #region Non-static property handling - static NpgsqlConnectionStringBuilder() + /// + /// Gets or sets the value associated with the specified key. 
+ /// + /// The key of the item to get or set. + /// The value associated with the specified key. + [AllowNull] + public override object this[string keyword] + { + get { - var properties = typeof(NpgsqlConnectionStringBuilder) - .GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic) - .Where(p => p.GetCustomAttribute() != null) - .ToArray(); - - Debug.Assert(properties.All(p => p.CanRead && p.CanWrite)); - Debug.Assert(properties.All(p => p.GetCustomAttribute() != null)); - - PropertiesByKeyword = ( - from p in properties - let displayName = p.GetCustomAttribute()!.DisplayName.ToUpperInvariant() - let propertyName = p.Name.ToUpperInvariant() - from k in new[] { displayName } - .Concat(propertyName != displayName ? new[] { propertyName } : EmptyStringArray ) - .Concat(p.GetCustomAttribute()!.Synonyms - .Select(a => a.ToUpperInvariant()) - ) - .Select(k => new { Property = p, Keyword = k }) - select k - ).ToDictionary(t => t.Keyword, t => t.Property); - - PropertyNameToCanonicalKeyword = properties.ToDictionary( - p => p.Name, - p => p.GetCustomAttribute()!.DisplayName - ); - - PropertyDefaults = properties - .Where(p => p.GetCustomAttribute() == null) - .ToDictionary( - p => p, - p => p.GetCustomAttribute() != null - ? p.GetCustomAttribute()!.Value - : (p.PropertyType.GetTypeInfo().IsValueType ? Activator.CreateInstance(p.PropertyType) : null) - ); - } - - #endregion - - #region Non-static property handling - - /// - /// Gets or sets the value associated with the specified key. - /// - /// The key of the item to get or set. - /// The value associated with the specified key. 
- [AllowNull] - public override object this[string keyword] + if (!TryGetValue(keyword, out var value)) + throw new ArgumentException("Keyword not supported: " + keyword, nameof(keyword)); + return value; + } + set { - get + if (value is null) { - if (!TryGetValue(keyword, out var value)) - throw new ArgumentException("Keyword not supported: " + keyword, nameof(keyword)); - return value; + Remove(keyword); + return; } - set + + try + { + var val = value; + GeneratedActions(GeneratedAction.Set, keyword.ToUpperInvariant(), ref val); + } + catch (Exception e) { - if (value is null) - { - Remove(keyword); - return; - } - - var p = GetProperty(keyword); - try - { - var convertedValue = p.PropertyType.GetTypeInfo().IsEnum && value is string str - ? Enum.Parse(p.PropertyType, str) - : Convert.ChangeType(value, p.PropertyType); - p.SetValue(this, convertedValue); - } - catch (Exception e) - { - throw new ArgumentException("Couldn't set " + keyword, keyword, e); - } + throw new ArgumentException("Couldn't set " + keyword, keyword, e); } } + } - object? IDictionary.this[string keyword] - { - get => this[keyword]; - set => this[keyword] = value!; - } + object? IDictionary.this[string keyword] + { + get => this[keyword]; + set => this[keyword] = value!; + } - /// - /// Adds an item to the . - /// - /// The key-value pair to be added. - public void Add(KeyValuePair item) - => this[item.Key] = item.Value!; + /// + /// Adds an item to the . + /// + /// The key-value pair to be added. + public void Add(KeyValuePair item) + => this[item.Key] = item.Value!; - void IDictionary.Add(string keyword, object? value) - => this[keyword] = value; + void IDictionary.Add(string keyword, object? value) + => this[keyword] = value; - /// - /// Removes the entry with the specified key from the DbConnectionStringBuilder instance. - /// - /// The key of the key/value pair to be removed from the connection string in this DbConnectionStringBuilder. 
- /// true if the key existed within the connection string and was removed; false if the key did not exist. - public override bool Remove(string keyword) - { - var p = GetProperty(keyword); - var canonicalName = PropertyNameToCanonicalKeyword[p.Name]; - var removed = base.ContainsKey(canonicalName); - // Note that string property setters call SetValue, which itself calls base.Remove(). - p.SetValue(this, PropertyDefaults[p]); - base.Remove(canonicalName); - return removed; - } + /// + /// Removes the entry with the specified key from the DbConnectionStringBuilder instance. + /// + /// The key of the key/value pair to be removed from the connection string in this DbConnectionStringBuilder. + /// true if the key existed within the connection string and was removed; false if the key did not exist. + public override bool Remove(string keyword) + { + object? value = null; + return GeneratedActions(GeneratedAction.Remove, keyword.ToUpperInvariant(), ref value); + } - /// - /// Removes the entry from the DbConnectionStringBuilder instance. - /// - /// The key/value pair to be removed from the connection string in this DbConnectionStringBuilder. - /// true if the key existed within the connection string and was removed; false if the key did not exist. - public bool Remove(KeyValuePair item) - => Remove(item.Key); + /// + /// Removes the entry from the DbConnectionStringBuilder instance. + /// + /// The key/value pair to be removed from the connection string in this DbConnectionStringBuilder. + /// true if the key existed within the connection string and was removed; false if the key did not exist. + public bool Remove(KeyValuePair item) + => Remove(item.Key); - /// - /// Clears the contents of the instance. - /// - public override void Clear() - { - Debug.Assert(Keys != null); - foreach (var k in Keys.ToArray()) { - Remove(k); - } - } + /// + /// Clears the contents of the instance. 
+ /// + public override void Clear() + { + Debug.Assert(Keys != null); + foreach (var k in (string[])Keys) + Remove(k); + } - /// - /// Determines whether the contains a specific key. - /// - /// The key to locate in the . - /// true if the contains an entry with the specified key; otherwise false. - public override bool ContainsKey(string keyword) - => keyword is null - ? throw new ArgumentNullException(nameof(keyword)) - : PropertiesByKeyword.ContainsKey(keyword.ToUpperInvariant()); - - /// - /// Determines whether the contains a specific key-value pair. - /// - /// The item to locate in the . - /// true if the contains the entry; otherwise false. - public bool Contains(KeyValuePair item) - => TryGetValue(item.Key, out var value) && - ((value == null && item.Value == null) || (value != null && value.Equals(item.Value))); - - PropertyInfo GetProperty(string keyword) - => PropertiesByKeyword.TryGetValue(keyword.ToUpperInvariant(), out var p) - ? p - : throw new ArgumentException("Keyword not supported: " + keyword, nameof(keyword)); - - /// - /// Retrieves a value corresponding to the supplied key from this . - /// - /// The key of the item to retrieve. - /// The value corresponding to the key. - /// true if keyword was found within the connection string, false otherwise. - public override bool TryGetValue(string keyword, [NotNullWhen(true)] out object? value) - { - if (keyword == null) - throw new ArgumentNullException(nameof(keyword)); - - if (!PropertiesByKeyword.ContainsKey(keyword.ToUpperInvariant())) - { - value = null; - return false; - } + /// + /// Determines whether the contains a specific key. + /// + /// The key to locate in the . + /// true if the contains an entry with the specified key; otherwise false. + public override bool ContainsKey(string keyword) + { + object? value = null; + return GeneratedActions(GeneratedAction.GetCanonical, (keyword ?? 
throw new ArgumentNullException(nameof(keyword))).ToUpperInvariant(), ref value); + } - value = GetProperty(keyword).GetValue(this) ?? ""; - return true; + /// + /// Determines whether the contains a specific key-value pair. + /// + /// The item to locate in the . + /// true if the contains the entry; otherwise false. + public bool Contains(KeyValuePair item) + => TryGetValue(item.Key, out var value) && + ((value == null && item.Value == null) || (value != null && value.Equals(item.Value))); - } + /// + /// Retrieves a value corresponding to the supplied key from this . + /// + /// The key of the item to retrieve. + /// The value corresponding to the key. + /// true if keyword was found within the connection string, false otherwise. + public override bool TryGetValue(string keyword, [NotNullWhen(true)] out object? value) + { + object? v = null; + var result = GeneratedActions(GeneratedAction.Get, (keyword ?? throw new ArgumentNullException(nameof(keyword))).ToUpperInvariant(), ref v); + value = v; + return result; + } - void SetValue(string propertyName, object? value) - { - var canonicalKeyword = PropertyNameToCanonicalKeyword[propertyName]; - if (value == null) - base.Remove(canonicalKeyword); - else - base[canonicalKeyword] = value; - } + void SetValue(string propertyName, object? value) + { + object? canonicalKeyword = null; + var result = GeneratedActions(GeneratedAction.GetCanonical, (propertyName ?? throw new ArgumentNullException(nameof(propertyName))).ToUpperInvariant(), ref canonicalKeyword); + if (!result) + throw new KeyNotFoundException(); + if (value == null) + base.Remove((string)canonicalKeyword!); + else + base[(string)canonicalKeyword!] = value; + } - #endregion + #endregion - #region Properties - Connection + #region Properties - Connection - /// - /// The hostname or IP address of the PostgreSQL server to connect to. 
- /// - [Category("Connection")] - [Description("The hostname or IP address of the PostgreSQL server to connect to.")] - [DisplayName("Host")] - [NpgsqlConnectionStringProperty("Server")] - public string? Host + /// + /// The hostname or IP address of the PostgreSQL server to connect to. + /// + [Category("Connection")] + [Description("The hostname or IP address of the PostgreSQL server to connect to.")] + [DisplayName("Host")] + [NpgsqlConnectionStringProperty("Server")] + public string? Host + { + get => _host; + set { - get => _host; - set - { - _host = value; - SetValue(nameof(Host), value); - _dataSourceCached = null; - } + _host = value; + SetValue(nameof(Host), value); + _dataSourceCached = null; } - string? _host; + } + string? _host; - /// - /// The TCP/IP port of the PostgreSQL server. - /// - [Category("Connection")] - [Description("The TCP port of the PostgreSQL server.")] - [DisplayName("Port")] - [NpgsqlConnectionStringProperty] - [DefaultValue(NpgsqlConnection.DefaultPort)] - public int Port + /// + /// The TCP/IP port of the PostgreSQL server. + /// + [Category("Connection")] + [Description("The TCP port of the PostgreSQL server.")] + [DisplayName("Port")] + [NpgsqlConnectionStringProperty] + [DefaultValue(NpgsqlConnection.DefaultPort)] + public int Port + { + get => _port; + set { - get => _port; - set - { - if (value <= 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "Invalid port: " + value); + if (value <= 0) + throw new ArgumentOutOfRangeException(nameof(value), value, "Invalid port: " + value); - _port = value; - SetValue(nameof(Port), value); - _dataSourceCached = null; - } + _port = value; + SetValue(nameof(Port), value); + _dataSourceCached = null; } - int _port; + } + int _port; - /// - /// The PostgreSQL database to connect to. - /// - [Category("Connection")] - [Description("The PostgreSQL database to connect to.")] - [DisplayName("Database")] - [NpgsqlConnectionStringProperty("DB")] - public string? 
Database + /// + /// The PostgreSQL database to connect to. + /// + [Category("Connection")] + [Description("The PostgreSQL database to connect to.")] + [DisplayName("Database")] + [NpgsqlConnectionStringProperty("DB")] + public string? Database + { + get => _database; + set { - get => _database; - set - { - _database = value; - SetValue(nameof(Database), value); - } + _database = value; + SetValue(nameof(Database), value); } - string? _database; + } + string? _database; - /// - /// The username to connect with. Not required if using IntegratedSecurity. - /// - [Category("Connection")] - [Description("The username to connect with. Not required if using IntegratedSecurity.")] - [DisplayName("Username")] - [NpgsqlConnectionStringProperty("User Name", "UserId", "User Id", "UID")] - public string? Username + /// + /// The username to connect with. + /// + [Category("Connection")] + [Description("The username to connect with.")] + [DisplayName("Username")] + [NpgsqlConnectionStringProperty("User Name", "UserId", "User Id", "UID")] + public string? Username + { + get => _username; + set { - get => _username; - set - { - _username = value; - SetValue(nameof(Username), value); - } + _username = value; + SetValue(nameof(Username), value); } - string? _username; + } + string? _username; - /// - /// The password to connect with. Not required if using IntegratedSecurity. - /// - [Category("Connection")] - [Description("The password to connect with. Not required if using IntegratedSecurity.")] - [PasswordPropertyText(true)] - [DisplayName("Password")] - [NpgsqlConnectionStringProperty("PSW", "PWD")] - public string? Password + /// + /// The password to connect with. + /// + [Category("Connection")] + [Description("The password to connect with.")] + [PasswordPropertyText(true)] + [DisplayName("Password")] + [NpgsqlConnectionStringProperty("PSW", "PWD")] + public string? 
Password + { + get => _password; + set { - get => _password; - set - { - _password = value; - SetValue(nameof(Password), value); - } + _password = value; + SetValue(nameof(Password), value); } - string? _password; + } + string? _password; - /// - /// Path to a PostgreSQL password file (PGPASSFILE), from which the password would be taken. - /// - [Category("Connection")] - [Description("Path to a PostgreSQL password file (PGPASSFILE), from which the password would be taken.")] - [DisplayName("Passfile")] - [NpgsqlConnectionStringProperty] - public string? Passfile + /// + /// Path to a PostgreSQL password file (PGPASSFILE), from which the password would be taken. + /// + [Category("Connection")] + [Description("Path to a PostgreSQL password file (PGPASSFILE), from which the password would be taken.")] + [DisplayName("Passfile")] + [NpgsqlConnectionStringProperty] + public string? Passfile + { + get => _passfile; + set { - get => _passfile; - set - { - _passfile = value; - SetValue(nameof(Passfile), value); - } + _passfile = value; + SetValue(nameof(Passfile), value); } + } - string? _passfile; + string? _passfile; - /// - /// The optional application name parameter to be sent to the backend during connection initiation. - /// - [Category("Connection")] - [Description("The optional application name parameter to be sent to the backend during connection initiation")] - [DisplayName("Application Name")] - [NpgsqlConnectionStringProperty] - public string? ApplicationName + /// + /// The optional application name parameter to be sent to the backend during connection initiation. + /// + [Category("Connection")] + [Description("The optional application name parameter to be sent to the backend during connection initiation")] + [DisplayName("Application Name")] + [NpgsqlConnectionStringProperty] + public string? 
ApplicationName + { + get => _applicationName; + set { - get => _applicationName; - set - { - _applicationName = value; - SetValue(nameof(ApplicationName), value); - } + _applicationName = value; + SetValue(nameof(ApplicationName), value); } - string? _applicationName; + } + string? _applicationName; - /// - /// Whether to enlist in an ambient TransactionScope. - /// - [Category("Connection")] - [Description("Whether to enlist in an ambient TransactionScope.")] - [DisplayName("Enlist")] - [DefaultValue(true)] - [NpgsqlConnectionStringProperty] - public bool Enlist + /// + /// Whether to enlist in an ambient TransactionScope. + /// + [Category("Connection")] + [Description("Whether to enlist in an ambient TransactionScope.")] + [DisplayName("Enlist")] + [DefaultValue(true)] + [NpgsqlConnectionStringProperty] + public bool Enlist + { + get => _enlist; + set { - get => _enlist; - set - { - _enlist = value; - SetValue(nameof(Enlist), value); - } + _enlist = value; + SetValue(nameof(Enlist), value); } - bool _enlist; + } + bool _enlist; - /// - /// Gets or sets the schema search path. - /// - [Category("Connection")] - [Description("Gets or sets the schema search path.")] - [DisplayName("Search Path")] - [NpgsqlConnectionStringProperty] - public string? SearchPath + /// + /// Gets or sets the schema search path. + /// + [Category("Connection")] + [Description("Gets or sets the schema search path.")] + [DisplayName("Search Path")] + [NpgsqlConnectionStringProperty] + public string? SearchPath + { + get => _searchPath; + set { - get => _searchPath; - set - { - _searchPath = value; - SetValue(nameof(SearchPath), value); - } + _searchPath = value; + SetValue(nameof(SearchPath), value); } - string? _searchPath; + } + string? _searchPath; - /// - /// Gets or sets the client_encoding parameter. - /// - [Category("Connection")] - [Description("Gets or sets the client_encoding parameter.")] - [DisplayName("Client Encoding")] - [NpgsqlConnectionStringProperty] - public string? 
ClientEncoding + /// + /// Gets or sets the client_encoding parameter. + /// + [Category("Connection")] + [Description("Gets or sets the client_encoding parameter.")] + [DisplayName("Client Encoding")] + [NpgsqlConnectionStringProperty] + public string? ClientEncoding + { + get => _clientEncoding; + set { - get => _clientEncoding; - set - { - _clientEncoding = value; - SetValue(nameof(ClientEncoding), value); - } + _clientEncoding = value; + SetValue(nameof(ClientEncoding), value); } - string? _clientEncoding; + } + string? _clientEncoding; - /// - /// Gets or sets the .NET encoding that will be used to encode/decode PostgreSQL string data. - /// - [Category("Connection")] - [Description("Gets or sets the .NET encoding that will be used to encode/decode PostgreSQL string data.")] - [DisplayName("Encoding")] - [DefaultValue("UTF8")] - [NpgsqlConnectionStringProperty] - public string Encoding + /// + /// Gets or sets the .NET encoding that will be used to encode/decode PostgreSQL string data. + /// + [Category("Connection")] + [Description("Gets or sets the .NET encoding that will be used to encode/decode PostgreSQL string data.")] + [DisplayName("Encoding")] + [DefaultValue("UTF8")] + [NpgsqlConnectionStringProperty] + public string Encoding + { + get => _encoding; + set { - get => _encoding; - set - { - _encoding = value; - SetValue(nameof(Encoding), value); - } + _encoding = value; + SetValue(nameof(Encoding), value); } - string _encoding = "UTF8"; + } + string _encoding = "UTF8"; - /// - /// Gets or sets the PostgreSQL session timezone, in Olson/IANA database format. - /// - [Category("Connection")] - [Description("Gets or sets the PostgreSQL session timezone, in Olson/IANA database format.")] - [DisplayName("Timezone")] - [NpgsqlConnectionStringProperty] - public string? Timezone + /// + /// Gets or sets the PostgreSQL session timezone, in Olson/IANA database format. 
+ /// + [Category("Connection")] + [Description("Gets or sets the PostgreSQL session timezone, in Olson/IANA database format.")] + [DisplayName("Timezone")] + [NpgsqlConnectionStringProperty] + public string? Timezone + { + get => _timezone; + set { - get => _timezone; - set - { - _timezone = value; - SetValue(nameof(Timezone), value); - } + _timezone = value; + SetValue(nameof(Timezone), value); } - string? _timezone; + } + string? _timezone; - #endregion + #endregion - #region Properties - Security + #region Properties - Security - /// - /// Controls whether SSL is required, disabled or preferred, depending on server support. - /// - [Category("Security")] - [Description("Controls whether SSL is required, disabled or preferred, depending on server support.")] - [DisplayName("SSL Mode")] - [NpgsqlConnectionStringProperty] - public SslMode SslMode + /// + /// Controls whether SSL is required, disabled or preferred, depending on server support. + /// + [Category("Security")] + [Description("Controls whether SSL is required, disabled or preferred, depending on server support.")] + [DisplayName("SSL Mode")] + [DefaultValue(SslMode.Prefer)] + [NpgsqlConnectionStringProperty] + public SslMode SslMode + { + get => _sslMode; + set { - get => _sslMode; - set - { - _sslMode = value; - SetValue(nameof(SslMode), value); - } + _sslMode = value; + SetValue(nameof(SslMode), value); } - SslMode _sslMode; + } + SslMode _sslMode; - /// - /// Whether to trust the server certificate without validating it. - /// - [Category("Security")] - [Description("Whether to trust the server certificate without validating it.")] - [DisplayName("Trust Server Certificate")] - [NpgsqlConnectionStringProperty] - public bool TrustServerCertificate + /// + /// Location of a client certificate to be sent to the server. 
+ /// + [Category("Security")] + [Description("Location of a client certificate to be sent to the server.")] + [DisplayName("SSL Certificate")] + [NpgsqlConnectionStringProperty] + public string? SslCertificate + { + get => _sslCertificate; + set { - get => _trustServerCertificate; - set - { - _trustServerCertificate = value; - SetValue(nameof(TrustServerCertificate), value); - } + _sslCertificate = value; + SetValue(nameof(SslCertificate), value); } - bool _trustServerCertificate; + } + string? _sslCertificate; - /// - /// Location of a client certificate to be sent to the server. - /// - [Category("Security")] - [Description("Location of a client certificate to be sent to the server.")] - [DisplayName("Client Certificate")] - [NpgsqlConnectionStringProperty] - public string? ClientCertificate + /// + /// Location of a client key for a client certificate to be sent to the server. + /// + [Category("Security")] + [Description("Location of a client key for a client certificate to be sent to the server.")] + [DisplayName("SSL Key")] + [NpgsqlConnectionStringProperty] + public string? SslKey + { + get => _sslKey; + set { - get => _clientCertificate; - set - { - _clientCertificate = value; - SetValue(nameof(ClientCertificate), value); - } + _sslKey = value; + SetValue(nameof(SslKey), value); } - string? _clientCertificate; + } + string? _sslKey; - /// - /// Key for a client certificate to be sent to the server. - /// - [Category("Security")] - [Description("Key for a client certificate to be sent to the server.")] - [DisplayName("Client Certificate Key")] - [NpgsqlConnectionStringProperty] - public string? ClientCertificateKey + /// + /// Password for a key for a client certificate. + /// + [Category("Security")] + [Description("Password for a key for a client certificate.")] + [DisplayName("SSL Password")] + [NpgsqlConnectionStringProperty] + public string? 
SslPassword + { + get => _sslPassword; + set { - get => _clientCertificateKey; - set - { - _clientCertificateKey = value; - SetValue(nameof(ClientCertificateKey), value); - } + _sslPassword = value; + SetValue(nameof(SslPassword), value); } - string? _clientCertificateKey; + } + string? _sslPassword; - /// - /// Location of a CA certificate used to validate the server certificate. - /// - [Category("Security")] - [Description("Location of a CA certificate used to validate the server certificate.")] - [DisplayName("Root Certificate")] - [NpgsqlConnectionStringProperty] - public string? RootCertificate + /// + /// Location of a CA certificate used to validate the server certificate. + /// + [Category("Security")] + [Description("Location of a CA certificate used to validate the server certificate.")] + [DisplayName("Root Certificate")] + [NpgsqlConnectionStringProperty] + public string? RootCertificate + { + get => _rootCertificate; + set { - get => _rootCertificate; - set - { - _rootCertificate = value; - SetValue(nameof(RootCertificate), value); - } + _rootCertificate = value; + SetValue(nameof(RootCertificate), value); } - string? _rootCertificate; + } + string? _rootCertificate; - /// - /// Whether to check the certificate revocation list during authentication. - /// False by default. - /// - [Category("Security")] - [Description("Whether to check the certificate revocation list during authentication.")] - [DisplayName("Check Certificate Revocation")] - [NpgsqlConnectionStringProperty] - public bool CheckCertificateRevocation + /// + /// Whether to check the certificate revocation list during authentication. + /// False by default. 
+ /// + [Category("Security")] + [Description("Whether to check the certificate revocation list during authentication.")] + [DisplayName("Check Certificate Revocation")] + [NpgsqlConnectionStringProperty] + public bool CheckCertificateRevocation + { + get => _checkCertificateRevocation; + set { - get => _checkCertificateRevocation; - set - { - _checkCertificateRevocation = value; - SetValue(nameof(CheckCertificateRevocation), value); - } + _checkCertificateRevocation = value; + SetValue(nameof(CheckCertificateRevocation), value); } - bool _checkCertificateRevocation; + } + bool _checkCertificateRevocation; - /// - /// Whether to use Windows integrated security to log in. - /// - [Category("Security")] - [Description("Whether to use Windows integrated security to log in.")] - [DisplayName("Integrated Security")] - [NpgsqlConnectionStringProperty] - public bool IntegratedSecurity + /// + /// The Kerberos service name to be used for authentication. + /// + [Category("Security")] + [Description("The Kerberos service name to be used for authentication.")] + [DisplayName("Kerberos Service Name")] + [NpgsqlConnectionStringProperty("Krbsrvname")] + [DefaultValue("postgres")] + public string KerberosServiceName + { + get => _kerberosServiceName; + set { - get => _integratedSecurity; - set - { - // No integrated security if we're on mono and .NET 4.5 because of ClaimsIdentity, - // see https://github.com/npgsql/Npgsql/issues/133 - if (value && Type.GetType("Mono.Runtime") != null) - throw new NotSupportedException("IntegratedSecurity is currently unsupported on mono and .NET 4.5 (see https://github.com/npgsql/Npgsql/issues/133)"); - _integratedSecurity = value; - SetValue(nameof(IntegratedSecurity), value); - } + _kerberosServiceName = value; + SetValue(nameof(KerberosServiceName), value); } - bool _integratedSecurity; + } + string _kerberosServiceName = "postgres"; - /// - /// The Kerberos service name to be used for authentication. 
- /// - [Category("Security")] - [Description("The Kerberos service name to be used for authentication.")] - [DisplayName("Kerberos Service Name")] - [NpgsqlConnectionStringProperty("Krbsrvname")] - [DefaultValue("postgres")] - public string KerberosServiceName + /// + /// The Kerberos realm to be used for authentication. + /// + [Category("Security")] + [Description("The Kerberos realm to be used for authentication.")] + [DisplayName("Include Realm")] + [NpgsqlConnectionStringProperty] + public bool IncludeRealm + { + get => _includeRealm; + set { - get => _kerberosServiceName; - set - { - _kerberosServiceName = value; - SetValue(nameof(KerberosServiceName), value); - } + _includeRealm = value; + SetValue(nameof(IncludeRealm), value); } - string _kerberosServiceName = "postgres"; + } + bool _includeRealm; - /// - /// The Kerberos realm to be used for authentication. - /// - [Category("Security")] - [Description("The Kerberos realm to be used for authentication.")] - [DisplayName("Include Realm")] - [NpgsqlConnectionStringProperty] - public bool IncludeRealm + /// + /// Gets or sets a Boolean value that indicates if security-sensitive information, such as the password, is not returned as part of the connection if the connection is open or has ever been in an open state. 
+ /// + [Category("Security")] + [Description("Gets or sets a Boolean value that indicates if security-sensitive information, such as the password, is not returned as part of the connection if the connection is open or has ever been in an open state.")] + [DisplayName("Persist Security Info")] + [NpgsqlConnectionStringProperty] + public bool PersistSecurityInfo + { + get => _persistSecurityInfo; + set { - get => _includeRealm; - set - { - _includeRealm = value; - SetValue(nameof(IncludeRealm), value); - } + _persistSecurityInfo = value; + SetValue(nameof(PersistSecurityInfo), value); } - bool _includeRealm; + } + bool _persistSecurityInfo; - /// - /// Gets or sets a Boolean value that indicates if security-sensitive information, such as the password, is not returned as part of the connection if the connection is open or has ever been in an open state. - /// - [Category("Security")] - [Description("Gets or sets a Boolean value that indicates if security-sensitive information, such as the password, is not returned as part of the connection if the connection is open or has ever been in an open state.")] - [DisplayName("Persist Security Info")] - [NpgsqlConnectionStringProperty] - public bool PersistSecurityInfo + /// + /// When enabled, parameter values are logged when commands are executed. Defaults to false. + /// + [Category("Security")] + [Description("When enabled, parameter values are logged when commands are executed. Defaults to false.")] + [DisplayName("Log Parameters")] + [NpgsqlConnectionStringProperty] + public bool LogParameters + { + get => _logParameters; + set { - get => _persistSecurityInfo; - set - { - _persistSecurityInfo = value; - SetValue(nameof(PersistSecurityInfo), value); - } + _logParameters = value; + SetValue(nameof(LogParameters), value); } - bool _persistSecurityInfo; + } + bool _logParameters; - /// - /// When enabled, parameter values are logged when commands are executed. Defaults to false. 
- /// - [Category("Security")] - [Description("When enabled, parameter values are logged when commands are executed. Defaults to false.")] - [DisplayName("Log Parameters")] - [NpgsqlConnectionStringProperty] - public bool LogParameters + internal const string IncludeExceptionDetailDisplayName = "Include Error Detail"; + + /// + /// When enabled, PostgreSQL error details are included on and + /// . These can contain sensitive data. + /// + [Category("Security")] + [Description("When enabled, PostgreSQL error and notice details are included on PostgresException.Detail and PostgresNotice.Detail. These can contain sensitive data.")] + [DisplayName(IncludeExceptionDetailDisplayName)] + [NpgsqlConnectionStringProperty] + public bool IncludeErrorDetail + { + get => _includeErrorDetail; + set { - get => _logParameters; - set - { - _logParameters = value; - SetValue(nameof(LogParameters), value); - } + _includeErrorDetail = value; + SetValue(nameof(IncludeErrorDetail), value); } - bool _logParameters; - - internal const string IncludeExceptionDetailDisplayName = "Include Error Detail"; + } + bool _includeErrorDetail; - /// - /// When enabled, PostgreSQL error details are included on and - /// . These can contain sensitive data. - /// - [Category("Security")] - [Description("When enabled, PostgreSQL error and notice details are included on PostgresException.Detail and PostgresNotice.Detail. These can contain sensitive data.")] - [DisplayName(IncludeExceptionDetailDisplayName)] - [NpgsqlConnectionStringProperty] - public bool IncludeErrorDetails + /// + /// Controls whether channel binding is required, disabled or preferred, depending on server support. 
+ /// + [Category("Security")] + [Description("Controls whether channel binding is required, disabled or preferred, depending on server support.")] + [DisplayName("Channel Binding")] + [DefaultValue(ChannelBinding.Prefer)] + [NpgsqlConnectionStringProperty] + public ChannelBinding ChannelBinding + { + get => _channelBinding; + set { - get => _includeErrorDetails; - set - { - _includeErrorDetails = value; - SetValue(nameof(IncludeErrorDetails), value); - } + _channelBinding = value; + SetValue(nameof(ChannelBinding), value); } - bool _includeErrorDetails; - + } + ChannelBinding _channelBinding; - #endregion + #endregion - #region Properties - Pooling + #region Properties - Pooling - /// - /// Whether connection pooling should be used. - /// - [Category("Pooling")] - [Description("Whether connection pooling should be used.")] - [DisplayName("Pooling")] - [NpgsqlConnectionStringProperty] - [DefaultValue(true)] - public bool Pooling + /// + /// Whether connection pooling should be used. + /// + [Category("Pooling")] + [Description("Whether connection pooling should be used.")] + [DisplayName("Pooling")] + [NpgsqlConnectionStringProperty] + [DefaultValue(true)] + public bool Pooling + { + get => _pooling; + set { - get => _pooling; - set - { - _pooling = value; - SetValue(nameof(Pooling), value); - } + _pooling = value; + SetValue(nameof(Pooling), value); } - bool _pooling; + } + bool _pooling; - /// - /// The minimum connection pool size. - /// - [Category("Pooling")] - [Description("The minimum connection pool size.")] - [DisplayName("Minimum Pool Size")] - [NpgsqlConnectionStringProperty] - [DefaultValue(0)] - public int MinPoolSize + /// + /// The minimum connection pool size. 
+ /// + [Category("Pooling")] + [Description("The minimum connection pool size.")] + [DisplayName("Minimum Pool Size")] + [NpgsqlConnectionStringProperty] + [DefaultValue(0)] + public int MinPoolSize + { + get => _minPoolSize; + set { - get => _minPoolSize; - set - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "MinPoolSize can't be negative"); + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value), value, "MinPoolSize can't be negative"); - _minPoolSize = value; - SetValue(nameof(MinPoolSize), value); - } + _minPoolSize = value; + SetValue(nameof(MinPoolSize), value); } - int _minPoolSize; + } + int _minPoolSize; - /// - /// The maximum connection pool size. - /// - [Category("Pooling")] - [Description("The maximum connection pool size.")] - [DisplayName("Maximum Pool Size")] - [NpgsqlConnectionStringProperty] - [DefaultValue(100)] - public int MaxPoolSize + /// + /// The maximum connection pool size. + /// + [Category("Pooling")] + [Description("The maximum connection pool size.")] + [DisplayName("Maximum Pool Size")] + [NpgsqlConnectionStringProperty] + [DefaultValue(100)] + public int MaxPoolSize + { + get => _maxPoolSize; + set { - get => _maxPoolSize; - set - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "MaxPoolSize can't be negative"); + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value), value, "MaxPoolSize can't be negative"); - _maxPoolSize = value; - SetValue(nameof(MaxPoolSize), value); - } - } - int _maxPoolSize; - - /// - /// The time to wait before closing idle connections in the pool if the count - /// of all connections exceeds MinPoolSize. - /// - /// The time (in seconds) to wait. The default value is 300. 
- [Category("Pooling")] - [Description("The time to wait before closing unused connections in the pool if the count of all connections exceeds MinPoolSize.")] - [DisplayName("Connection Idle Lifetime")] - [NpgsqlConnectionStringProperty] - [DefaultValue(300)] - public int ConnectionIdleLifetime - { - get => _connectionIdleLifetime; - set - { - _connectionIdleLifetime = value; - SetValue(nameof(ConnectionIdleLifetime), value); - } + _maxPoolSize = value; + SetValue(nameof(MaxPoolSize), value); } - int _connectionIdleLifetime; - - /// - /// How many seconds the pool waits before attempting to prune idle connections that are beyond - /// idle lifetime (. - /// - /// The interval (in seconds). The default value is 10. - [Category("Pooling")] - [Description("How many seconds the pool waits before attempting to prune idle connections that are beyond idle lifetime.")] - [DisplayName("Connection Pruning Interval")] - [NpgsqlConnectionStringProperty] - [DefaultValue(10)] - public int ConnectionPruningInterval - { - get => _connectionPruningInterval; - set - { - _connectionPruningInterval = value; - SetValue(nameof(ConnectionPruningInterval), value); - } - } - int _connectionPruningInterval; - - /// - /// The total maximum lifetime of connections (in seconds). Connections which have exceeded this value will be - /// destroyed instead of returned from the pool. This is useful in clustered configurations to force load - /// balancing between a running server and a server just brought online. - /// - /// The time (in seconds) to wait, or 0 to to make connections last indefinitely (the default). 
- [Category("Pooling")] - [Description("The total maximum lifetime of connections (in seconds).")] - [DisplayName("Connection Lifetime")] - [NpgsqlConnectionStringProperty("Load Balance Timeout")] - [DefaultValue(0)] - public int ConnectionLifetime - { - get => _connectionLifetime; - set - { - _connectionLifetime = value; - SetValue(nameof(ConnectionLifetime), value); - } - } - int _connectionLifetime; - - #endregion - - #region Properties - Timeouts + } + int _maxPoolSize; - /// - /// The time to wait (in seconds) while trying to establish a connection before terminating the attempt and generating an error. - /// Defaults to 15 seconds. - /// - [Category("Timeouts")] - [Description("The time to wait (in seconds) while trying to establish a connection before terminating the attempt and generating an error.")] - [DisplayName("Timeout")] - [NpgsqlConnectionStringProperty] - [DefaultValue(DefaultTimeout)] - public int Timeout + /// + /// The time to wait before closing idle connections in the pool if the count + /// of all connections exceeds MinPoolSize. + /// + /// The time (in seconds) to wait. The default value is 300. 
+ [Category("Pooling")] + [Description("The time to wait before closing unused connections in the pool if the count of all connections exceeds MinPoolSize.")] + [DisplayName("Connection Idle Lifetime")] + [NpgsqlConnectionStringProperty] + [DefaultValue(300)] + public int ConnectionIdleLifetime + { + get => _connectionIdleLifetime; + set { - get => _timeout; - set - { - if (value < 0 || value > NpgsqlConnection.TimeoutLimit) - throw new ArgumentOutOfRangeException(nameof(value), value, "Timeout must be between 0 and " + NpgsqlConnection.TimeoutLimit); - - _timeout = value; - SetValue(nameof(Timeout), value); - } + _connectionIdleLifetime = value; + SetValue(nameof(ConnectionIdleLifetime), value); } - int _timeout; + } + int _connectionIdleLifetime; - internal const int DefaultTimeout = 15; - - /// - /// The time to wait (in seconds) while trying to execute a command before terminating the attempt and generating an error. - /// Defaults to 30 seconds. - /// - [Category("Timeouts")] - [Description("The time to wait (in seconds) while trying to execute a command before terminating the attempt and generating an error. Set to zero for infinity.")] - [DisplayName("Command Timeout")] - [NpgsqlConnectionStringProperty] - [DefaultValue(NpgsqlCommand.DefaultTimeout)] - public int CommandTimeout + /// + /// How many seconds the pool waits before attempting to prune idle connections that are beyond + /// idle lifetime (. + /// + /// The interval (in seconds). The default value is 10. 
+ [Category("Pooling")] + [Description("How many seconds the pool waits before attempting to prune idle connections that are beyond idle lifetime.")] + [DisplayName("Connection Pruning Interval")] + [NpgsqlConnectionStringProperty] + [DefaultValue(10)] + public int ConnectionPruningInterval + { + get => _connectionPruningInterval; + set { - get => _commandTimeout; - set - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "CommandTimeout can't be negative"); + _connectionPruningInterval = value; + SetValue(nameof(ConnectionPruningInterval), value); + } + } + int _connectionPruningInterval; - _commandTimeout = value; - SetValue(nameof(CommandTimeout), value); - } + /// + /// The total maximum lifetime of connections (in seconds). Connections which have exceeded this value will be + /// destroyed instead of returned from the pool. This is useful in clustered configurations to force load + /// balancing between a running server and a server just brought online. + /// + /// The time (in seconds) to wait, or 0 to to make connections last indefinitely (the default). + [Category("Pooling")] + [Description("The total maximum lifetime of connections (in seconds).")] + [DisplayName("Connection Lifetime")] + [NpgsqlConnectionStringProperty("Load Balance Timeout")] + public int ConnectionLifetime + { + get => _connectionLifetime; + set + { + _connectionLifetime = value; + SetValue(nameof(ConnectionLifetime), value); } - int _commandTimeout; + } + int _connectionLifetime; + + #endregion + + #region Properties - Timeouts - /// - /// The time to wait (in seconds) while trying to execute a an internal command before terminating the attempt and generating an error. - /// - [Category("Timeouts")] - [Description("The time to wait (in seconds) while trying to execute a an internal command before terminating the attempt and generating an error. 
-1 uses CommandTimeout, 0 means no timeout.")] - [DisplayName("Internal Command Timeout")] - [NpgsqlConnectionStringProperty] - [DefaultValue(-1)] - public int InternalCommandTimeout + /// + /// The time to wait (in seconds) while trying to establish a connection before terminating the attempt and generating an error. + /// Defaults to 15 seconds. + /// + [Category("Timeouts")] + [Description("The time to wait (in seconds) while trying to establish a connection before terminating the attempt and generating an error.")] + [DisplayName("Timeout")] + [NpgsqlConnectionStringProperty] + [DefaultValue(DefaultTimeout)] + public int Timeout + { + get => _timeout; + set { - get => _internalCommandTimeout; - set - { - if (value != 0 && value != -1 && value < NpgsqlConnector.MinimumInternalCommandTimeout) - throw new ArgumentOutOfRangeException(nameof(value), value, - $"InternalCommandTimeout must be >= {NpgsqlConnector.MinimumInternalCommandTimeout}, 0 (infinite) or -1 (use CommandTimeout)"); + if (value < 0 || value > NpgsqlConnection.TimeoutLimit) + throw new ArgumentOutOfRangeException(nameof(value), value, "Timeout must be between 0 and " + NpgsqlConnection.TimeoutLimit); - _internalCommandTimeout = value; - SetValue(nameof(InternalCommandTimeout), value); - } + _timeout = value; + SetValue(nameof(Timeout), value); } - int _internalCommandTimeout; - - /// - /// The time to wait (in milliseconds) while trying to read a response for a cancellation request for a timed out or cancelled query, before terminating the attempt and generating an error. - /// Defaults to 2000 milliseconds. - /// - [Category("Timeouts")] - [Description("After Command Timeout is reached (or user supplied cancellation token is cancelled) and command cancellation is attempted, Npgsql waits for this additional timeout (in milliseconds) before breaking the connection. 
Defaults to 2000, set to zero for infinity.")] - [DisplayName("Cancellation Timeout")] - [NpgsqlConnectionStringProperty] - [DefaultValue(2000)] - public int CancellationTimeout - { - get => _cancellationTimeout; - set - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(CancellationTimeout)} can't be negative"); + } + int _timeout; - _cancellationTimeout = value; - SetValue(nameof(CancellationTimeout), value); - } - } - int _cancellationTimeout; + internal const int DefaultTimeout = 15; - #endregion + /// + /// The time to wait (in seconds) while trying to execute a command before terminating the attempt and generating an error. + /// Defaults to 30 seconds. + /// + [Category("Timeouts")] + [Description("The time to wait (in seconds) while trying to execute a command before terminating the attempt and generating an error. Set to zero for infinity.")] + [DisplayName("Command Timeout")] + [NpgsqlConnectionStringProperty] + [DefaultValue(NpgsqlCommand.DefaultTimeout)] + public int CommandTimeout + { + get => _commandTimeout; + set + { + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value), value, "CommandTimeout can't be negative"); - #region Properties - Entity Framework + _commandTimeout = value; + SetValue(nameof(CommandTimeout), value); + } + } + int _commandTimeout; - /// - /// The database template to specify when creating a database in Entity Framework. If not specified, - /// PostgreSQL defaults to "template1". - /// - /// - /// https://www.postgresql.org/docs/current/static/manage-ag-templatedbs.html - /// - [Category("Entity Framework")] - [Description("The database template to specify when creating a database in Entity Framework. If not specified, PostgreSQL defaults to \"template1\".")] - [DisplayName("EF Template Database")] - [NpgsqlConnectionStringProperty] - public string? 
EntityTemplateDatabase + /// + /// The time to wait (in milliseconds) while trying to read a response for a cancellation request for a timed out or cancelled query, before terminating the attempt and generating an error. + /// Zero for infinity, -1 to skip the wait. + /// Defaults to 2000 milliseconds. + /// + [Category("Timeouts")] + [Description("After Command Timeout is reached (or user supplied cancellation token is cancelled) and command cancellation is attempted, Npgsql waits for this additional timeout (in milliseconds) before breaking the connection. Defaults to 2000, set to zero for infinity.")] + [DisplayName("Cancellation Timeout")] + [NpgsqlConnectionStringProperty] + [DefaultValue(2000)] + public int CancellationTimeout + { + get => _cancellationTimeout; + set { - get => _entityTemplateDatabase; - set - { - _entityTemplateDatabase = value; - SetValue(nameof(EntityTemplateDatabase), value); - } - } - string? _entityTemplateDatabase; - - /// - /// The database admin to specify when creating and dropping a database in Entity Framework. This is needed because - /// Npgsql needs to connect to a database in order to send the create/drop database command. - /// If not specified, defaults to "template1". Check NpgsqlServices.UsingPostgresDBConnection for more information. - /// - [Category("Entity Framework")] - [Description("The database admin to specify when creating and dropping a database in Entity Framework. If not specified, defaults to \"template1\".")] - [DisplayName("EF Admin Database")] - [NpgsqlConnectionStringProperty] - public string? EntityAdminDatabase - { - get => _entityAdminDatabase; - set - { - _entityAdminDatabase = value; - SetValue(nameof(EntityAdminDatabase), value); - } + if (value < -1) + throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(CancellationTimeout)} can't less than -1"); + + _cancellationTimeout = value; + SetValue(nameof(CancellationTimeout), value); } - string? 
_entityAdminDatabase; + } + int _cancellationTimeout; - #endregion + #endregion - #region Properties - Advanced + #region Properties - Failover and load balancing - /// - /// The number of seconds of connection inactivity before Npgsql sends a keepalive query. - /// Set to 0 (the default) to disable. - /// - [Category("Advanced")] - [Description("The number of seconds of connection inactivity before Npgsql sends a keepalive query.")] - [DisplayName("Keepalive")] - [NpgsqlConnectionStringProperty] - public int KeepAlive + /// + /// Determines the preferred PostgreSQL target server type. + /// + [Category("Failover and load balancing")] + [Description("Determines the preferred PostgreSQL target server type.")] + [DisplayName("Target Session Attributes")] + [NpgsqlConnectionStringProperty] + public string? TargetSessionAttributes + { + get => TargetSessionAttributesParsed switch { - get => _keepAlive; - set - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "KeepAlive can't be negative"); + Npgsql.TargetSessionAttributes.Any => "any", + Npgsql.TargetSessionAttributes.Primary => "primary", + Npgsql.TargetSessionAttributes.Standby => "standby", + Npgsql.TargetSessionAttributes.PreferPrimary => "prefer-primary", + Npgsql.TargetSessionAttributes.PreferStandby => "prefer-standby", + Npgsql.TargetSessionAttributes.ReadWrite => "read-write", + Npgsql.TargetSessionAttributes.ReadOnly => "read-only", + null => null, - _keepAlive = value; - SetValue(nameof(KeepAlive), value); - } - } - int _keepAlive; + _ => throw new ArgumentException($"Unhandled enum value '{TargetSessionAttributesParsed}'") + }; - /// - /// Whether to use TCP keepalive with system defaults if overrides isn't specified. 
- /// - [Category("Advanced")] - [Description("Whether to use TCP keepalive with system defaults if overrides isn't specified.")] - [DisplayName("TCP Keepalive")] - [NpgsqlConnectionStringProperty] - public bool TcpKeepAlive + set { - get => _tcpKeepAlive; - set - { - _tcpKeepAlive = value; - SetValue(nameof(TcpKeepAlive), value); - } + TargetSessionAttributesParsed = value is null ? null : ParseTargetSessionAttributes(value); + SetValue(nameof(TargetSessionAttributes), value); } - bool _tcpKeepAlive; - - /// - /// The number of seconds of connection inactivity before a TCP keepalive query is sent. - /// Use of this option is discouraged, use instead if possible. - /// Set to 0 (the default) to disable. - /// - [Category("Advanced")] - [Description("The number of seconds of connection inactivity before a TCP keepalive query is sent.")] - [DisplayName("TCP Keepalive Time")] - [NpgsqlConnectionStringProperty] - public int TcpKeepAliveTime - { - get => _tcpKeepAliveTime; - set - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "TcpKeepAliveTime can't be negative"); + } - _tcpKeepAliveTime = value; - SetValue(nameof(TcpKeepAliveTime), value); - } - } - int _tcpKeepAliveTime; + internal TargetSessionAttributes? TargetSessionAttributesParsed { get; set; } - /// - /// The interval, in seconds, between when successive keep-alive packets are sent if no acknowledgement is received. - /// Defaults to the value of . must be non-zero as well. 
- /// - [Category("Advanced")] - [Description("The interval, in seconds, between when successive keep-alive packets are sent if no acknowledgement is received.")] - [DisplayName("TCP Keepalive Interval")] - [NpgsqlConnectionStringProperty] - public int TcpKeepAliveInterval + internal static TargetSessionAttributes ParseTargetSessionAttributes(string s) + => s switch { - get => _tcpKeepAliveInterval; - set - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "TcpKeepAliveInterval can't be negative"); + "any" => Npgsql.TargetSessionAttributes.Any, + "primary" => Npgsql.TargetSessionAttributes.Primary, + "standby" => Npgsql.TargetSessionAttributes.Standby, + "prefer-primary" => Npgsql.TargetSessionAttributes.PreferPrimary, + "prefer-standby" => Npgsql.TargetSessionAttributes.PreferStandby, + "read-write" => Npgsql.TargetSessionAttributes.ReadWrite, + "read-only" => Npgsql.TargetSessionAttributes.ReadOnly, - _tcpKeepAliveInterval = value; - SetValue(nameof(TcpKeepAliveInterval), value); - } - } - int _tcpKeepAliveInterval; + _ => throw new ArgumentException($"TargetSessionAttributes contains an invalid value '{s}'") + }; - /// - /// Determines the size of the internal buffer Npgsql uses when reading. Increasing may improve performance if transferring large values from the database. - /// - [Category("Advanced")] - [Description("Determines the size of the internal buffer Npgsql uses when reading. Increasing may improve performance if transferring large values from the database.")] - [DisplayName("Read Buffer Size")] - [NpgsqlConnectionStringProperty] - [DefaultValue(NpgsqlReadBuffer.DefaultSize)] - public int ReadBufferSize + /// + /// Enables balancing between multiple hosts by round-robin. 
+ /// + [Category("Failover and load balancing")] + [Description("Enables balancing between multiple hosts by round-robin.")] + [DisplayName("Load Balance Hosts")] + [NpgsqlConnectionStringProperty] + public bool LoadBalanceHosts + { + get => _loadBalanceHosts; + set { - get => _readBufferSize; - set - { - _readBufferSize = value; - SetValue(nameof(ReadBufferSize), value); - } + _loadBalanceHosts = value; + SetValue(nameof(LoadBalanceHosts), value); } - int _readBufferSize; + } + bool _loadBalanceHosts; - /// - /// Determines the size of the internal buffer Npgsql uses when writing. Increasing may improve performance if transferring large values to the database. - /// - [Category("Advanced")] - [Description("Determines the size of the internal buffer Npgsql uses when writing. Increasing may improve performance if transferring large values to the database.")] - [DisplayName("Write Buffer Size")] - [NpgsqlConnectionStringProperty] - [DefaultValue(NpgsqlWriteBuffer.DefaultSize)] - public int WriteBufferSize + /// + /// Controls for how long the host's cached state will be considered as valid. + /// + [Category("Failover and load balancing")] + [Description("Controls for how long the host's cached state will be considered as valid.")] + [DisplayName("Host Recheck Seconds")] + [DefaultValue(10)] + [NpgsqlConnectionStringProperty] + public int HostRecheckSeconds + { + get => _hostRecheckSeconds; + set { - get => _writeBufferSize; - set - { - _writeBufferSize = value; - SetValue(nameof(WriteBufferSize), value); - } + if (value < 0) + throw new ArgumentException($"{HostRecheckSeconds} cannot be negative", nameof(HostRecheckSeconds)); + _hostRecheckSeconds = value; + SetValue(nameof(HostRecheckSeconds), value); } - int _writeBufferSize; + } + int _hostRecheckSeconds; - /// - /// Determines the size of socket read buffer. 
- /// - [Category("Advanced")] - [Description("Determines the size of socket receive buffer.")] - [DisplayName("Socket Receive Buffer Size")] - [NpgsqlConnectionStringProperty] - public int SocketReceiveBufferSize + #endregion Properties - Failover and load balancing + + #region Properties - Advanced + + /// + /// The number of seconds of connection inactivity before Npgsql sends a keepalive query. + /// Set to 0 (the default) to disable. + /// + [Category("Advanced")] + [Description("The number of seconds of connection inactivity before Npgsql sends a keepalive query.")] + [DisplayName("Keepalive")] + [NpgsqlConnectionStringProperty] + public int KeepAlive + { + get => _keepAlive; + set { - get => _socketReceiveBufferSize; - set - { - _socketReceiveBufferSize = value; - SetValue(nameof(SocketReceiveBufferSize), value); - } + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value), value, "KeepAlive can't be negative"); + + _keepAlive = value; + SetValue(nameof(KeepAlive), value); } - int _socketReceiveBufferSize; + } + int _keepAlive; - /// - /// Determines the size of socket send buffer. - /// - [Category("Advanced")] - [Description("Determines the size of socket send buffer.")] - [DisplayName("Socket Send Buffer Size")] - [NpgsqlConnectionStringProperty] - public int SocketSendBufferSize + /// + /// Whether to use TCP keepalive with system defaults if overrides isn't specified. 
+ /// + [Category("Advanced")] + [Description("Whether to use TCP keepalive with system defaults if overrides isn't specified.")] + [DisplayName("TCP Keepalive")] + [NpgsqlConnectionStringProperty] + public bool TcpKeepAlive + { + get => _tcpKeepAlive; + set { - get => _socketSendBufferSize; - set - { - _socketSendBufferSize = value; - SetValue(nameof(SocketSendBufferSize), value); - } + _tcpKeepAlive = value; + SetValue(nameof(TcpKeepAlive), value); } - int _socketSendBufferSize; - - /// - /// The maximum number SQL statements that can be automatically prepared at any given point. - /// Beyond this number the least-recently-used statement will be recycled. - /// Zero (the default) disables automatic preparation. - /// - [Category("Advanced")] - [Description("The maximum number SQL statements that can be automatically prepared at any given point. Beyond this number the least-recently-used statement will be recycled. Zero (the default) disables automatic preparation.")] - [DisplayName("Max Auto Prepare")] - [NpgsqlConnectionStringProperty] - public int MaxAutoPrepare - { - get => _maxAutoPrepare; - set - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(MaxAutoPrepare)} cannot be negative"); + } + bool _tcpKeepAlive; - _maxAutoPrepare = value; - SetValue(nameof(MaxAutoPrepare), value); - } - } - int _maxAutoPrepare; - - /// - /// The minimum number of usages an SQL statement is used before it's automatically prepared. - /// Defaults to 5. - /// - [Category("Advanced")] - [Description("The minimum number of usages an SQL statement is used before it's automatically prepared. 
Defaults to 5.")] - [DisplayName("Auto Prepare Min Usages")] - [NpgsqlConnectionStringProperty] - [DefaultValue(5)] - public int AutoPrepareMinUsages - { - get => _autoPrepareMinUsages; - set - { - if (value < 1) - throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(AutoPrepareMinUsages)} must be 1 or greater"); + /// + /// The number of seconds of connection inactivity before a TCP keepalive query is sent. + /// Use of this option is discouraged, use instead if possible. + /// Set to 0 (the default) to disable. + /// + [Category("Advanced")] + [Description("The number of seconds of connection inactivity before a TCP keepalive query is sent.")] + [DisplayName("TCP Keepalive Time")] + [NpgsqlConnectionStringProperty] + public int TcpKeepAliveTime + { + get => _tcpKeepAliveTime; + set + { + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value), value, "TcpKeepAliveTime can't be negative"); - _autoPrepareMinUsages = value; - SetValue(nameof(AutoPrepareMinUsages), value); - } + _tcpKeepAliveTime = value; + SetValue(nameof(TcpKeepAliveTime), value); } - int _autoPrepareMinUsages; + } + int _tcpKeepAliveTime; - /// - /// If set to true, a pool connection's state won't be reset when it is closed (improves performance). - /// Do not specify this unless you know what you're doing. - /// - [Category("Advanced")] - [Description("If set to true, a pool connection's state won't be reset when it is closed (improves performance). Do not specify this unless you know what you're doing.")] - [DisplayName("No Reset On Close")] - [NpgsqlConnectionStringProperty] - public bool NoResetOnClose + /// + /// The interval, in seconds, between when successive keep-alive packets are sent if no acknowledgement is received. + /// Defaults to the value of . must be non-zero as well. 
+ /// + [Category("Advanced")] + [Description("The interval, in seconds, between when successive keep-alive packets are sent if no acknowledgement is received.")] + [DisplayName("TCP Keepalive Interval")] + [NpgsqlConnectionStringProperty] + public int TcpKeepAliveInterval + { + get => _tcpKeepAliveInterval; + set { - get => _noResetOnClose; - set - { - _noResetOnClose = value; - SetValue(nameof(NoResetOnClose), value); - } + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value), value, "TcpKeepAliveInterval can't be negative"); + + _tcpKeepAliveInterval = value; + SetValue(nameof(TcpKeepAliveInterval), value); } - bool _noResetOnClose; + } + int _tcpKeepAliveInterval; - /// - /// Load table composite type definitions, and not just free-standing composite types. - /// - [Category("Advanced")] - [Description("Load table composite type definitions, and not just free-standing composite types.")] - [DisplayName("Load Table Composites")] - [NpgsqlConnectionStringProperty] - public bool LoadTableComposites + /// + /// Determines the size of the internal buffer Npgsql uses when reading. Increasing may improve performance if transferring large values from the database. + /// + [Category("Advanced")] + [Description("Determines the size of the internal buffer Npgsql uses when reading. Increasing may improve performance if transferring large values from the database.")] + [DisplayName("Read Buffer Size")] + [NpgsqlConnectionStringProperty] + [DefaultValue(NpgsqlReadBuffer.DefaultSize)] + public int ReadBufferSize + { + get => _readBufferSize; + set { - get => _loadTableComposites; - set - { - _loadTableComposites = value; - SetValue(nameof(LoadTableComposites), value); - } - } - bool _loadTableComposites; - - /// - /// Set the replication mode of the connection - /// - /// - /// This property and its corresponding enum are intentionally kept internal as they - /// should not be set by users or even be visible in their connection strings. 
- /// Replication connections are a special kind of connection that is encapsulated in - /// - /// and . - /// - - [NpgsqlConnectionStringProperty] - [DisplayName("Replication Mode")] - internal ReplicationMode ReplicationMode - { - get => _replicationMode; - set - { - _replicationMode = value; - SetValue(nameof(ReplicationMode), value); - } + _readBufferSize = value; + SetValue(nameof(ReadBufferSize), value); } - ReplicationMode _replicationMode; + } + int _readBufferSize; - /// - /// Set PostgreSQL configuration parameter default values for the connection. - /// - [Category("Advanced")] - [Description("Set PostgreSQL configuration parameter default values for the connection.")] - [DisplayName("Options")] - [NpgsqlConnectionStringProperty] - public string? Options + /// + /// Determines the size of the internal buffer Npgsql uses when writing. Increasing may improve performance if transferring large values to the database. + /// + [Category("Advanced")] + [Description("Determines the size of the internal buffer Npgsql uses when writing. Increasing may improve performance if transferring large values to the database.")] + [DisplayName("Write Buffer Size")] + [NpgsqlConnectionStringProperty] + [DefaultValue(NpgsqlWriteBuffer.DefaultSize)] + public int WriteBufferSize + { + get => _writeBufferSize; + set { - get => _options; - set - { - _options = value; - SetValue(nameof(Options), value); - } + _writeBufferSize = value; + SetValue(nameof(WriteBufferSize), value); } + } + int _writeBufferSize; - string? _options; - - #endregion - - #region Multiplexing - - /// - /// Enables multiplexing, which allows more efficient use of connections. - /// - [Category("Multiplexing")] - [Description("Enables multiplexing, which allows more efficient use of connections.")] - [DisplayName("Multiplexing")] - [NpgsqlConnectionStringProperty] - [DefaultValue(false)] - public bool Multiplexing + /// + /// Determines the size of socket read buffer. 
+ /// + [Category("Advanced")] + [Description("Determines the size of socket receive buffer.")] + [DisplayName("Socket Receive Buffer Size")] + [NpgsqlConnectionStringProperty] + public int SocketReceiveBufferSize + { + get => _socketReceiveBufferSize; + set { - get => _multiplexing; - set - { - _multiplexing = value; - SetValue(nameof(Multiplexing), value); - } - } - bool _multiplexing; - - /// - /// When multiplexing is enabled, determines the maximum amount of time to wait for further - /// commands before flushing to the network. In microseconds, 0 disables waiting altogether. - /// - [Category("Multiplexing")] - [Description("When multiplexing is enabled, determines the maximum amount of time to wait for further " + - "commands before flushing to the network. In microseconds, 0 disables waiting altogether.")] - [DisplayName("Write Coalescing Delay Us")] - [NpgsqlConnectionStringProperty] - [DefaultValue(0)] - public int WriteCoalescingDelayUs - { - get => _writeCoalescingDelayUs; - set - { - _writeCoalescingDelayUs = value; - SetValue(nameof(WriteCoalescingDelayUs), value); - } + _socketReceiveBufferSize = value; + SetValue(nameof(SocketReceiveBufferSize), value); } - int _writeCoalescingDelayUs; - - /// - /// When multiplexing is enabled, determines the maximum number of outgoing bytes to buffer before - /// flushing to the network. - /// - [Category("Multiplexing")] - [Description("When multiplexing is enabled, determines the maximum number of outgoing bytes to buffer before " + - "flushing to the network.")] - [DisplayName("Write Coalescing Buffer Threshold Bytes")] - [NpgsqlConnectionStringProperty] - [DefaultValue(1000)] - public int WriteCoalescingBufferThresholdBytes - { - get => _writeCoalescingBufferThresholdBytes; - set - { - _writeCoalescingBufferThresholdBytes = value; - SetValue(nameof(WriteCoalescingBufferThresholdBytes), value); - } + } + int _socketReceiveBufferSize; + + /// + /// Determines the size of socket send buffer. 
+ /// + [Category("Advanced")] + [Description("Determines the size of socket send buffer.")] + [DisplayName("Socket Send Buffer Size")] + [NpgsqlConnectionStringProperty] + public int SocketSendBufferSize + { + get => _socketSendBufferSize; + set + { + _socketSendBufferSize = value; + SetValue(nameof(SocketSendBufferSize), value); } - int _writeCoalescingBufferThresholdBytes; + } + int _socketSendBufferSize; - #endregion + /// + /// The maximum number SQL statements that can be automatically prepared at any given point. + /// Beyond this number the least-recently-used statement will be recycled. + /// Zero (the default) disables automatic preparation. + /// + [Category("Advanced")] + [Description("The maximum number SQL statements that can be automatically prepared at any given point. Beyond this number the least-recently-used statement will be recycled. Zero (the default) disables automatic preparation.")] + [DisplayName("Max Auto Prepare")] + [NpgsqlConnectionStringProperty] + public int MaxAutoPrepare + { + get => _maxAutoPrepare; + set + { + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(MaxAutoPrepare)} cannot be negative"); - #region Properties - Compatibility + _maxAutoPrepare = value; + SetValue(nameof(MaxAutoPrepare), value); + } + } + int _maxAutoPrepare; - /// - /// A compatibility mode for special PostgreSQL server types. - /// - [Category("Compatibility")] - [Description("A compatibility mode for special PostgreSQL server types.")] - [DisplayName("Server Compatibility Mode")] - [NpgsqlConnectionStringProperty] - public ServerCompatibilityMode ServerCompatibilityMode + /// + /// The minimum number of usages an SQL statement is used before it's automatically prepared. + /// Defaults to 5. + /// + [Category("Advanced")] + [Description("The minimum number of usages an SQL statement is used before it's automatically prepared. 
Defaults to 5.")] + [DisplayName("Auto Prepare Min Usages")] + [NpgsqlConnectionStringProperty] + [DefaultValue(5)] + public int AutoPrepareMinUsages + { + get => _autoPrepareMinUsages; + set { - get => _serverCompatibilityMode; - set - { - _serverCompatibilityMode = value; - SetValue(nameof(ServerCompatibilityMode), value); - } + if (value < 1) + throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(AutoPrepareMinUsages)} must be 1 or greater"); + + _autoPrepareMinUsages = value; + SetValue(nameof(AutoPrepareMinUsages), value); } - ServerCompatibilityMode _serverCompatibilityMode; + } + int _autoPrepareMinUsages; - /// - /// Makes MaxValue and MinValue timestamps and dates readable as infinity and negative infinity. - /// - [Category("Compatibility")] - [Description("Makes MaxValue and MinValue timestamps and dates readable as infinity and negative infinity.")] - [DisplayName("Convert Infinity DateTime")] - [NpgsqlConnectionStringProperty] - public bool ConvertInfinityDateTime + /// + /// If set to true, a pool connection's state won't be reset when it is closed (improves performance). + /// Do not specify this unless you know what you're doing. + /// + [Category("Advanced")] + [Description("If set to true, a pool connection's state won't be reset when it is closed (improves performance). 
Do not specify this unless you know what you're doing.")] + [DisplayName("No Reset On Close")] + [NpgsqlConnectionStringProperty] + public bool NoResetOnClose + { + get => _noResetOnClose; + set { - get => _convertInfinityDateTime; - set - { - _convertInfinityDateTime = value; - SetValue(nameof(ConvertInfinityDateTime), value); - } + _noResetOnClose = value; + SetValue(nameof(NoResetOnClose), value); } - bool _convertInfinityDateTime; - - #endregion - - #region Properties - Obsolete + } + bool _noResetOnClose; - /// - /// Obsolete, see https://www.npgsql.org/doc/release-notes/3.1.html - /// - [Category("Obsolete")] - [Description("Obsolete, see https://www.npgsql.org/doc/release-notes/3.1.html")] - [DisplayName("Continuous Processing")] - [NpgsqlConnectionStringProperty] - [Obsolete("The ContinuousProcessing parameter is no longer supported.")] - public bool ContinuousProcessing + /// + /// Load table composite type definitions, and not just free-standing composite types. + /// + [Category("Advanced")] + [Description("Load table composite type definitions, and not just free-standing composite types.")] + [DisplayName("Load Table Composites")] + [NpgsqlConnectionStringProperty] + public bool LoadTableComposites + { + get => _loadTableComposites; + set { - get => false; - set => throw new NotSupportedException("The ContinuousProcessing parameter is no longer supported. 
Please see https://www.npgsql.org/doc/release-notes/3.1.html"); + _loadTableComposites = value; + SetValue(nameof(LoadTableComposites), value); } + } + bool _loadTableComposites; - /// - /// Obsolete, see https://www.npgsql.org/doc/release-notes/3.1.html - /// - [Category("Obsolete")] - [Description("Obsolete, see https://www.npgsql.org/doc/release-notes/3.1.html")] - [DisplayName("Backend Timeouts")] - [NpgsqlConnectionStringProperty] - [Obsolete("The BackendTimeouts parameter is no longer supported")] - public bool BackendTimeouts + /// + /// Set the replication mode of the connection + /// + /// + /// This property and its corresponding enum are intentionally kept internal as they + /// should not be set by users or even be visible in their connection strings. + /// Replication connections are a special kind of connection that is encapsulated in + /// + /// and . + /// + [NpgsqlConnectionStringProperty] + [DisplayName("Replication Mode")] + internal ReplicationMode ReplicationMode + { + get => _replicationMode; + set { - get => false; - set => throw new NotSupportedException("The BackendTimeouts parameter is no longer supported. Please see https://www.npgsql.org/doc/release-notes/3.1.html"); + _replicationMode = value; + SetValue(nameof(ReplicationMode), value); } + } + ReplicationMode _replicationMode; - /// - /// Obsolete, see https://www.npgsql.org/doc/release-notes/3.0.html - /// - [Category("Obsolete")] - [Description("Obsolete, see https://www.npgsql.org/doc/v/3.0.html")] - [DisplayName("Preload Reader")] - [NpgsqlConnectionStringProperty] - [Obsolete("The PreloadReader parameter is no longer supported")] - public bool PreloadReader + /// + /// Set PostgreSQL configuration parameter default values for the connection. + /// + [Category("Advanced")] + [Description("Set PostgreSQL configuration parameter default values for the connection.")] + [DisplayName("Options")] + [NpgsqlConnectionStringProperty] + public string? 
Options + { + get => _options; + set { - get => false; - set => throw new NotSupportedException("The PreloadReader parameter is no longer supported. Please see https://www.npgsql.org/doc/release-notes/3.0.html"); + _options = value; + SetValue(nameof(Options), value); } + } + + string? _options; - /// - /// Obsolete, see https://www.npgsql.org/doc/release-notes/3.0.html - /// - [Category("Obsolete")] - [Description("Obsolete, see https://www.npgsql.org/doc/release-notes/3.0.html")] - [DisplayName("Use Extended Types")] - [NpgsqlConnectionStringProperty] - [Obsolete("The UseExtendedTypes parameter is no longer supported")] - public bool UseExtendedTypes + /// + /// Configure the way arrays of value types are returned when requested as object instances. + /// + [Category("Advanced")] + [Description("Configure the way arrays of value types are returned when requested as object instances.")] + [DisplayName("Array Nullability Mode")] + [NpgsqlConnectionStringProperty] + public ArrayNullabilityMode ArrayNullabilityMode + { + get => _arrayNullabilityMode; + set { - get => false; - set => throw new NotSupportedException("The UseExtendedTypes parameter is no longer supported. Please see https://www.npgsql.org/doc/release-notes/3.0.html"); + _arrayNullabilityMode = value; + SetValue(nameof(ArrayNullabilityMode), value); } + } + + ArrayNullabilityMode _arrayNullabilityMode; + + #endregion + + #region Multiplexing - /// - /// Obsolete, see https://www.npgsql.org/doc/release-notes/4.1.html - /// - [Category("Obsolete")] - [Description("Obsolete, see https://www.npgsql.org/doc/release-notes/4.1.html")] - [DisplayName("Use Ssl Stream")] - [NpgsqlConnectionStringProperty] - [Obsolete("The UseSslStream parameter is no longer supported (always true)")] - public bool UseSslStream + /// + /// Enables multiplexing, which allows more efficient use of connections. 
+ /// + [Category("Multiplexing")] + [Description("Enables multiplexing, which allows more efficient use of connections.")] + [DisplayName("Multiplexing")] + [NpgsqlConnectionStringProperty] + [DefaultValue(false)] + public bool Multiplexing + { + get => _multiplexing; + set { - get => true; - set => throw new NotSupportedException("The UseSslStream parameter is no longer supported (SslStream is always used). Please see https://www.npgsql.org/doc/release-notes/4.1.html"); + _multiplexing = value; + SetValue(nameof(Multiplexing), value); } + } + bool _multiplexing; - /// - /// Writes connection performance information to performance counters. - /// - [Category("Advanced")] - [Description("Writes connection performance information to performance counters.")] - [DisplayName("Use Perf Counters")] - [NpgsqlConnectionStringProperty] - [Obsolete("The UsePerfCounters parameter is no longer supported")] - public bool UsePerfCounters + /// + /// When multiplexing is enabled, determines the maximum number of outgoing bytes to buffer before + /// flushing to the network. + /// + [Category("Multiplexing")] + [Description("When multiplexing is enabled, determines the maximum number of outgoing bytes to buffer before " + + "flushing to the network.")] + [DisplayName("Write Coalescing Buffer Threshold Bytes")] + [NpgsqlConnectionStringProperty] + [DefaultValue(1000)] + public int WriteCoalescingBufferThresholdBytes + { + get => _writeCoalescingBufferThresholdBytes; + set { - get => false; - set => throw new NotSupportedException("The UsePerfCounters parameter is no longer supported. Please see https://www.npgsql.org/doc/release-notes/5.0.html"); + _writeCoalescingBufferThresholdBytes = value; + SetValue(nameof(WriteCoalescingBufferThresholdBytes), value); } + } + int _writeCoalescingBufferThresholdBytes; - #endregion + #endregion - #region Misc + #region Properties - Compatibility - internal void Validate() + /// + /// A compatibility mode for special PostgreSQL server types. 
+ /// + [Category("Compatibility")] + [Description("A compatibility mode for special PostgreSQL server types.")] + [DisplayName("Server Compatibility Mode")] + [NpgsqlConnectionStringProperty] + public ServerCompatibilityMode ServerCompatibilityMode + { + get => _serverCompatibilityMode; + set { - if (string.IsNullOrWhiteSpace(Host)) - throw new ArgumentException("Host can't be null"); - if (Multiplexing && !Pooling) - throw new ArgumentException("Pooling must be on to use multiplexing"); + _serverCompatibilityMode = value; + SetValue(nameof(ServerCompatibilityMode), value); } + } + ServerCompatibilityMode _serverCompatibilityMode; - internal string ToStringWithoutPassword() - { - var clone = Clone(); - clone.Password = null; - return clone.ToString(); - } + #endregion - internal NpgsqlConnectionStringBuilder Clone() => new NpgsqlConnectionStringBuilder(ConnectionString); + #region Properties - Obsolete - /// - /// Determines whether the specified object is equal to the current object. - /// - public override bool Equals(object? obj) - => obj is NpgsqlConnectionStringBuilder o && EquivalentTo(o); + /// + /// Whether to trust the server certificate without validating it. + /// + [Category("Security")] + [Description("Whether to trust the server certificate without validating it.")] + [DisplayName("Trust Server Certificate")] + [Obsolete("The TrustServerCertificate parameter is no longer needed and does nothing.")] + [NpgsqlConnectionStringProperty] + public bool TrustServerCertificate + { + get => _trustServerCertificate; + set + { + _trustServerCertificate = value; + SetValue(nameof(TrustServerCertificate), value); + } + } + bool _trustServerCertificate; - /// - /// Hash function. - /// - /// - public override int GetHashCode() => Host?.GetHashCode() ?? 0; + /// + /// The time to wait (in seconds) while trying to execute a an internal command before terminating the attempt and generating an error. 
+ /// + [Category("Obsolete")] + [Description("The time to wait (in seconds) while trying to execute a an internal command before terminating the attempt and generating an error. -1 uses CommandTimeout, 0 means no timeout.")] + [DisplayName("Internal Command Timeout")] + [NpgsqlConnectionStringProperty] + [DefaultValue(-1)] + [Obsolete("The InternalCommandTimeout parameter is no longer needed and does nothing.")] + public int InternalCommandTimeout + { + get => _internalCommandTimeout; + set + { + if (value != 0 && value != -1 && value < NpgsqlConnector.MinimumInternalCommandTimeout) + throw new ArgumentOutOfRangeException(nameof(value), value, + $"InternalCommandTimeout must be >= {NpgsqlConnector.MinimumInternalCommandTimeout}, 0 (infinite) or -1 (use CommandTimeout)"); - #endregion + _internalCommandTimeout = value; + SetValue(nameof(InternalCommandTimeout), value); + } + } + int _internalCommandTimeout; - #region IDictionary + #endregion - /// - /// Gets an containing the keys of the . - /// - public new ICollection Keys => base.Keys.Cast().ToArray()!; + #region Misc - /// - /// Gets an containing the values in the . - /// - public new ICollection Values => base.Values.Cast().ToArray(); + internal void PostProcessAndValidate() + { + if (string.IsNullOrWhiteSpace(Host)) + throw new ArgumentException("Host can't be null"); + if (Multiplexing && !Pooling) + throw new ArgumentException("Pooling must be on to use multiplexing"); - /// - /// Copies the elements of the to an Array, starting at a particular Array index. - /// - /// - /// The one-dimensional Array that is the destination of the elements copied from . - /// The Array must have zero-based indexing. - /// - /// - /// The zero-based index in array at which copying begins. 
- /// - public void CopyTo(KeyValuePair[] array, int arrayIndex) + if (!Host.Contains(",")) { - foreach (var kv in this) - array[arrayIndex++] = kv; + if (TargetSessionAttributesParsed is not null && + TargetSessionAttributesParsed != Npgsql.TargetSessionAttributes.Any) + { + throw new NotSupportedException("Target Session Attributes other then Any is only supported with multiple hosts"); + } + + // Support single host:port format in Host + if (!IsUnixSocket(Host, Port, out _) && + TrySplitHostPort(Host.AsSpan(), out var newHost, out var newPort)) + { + Host = newHost; + Port = newPort; + } } + } - /// - /// Returns an enumerator that iterates through the . - /// - /// - public IEnumerator> GetEnumerator() + internal string ToStringWithoutPassword() + { + var clone = Clone(); + clone.Password = null; + return clone.ToString(); + } + + internal string ConnectionStringForMultipleHosts + { + get { - foreach (var k in Keys) - yield return new KeyValuePair(k, this[k]); + var clone = Clone(); + clone[nameof(TargetSessionAttributes)] = null; + return clone.ConnectionString; } + } - #endregion IDictionary + internal NpgsqlConnectionStringBuilder Clone() => new(ConnectionString); - #region ICustomTypeDescriptor -#nullable disable -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - protected override void GetProperties(Hashtable propertyDescriptors) + internal static bool TrySplitHostPort(ReadOnlySpan originalHost, [NotNullWhen(true)] out string? host, out int port) + { + var portSeparator = originalHost.LastIndexOf(':'); + if (portSeparator != -1) { - // Tweak which properties are exposed via TypeDescriptor. This affects the VS DDEX - // provider, for example. 
- base.GetProperties(propertyDescriptors); - - var toRemove = propertyDescriptors.Values - .Cast() - .Where(d => - !d.Attributes.Cast().Any(a => a is NpgsqlConnectionStringPropertyAttribute) || - d.Attributes.Cast().Any(a => a is ObsoleteAttribute) - ) - .ToList(); - foreach (var o in toRemove) - propertyDescriptors.Remove(o.DisplayName); + var otherColon = originalHost.Slice(0, portSeparator).LastIndexOf(':'); + var ipv6End = originalHost.LastIndexOf(']'); + if (otherColon == -1 || portSeparator > ipv6End && otherColon < ipv6End) + { + port = originalHost.Slice(portSeparator + 1).ParseInt(); + host = originalHost.Slice(0, portSeparator).ToString(); + return true; + } } -#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member -#nullable enable - #endregion - internal static readonly string[] EmptyStringArray = new string[0]; + port = -1; + host = null; + return false; } - #region Attributes - - /// - /// Marks on which participate in the connection - /// string. Optionally holds a set of synonyms for the property. - /// - [AttributeUsage(AttributeTargets.Property)] - public class NpgsqlConnectionStringPropertyAttribute : Attribute + internal static bool IsUnixSocket(string host, int port, [NotNullWhen(true)] out string? socketPath, bool replaceForAbstract = true) { - /// - /// Holds a list of synonyms for the property. - /// - public string[] Synonyms { get; } + socketPath = null; + if (string.IsNullOrEmpty(host)) + return false; - /// - /// Creates a . - /// - public NpgsqlConnectionStringPropertyAttribute() + var isPathRooted = Path.IsPathRooted(host); + + if (host[0] == '@') { - Synonyms = NpgsqlConnectionStringBuilder.EmptyStringArray; + if (replaceForAbstract) + host = $"\0{host.Substring(1)}"; + isPathRooted = true; } - /// - /// Creates a . 
- /// - public NpgsqlConnectionStringPropertyAttribute(params string[] synonyms) + if (isPathRooted) { - Synonyms = synonyms; + socketPath = Path.Combine(host, $".s.PGSQL.{port}"); + return true; } + + return false; } + /// + /// Determines whether the specified object is equal to the current object. + /// + public override bool Equals(object? obj) + => obj is NpgsqlConnectionStringBuilder o && EquivalentTo(o); + + /// + /// Hash function. + /// + /// + public override int GetHashCode() => Host?.GetHashCode() ?? 0; + #endregion - #region Enums + #region IDictionary /// - /// An option specified in the connection string that activates special compatibility features. + /// Gets an containing the keys of the . /// - public enum ServerCompatibilityMode + public new ICollection Keys { - /// - /// No special server compatibility mode is active - /// - None, - /// - /// The server is an Amazon Redshift instance. - /// - Redshift, - /// - /// The server is doesn't support full type loading from the PostgreSQL catalogs, support the basic set - /// of types via information hardcoded inside Npgsql. - /// - NoTypeLoading, + get + { + var result = new string[base.Keys.Count]; + var i = 0; + foreach (var key in base.Keys) + result[i++] = (string)key; + return result; + } } /// - /// Specifies how to manage SSL. + /// Gets an containing the values in the . /// - public enum SslMode + public new ICollection Values { - /// - /// SSL is disabled. If the server requires SSL, the connection will fail. - /// - Disable, - /// - /// Prefer SSL connections if the server allows them, but allow connections without SSL. - /// - Prefer, - /// - /// Fail the connection if the server doesn't support SSL. 
- /// - Require, + get + { + var result = new object?[base.Keys.Count]; + var i = 0; + foreach (var key in base.Values) + result[i++] = (object?)key; + return result; + } } /// - /// Specifies whether the connection shall be initialized as a physical or - /// logical replication connection + /// Copies the elements of the to an Array, starting at a particular Array index. /// - /// - /// This enum and its corresponding property are intentionally kept internal as they - /// should not be set by users or even be visible in their connection strings. - /// Replication connections are a special kind of connection that is encapsulated in - /// - /// and . - /// - enum ReplicationMode - { - /// - /// Replication disabled. This is the default - /// - Off, - /// - /// Physical replication enabled - /// - Physical, - /// - /// Logical replication enabled - /// - Logical + /// + /// The one-dimensional Array that is the destination of the elements copied from . + /// The Array must have zero-based indexing. + /// + /// + /// The zero-based index in array at which copying begins. + /// + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + foreach (var kv in this) + array[arrayIndex++] = kv; } + + /// + /// Returns an enumerator that iterates through the . + /// + /// + public IEnumerator> GetEnumerator() + { + foreach (var k in Keys) + yield return new KeyValuePair(k, this[k]); + } + + #endregion IDictionary + + #region ICustomTypeDescriptor + + /// + [RequiresUnreferencedCode("PropertyDescriptor's PropertyType cannot be statically discovered.")] + protected override void GetProperties(Hashtable propertyDescriptors) + { + // Tweak which properties are exposed via TypeDescriptor. This affects the VS DDEX + // provider, for example. 
+ base.GetProperties(propertyDescriptors); + + var toRemove = new List(); + foreach (var value in propertyDescriptors.Values) + { + var d = (PropertyDescriptor)value; + foreach (var attribute in d.Attributes) + if (attribute is NpgsqlConnectionStringPropertyAttribute or ObsoleteAttribute) + toRemove.Add(d); + } + + foreach (var o in toRemove) + propertyDescriptors.Remove(o.DisplayName); + } + #endregion } + +#region Attributes + +/// +/// Marks on which participate in the connection +/// string. Optionally holds a set of synonyms for the property. +/// +[AttributeUsage(AttributeTargets.Property)] +sealed class NpgsqlConnectionStringPropertyAttribute : Attribute +{ + /// + /// Holds a list of synonyms for the property. + /// + public string[] Synonyms { get; } + + /// + /// Creates a . + /// + public NpgsqlConnectionStringPropertyAttribute() + => Synonyms = Array.Empty(); + + /// + /// Creates a . + /// + public NpgsqlConnectionStringPropertyAttribute(params string[] synonyms) + => Synonyms = synonyms; +} + +#endregion + +#region Enums + +/// +/// An option specified in the connection string that activates special compatibility features. +/// +public enum ServerCompatibilityMode +{ + /// + /// No special server compatibility mode is active + /// + None, + /// + /// The server is an Amazon Redshift instance. + /// + Redshift, + /// + /// The server is doesn't support full type loading from the PostgreSQL catalogs, support the basic set + /// of types via information hardcoded inside Npgsql. + /// + NoTypeLoading, +} + +/// +/// Specifies how to manage SSL. +/// +public enum SslMode +{ + /// + /// SSL is disabled. If the server requires SSL, the connection will fail. + /// + Disable, + /// + /// Prefer non-SSL connections if the server allows them, but allow SSL connections. + /// + Allow, + /// + /// Prefer SSL connections if the server allows them, but allow connections without SSL. + /// + Prefer, + /// + /// Fail the connection if the server doesn't support SSL. 
+ /// + Require, + /// + /// Fail the connection if the server doesn't support SSL. Also verifies server certificate. + /// + VerifyCA, + /// + /// Fail the connection if the server doesn't support SSL. Also verifies server certificate with host's name. + /// + VerifyFull +} + +/// +/// Specifies how to manage channel binding. +/// +public enum ChannelBinding +{ + /// + /// Channel binding is disabled. If the server requires channel binding, the connection will fail. + /// + Disable, + /// + /// Prefer channel binding if the server allows it, but connect without it if not. + /// + Prefer, + /// + /// Fail the connection if the server doesn't support channel binding. + /// + Require +} + +/// +/// Specifies how the mapping of arrays of +/// value types +/// behaves with respect to nullability when they are requested via an API returning an . +/// +public enum ArrayNullabilityMode +{ + /// + /// Arrays of value types are always returned as non-nullable arrays (e.g. int[]). + /// If the PostgreSQL array contains a NULL value, an exception is thrown. This is the default mode. + /// + Never, + /// + /// Arrays of value types are always returned as nullable arrays (e.g. int?[]). + /// + Always, + /// + /// The type of array that gets returned is determined at runtime. + /// Arrays of value types are returned as non-nullable arrays (e.g. int[]) + /// if the actual instance that gets returned doesn't contain null values + /// and as nullable arrays (e.g. int?[]) if it does. + /// + /// When using this setting, make sure that your code is prepared to the fact + /// that the actual type of array instances returned from APIs like + /// may change on a row by row base. + PerInstance, +} + +/// +/// Specifies whether the connection shall be initialized as a physical or +/// logical replication connection +/// +/// +/// This enum and its corresponding property are intentionally kept internal as they +/// should not be set by users or even be visible in their connection strings. 
+/// Replication connections are a special kind of connection that is encapsulated in +/// +/// and . +/// +enum ReplicationMode +{ + /// + /// Replication disabled. This is the default + /// + Off, + /// + /// Physical replication enabled + /// + Physical, + /// + /// Logical replication enabled + /// + Logical +} + +#endregion diff --git a/src/Npgsql/NpgsqlConnector.Auth.cs b/src/Npgsql/NpgsqlConnector.Auth.cs deleted file mode 100644 index b53594cf00..0000000000 --- a/src/Npgsql/NpgsqlConnector.Auth.cs +++ /dev/null @@ -1,483 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Net; -using System.Net.Security; -using System.Security.Cryptography; -using System.Security.Cryptography.X509Certificates; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Util; -using static Npgsql.Util.Statics; - -namespace Npgsql -{ - partial class NpgsqlConnector - { - async Task Authenticate(string username, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) - { - Log.Trace("Authenticating...", Id); - - timeout.CheckAndApply(this); - var msg = Expect(await ReadMessage(async), this); - switch (msg.AuthRequestType) - { - case AuthenticationRequestType.AuthenticationOk: - return; - - case AuthenticationRequestType.AuthenticationCleartextPassword: - await AuthenticateCleartext(username, async, cancellationToken); - return; - - case AuthenticationRequestType.AuthenticationMD5Password: - await AuthenticateMD5(username, ((AuthenticationMD5PasswordMessage)msg).Salt, async, cancellationToken); - return; - - case AuthenticationRequestType.AuthenticationSASL: - await AuthenticateSASL(((AuthenticationSASLMessage)msg).Mechanisms, username, async, cancellationToken); - return; - - case AuthenticationRequestType.AuthenticationGSS: - case AuthenticationRequestType.AuthenticationSSPI: - await AuthenticateGSS(async); - return; - - case 
AuthenticationRequestType.AuthenticationGSSContinue: - throw new NpgsqlException("Can't start auth cycle with AuthenticationGSSContinue"); - - default: - throw new NotSupportedException($"Authentication method not supported (Received: {msg.AuthRequestType})"); - } - } - - async Task AuthenticateCleartext(string username, bool async, CancellationToken cancellationToken = default) - { - var passwd = GetPassword(username); - if (passwd == null) - throw new NpgsqlException("No password has been provided but the backend requires one (in cleartext)"); - - var encoded = new byte[Encoding.UTF8.GetByteCount(passwd) + 1]; - Encoding.UTF8.GetBytes(passwd, 0, passwd.Length, encoded, 0); - - await WritePassword(encoded, async, cancellationToken); - await Flush(async, cancellationToken); - Expect(await ReadMessage(async), this); - } - - async Task AuthenticateSASL(List mechanisms, string username, bool async, CancellationToken cancellationToken = default) - { - // At the time of writing PostgreSQL only supports SCRAM-SHA-256 and SCRAM-SHA-256-PLUS - var supportsSha256 = mechanisms.Contains("SCRAM-SHA-256"); - var supportsSha256Plus = mechanisms.Contains("SCRAM-SHA-256-PLUS"); - if (!supportsSha256 && !supportsSha256Plus) - throw new NpgsqlException("No supported SASL mechanism found (only SCRAM-SHA-256 and SCRAM-SHA-256-PLUS are supported for now). " + - "Mechanisms received from server: " + string.Join(", ", mechanisms)); - - var mechanism = string.Empty; - var cbindFlag = string.Empty; - var cbind = string.Empty; - var successfulBind = false; - - if (supportsSha256Plus) - { - var sslStream = (SslStream)_stream; - if (sslStream.RemoteCertificate is null) - { - Log.Warn("Remote certificate null, falling back to SCRAM-SHA-256"); - } - else - { - using var remoteCertificate = new X509Certificate2(sslStream.RemoteCertificate); - // Checking for hashing algorithms - HashAlgorithm? 
hashAlgorithm = null; - var algorithmName = remoteCertificate.SignatureAlgorithm.FriendlyName; - if (algorithmName is null) - { - Log.Warn("Signature algorithm was null, falling back to SCRAM-SHA-256"); - } - else if (algorithmName.StartsWith("sha1", StringComparison.OrdinalIgnoreCase) || - algorithmName.StartsWith("md5", StringComparison.OrdinalIgnoreCase) || - algorithmName.StartsWith("sha256", StringComparison.OrdinalIgnoreCase)) - { - hashAlgorithm = SHA256.Create(); - } - else if (algorithmName.StartsWith("sha384", StringComparison.OrdinalIgnoreCase)) - { - hashAlgorithm = SHA384.Create(); - } - else if (algorithmName.StartsWith("sha512", StringComparison.OrdinalIgnoreCase)) - { - hashAlgorithm = SHA512.Create(); - } - else - { - Log.Warn($"Support for signature algorithm {algorithmName} is not yet implemented, falling back to SCRAM-SHA-256"); - } - - if (hashAlgorithm != null) - { - using var _ = hashAlgorithm; - - // RFC 5929 - mechanism = "SCRAM-SHA-256-PLUS"; - // PostgreSQL only supports tls-server-end-point binding - cbindFlag = "p=tls-server-end-point"; - // SCRAM-SHA-256-PLUS depends on using ssl stream, so it's fine - var cbindFlagBytes = Encoding.UTF8.GetBytes($"{cbindFlag},,"); - - var certificateHash = hashAlgorithm.ComputeHash(remoteCertificate.GetRawCertData()); - var cbindBytes = cbindFlagBytes.Concat(certificateHash).ToArray(); - cbind = Convert.ToBase64String(cbindBytes); - successfulBind = true; - IsScramPlus = true; - } - } - } - - if (!successfulBind && supportsSha256) - { - mechanism = "SCRAM-SHA-256"; - // We can get here if PostgreSQL supports only SCRAM-SHA-256 or there was an error while binding to SCRAM-SHA-256-PLUS - // So, we set 'n' (client does not support binding) if there was an error while binding - // or 'y' (client supports but server doesn't) in other case - cbindFlag = supportsSha256Plus ? "n" : "y"; - cbind = supportsSha256Plus ? 
"biws" : "eSws"; - successfulBind = true; - IsScram = true; - } - - if (!successfulBind) - { - // We can get here if PostgreSQL supports only SCRAM-SHA-256-PLUS but there was an error while binding to it - throw new NpgsqlException("Unable to bind to SCRAM-SHA-256-PLUS, check logs for more information"); - } - - var passwd = GetPassword(username) ?? - throw new NpgsqlException($"No password has been provided but the backend requires one (in SASL/{mechanism})"); - - // Assumption: the write buffer is big enough to contain all our outgoing messages - var clientNonce = GetNonce(); - - await WriteSASLInitialResponse(mechanism, PGUtil.UTF8Encoding.GetBytes($"{cbindFlag},,n=*,r={clientNonce}"), async, cancellationToken); - await Flush(async, cancellationToken); - - var saslContinueMsg = Expect(await ReadMessage(async), this); - if (saslContinueMsg.AuthRequestType != AuthenticationRequestType.AuthenticationSASLContinue) - throw new NpgsqlException("[SASL] AuthenticationSASLFinal message expected"); - var firstServerMsg = AuthenticationSCRAMServerFirstMessage.Load(saslContinueMsg.Payload); - if (!firstServerMsg.Nonce.StartsWith(clientNonce)) - throw new NpgsqlException("[SCRAM] Malformed SCRAMServerFirst message: server nonce doesn't start with client nonce"); - - var saltBytes = Convert.FromBase64String(firstServerMsg.Salt); - var saltedPassword = Hi(passwd.Normalize(NormalizationForm.FormKC), saltBytes, firstServerMsg.Iteration); - - var clientKey = HMAC(saltedPassword, "Client Key"); - byte[] storedKey; - using (var sha256 = SHA256.Create()) - storedKey = sha256.ComputeHash(clientKey); - - var clientFirstMessageBare = $"n=*,r={clientNonce}"; - var serverFirstMessage = $"r={firstServerMsg.Nonce},s={firstServerMsg.Salt},i={firstServerMsg.Iteration}"; - var clientFinalMessageWithoutProof = $"c={cbind},r={firstServerMsg.Nonce}"; - - var authMessage = $"{clientFirstMessageBare},{serverFirstMessage},{clientFinalMessageWithoutProof}"; - - var clientSignature = HMAC(storedKey, 
authMessage); - var clientProofBytes = Xor(clientKey, clientSignature); - var clientProof = Convert.ToBase64String(clientProofBytes); - - var serverKey = HMAC(saltedPassword, "Server Key"); - var serverSignature = HMAC(serverKey, authMessage); - - var messageStr = $"{clientFinalMessageWithoutProof},p={clientProof}"; - - await WriteSASLResponse(Encoding.UTF8.GetBytes(messageStr), async, cancellationToken); - await Flush(async, cancellationToken); - - var saslFinalServerMsg = Expect(await ReadMessage(async), this); - if (saslFinalServerMsg.AuthRequestType != AuthenticationRequestType.AuthenticationSASLFinal) - throw new NpgsqlException("[SASL] AuthenticationSASLFinal message expected"); - - var scramFinalServerMsg = AuthenticationSCRAMServerFinalMessage.Load(saslFinalServerMsg.Payload); - if (scramFinalServerMsg.ServerSignature != Convert.ToBase64String(serverSignature)) - throw new NpgsqlException("[SCRAM] Unable to verify server signature"); - - var okMsg = Expect(await ReadMessage(async), this); - if (okMsg.AuthRequestType != AuthenticationRequestType.AuthenticationOk) - throw new NpgsqlException("[SASL] Expected AuthenticationOK message"); - - static string GetNonce() - { - using var rncProvider = RandomNumberGenerator.Create(); - var nonceBytes = new byte[18]; - - rncProvider.GetBytes(nonceBytes); - return Convert.ToBase64String(nonceBytes); - } - - static byte[] Hi(string str, byte[] salt, int count) - { - using var hmac = new HMACSHA256(Encoding.UTF8.GetBytes(str)); - var salt1 = new byte[salt.Length + 4]; - byte[] hi, u1; - - Buffer.BlockCopy(salt, 0, salt1, 0, salt.Length); - salt1[salt1.Length - 1] = 1; - - hi = u1 = hmac.ComputeHash(salt1); - - for (var i = 1; i < count; i++) - { - var u2 = hmac.ComputeHash(u1); - Xor(hi, u2); - u1 = u2; - } - - return hi; - } - - static byte[] Xor(byte[] buffer1, byte[] buffer2) - { - for (var i = 0; i < buffer1.Length; i++) - buffer1[i] ^= buffer2[i]; - return buffer1; - } - - static byte[] HMAC(byte[] data, string key) 
- { - using var hmacsha256 = new HMACSHA256(data); - return hmacsha256.ComputeHash(Encoding.UTF8.GetBytes(key)); - } - } - - async Task AuthenticateMD5(string username, byte[] salt, bool async, CancellationToken cancellationToken = default) - { - var passwd = GetPassword(username); - if (passwd == null) - throw new NpgsqlException("No password has been provided but the backend requires one (in MD5)"); - - byte[] result; - using (var md5 = MD5.Create()) - { - // First phase - var passwordBytes = PGUtil.UTF8Encoding.GetBytes(passwd); - var usernameBytes = PGUtil.UTF8Encoding.GetBytes(username); - var cryptBuf = new byte[passwordBytes.Length + usernameBytes.Length]; - passwordBytes.CopyTo(cryptBuf, 0); - usernameBytes.CopyTo(cryptBuf, passwordBytes.Length); - - var sb = new StringBuilder(); - var hashResult = md5.ComputeHash(cryptBuf); - foreach (var b in hashResult) - sb.Append(b.ToString("x2")); - - var prehash = sb.ToString(); - - var prehashbytes = PGUtil.UTF8Encoding.GetBytes(prehash); - cryptBuf = new byte[prehashbytes.Length + 4]; - - Array.Copy(salt, 0, cryptBuf, prehashbytes.Length, 4); - - // 2. 
- prehashbytes.CopyTo(cryptBuf, 0); - - sb = new StringBuilder("md5"); - hashResult = md5.ComputeHash(cryptBuf); - foreach (var b in hashResult) - sb.Append(b.ToString("x2")); - - var resultString = sb.ToString(); - result = new byte[Encoding.UTF8.GetByteCount(resultString) + 1]; - Encoding.UTF8.GetBytes(resultString, 0, resultString.Length, result, 0); - result[result.Length - 1] = 0; - } - - await WritePassword(result, async, cancellationToken); - await Flush(async, cancellationToken); - Expect(await ReadMessage(async), this); - } - - async Task AuthenticateGSS(bool async) - { - if (!IntegratedSecurity) - throw new NpgsqlException("GSS/SSPI authentication but IntegratedSecurity not enabled"); - - using var negotiateStream = new NegotiateStream(new GSSPasswordMessageStream(this), true); - try - { - var targetName = $"{KerberosServiceName}/{Host}"; - if (async) - await negotiateStream.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, targetName); - else - negotiateStream.AuthenticateAsClient(CredentialCache.DefaultNetworkCredentials, targetName); - } - catch (AuthenticationCompleteException) - { - return; - } - catch (IOException e) when (e.InnerException is AuthenticationCompleteException) - { - return; - } - catch (IOException e) when (e.InnerException is PostgresException) - { - throw e.InnerException; - } - - throw new NpgsqlException("NegotiateStream.AuthenticateAsClient completed unexpectedly without signaling success"); - } - - /// - /// This Stream is placed between NegotiateStream and the socket's NetworkStream (or SSLStream). It intercepts - /// traffic and performs the following operations: - /// * Outgoing messages are framed in PostgreSQL's PasswordMessage, and incoming are stripped of it. - /// * NegotiateStream frames payloads with a 5-byte header, which PostgreSQL doesn't understand. This header is - /// stripped from outgoing messages and added to incoming ones. 
- /// - /// - /// See https://referencesource.microsoft.com/#System/net/System/Net/_StreamFramer.cs,16417e735f0e9530,references - /// - class GSSPasswordMessageStream : Stream - { - readonly NpgsqlConnector _connector; - int _leftToWrite; - int _leftToRead, _readPos; - byte[]? _readBuf; - - internal GSSPasswordMessageStream(NpgsqlConnector connector) - => _connector = connector; - - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - => Write(buffer, offset, count, true, cancellationToken); - - public override void Write(byte[] buffer, int offset, int count) - => Write(buffer, offset, count, false).GetAwaiter().GetResult(); - - async Task Write(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - if (_leftToWrite == 0) - { - // We're writing the frame header, which contains the payload size. - _leftToWrite = (buffer[3] << 8) | buffer[4]; - - buffer[0] = 22; - if (buffer[1] != 1) - throw new NotSupportedException($"Received frame header major v {buffer[1]} (different from 1)"); - if (buffer[2] != 0) - throw new NotSupportedException($"Received frame header minor v {buffer[2]} (different from 0)"); - - // In case of payload data in the same buffer just after the frame header - if (count == 5) - return; - count -= 5; - offset += 5; - } - - if (count > _leftToWrite) - throw new NpgsqlException($"NegotiateStream trying to write {count} bytes but according to frame header we only have {_leftToWrite} left!"); - await _connector.WritePassword(buffer, offset, count, async, cancellationToken); - await _connector.Flush(async, cancellationToken); - _leftToWrite -= count; - } - - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - => Read(buffer, offset, count, true, cancellationToken); - - public override int Read(byte[] buffer, int offset, int count) - => Read(buffer, offset, count, 
false).GetAwaiter().GetResult(); - - async Task Read(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - if (_leftToRead == 0) - { - var response = Expect(await _connector.ReadMessage(async), _connector); - if (response.AuthRequestType == AuthenticationRequestType.AuthenticationOk) - throw new AuthenticationCompleteException(); - var gssMsg = response as AuthenticationGSSContinueMessage; - if (gssMsg == null) - throw new NpgsqlException($"Received unexpected authentication request message {response.AuthRequestType}"); - _readBuf = gssMsg.AuthenticationData; - _leftToRead = gssMsg.AuthenticationData.Length; - _readPos = 0; - buffer[0] = 22; - buffer[1] = 1; - buffer[2] = 0; - buffer[3] = (byte)((_leftToRead >> 8) & 0xFF); - buffer[4] = (byte)(_leftToRead & 0xFF); - return 5; - } - - if (count > _leftToRead) - throw new NpgsqlException($"NegotiateStream trying to read {count} bytes but according to frame header we only have {_leftToRead} left!"); - count = Math.Min(count, _leftToRead); - Array.Copy(_readBuf!, _readPos, buffer, offset, count); - _leftToRead -= count; - return count; - } - - public override void Flush() { } - - public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); - public override void SetLength(long value) => throw new NotSupportedException(); - - public override bool CanRead => true; - public override bool CanWrite => true; - public override bool CanSeek => false; - public override long Length => throw new NotSupportedException(); - - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - } - - class AuthenticationCompleteException : Exception { } - - string? 
GetPassword(string username) - { - var password = Settings.Password; - if (password != null) - return password; - - if (ProvidePasswordCallback is { } passwordCallback) - try - { - Log.Trace($"Taking password from {nameof(ProvidePasswordCallback)} delegate"); - password = passwordCallback(Host, Port, Settings.Database!, username); - } - catch (Exception e) - { - throw new NpgsqlException($"Obtaining password using {nameof(NpgsqlConnection)}.{nameof(ProvidePasswordCallback)} delegate failed", e); - } - - if (password is null) - password = PostgresEnvironment.Password; - - if (password != null) - return password; - - var passFile = Settings.Passfile ?? PostgresEnvironment.PassFile; - if (passFile is null && PostgresEnvironment.PassFileDefault is string passFileDefault) - { - passFile = passFileDefault; - } - - if (passFile != null) - { - var matchingEntry = new PgPassFile(passFile!) - .GetFirstMatchingEntry(Host, Port, Settings.Database!, username); - if (matchingEntry != null) - { - Log.Trace("Taking password from pgpass file"); - password = matchingEntry.Password; - } - } - - return password; - } - } -} diff --git a/src/Npgsql/NpgsqlConnector.FrontendMessages.cs b/src/Npgsql/NpgsqlConnector.FrontendMessages.cs deleted file mode 100644 index 4f7f90ec13..0000000000 --- a/src/Npgsql/NpgsqlConnector.FrontendMessages.cs +++ /dev/null @@ -1,468 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.Util; -// ReSharper disable VariableHidesOuterVariable - -namespace Npgsql -{ - partial class NpgsqlConnector - { - internal Task WriteDescribe(StatementOrPortal statementOrPortal, string name, bool async, CancellationToken cancellationToken = default) - { - Debug.Assert(name.All(c => c < 128)); - - var len = sizeof(byte) + // Message code - sizeof(int) + // Length - sizeof(byte) + // Statement or portal - (name.Length + 1); // Statement/portal name - - if 
(WriteBuffer.WriteSpaceLeft < len) - return FlushAndWrite(len, statementOrPortal, name, async); - - Write(len, statementOrPortal, name); - return Task.CompletedTask; - - async Task FlushAndWrite(int len, StatementOrPortal statementOrPortal, string name, bool async) - { - await Flush(async, cancellationToken); - Debug.Assert(len <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {len} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})"); - Write(len, statementOrPortal, name); - } - - void Write(int len, StatementOrPortal statementOrPortal, string name) - { - WriteBuffer.WriteByte(FrontendMessageCode.Describe); - WriteBuffer.WriteInt32(len - 1); - WriteBuffer.WriteByte((byte)statementOrPortal); - WriteBuffer.WriteNullTerminatedString(name); - } - } - - internal Task WriteSync(bool async, CancellationToken cancellationToken = default) - { - const int len = sizeof(byte) + // Message code - sizeof(int); // Length - - if (WriteBuffer.WriteSpaceLeft < len) - return FlushAndWrite(async); - - Write(); - return Task.CompletedTask; - - async Task FlushAndWrite(bool async) - { - await Flush(async, cancellationToken); - Debug.Assert(len <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {len} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})"); - Write(); - } - - void Write() - { - WriteBuffer.WriteByte(FrontendMessageCode.Sync); - WriteBuffer.WriteInt32(len - 1); - } - } - - internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellationToken = default) - { - // Note: non-empty portal currently not supported - - const int len = sizeof(byte) + // Message code - sizeof(int) + // Length - sizeof(byte) + // Null-terminated portal name (always empty for now) - sizeof(int); // Max number of rows - - if (WriteBuffer.WriteSpaceLeft < len) - return FlushAndWrite(maxRows, async); - - Write(maxRows); - return Task.CompletedTask; - - async Task FlushAndWrite(int maxRows, bool async) - { - await 
Flush(async, cancellationToken); - Debug.Assert(10 <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length 10 which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})"); - Write(maxRows); - } - - void Write(int maxRows) - { - WriteBuffer.WriteByte(FrontendMessageCode.Execute); - WriteBuffer.WriteInt32(len - 1); - WriteBuffer.WriteByte(0); // Portal is always empty for now - WriteBuffer.WriteInt32(maxRows); - } - } - - internal async Task WriteParse(string sql, string statementName, List inputParameters, bool async, CancellationToken cancellationToken = default) - { - Debug.Assert(statementName.All(c => c < 128)); - - int queryByteLen; - try - { - queryByteLen = TextEncoding.GetByteCount(sql); - } - catch (Exception e) - { - Break(e); - throw; - } - - if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1) - await Flush(async, cancellationToken); - - var messageLength = - sizeof(byte) + // Message code - sizeof(int) + // Length - statementName.Length + // Statement name - sizeof(byte) + // Null terminator for the statement name - queryByteLen + sizeof(byte) + // SQL query length plus null terminator - sizeof(ushort) + // Number of parameters - inputParameters.Count * sizeof(int); // Parameter OIDs - - WriteBuffer.WriteByte(FrontendMessageCode.Parse); - WriteBuffer.WriteInt32(messageLength - 1); - WriteBuffer.WriteNullTerminatedString(statementName); - - await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken); - - if (WriteBuffer.WriteSpaceLeft < 1 + 2) - await Flush(async, cancellationToken); - WriteBuffer.WriteByte(0); // Null terminator for the query - WriteBuffer.WriteUInt16((ushort)inputParameters.Count); - - foreach (var p in inputParameters) - { - if (WriteBuffer.WriteSpaceLeft < 4) - await Flush(async, cancellationToken); - - WriteBuffer.WriteInt32((int)p.Handler!.PostgresType.OID); - } - } - - internal async Task WriteBind( - List inputParameters, - string portal, - string statement, - bool 
allResultTypesAreUnknown, - bool[]? unknownResultTypeList, - bool async, - CancellationToken cancellationToken = default) - { - Debug.Assert(statement.All(c => c < 128)); - Debug.Assert(portal.All(c => c < 128)); - - var headerLength = - sizeof(byte) + // Message code - sizeof(int) + // Message length - sizeof(byte) + // Portal is always empty (only a null terminator) - statement.Length + sizeof(byte) + // Statement name plus null terminator - sizeof(ushort); // Number of parameter format codes that follow - - if (WriteBuffer.WriteSpaceLeft < headerLength) - { - Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header"); - await Flush(async, cancellationToken); - } - - var formatCodesSum = 0; - var paramsLength = 0; - foreach (var p in inputParameters) - { - formatCodesSum += (int)p.FormatCode; - p.LengthCache?.Rewind(); - paramsLength += p.ValidateAndGetLength(); - } - - var formatCodeListLength = formatCodesSum == 0 ? 0 : formatCodesSum == inputParameters.Count ? 1 : inputParameters.Count; - - var messageLength = headerLength + - sizeof(short) * formatCodeListLength + // List of format codes - sizeof(short) + // Number of parameters - sizeof(int) * inputParameters.Count + // Parameter lengths - paramsLength + // Parameter values - sizeof(short) + // Number of result format codes - sizeof(short) * (unknownResultTypeList?.Length ?? 
1); // Result format codes - - WriteBuffer.WriteByte(FrontendMessageCode.Bind); - WriteBuffer.WriteInt32(messageLength - 1); - Debug.Assert(portal == string.Empty); - WriteBuffer.WriteByte(0); // Portal is always empty - - WriteBuffer.WriteNullTerminatedString(statement); - WriteBuffer.WriteInt16(formatCodeListLength); - - // 0 length implicitly means all-text, 1 means all-binary, >1 means mix-and-match - if (formatCodeListLength == 1) - { - if (WriteBuffer.WriteSpaceLeft < 2) - await Flush(async, cancellationToken); - WriteBuffer.WriteInt16((short)FormatCode.Binary); - } - else if (formatCodeListLength > 1) - { - foreach (var p in inputParameters) - { - if (WriteBuffer.WriteSpaceLeft < 2) - await Flush(async, cancellationToken); - WriteBuffer.WriteInt16((short)p.FormatCode); - } - } - - if (WriteBuffer.WriteSpaceLeft < 2) - await Flush(async, cancellationToken); - - WriteBuffer.WriteUInt16((ushort)inputParameters.Count); - - foreach (var param in inputParameters) - { - param.LengthCache?.Rewind(); - await param.WriteWithLength(WriteBuffer, async, cancellationToken); - } - - if (unknownResultTypeList != null) - { - if (WriteBuffer.WriteSpaceLeft < 2 + unknownResultTypeList.Length * 2) - await Flush(async, cancellationToken); - WriteBuffer.WriteInt16(unknownResultTypeList.Length); - foreach (var t in unknownResultTypeList) - WriteBuffer.WriteInt16(t ? 0 : 1); - } - else - { - if (WriteBuffer.WriteSpaceLeft < 4) - await Flush(async, cancellationToken); - WriteBuffer.WriteInt16(1); - WriteBuffer.WriteInt16(allResultTypesAreUnknown ? 
0 : 1); - } - } - - internal Task WriteClose(StatementOrPortal type, string name, bool async, CancellationToken cancellationToken = default) - { - var len = sizeof(byte) + // Message code - sizeof(int) + // Length - sizeof(byte) + // Statement or portal - name.Length + sizeof(byte); // Statement or portal name plus null terminator - - if (WriteBuffer.WriteSpaceLeft < 10) - return FlushAndWrite(len, type, name, async); - - Write(len, type, name); - return Task.CompletedTask; - - async Task FlushAndWrite(int len, StatementOrPortal type, string name, bool async) - { - await Flush(async, cancellationToken); - Debug.Assert(len <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {len} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})"); - Write(len, type, name); - } - - void Write(int len, StatementOrPortal type, string name) - { - WriteBuffer.WriteByte(FrontendMessageCode.Close); - WriteBuffer.WriteInt32(len - 1); - WriteBuffer.WriteByte((byte)type); - WriteBuffer.WriteNullTerminatedString(name); - } - } - - internal void WriteQuery(string sql) => WriteQuery(sql, false).GetAwaiter().GetResult(); - - internal async Task WriteQuery(string sql, bool async, CancellationToken cancellationToken = default) - { - var queryByteLen = TextEncoding.GetByteCount(sql); - - if (WriteBuffer.WriteSpaceLeft < 1 + 4) - await Flush(async, cancellationToken); - - WriteBuffer.WriteByte(FrontendMessageCode.Query); - WriteBuffer.WriteInt32( - sizeof(int) + // Message length (including self excluding code) - queryByteLen + // Query byte length - sizeof(byte)); // Null terminator - - await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken); - if (WriteBuffer.WriteSpaceLeft < 1) - await Flush(async, cancellationToken); - WriteBuffer.WriteByte(0); // Null terminator - } - - internal void WriteCopyDone() => WriteCopyDone(false).GetAwaiter().GetResult(); - - internal async Task WriteCopyDone(bool async, CancellationToken cancellationToken = 
default) - { - const int len = sizeof(byte) + // Message code - sizeof(int); // Length - - if (WriteBuffer.WriteSpaceLeft < len) - await Flush(async, cancellationToken); - - WriteBuffer.WriteByte(FrontendMessageCode.CopyDone); - WriteBuffer.WriteInt32(len - 1); - } - - internal async Task WriteCopyFail(bool async, CancellationToken cancellationToken = default) - { - // Note: error message not supported for now - - const int len = sizeof(byte) + // Message code - sizeof(int) + // Length - sizeof(byte); // Error message is always empty (only a null terminator) - - if (WriteBuffer.WriteSpaceLeft < len) - await Flush(async, cancellationToken); - - WriteBuffer.WriteByte(FrontendMessageCode.CopyFail); - WriteBuffer.WriteInt32(len - 1); - WriteBuffer.WriteByte(0); // Error message is always empty (only a null terminator) - } - - internal void WriteCancelRequest(int backendProcessId, int backendSecretKey) - { - const int len = sizeof(int) + // Length - sizeof(int) + // Cancel request code - sizeof(int) + // Backend process id - sizeof(int); // Backend secret key - - Debug.Assert(backendProcessId != 0); - - if (WriteBuffer.WriteSpaceLeft < len) - Flush(false).GetAwaiter().GetResult(); - - WriteBuffer.WriteInt32(len); - WriteBuffer.WriteInt32(1234 << 16 | 5678); - WriteBuffer.WriteInt32(backendProcessId); - WriteBuffer.WriteInt32(backendSecretKey); - } - - internal void WriteTerminate() - { - const int len = sizeof(byte) + // Message code - sizeof(int); // Length - - if (WriteBuffer.WriteSpaceLeft < len) - Flush(false).GetAwaiter().GetResult(); - - WriteBuffer.WriteByte(FrontendMessageCode.Terminate); - WriteBuffer.WriteInt32(len - 1); - } - - internal void WriteSslRequest() - { - const int len = sizeof(int) + // Length - sizeof(int); // SSL request code - - if (WriteBuffer.WriteSpaceLeft < len) - Flush(false).GetAwaiter().GetResult(); - - WriteBuffer.WriteInt32(len); - WriteBuffer.WriteInt32(80877103); - } - - internal void WriteStartup(Dictionary parameters) - { - const 
int protocolVersion3 = 3 << 16; // 196608 - - var len = sizeof(int) + // Length - sizeof(int) + // Protocol version - sizeof(byte); // Trailing zero byte - - foreach (var kvp in parameters) - len += PGUtil.UTF8Encoding.GetByteCount(kvp.Key) + 1 + - PGUtil.UTF8Encoding.GetByteCount(kvp.Value) + 1; - - // Should really never happen, just in case - if (len > WriteBuffer.Size) - throw new Exception("Startup message bigger than buffer"); - - WriteBuffer.WriteInt32(len); - WriteBuffer.WriteInt32(protocolVersion3); - - foreach (var kv in parameters) - { - WriteBuffer.WriteString(kv.Key); - WriteBuffer.WriteByte(0); - WriteBuffer.WriteString(kv.Value); - WriteBuffer.WriteByte(0); - } - - WriteBuffer.WriteByte(0); - } - - #region Authentication - - internal Task WritePassword(byte[] payload, bool async, CancellationToken cancellationToken = default) => WritePassword(payload, 0, payload.Length, async, cancellationToken); - - internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int)) - await WriteBuffer.Flush(async, cancellationToken); - WriteBuffer.WriteByte(FrontendMessageCode.Password); - WriteBuffer.WriteInt32(sizeof(int) + count); - - if (count <= WriteBuffer.WriteSpaceLeft) - { - // The entire array fits in our WriteBuffer, copy it into the WriteBuffer as usual. 
- WriteBuffer.WriteBytes(payload, offset, count); - return; - } - - await WriteBuffer.Flush(async, cancellationToken); - await WriteBuffer.DirectWrite(new ReadOnlyMemory(payload, offset, count), async, cancellationToken); - } - - internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialResponse, bool async, CancellationToken cancellationToken = default) - { - var len = sizeof(byte) + // Message code - sizeof(int) + // Length - PGUtil.UTF8Encoding.GetByteCount(mechanism) + sizeof(byte) + // Mechanism plus null terminator - sizeof(int) + // Initial response length - (initialResponse?.Length ?? 0); // Initial response payload - - if (WriteBuffer.WriteSpaceLeft < len) - await WriteBuffer.Flush(async, cancellationToken); - - WriteBuffer.WriteByte(FrontendMessageCode.Password); - WriteBuffer.WriteInt32(len - 1); - - WriteBuffer.WriteString(mechanism); - WriteBuffer.WriteByte(0); // null terminator - if (initialResponse == null) - WriteBuffer.WriteInt32(-1); - else - { - WriteBuffer.WriteInt32(initialResponse.Length); - WriteBuffer.WriteBytes(initialResponse); - } - } - - internal Task WriteSASLResponse(byte[] payload, bool async, CancellationToken cancellationToken = default) => WritePassword(payload, async, cancellationToken); - - #endregion Authentication - - internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default) - { - if (WriteBuffer.WriteSpaceLeft < data.Length) - return FlushAndWrite(data, async); - - WriteBuffer.WriteBytes(data, 0, data.Length); - return Task.CompletedTask; - - async Task FlushAndWrite(byte[] data, bool async) - { - await Flush(async, cancellationToken); - Debug.Assert(data.Length <= WriteBuffer.WriteSpaceLeft, $"Pregenerated message has length {data.Length} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})"); - WriteBuffer.WriteBytes(data, 0, data.Length); - } - } - - internal void Flush() => WriteBuffer.Flush(false).GetAwaiter().GetResult(); - - internal 
Task Flush(bool async, CancellationToken cancellationToken = default) => WriteBuffer.Flush(async, cancellationToken); - } -} diff --git a/src/Npgsql/NpgsqlConnector.cs b/src/Npgsql/NpgsqlConnector.cs deleted file mode 100644 index 51511de613..0000000000 --- a/src/Npgsql/NpgsqlConnector.cs +++ /dev/null @@ -1,2394 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.IO; -using System.Linq; -using System.Net; -using System.Net.Security; -using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using System.Runtime.ExceptionServices; -using System.Security.Authentication; -using System.Security.Cryptography.X509Certificates; -using System.Text; -using System.Threading; -using System.Threading.Channels; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.Logging; -using Npgsql.TypeMapping; -using Npgsql.Util; -using static Npgsql.Util.Statics; - -namespace Npgsql -{ - /// - /// Represents a connection to a PostgreSQL backend. Unlike NpgsqlConnection objects, which are - /// exposed to users, connectors are internal to Npgsql and are recycled by the connection pool. - /// - sealed partial class NpgsqlConnector : IDisposable - { - #region Fields and Properties - - /// - /// The physical connection socket to the backend. - /// - Socket _socket = default!; - - /// - /// The physical connection stream to the backend, without anything on top. - /// - NetworkStream _baseStream = default!; - - /// - /// The physical connection stream to the backend, layered with an SSL/TLS stream if in secure mode. - /// - Stream _stream = default!; - - internal NpgsqlConnectionStringBuilder Settings { get; } - internal string ConnectionString { get; } - - ProvideClientCertificatesCallback? ProvideClientCertificatesCallback { get; } - RemoteCertificateValidationCallback? 
UserCertificateValidationCallback { get; } - ProvidePasswordCallback? ProvidePasswordCallback { get; } - - internal Encoding TextEncoding { get; private set; } = default!; - - /// - /// Same as , except that it does not throw an exception if an invalid char is - /// encountered (exception fallback), but rather replaces it with a question mark character (replacement - /// fallback). - /// - internal Encoding RelaxedTextEncoding { get; private set; } = default!; - - /// - /// Buffer used for reading data. - /// - internal NpgsqlReadBuffer ReadBuffer { get; private set; } = default!; - - /// - /// If we read a data row that's bigger than , we allocate an oversize buffer. - /// The original (smaller) buffer is stored here, and restored when the connection is reset. - /// - NpgsqlReadBuffer? _origReadBuffer; - - /// - /// Buffer used for writing data. - /// - internal NpgsqlWriteBuffer WriteBuffer { get; private set; } = default!; - - /// - /// The secret key of the backend for this connector, used for query cancellation. - /// - int _backendSecretKey; - - /// - /// The process ID of the backend for this connector. - /// - internal int BackendProcessId { get; private set; } - - bool SupportsPostgresCancellation => BackendProcessId != 0; - - /// - /// A unique ID identifying this connector, used for logging. Currently mapped to BackendProcessId - /// - internal int Id => BackendProcessId; - - internal NpgsqlDatabaseInfo DatabaseInfo { get; private set; } = default!; - - internal ConnectorTypeMapper TypeMapper { get; set; } = default!; - - /// - /// The current transaction status for this connector. - /// - internal TransactionStatus TransactionStatus { get; set; } - - /// - /// A transaction object for this connector. Since only one transaction can be in progress at any given time, - /// this instance is recycled. To check whether a transaction is currently in progress on this connector, - /// see . - /// - internal NpgsqlTransaction? 
Transaction { get; set; } - - /// - /// The NpgsqlConnection that (currently) owns this connector. Null if the connector isn't - /// owned (i.e. idle in the pool) - /// - internal NpgsqlConnection? Connection { get; set; } - - /// - /// The number of messages that were prepended to the current message chain, but not yet sent. - /// Note that this only tracks messages which produce a ReadyForQuery message - /// - internal int PendingPrependedResponses { get; set; } - - internal NpgsqlDataReader? CurrentReader; - - internal PreparedStatementManager PreparedStatementManager; - - /// - /// If the connector is currently in COPY mode, holds a reference to the importer/exporter object. - /// Otherwise null. - /// - internal ICancelable? CurrentCopyOperation; - - /// - /// Holds all run-time parameters received from the backend (via ParameterStatus messages) - /// - internal readonly Dictionary PostgresParameters; - - /// - /// Holds all run-time parameters in raw, binary format for efficient handling without allocations. - /// - readonly List<(byte[] Name, byte[] Value)> _rawParameters = new List<(byte[], byte[])>(); - - /// - /// If this connector was broken, this contains the exception that caused the break. - /// - volatile Exception? _breakReason; - - /// - /// Semaphore, used to synchronize DatabaseInfo between multiple connections, so it wouldn't be loaded in parallel. - /// - static readonly SemaphoreSlim DatabaseInfoSemaphore = new SemaphoreSlim(1); - - /// - /// - /// Used by the pool to indicate that I/O is currently in progress on this connector, so that another write - /// isn't started concurrently. Note that since we have only one write loop, this is only ever usedto - /// protect against an over-capacity writes into a connector that's currently *asynchronously* writing. - /// - /// - /// It is guaranteed that the currently-executing - /// Specifically, reading may occur - and the connector may even be returned to the pool - before this is - /// released. 
- /// - /// - internal volatile int MultiplexAsyncWritingLock; - - /// - internal void FlagAsNotWritableForMultiplexing() - { - if (Settings.Multiplexing) - { - Debug.Assert(CommandsInFlightCount > 0 || IsBroken || IsClosed, - $"About to mark multiplexing connector as non-writable, but {nameof(CommandsInFlightCount)} is {CommandsInFlightCount}"); - - Interlocked.Exchange(ref MultiplexAsyncWritingLock, 1); - } - } - - /// - internal void FlagAsWritableForMultiplexing() - { - if (Settings.Multiplexing && Interlocked.CompareExchange(ref MultiplexAsyncWritingLock, 0, 1) != 1) - throw new Exception("Multiplexing lock was not taken when releasing. Please report a bug."); - } - - /// - /// The timeout for reading messages that are part of the user's command - /// (i.e. which aren't internal prepended commands). - /// - /// Precision is milliseconds - internal int UserTimeout { private get; set; } - - /// - /// A lock that's taken while a user action is in progress, e.g. a command being executed. - /// Only used when keepalive is enabled, otherwise null. - /// - SemaphoreSlim? _userLock; - - /// - /// A lock that's taken while a cancellation is being delivered; new queries are blocked until the - /// cancellation is delivered. This reduces the chance that a cancellation meant for a previous - /// command will accidentally cancel a later one, see #615. - /// - internal object CancelLock { get; } - - readonly bool _isKeepAliveEnabled; - readonly Timer? _keepAliveTimer; - - /// - /// The command currently being executed by the connector, null otherwise. - /// Used only for concurrent use error reporting purposes. - /// - NpgsqlCommand? _currentCommand; - - bool _sendResetOnClose; - - ConnectorPool? _pool; - - /// - /// Contains the UTC timestamp when this connector was opened, used to implement - /// . 
- /// - internal DateTime OpenTimestamp { get; private set; } - - internal int ClearCounter { get; set; } - - volatile bool _postgresCancellationPerformed; - internal bool PostgresCancellationPerformed - { - get => _postgresCancellationPerformed; - private set => _postgresCancellationPerformed = value; - } - - volatile bool _userCancellationRequested; - CancellationTokenRegistration _cancellationTokenRegistration; - internal bool UserCancellationRequested => _userCancellationRequested; - internal CancellationToken UserCancellationToken { get; set; } - internal bool AttemptPostgresCancellation { get; private set; } - static readonly TimeSpan _cancelImmediatelyTimeout = TimeSpan.FromMilliseconds(-1); - - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlConnector)); - - internal readonly Stopwatch QueryLogStopWatch = new Stopwatch(); - - #endregion - - #region Constants - - /// - /// The minimum timeout that can be set on internal commands such as COMMIT, ROLLBACK. - /// - /// Precision is seconds - internal const int MinimumInternalCommandTimeout = 3; - - #endregion - - #region Reusable Message Objects - - byte[]? _resetWithoutDeallocateMessage; - - int _resetWithoutDeallocateResponseCount; - - // Backend - readonly CommandCompleteMessage _commandCompleteMessage = new CommandCompleteMessage(); - readonly ReadyForQueryMessage _readyForQueryMessage = new ReadyForQueryMessage(); - readonly ParameterDescriptionMessage _parameterDescriptionMessage = new ParameterDescriptionMessage(); - readonly DataRowMessage _dataRowMessage = new DataRowMessage(); - readonly RowDescriptionMessage _rowDescriptionMessage = new RowDescriptionMessage(); - - // Since COPY is rarely used, allocate these lazily - CopyInResponseMessage? _copyInResponseMessage; - CopyOutResponseMessage? _copyOutResponseMessage; - CopyDataMessage? _copyDataMessage; - CopyBothResponseMessage? 
_copyBothResponseMessage; - - #endregion - - internal NpgsqlDataReader DataReader { get; set; } - - #region Constructors - - internal NpgsqlConnector(NpgsqlConnection connection) - : this(connection.Settings, connection.OriginalConnectionString) - { - Connection = connection; - _pool = connection.Pool; - Connection.Connector = this; - ProvideClientCertificatesCallback = Connection.ProvideClientCertificatesCallback; - UserCertificateValidationCallback = Connection.UserCertificateValidationCallback; - ProvidePasswordCallback = Connection.ProvidePasswordCallback; - } - - NpgsqlConnector(NpgsqlConnector connector) - : this(connector.Settings, connector.ConnectionString) - { - ProvideClientCertificatesCallback = connector.ProvideClientCertificatesCallback; - UserCertificateValidationCallback = connector.UserCertificateValidationCallback; - ProvidePasswordCallback = connector.ProvidePasswordCallback; - } - - /// - /// Creates a new connector with the given connection string. - /// - /// The parsed connection string. - /// The connection string. - NpgsqlConnector(NpgsqlConnectionStringBuilder settings, string connectionString) - { - State = ConnectorState.Closed; - TransactionStatus = TransactionStatus.Idle; - Settings = settings; - ConnectionString = connectionString; - PostgresParameters = new Dictionary(); - - CancelLock = new object(); - - _isKeepAliveEnabled = Settings.KeepAlive > 0; - if (_isKeepAliveEnabled) - { - _userLock = new SemaphoreSlim(1, 1); - _keepAliveTimer = new Timer(PerformKeepAlive, null, Timeout.Infinite, Timeout.Infinite); - } - - DataReader = new NpgsqlDataReader(this); - - // TODO: Not just for automatic preparation anymore... - PreparedStatementManager = new PreparedStatementManager(this); - - if (settings.Multiplexing) - { - // Note: It's OK for this channel to be unbounded: each command enqueued to it is accompanied by sending - // it to PostgreSQL. If we overload it, a TCP zero window will make us block on the networking side - // anyway. 
- // Note: the in-flight channel can probably be single-writer, but that doesn't actually do anything - // at this point. And we currently rely on being able to complete the channel at any point (from - // Break). We may want to revisit this if an optimized, SingleWriter implementation is introduced. - var commandsInFlightChannel = Channel.CreateUnbounded( - new UnboundedChannelOptions { SingleReader = true }); - CommandsInFlightReader = commandsInFlightChannel.Reader; - CommandsInFlightWriter = commandsInFlightChannel.Writer; - - // TODO: Properly implement this - if (_isKeepAliveEnabled) - throw new NotImplementedException("Keepalive not yet implemented for multiplexing"); - } - } - - #endregion - - #region Configuration settings - - string Host => Settings.Host!; - int Port => Settings.Port; - string Database => Settings.Database!; - string KerberosServiceName => Settings.KerberosServiceName; - SslMode SslMode => Settings.SslMode; - int ConnectionTimeout => Settings.Timeout; - bool IntegratedSecurity => Settings.IntegratedSecurity; - internal bool ConvertInfinityDateTime => Settings.ConvertInfinityDateTime; - - /// - /// The actual command timeout value that gets set on internal commands. - /// - /// Precision is milliseconds - int InternalCommandTimeout - { - get - { - var internalTimeout = Settings.InternalCommandTimeout; - if (internalTimeout == -1) - return Math.Max(Settings.CommandTimeout, MinimumInternalCommandTimeout) * 1000; - - // Todo: Decide what we really want here - // This assertion can easily fail if InternalCommandTimeout is set to 1 or 2 in the connection string - // We probably don't want to allow these values but in that case a Debug.Assert is the wrong way to enforce it. 
- Debug.Assert(internalTimeout == 0 || internalTimeout >= MinimumInternalCommandTimeout); - return internalTimeout * 1000; - } - } - - #endregion Configuration settings - - #region State management - - int _state; - - /// - /// Gets the current state of the connector - /// - internal ConnectorState State - { - get => (ConnectorState)_state; - set - { - var newState = (int)value; - if (newState == _state) - return; - Interlocked.Exchange(ref _state, newState); - } - } - - /// - /// Returns whether the connector is open, regardless of any task it is currently performing - /// - bool IsConnected - => State switch - { - ConnectorState.Ready => true, - ConnectorState.Executing => true, - ConnectorState.Fetching => true, - ConnectorState.Waiting => true, - ConnectorState.Copy => true, - ConnectorState.Replication => true, - ConnectorState.Closed => false, - ConnectorState.Connecting => false, - ConnectorState.Broken => false, - _ => throw new ArgumentOutOfRangeException("Unknown state: " + State) - }; - - internal bool IsReady => State == ConnectorState.Ready; - internal bool IsClosed => State == ConnectorState.Closed; - internal bool IsBroken => State == ConnectorState.Broken; - - #endregion - - #region Open - - /// - /// Opens the physical connection to the server. - /// - /// Usually called by the RequestConnector - /// Method of the connection pool manager. 
- internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) - { - Debug.Assert(Connection != null && Connection.Connector == this); - Debug.Assert(State == ConnectorState.Closed); - - State = ConnectorState.Connecting; - - try - { - await RawOpen(timeout, async, cancellationToken); - - var username = GetUsername(); - if (Settings.Database == null) - Settings.Database = username; - - timeout.CheckAndApply(this); - WriteStartupMessage(username); - await Flush(async, cancellationToken); - - using (StartCancellableOperation(cancellationToken, attemptPgCancellation: false)) - { - await Authenticate(username, timeout, async, cancellationToken); - - // We treat BackendKeyData as optional because some PostgreSQL-like database - // don't send it (CockroachDB, CrateDB) - var msg = await ReadMessage(async); - if (msg.Code == BackendMessageCode.BackendKeyData) - { - var keyDataMsg = (BackendKeyDataMessage)msg; - BackendProcessId = keyDataMsg.BackendProcessId; - _backendSecretKey = keyDataMsg.BackendSecretKey; - msg = await ReadMessage(async); - } - - if (msg.Code != BackendMessageCode.ReadyForQuery) - throw new NpgsqlException($"Received backend message {msg.Code} while expecting ReadyForQuery. Please file a bug."); - - State = ConnectorState.Ready; - } - - await LoadDatabaseInfo(forceReload: false, timeout, async, cancellationToken); - - if (Settings.Pooling && !Settings.Multiplexing && !Settings.NoResetOnClose && DatabaseInfo.SupportsDiscard) - { - _sendResetOnClose = true; - GenerateResetMessage(); - } - - OpenTimestamp = DateTime.UtcNow; - Log.Trace($"Opened connection to {Host}:{Port}"); - - if (Settings.Multiplexing) - { - // Start an infinite async loop, which processes incoming multiplexing traffic. - // It is intentionally not awaited and will run as long as the connector is alive. - // The CommandsInFlightWriter channel is completed in Cleanup, which should cause this task - // to complete. 
- _ = Task.Run(MultiplexingReadLoop, CancellationToken.None) - .ContinueWith(t => - { - // Note that we *must* observe the exception if the task is faulted. - Log.Error("Exception bubbled out of multiplexing read loop", t.Exception!, Id); - }, TaskContinuationOptions.OnlyOnFaulted); - } - - if (_isKeepAliveEnabled) - { - // Start the keep alive mechanism to work by scheduling the timer. - // Otherwise, it doesn't work for cases when no query executed during - // the connection lifetime in case of a new connector. - lock (this) - { - var keepAlive = Settings.KeepAlive * 1000; - _keepAliveTimer!.Change(keepAlive, keepAlive); - } - } - } - catch (Exception e) - { - Break(e); - throw; - } - } - - internal async ValueTask LoadDatabaseInfo(bool forceReload, NpgsqlTimeout timeout, bool async, - CancellationToken cancellationToken = default) - { - // Super hacky stuff... - - var prevBindingScope = Connection!.ConnectorBindingScope; - Connection.ConnectorBindingScope = ConnectorBindingScope.PhysicalConnecting; - using var _ = Defer(static (conn, prevScope) => conn.ConnectorBindingScope = prevScope, Connection, prevBindingScope); - - // The type loading below will need to send queries to the database, and that depends on a type mapper - // being set up (even if its empty) - TypeMapper = new ConnectorTypeMapper(this); - - var key = new NpgsqlDatabaseInfoCacheKey(Settings); - if (forceReload || !NpgsqlDatabaseInfo.Cache.TryGetValue(key, out var database)) - { - var hasSemaphore = async - ? 
await DatabaseInfoSemaphore.WaitAsync(timeout.TimeLeft, cancellationToken) - : DatabaseInfoSemaphore.Wait(timeout.TimeLeft, cancellationToken); - - // We've timed out - calling Check, to throw the correct exception - if (!hasSemaphore) - timeout.Check(); - - try - { - if (forceReload || !NpgsqlDatabaseInfo.Cache.TryGetValue(key, out database)) - { - NpgsqlDatabaseInfo.Cache[key] = database = await NpgsqlDatabaseInfo.Load(Connection, - timeout, async); - } - } - finally - { - DatabaseInfoSemaphore.Release(); - } - } - - DatabaseInfo = database!; - TypeMapper.Bind(DatabaseInfo); - } - - void WriteStartupMessage(string username) - { - var startupParams = new Dictionary - { - ["user"] = username, - ["client_encoding"] = Settings.ClientEncoding ?? - PostgresEnvironment.ClientEncoding ?? - "UTF8", - ["database"] = Settings.Database! - }; - - if (Settings.ApplicationName?.Length > 0) - startupParams["application_name"] = Settings.ApplicationName; - - if (Settings.SearchPath?.Length > 0) - startupParams["search_path"] = Settings.SearchPath; - - var timezone = Settings.Timezone ?? PostgresEnvironment.TimeZone; - if (timezone != null) - startupParams["TimeZone"] = timezone; - - var options = Settings.Options ?? 
PostgresEnvironment.Options; - if (options?.Length > 0) - startupParams["options"] = options; - - switch (Settings.ReplicationMode) - { - case ReplicationMode.Logical: - startupParams["replication"] = "database"; - break; - case ReplicationMode.Physical: - startupParams["replication"] = "true"; - break; - } - - WriteStartup(startupParams); - } - - string GetUsername() - { - var username = Settings.Username; - if (username?.Length > 0) - return username; - - username = PostgresEnvironment.User; - if (username?.Length > 0) - return username; - - if (!PGUtil.IsWindows) - { - username = KerberosUsernameProvider.GetUsername(Settings.IncludeRealm); - if (username?.Length > 0) - return username; - } - - username = Environment.UserName; - if (username?.Length > 0) - return username; - - throw new NpgsqlException("No username could be found, please specify one explicitly"); - } - - async Task RawOpen(NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) - { - var cert = default(X509Certificate2?); - try - { - if (async) - await ConnectAsync(timeout, cancellationToken); - else - Connect(timeout); - - _baseStream = new NetworkStream(_socket, true); - _stream = _baseStream; - - if (Settings.Encoding == "UTF8") - { - TextEncoding = PGUtil.UTF8Encoding; - RelaxedTextEncoding = PGUtil.RelaxedUTF8Encoding; - } - else - { - TextEncoding = Encoding.GetEncoding(Settings.Encoding, EncoderFallback.ExceptionFallback, DecoderFallback.ExceptionFallback); - RelaxedTextEncoding = Encoding.GetEncoding(Settings.Encoding, EncoderFallback.ReplacementFallback, DecoderFallback.ReplacementFallback); - } - - ReadBuffer = new NpgsqlReadBuffer(this, _stream, _socket, Settings.ReadBufferSize, TextEncoding, RelaxedTextEncoding); - WriteBuffer = new NpgsqlWriteBuffer(this, _stream, _socket, Settings.WriteBufferSize, TextEncoding); - - timeout.CheckAndApply(this); - - if (SslMode == SslMode.Require || SslMode == SslMode.Prefer) - { - WriteSslRequest(); - await Flush(async, 
cancellationToken); - - await ReadBuffer.Ensure(1, async); - var response = (char)ReadBuffer.ReadByte(); - timeout.CheckAndApply(this); - - switch (response) - { - default: - throw new NpgsqlException($"Received unknown response {response} for SSLRequest (expecting S or N)"); - case 'N': - if (SslMode == SslMode.Require) - throw new NpgsqlException("SSL connection requested. No SSL enabled connection from this host is configured."); - break; - case 'S': - var clientCertificates = new X509Certificate2Collection(); - var certPath = Settings.ClientCertificate ?? PostgresEnvironment.SslCert; - - if (certPath is null && PostgresEnvironment.SslCertDefault is string certPathDefault) - certPath = certPathDefault; - - if (certPath != null) - { - cert = new X509Certificate2(certPath, Settings.ClientCertificateKey ?? PostgresEnvironment.SslKey); - clientCertificates.Add(cert); - } - - ProvideClientCertificatesCallback?.Invoke(clientCertificates); - - var certificateValidationCallback = Settings.TrustServerCertificate - ? SslTrustServerValidation - : (Settings.RootCertificate ?? PostgresEnvironment.SslCertRoot ?? PostgresEnvironment.SslCertRootDefault) is { } certRootPath - ? SslRootValidation(certRootPath) - : UserCertificateValidationCallback is { } userValidation - ? 
userValidation - : SslDefaultValidation; - - timeout.CheckAndApply(this); - - try - { - var sslStream = new SslStream(_stream, leaveInnerStreamOpen: false, certificateValidationCallback); - - if (async) - await sslStream.AuthenticateAsClientAsync(Host, clientCertificates, - SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, Settings.CheckCertificateRevocation); - else - sslStream.AuthenticateAsClient(Host, clientCertificates, - SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, Settings.CheckCertificateRevocation); - - _stream = sslStream; - } - catch (Exception e) - { - throw new NpgsqlException("Exception while performing SSL handshake", e); - } - - ReadBuffer.Clear(); // Reset to empty after reading single SSL char - ReadBuffer.Underlying = _stream; - WriteBuffer.Underlying = _stream; - IsSecure = true; - Log.Trace("SSL negotiation successful"); - break; - } - } - - Log.Trace($"Socket connected to {Host}:{Port}"); - } - catch - { - cert?.Dispose(); - - _stream?.Dispose(); - _stream = null!; - - _baseStream?.Dispose(); - _baseStream = null!; - - _socket?.Dispose(); - _socket = null!; - - throw; - } - } - - void Connect(NpgsqlTimeout timeout) - { - // Note that there aren't any timeout-able or cancellable DNS methods - var endpoints = Path.IsPathRooted(Host) - ? 
new EndPoint[] { new UnixDomainSocketEndPoint(Path.Combine(Host, $".s.PGSQL.{Port}")) } - : Dns.GetHostAddresses(Host).Select(a => new IPEndPoint(a, Port)).ToArray(); - timeout.Check(); - - // Give each endpoint an equal share of the remaining time - var perEndpointTimeout = -1; // Default to infinity - if (timeout.IsSet) - { - var timeoutTicks = timeout.TimeLeft.Ticks; - if (timeoutTicks <= 0) - throw new TimeoutException(); - perEndpointTimeout = (int)(timeoutTicks / endpoints.Length / 10); - } - - for (var i = 0; i < endpoints.Length; i++) - { - var endpoint = endpoints[i]; - Log.Trace($"Attempting to connect to {endpoint}"); - var protocolType = endpoint.AddressFamily == AddressFamily.InterNetwork ? ProtocolType.Tcp : ProtocolType.IP; - var socket = new Socket(endpoint.AddressFamily, SocketType.Stream, protocolType) - { - Blocking = false - }; - - try - { - try - { - socket.Connect(endpoint); - } - catch (SocketException e) - { - if (e.SocketErrorCode != SocketError.WouldBlock) - throw; - } - var write = new List { socket }; - var error = new List { socket }; - Socket.Select(null, write, error, perEndpointTimeout); - var errorCode = (int) socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.Error)!; - if (errorCode != 0) - throw new SocketException(errorCode); - if (!write.Any()) - throw new TimeoutException("Timeout during connection attempt"); - socket.Blocking = true; - SetSocketOptions(socket); - _socket = socket; - return; - } - catch (Exception e) - { - try { socket.Dispose(); } - catch - { - // ignored - } - - Log.Trace($"Failed to connect to {endpoint}", e); - - if (i == endpoints.Length - 1) - throw new NpgsqlException("Exception while connecting", e); - } - } - } - - async Task ConnectAsync(NpgsqlTimeout timeout, CancellationToken cancellationToken) - { - // Note that there aren't any timeout-able or cancellable DNS methods - var endpoints = Path.IsPathRooted(Host) - ? 
new EndPoint[] { new UnixDomainSocketEndPoint(Path.Combine(Host, $".s.PGSQL.{Port}")) } - : (await Dns.GetHostAddressesAsync(Host).WithCancellationAndTimeout(timeout, cancellationToken)) - .Select(a => new IPEndPoint(a, Port)).ToArray(); - - // Give each IP an equal share of the remaining time - var perIpTimespan = default(TimeSpan); - var perIpTimeout = timeout; - if (timeout.IsSet) - { - var timeoutTicks = timeout.TimeLeft.Ticks; - if (timeoutTicks <= 0) - throw new TimeoutException(); - perIpTimespan = new TimeSpan(timeoutTicks / endpoints.Length); - perIpTimeout = new NpgsqlTimeout(perIpTimespan); - } - - for (var i = 0; i < endpoints.Length; i++) - { - var endpoint = endpoints[i]; - Log.Trace($"Attempting to connect to {endpoint}"); - var protocolType = endpoint.AddressFamily == AddressFamily.InterNetwork ? ProtocolType.Tcp : ProtocolType.IP; - var socket = new Socket(endpoint.AddressFamily, SocketType.Stream, protocolType); - CancellationTokenSource? combinedCts = null; - try - { - // .NET 5.0 added cancellation support to ConnectAsync, which allows us to implement real - // cancellation and timeout. On older TFMs, we fake-cancel the operation, i.e. stop waiting - // and raise the exception, but the actual connection task is left running. 
- -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - await socket.ConnectAsync(endpoint) - .WithCancellationAndTimeout(perIpTimeout, cancellationToken); -#else - var finalCt = cancellationToken; - - if (perIpTimeout.IsSet) - { - combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - combinedCts.CancelAfter((int)perIpTimeout.TimeLeft.TotalMilliseconds); - finalCt = combinedCts.Token; - } - - await socket.ConnectAsync(endpoint, finalCt); -#endif - - SetSocketOptions(socket); - _socket = socket; - return; - } - catch (Exception e) - { - try - { - socket.Dispose(); - } - catch - { - // ignored - } - - cancellationToken.ThrowIfCancellationRequested(); - - if (e is OperationCanceledException) - e = new TimeoutException("Timeout during connection attempt"); - - Log.Trace($"Failed to connect to {endpoint}", e); - - if (i == endpoints.Length - 1) - { - throw new NpgsqlException("Exception while connecting", e); - } - } - finally - { - combinedCts?.Dispose(); - } - } - } - - void SetSocketOptions(Socket socket) - { - if (socket.AddressFamily == AddressFamily.InterNetwork) - socket.NoDelay = true; - if (Settings.SocketReceiveBufferSize > 0) - socket.ReceiveBufferSize = Settings.SocketReceiveBufferSize; - if (Settings.SocketSendBufferSize > 0) - socket.SendBufferSize = Settings.SocketSendBufferSize; - - if (Settings.TcpKeepAlive) - socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true); - if (Settings.TcpKeepAliveInterval > 0 && Settings.TcpKeepAliveTime == 0) - throw new ArgumentException("If TcpKeepAliveInterval is defined, TcpKeepAliveTime must be defined as well"); - if (Settings.TcpKeepAliveTime > 0) - { - var timeSeconds = Settings.TcpKeepAliveTime; - var intervalSeconds = Settings.TcpKeepAliveInterval > 0 - ? 
Settings.TcpKeepAliveInterval - : Settings.TcpKeepAliveTime; - -#if NETSTANDARD2_0 || NETSTANDARD2_1 - var timeMilliseconds = timeSeconds * 1000; - var intervalMilliseconds = intervalSeconds * 1000; - - // For the following see https://msdn.microsoft.com/en-us/library/dd877220.aspx - var uintSize = Marshal.SizeOf(typeof(uint)); - var inOptionValues = new byte[uintSize * 3]; - BitConverter.GetBytes((uint)1).CopyTo(inOptionValues, 0); - BitConverter.GetBytes((uint)timeMilliseconds).CopyTo(inOptionValues, uintSize); - BitConverter.GetBytes((uint)intervalMilliseconds).CopyTo(inOptionValues, uintSize * 2); - var result = 0; - try - { - result = socket.IOControl(IOControlCode.KeepAliveValues, inOptionValues, null); - } - catch (PlatformNotSupportedException) - { - throw new PlatformNotSupportedException("Setting TCP Keepalive Time and TCP Keepalive Interval is supported only on Windows, Mono and .NET Core 3.1+. " + - "TCP keepalives can still be used on other systems but are enabled via the TcpKeepAlive option or configured globally for the machine, see the relevant docs."); - } - - if (result != 0) - throw new NpgsqlException($"Got non-zero value when trying to set TCP keepalive: {result}"); -#else - socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true); - socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, timeSeconds); - socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, intervalSeconds); -#endif - } - } - - #endregion - - #region I/O - - internal readonly ChannelReader? CommandsInFlightReader; - internal readonly ChannelWriter? 
CommandsInFlightWriter; - - internal volatile int CommandsInFlightCount; - - internal ManualResetValueTaskSource ReaderCompleted { get; } = - new ManualResetValueTaskSource { RunContinuationsAsynchronously = true }; - - async Task MultiplexingReadLoop() - { - Debug.Assert(Settings.Multiplexing); - Debug.Assert(CommandsInFlightReader != null); - - NpgsqlCommand? command = null; - var commandsRead = 0; - - try - { - while (await CommandsInFlightReader.WaitToReadAsync()) - { - commandsRead = 0; - Debug.Assert(!InTransaction); - - while (CommandsInFlightReader.TryRead(out command)) - { - commandsRead++; - - await ReadBuffer.Ensure(5, true); - - // We have a resultset for the command - hand back control to the command (which will - // return it to the user) - ReaderCompleted.Reset(); - command.ExecutionCompletion.SetResult(this); - - // Now wait until that command's reader is disposed. Note that RunContinuationsAsynchronously is - // true, so that the user code calling NpgsqlDataReader.Dispose will not continue executing - // synchronously here. The prevents issues if the code after the next command's execution - // completion blocks. - await new ValueTask(ReaderCompleted, ReaderCompleted.Version); - Debug.Assert(!InTransaction); - } - - // Atomically update the commands in-flight counter, and check if it reached 0. If so, the - // connector is idle and can be returned. - // Note that this is racing with over-capacity writing, which can select any connector at any - // time (see MultiplexingWriteLoop), and we must make absolutely sure that if a connector is - // returned to the pool, it is *never* written to unless properly dequeued from the Idle channel. - if (Interlocked.Add(ref CommandsInFlightCount, -commandsRead) == 0) - { - // There's a race condition where the continuation of an asynchronous multiplexing write may not - // have executed yet, and the flush may still be in progress. 
We know all I/O has already - // been sent - because the reader has already consumed the entire resultset. So we wait until - // the connector's write lock has been released (long waiting will never occur here). - SpinWait.SpinUntil(() => MultiplexAsyncWritingLock == 0); - - _pool!.Return(this); - } - } - - Log.Trace("Exiting multiplexing read loop", Id); - } - catch (Exception e) - { - Debug.Assert(IsBroken); - - // Decrement the commands already dequeued from the in-flight counter - Interlocked.Add(ref CommandsInFlightCount, -commandsRead); - - // When a connector is broken, the causing exception is stored on it. We fail commands with - // that exception - rather than the one thrown here - since the break may have happened during - // writing, and we want to bubble that one up. - - // Drain any pending in-flight commands and fail them. Note that some have only been written - // to the buffer, and not sent to the server. - command?.ExecutionCompletion.SetException(_breakReason!); - try - { - while (true) - { - var pendingCommand = await CommandsInFlightReader.ReadAsync(); - - // TODO: the exception we have here is sometimes just the result of the write loop breaking - // the connector, so it doesn't represent the actual root cause. - pendingCommand.ExecutionCompletion.SetException(_breakReason!); - } - } - catch (ChannelClosedException) - { - // All good, drained to the channel and failed all commands - } - - // "Return" the connector to the pool to for cleanup (e.g. update total connector count) - _pool!.Return(this); - - Log.Error("Exception in multiplexing read loop", e, Id); - } - - Debug.Assert(CommandsInFlightCount == 0); - } - - #endregion - - #region Frontend message processing - - /// - /// Prepends a message to be sent at the beginning of the next message chain. 
- /// - internal void PrependInternalMessage(byte[] rawMessage, int responseMessageCount) - { - PendingPrependedResponses += responseMessageCount; - - var t = WritePregenerated(rawMessage); - Debug.Assert(t.IsCompleted, "Could not fully write pregenerated message into the buffer"); - } - - #endregion - - #region Backend message processing - - internal IBackendMessage ReadMessage(DataRowLoadingMode dataRowLoadingMode = DataRowLoadingMode.NonSequential) - => ReadMessage(async: false, dataRowLoadingMode).GetAwaiter().GetResult(); - - internal ValueTask ReadMessage(bool async, DataRowLoadingMode dataRowLoadingMode = DataRowLoadingMode.NonSequential) - => ReadMessage(async, dataRowLoadingMode, readingNotifications: false)!; - - internal ValueTask ReadMessageWithNotifications(bool async) - => ReadMessage(async, DataRowLoadingMode.NonSequential, readingNotifications: true); - - internal ValueTask ReadMessage( - bool async, - DataRowLoadingMode dataRowLoadingMode, - bool readingNotifications) - { - if (PendingPrependedResponses > 0 || - dataRowLoadingMode != DataRowLoadingMode.NonSequential || - readingNotifications || - ReadBuffer.ReadBytesLeft < 5) - { - return ReadMessageLong(this, async, dataRowLoadingMode, readingNotifications: readingNotifications); - } - - var messageCode = (BackendMessageCode)ReadBuffer.ReadByte(); - switch (messageCode) - { - case BackendMessageCode.NoticeResponse: - case BackendMessageCode.NotificationResponse: - case BackendMessageCode.ParameterStatus: - case BackendMessageCode.ErrorResponse: - ReadBuffer.ReadPosition--; - return ReadMessageLong(this, async, dataRowLoadingMode, readingNotifications: false); - case BackendMessageCode.ReadyForQuery: - break; - } - - PGUtil.ValidateBackendMessageCode(messageCode); - var len = ReadBuffer.ReadInt32() - 4; // Transmitted length includes itself - if (len > ReadBuffer.ReadBytesLeft) - { - ReadBuffer.ReadPosition -= 5; - return ReadMessageLong(this, async, dataRowLoadingMode, readingNotifications: 
false); - } - - return new ValueTask(ParseServerMessage(ReadBuffer, messageCode, len, false)); - - static async ValueTask ReadMessageLong( - NpgsqlConnector connector, - bool async, - DataRowLoadingMode dataRowLoadingMode, - bool readingNotifications, - bool isReadingPrependedMessage = false) - { - // First read the responses of any prepended messages. - if (connector.PendingPrependedResponses > 0 && !isReadingPrependedMessage) - { - try - { - // TODO: There could be room for optimization here, rather than the async call(s) - connector.ReadBuffer.Timeout = TimeSpan.FromMilliseconds(connector.InternalCommandTimeout); - for (; connector.PendingPrependedResponses > 0; connector.PendingPrependedResponses--) - await ReadMessageLong(connector, async, DataRowLoadingMode.Skip, readingNotifications: false, isReadingPrependedMessage: true); - } - catch (PostgresException e) - { - throw connector.Break(e); - } - } - - PostgresException? error = null; - - try - { - connector.ReadBuffer.Timeout = TimeSpan.FromMilliseconds(connector.UserTimeout); - - while (true) - { - await connector.ReadBuffer.Ensure(5, async, readingNotifications); - var messageCode = (BackendMessageCode)connector.ReadBuffer.ReadByte(); - PGUtil.ValidateBackendMessageCode(messageCode); - var len = connector.ReadBuffer.ReadInt32() - 4; // Transmitted length includes itself - - if ((messageCode == BackendMessageCode.DataRow && - dataRowLoadingMode != DataRowLoadingMode.NonSequential) || - messageCode == BackendMessageCode.CopyData) - { - if (dataRowLoadingMode == DataRowLoadingMode.Skip) - { - await connector.ReadBuffer.Skip(len, async); - continue; - } - } - else if (len > connector.ReadBuffer.ReadBytesLeft) - { - if (len > connector.ReadBuffer.Size) - { - var oversizeBuffer = connector.ReadBuffer.AllocateOversize(len); - - if (connector._origReadBuffer == null) - connector._origReadBuffer = connector.ReadBuffer; - else - connector.ReadBuffer.Dispose(); - - connector.ReadBuffer = oversizeBuffer; - } - - await 
connector.ReadBuffer.Ensure(len, async); - } - - var msg = connector.ParseServerMessage(connector.ReadBuffer, messageCode, len, isReadingPrependedMessage); - - switch (messageCode) - { - case BackendMessageCode.ErrorResponse: - Debug.Assert(msg == null); - - // An ErrorResponse is (almost) always followed by a ReadyForQuery. Save the error - // and throw it as an exception when the ReadyForQuery is received (next). - error = PostgresException.Load(connector.ReadBuffer, connector.Settings.IncludeErrorDetails); - - if (connector.State == ConnectorState.Connecting) - { - // During the startup/authentication phase, an ErrorResponse isn't followed by - // an RFQ. Instead, the server closes the connection immediately - throw error; - } - - continue; - - case BackendMessageCode.ReadyForQuery: - if (error != null) - { - NpgsqlEventSource.Log.CommandFailed(); - throw error; - } - - break; - - // Asynchronous messages which can come anytime, they have already been handled - // in ParseServerMessage. Read the next message. - case BackendMessageCode.NoticeResponse: - case BackendMessageCode.NotificationResponse: - case BackendMessageCode.ParameterStatus: - Debug.Assert(msg == null); - if (!readingNotifications) - continue; - return null; - } - - Debug.Assert(msg != null, "Message is null for code: " + messageCode); - return msg; - } - } - catch (PostgresException e) - { - // TODO: move it up the stack, like #3126 did (relevant for non-command-execution scenarios, like COPY) - if (connector.CurrentReader is null) - connector.EndUserAction(); - - if (e.SqlState == PostgresErrorCodes.QueryCanceled && connector.PostgresCancellationPerformed) - { - // The query could be canceled because of a user cancellation or a timeout - raise the proper exception. - // If _postgresCancellationPerformed is false, this is an unsolicited cancellation - - // just bubble up thePostgresException. - throw connector.UserCancellationRequested - ? 
new OperationCanceledException("Query was cancelled", e, connector.UserCancellationToken) - : new NpgsqlException("Exception while reading from stream", - new TimeoutException("Timeout during reading attempt")); - } - - throw; - } - catch (NpgsqlException) - { - // An ErrorResponse isn't followed by ReadyForQuery - if (error != null) - ExceptionDispatchInfo.Capture(error).Throw(); - throw; - } - } - } - - internal IBackendMessage? ParseServerMessage(NpgsqlReadBuffer buf, BackendMessageCode code, int len, bool isPrependedMessage) - { - switch (code) - { - case BackendMessageCode.RowDescription: - return _rowDescriptionMessage.Load(buf, TypeMapper); - case BackendMessageCode.DataRow: - return _dataRowMessage.Load(len); - case BackendMessageCode.CommandComplete: - return _commandCompleteMessage.Load(buf, len); - case BackendMessageCode.ReadyForQuery: - var rfq = _readyForQueryMessage.Load(buf); - if (!isPrependedMessage) { - // Transaction status on prepended messages shouldn't be processed, because there may be prepended messages - // before the begin transaction message. In this case, they will contain transaction status Idle, which will - // clear our Pending transaction status. Only process transaction status on RFQ's from user-provided, non - // prepended messages. 
- ProcessNewTransactionStatus(rfq.TransactionStatusIndicator); - } - return rfq; - case BackendMessageCode.EmptyQueryResponse: - return EmptyQueryMessage.Instance; - case BackendMessageCode.ParseComplete: - return ParseCompleteMessage.Instance; - case BackendMessageCode.ParameterDescription: - return _parameterDescriptionMessage.Load(buf); - case BackendMessageCode.BindComplete: - return BindCompleteMessage.Instance; - case BackendMessageCode.NoData: - return NoDataMessage.Instance; - case BackendMessageCode.CloseComplete: - return CloseCompletedMessage.Instance; - case BackendMessageCode.ParameterStatus: - ReadParameterStatus(buf.GetNullTerminatedBytes(), buf.GetNullTerminatedBytes()); - return null; - case BackendMessageCode.NoticeResponse: - var notice = PostgresNotice.Load(buf, Settings.IncludeErrorDetails); - Log.Debug($"Received notice: {notice.MessageText}", Id); - Connection?.OnNotice(notice); - return null; - case BackendMessageCode.NotificationResponse: - Connection?.OnNotification(new NpgsqlNotificationEventArgs(buf)); - return null; - - case BackendMessageCode.AuthenticationRequest: - var authType = (AuthenticationRequestType)buf.ReadInt32(); - return authType switch - { - AuthenticationRequestType.AuthenticationOk => (AuthenticationRequestMessage)AuthenticationOkMessage.Instance, - AuthenticationRequestType.AuthenticationCleartextPassword => AuthenticationCleartextPasswordMessage.Instance, - AuthenticationRequestType.AuthenticationMD5Password => AuthenticationMD5PasswordMessage.Load(buf), - AuthenticationRequestType.AuthenticationGSS => AuthenticationGSSMessage.Instance, - AuthenticationRequestType.AuthenticationSSPI => AuthenticationSSPIMessage.Instance, - AuthenticationRequestType.AuthenticationGSSContinue => AuthenticationGSSContinueMessage.Load(buf, len), - AuthenticationRequestType.AuthenticationSASL => new AuthenticationSASLMessage(buf), - AuthenticationRequestType.AuthenticationSASLContinue => new AuthenticationSASLContinueMessage(buf, len - 4), 
- AuthenticationRequestType.AuthenticationSASLFinal => new AuthenticationSASLFinalMessage(buf, len - 4), - _ => throw new NotSupportedException($"Authentication method not supported (Received: {authType})") - }; - - case BackendMessageCode.BackendKeyData: - return new BackendKeyDataMessage(buf); - - case BackendMessageCode.CopyInResponse: - return (_copyInResponseMessage ??= new CopyInResponseMessage()).Load(ReadBuffer); - case BackendMessageCode.CopyOutResponse: - return (_copyOutResponseMessage ??= new CopyOutResponseMessage()).Load(ReadBuffer); - case BackendMessageCode.CopyData: - return (_copyDataMessage ??= new CopyDataMessage()).Load(len); - case BackendMessageCode.CopyBothResponse: - return (_copyBothResponseMessage ??= new CopyBothResponseMessage()).Load(ReadBuffer); - - case BackendMessageCode.CopyDone: - return CopyDoneMessage.Instance; - - case BackendMessageCode.PortalSuspended: - throw new NpgsqlException("Unimplemented message: " + code); - case BackendMessageCode.ErrorResponse: - return null; - - case BackendMessageCode.FunctionCallResponse: - // We don't use the obsolete function call protocol - throw new NpgsqlException("Unexpected backend message: " + code); - - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {code} of enum {nameof(BackendMessageCode)}. Please file a bug."); - } - } - - /// - /// Reads backend messages and discards them, stopping only after a message of the given type has - /// been seen. Only a sync I/O version of this method exists - in async flows we inline the loop - /// rather than calling an additional async method, in order to avoid the overhead. 
- /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal IBackendMessage SkipUntil(BackendMessageCode stopAt) - { - Debug.Assert(stopAt != BackendMessageCode.DataRow, "Shouldn't be used for rows, doesn't know about sequential"); - - while (true) - { - var msg = ReadMessage(async: false, DataRowLoadingMode.Skip).GetAwaiter().GetResult()!; - Debug.Assert(!(msg is DataRowMessage)); - if (msg.Code == stopAt) - return msg; - } - } - - #endregion Backend message processing - - #region Transactions - - internal async Task Rollback(bool async, CancellationToken cancellationToken = default) - { - Log.Debug("Rolling back transaction", Id); - using (StartUserAction(cancellationToken)) - await ExecuteInternalCommand(PregeneratedMessages.RollbackTransaction, async, cancellationToken); - } - - internal bool InTransaction - => TransactionStatus switch - { - TransactionStatus.Idle => false, - TransactionStatus.Pending => true, - TransactionStatus.InTransactionBlock => true, - TransactionStatus.InFailedTransactionBlock => true, - _ => throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {TransactionStatus} of enum {nameof(TransactionStatus)}. Please file a bug.") - }; - - /// - /// Handles a new transaction indicator received on a ReadyForQuery message - /// - void ProcessNewTransactionStatus(TransactionStatus newStatus) - { - if (newStatus == TransactionStatus) - return; - - TransactionStatus = newStatus; - - switch (newStatus) - { - case TransactionStatus.Idle: - break; - case TransactionStatus.InTransactionBlock: - case TransactionStatus.InFailedTransactionBlock: - // In multiplexing mode, we can't support transaction in SQL: the connector must be removed from the - // writable connectors list, otherwise other commands may get written to it. So the user must tell us - // about the transaction via BeginTransaction. 
- if (Connection is null) - { - Debug.Assert(Settings.Multiplexing); - throw new NotSupportedException("In multiplexing mode, transactions must be started with BeginTransaction"); - } - break; - case TransactionStatus.Pending: - throw new Exception($"Internal Npgsql bug: invalid TransactionStatus {nameof(TransactionStatus.Pending)} received, should be frontend-only"); - default: - throw new InvalidOperationException( - $"Internal Npgsql bug: unexpected value {newStatus} of enum {nameof(TransactionStatus)}. Please file a bug."); - } - } - - internal void ClearTransaction() - { - Transaction?.DisposeImmediately(); - TransactionStatus = TransactionStatus.Idle; - } - - #endregion - - #region SSL - - /// - /// Returns whether SSL is being used for the connection - /// - internal bool IsSecure { get; private set; } - - /// - /// Returns whether SCRAM-SHA256 is being user for the connection - /// - internal bool IsScram { get; private set; } - - /// - /// Returns whether SCRAM-SHA256-PLUS is being user for the connection - /// - internal bool IsScramPlus { get; private set; } - - static readonly RemoteCertificateValidationCallback SslDefaultValidation = - (sender, certificate, chain, sslPolicyErrors) - => sslPolicyErrors == SslPolicyErrors.None; - - static readonly RemoteCertificateValidationCallback SslTrustServerValidation = - (sender, certificate, chain, sslPolicyErrors) - => true; - - static RemoteCertificateValidationCallback SslRootValidation(string certRootPath) => - (sender, certificate, chain, sslPolicyErrors) => - { - if (certificate is null || chain is null) - return false; - - chain.ChainPolicy.ExtraStore.Add(new X509Certificate2(certRootPath)); - return chain.Build(certificate as X509Certificate2 ?? 
new X509Certificate2(certificate)); - }; - - #endregion SSL - - #region Cancel - - internal void PerformUserCancellation() - { - _userCancellationRequested = true; - - if (AttemptPostgresCancellation && SupportsPostgresCancellation) - { - var cancellationTimeout = Settings.CancellationTimeout; - if (PerformPostgresCancellation() && cancellationTimeout >= 0) - { - if (cancellationTimeout > 0) - { - UserTimeout = cancellationTimeout; - ReadBuffer.Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); - ReadBuffer.Cts.CancelAfter(cancellationTimeout); - } - - return; - } - } - - UserTimeout = -1; - ReadBuffer.Timeout = _cancelImmediatelyTimeout; - ReadBuffer.Cts.Cancel(); - } - - /// - /// Creates another connector and sends a cancel request through it for this connector. This method never throws, but returns - /// whether the cancellation attempt failed. - /// - /// - /// - /// if the cancellation request was successfully delivered, or if it was skipped because a previous - /// request was already sent. if the cancellation request could not be delivered because of an exception - /// (the method logs internally). - /// - /// - /// This does not indicate whether the cancellation attempt was successful on the PostgreSQL side - only if the request was - /// delivered. 
- /// - /// - internal bool PerformPostgresCancellation() - { - Debug.Assert(BackendProcessId != 0, "PostgreSQL cancellation requested by the backend doesn't support it"); - - lock (CancelLock) - { - if (PostgresCancellationPerformed) - return true; - - Log.Debug("Sending cancellation...", Id); - PostgresCancellationPerformed = true; - - try - { - var cancelConnector = new NpgsqlConnector(this); - cancelConnector.DoCancelRequest(BackendProcessId, _backendSecretKey); - } - catch (Exception e) - { - var socketException = e.InnerException as SocketException; - if (socketException == null || socketException.SocketErrorCode != SocketError.ConnectionReset) - { - Log.Debug("Exception caught while attempting to cancel command", e, Id); - return false; - } - } - - return true; - } - } - - void DoCancelRequest(int backendProcessId, int backendSecretKey) - { - Debug.Assert(State == ConnectorState.Closed); - - try - { - RawOpen(new NpgsqlTimeout(TimeSpan.FromSeconds(ConnectionTimeout)), false, CancellationToken.None) - .GetAwaiter().GetResult(); - WriteCancelRequest(backendProcessId, backendSecretKey); - Flush(); - - Debug.Assert(ReadBuffer.ReadPosition == 0); - - // Now wait for the server to close the connection, better chance of the cancellation - // actually being delivered before we continue with the user's logic. - var count = _stream.Read(ReadBuffer.Buffer, 0, 1); - if (count > 0) - Log.Error("Received response after sending cancel request, shouldn't happen! 
First byte: " + ReadBuffer.Buffer[0]); - } - finally - { - lock (this) - Cleanup(); - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal CancellationTokenRegistration StartCancellableOperation( - CancellationToken cancellationToken = default, - bool attemptPgCancellation = true) - { - _userCancellationRequested = PostgresCancellationPerformed = false; - UserCancellationToken = cancellationToken; - ReadBuffer.Cts.ResetCts(); - - AttemptPostgresCancellation = attemptPgCancellation; - return _cancellationTokenRegistration = - cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformUserCancellation(), this); - } - - /// - /// Starts a new cancellable operation within an ongoing user action. This should only be used if a single user - /// action spans several different actions which each has its own cancellation tokens. For example, a command - /// execution is a single user action, but spans ExecuteReaderQuery, NextResult, Read and so forth. - /// - /// - /// Only one level of nested operations is supported. It is an error to call this method if it has previously - /// been called, and the returned was not disposed. - /// - /// - /// The cancellation token provided by the user. Callbacks will be registered on this token for executing the - /// cancellation, and the token will be included in any thrown . - /// - /// - /// If , PostgreSQL cancellation will be attempted when the user requests cancellation or - /// a timeout occurs, followed by a client-side socket cancellation once - /// has elapsed. If , - /// PostgreSQL cancellation will be skipped and client-socket cancellation will occur immediately. 
- /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal CancellationTokenRegistration StartNestedCancellableOperation( - CancellationToken cancellationToken = default, - bool attemptPgCancellation = true) - { - UserCancellationToken = cancellationToken; - AttemptPostgresCancellation = attemptPgCancellation; - - return _cancellationTokenRegistration = - cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformUserCancellation(), this); - } - - #endregion Cancel - - #region Close / Reset - - /// - /// Closes ongoing operations, i.e. an open reader exists or a COPY operation still in progress, as - /// part of a connection close. - /// - internal async Task CloseOngoingOperations(bool async, CancellationToken cancellationToken = default) - { - var reader = CurrentReader; - var copyOperation = CurrentCopyOperation; - - if (reader != null) - await reader.Close(connectionClosing: true, async, isDisposing: false); - else if (copyOperation != null) - { - // TODO: There's probably a race condition as the COPY operation may finish on its own during the next few lines - - // Note: we only want to cancel import operations, since in these cases cancel is safe. - // Export cancellations go through the PostgreSQL "asynchronous" cancel mechanism and are - // therefore vulnerable to the race condition in #615. 
- if (copyOperation is NpgsqlBinaryImporter || - copyOperation is NpgsqlCopyTextWriter || - copyOperation is NpgsqlRawCopyStream rawCopyStream && rawCopyStream.CanWrite) - { - try - { - copyOperation.Cancel(); - } - catch (Exception e) - { - Log.Warn("Error while cancelling COPY on connector close", e, Id); - } - } - - try - { - copyOperation.Dispose(); - } - catch (Exception e) - { - Log.Warn("Error while disposing cancelled COPY on connector close", e, Id); - } - } - } - - // TODO in theory this should be async-optional, but the only I/O done here is the Terminate Flush, which is - // very unlikely to block (plus locking would need to be worked out) - internal void Close() - { - lock (this) - { - Log.Trace("Closing connector", Id); - - if (IsReady) - { - try - { - WriteTerminate(); - Flush(); - } - catch (Exception e) - { - Log.Error("Exception while closing connector", e, Id); - Debug.Assert(IsBroken); - } - } - - switch (State) - { - case ConnectorState.Broken: - case ConnectorState.Closed: - return; - } - - State = ConnectorState.Closed; - Cleanup(); - } - } - - public void Dispose() => Close(); - - /// - /// Called when an unexpected message has been received during an action. Breaks the - /// connector and returns the appropriate message. - /// - internal Exception UnexpectedMessageReceived(BackendMessageCode received) - => throw Break(new Exception($"Received unexpected backend message {received}. Please file a bug.")); - - /// - /// Called when a connector becomes completely unusable, e.g. when an unexpected I/O exception is raised or when - /// we lose protocol sync. - /// Note that fatal errors during the Open phase do *not* pass through here. - /// - /// The exception that caused the break. - /// The exception given in for chaining calls. 
- internal Exception Break(Exception reason) - { - Debug.Assert(!IsClosed); - - lock (this) - { - if (State != ConnectorState.Broken) - { - Log.Error("Breaking connector", reason, Id); - - // Note that we may be reading and writing from the same connector concurrently, so safely set - // the original reason for the break before actually closing the socket etc. - Interlocked.CompareExchange(ref _breakReason, reason, null); - - State = ConnectorState.Broken; - Cleanup(); - } - - return reason; - } - } - - /// - /// Closes the socket and cleans up client-side resources associated with this connector. - /// - /// - /// This method doesn't actually perform any meaningful I/O, and therefore is sync-only. - /// - void Cleanup() - { - Debug.Assert(Monitor.IsEntered(this)); - - if (Settings.Multiplexing) - { - FlagAsNotWritableForMultiplexing(); - - // Note that in multiplexing, this could be called from the read loop, while the write loop is - // writing into the channel. To make sure this race condition isn't a problem, the channel currently - // isn't set up with SingleWriter (since at this point it doesn't do anything). 
- CommandsInFlightWriter!.Complete(); - - // The connector's read loop has a continuation to observe and log any exception coming out - // (see Open) - } - - - Log.Trace("Cleaning up connector", Id); - try - { - _stream?.Dispose(); - } - catch - { - // ignored - } - - if (CurrentReader != null) - { - CurrentReader.Command.State = CommandState.Idle; - try - { - // Will never complete asynchronously (stream is already closed) - CurrentReader.Close(); - } - catch - { - // ignored - } - CurrentReader = null; - } - - ClearTransaction(); -#pragma warning disable CS8625 - - _stream = null; - _baseStream = null; - _origReadBuffer?.Dispose(); - _origReadBuffer = null; - ReadBuffer?.Dispose(); - ReadBuffer = null; - WriteBuffer?.Dispose(); - WriteBuffer = null; - Connection = null; - PostgresParameters.Clear(); - _currentCommand = null; - - if (_isKeepAliveEnabled) - { - _userLock!.Dispose(); - _userLock = null; - _keepAliveTimer!.Change(Timeout.Infinite, Timeout.Infinite); - _keepAliveTimer.Dispose(); - } -#pragma warning restore CS8625 - } - - void GenerateResetMessage() - { - var sb = new StringBuilder("SET SESSION AUTHORIZATION DEFAULT;RESET ALL;"); - _resetWithoutDeallocateResponseCount = 2; - if (DatabaseInfo.SupportsCloseAll) - { - sb.Append("CLOSE ALL;"); - _resetWithoutDeallocateResponseCount++; - } - if (DatabaseInfo.SupportsUnlisten) - { - sb.Append("UNLISTEN *;"); - _resetWithoutDeallocateResponseCount++; - } - if (DatabaseInfo.SupportsAdvisoryLocks) - { - sb.Append("SELECT pg_advisory_unlock_all();"); - _resetWithoutDeallocateResponseCount += 2; - } - if (DatabaseInfo.SupportsDiscardSequences) - { - sb.Append("DISCARD SEQUENCES;"); - _resetWithoutDeallocateResponseCount++; - } - if (DatabaseInfo.SupportsDiscardTemp) - { - sb.Append("DISCARD TEMP"); - _resetWithoutDeallocateResponseCount++; - } - - _resetWithoutDeallocateResponseCount++; // One ReadyForQuery at the end - - _resetWithoutDeallocateMessage = PregeneratedMessages.Generate(WriteBuffer, sb.ToString()); 
- } - - /// - /// Called when a pooled connection is closed, and its connector is returned to the pool. - /// Resets the connector back to its initial state, releasing server-side sources - /// (e.g. prepared statements), resetting parameters to their defaults, and resetting client-side - /// state - /// - internal async Task Reset(bool async, CancellationToken cancellationToken = default) - { - Debug.Assert(IsReady); - - // Our buffer may contain unsent prepended messages (such as BeginTransaction), clear it out completely - WriteBuffer.Clear(); - PendingPrependedResponses = 0; - - // We may have allocated an oversize read buffer, switch back to the original one - // TODO: Replace this with array pooling, #2326 - if (_origReadBuffer != null) - { - ReadBuffer.Dispose(); - ReadBuffer = _origReadBuffer; - _origReadBuffer = null; - } - - Transaction?.UnbindIfNecessary(); - - var endBindingScope = false; - - // Must rollback transaction before sending DISCARD ALL - switch (TransactionStatus) - { - case TransactionStatus.Idle: - // There is an undisposed transaction on multiplexing connection - endBindingScope = Connection?.ConnectorBindingScope == ConnectorBindingScope.Transaction; - break; - case TransactionStatus.Pending: - // BeginTransaction() was called, but was left in the write buffer and not yet sent to server. - // Just clear the transaction state. - ProcessNewTransactionStatus(TransactionStatus.Idle); - ClearTransaction(); - endBindingScope = true; - break; - case TransactionStatus.InTransactionBlock: - case TransactionStatus.InFailedTransactionBlock: - await Rollback(async, cancellationToken); - ClearTransaction(); - endBindingScope = true; - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {TransactionStatus} of enum {nameof(TransactionStatus)}. 
Please file a bug."); - } - - if (_sendResetOnClose) - { - if (PreparedStatementManager.NumPrepared > 0) - { - // We have prepared statements, so we can't reset the connection state with DISCARD ALL - // Note: the send buffer has been cleared above, and we assume all this will fit in it. - PrependInternalMessage(_resetWithoutDeallocateMessage!, _resetWithoutDeallocateResponseCount); - } - else - { - // There are no prepared statements. - // We simply send DISCARD ALL which is more efficient than sending the above messages separately - PrependInternalMessage(PregeneratedMessages.DiscardAll, 2); - } - } - - DataReader.UnbindIfNecessary(); - - if (endBindingScope) - { - // Connection is null if a connection enlisted in a TransactionScope was closed before the - // TransactionScope completed - the connector is still enlisted, but has no connection. - Connection?.EndBindingScope(ConnectorBindingScope.Transaction); - } - } - - internal void UnprepareAll() - { - ExecuteInternalCommand("DEALLOCATE ALL"); - PreparedStatementManager.ClearAll(); - } - - #endregion Close / Reset - - #region Locking - - internal UserAction StartUserAction(CancellationToken cancellationToken = default, bool attemptPgCancellation = true) - => StartUserAction(ConnectorState.Executing, command: null, cancellationToken, attemptPgCancellation); - - internal UserAction StartUserAction( - ConnectorState newState, - CancellationToken cancellationToken = default, - bool attemptPgCancellation = true) - => StartUserAction(newState, command: null, cancellationToken, attemptPgCancellation); - - /// - /// Starts a user action. This makes sure that another action isn't already in progress, handles synchronization with keepalive, - /// and sets up cancellation. - /// - /// The new state to be set when entering this user action. - /// - /// The that is starting execution - if an is - /// thrown, it will reference this. - /// - /// - /// The cancellation token provided by the user. 
Callbacks will be registered on this token for executing the cancellation, - /// and the token will be included in any thrown . - /// - /// - /// If , PostgreSQL cancellation will be attempted when the user requests cancellation or a timeout - /// occurs, followed by a client-side socket cancellation once has - /// elapsed. If , PostgreSQL cancellation will be skipped and client-socket cancellation will occur - /// immediately. - /// - internal UserAction StartUserAction( - ConnectorState newState, - NpgsqlCommand? command, - CancellationToken cancellationToken = default, - bool attemptPgCancellation = true) - { - // If keepalive is enabled, we must protect state transitions with a SemaphoreSlim - // (which itself must be protected by a lock, since its dispose isn't thread-safe). - // This will make the keepalive abort safely if a user query is in progress, and make - // the user query wait if a keepalive is in progress. - - // If keepalive isn't enabled, we don't use the semaphore and rely only on the connector's - // state (updated via Interlocked.Exchange) to detect concurrent use, on a best-effort basis. - if (!_isKeepAliveEnabled) - return DoStartUserAction(); - - lock (this) - { - if (!_userLock!.Wait(0)) - { - var currentCommand = _currentCommand; - throw currentCommand == null - ? new NpgsqlOperationInProgressException(State) - : new NpgsqlOperationInProgressException(currentCommand); - } - - try - { - // Disable keepalive, it will be restarted at the end of the user action - _keepAliveTimer!.Change(Timeout.Infinite, Timeout.Infinite); - - // We now have both locks and are sure nothing else is running. - // Check that the connector is ready. 
- return DoStartUserAction(); - } - catch - { - _userLock.Release(); - throw; - } - } - - UserAction DoStartUserAction() - { - switch (State) - { - case ConnectorState.Ready: - break; - case ConnectorState.Closed: - case ConnectorState.Broken: - throw new InvalidOperationException("Connection is not open"); - case ConnectorState.Executing: - case ConnectorState.Fetching: - case ConnectorState.Waiting: - case ConnectorState.Replication: - case ConnectorState.Connecting: - case ConnectorState.Copy: - var currentCommand = _currentCommand; - throw currentCommand == null - ? new NpgsqlOperationInProgressException(State) - : new NpgsqlOperationInProgressException(currentCommand); - default: - throw new ArgumentOutOfRangeException(nameof(State), State, "Invalid connector state: " + State); - } - - Debug.Assert(IsReady); - - cancellationToken.ThrowIfCancellationRequested(); - - Log.Trace("Start user action", Id); - State = newState; - _currentCommand = command; - - StartCancellableOperation(cancellationToken, attemptPgCancellation); - - // We reset the UserTimeout for every user action, so it wouldn't leak from the previous query or action - // For example, we might have successfully cancelled the previous query (so the connection is not broken) - // But the next time, we call the Prepare, which doesn't set it's own timeout - UserTimeout = (command?.CommandTimeout ?? 
Settings.CommandTimeout) * 1000; - - return new UserAction(this); - } - } - - internal void EndUserAction() - { - Debug.Assert(CurrentReader == null); - - _cancellationTokenRegistration.Dispose(); - - if (_isKeepAliveEnabled) - { - lock (this) - { - if (IsReady || !IsConnected) - return; - - var keepAlive = Settings.KeepAlive * 1000; - _keepAliveTimer!.Change(keepAlive, keepAlive); - - Log.Trace("End user action", Id); - _currentCommand = null; - _userLock!.Release(); - State = ConnectorState.Ready; - } - } - else - { - if (IsReady || !IsConnected) - return; - - Log.Trace("End user action", Id); - _currentCommand = null; - State = ConnectorState.Ready; - } - } - - /// - /// An IDisposable wrapper around . - /// - internal readonly struct UserAction : IDisposable - { - readonly NpgsqlConnector _connector; - internal UserAction(NpgsqlConnector connector) => _connector = connector; - public void Dispose() => _connector.EndUserAction(); - } - - #endregion - - #region Keepalive - -#pragma warning disable CA1801 // Review unused parameters - void PerformKeepAlive(object? state) - { - Debug.Assert(_isKeepAliveEnabled); - - // SemaphoreSlim.Dispose() isn't thread-safe - it may be in progress so we shouldn't try to wait on it; - // we need a standard lock to protect it. - if (!Monitor.TryEnter(this)) - return; - - try - { - // There may already be a user action, or the connector may be closed etc. 
- if (!IsReady) - return; - - Log.Trace("Performed keepalive", Id); - WritePregenerated(PregeneratedMessages.KeepAlive); - Flush(); - SkipUntil(BackendMessageCode.ReadyForQuery); - } - catch (Exception e) - { - Log.Error("Keepalive failure", e, Id); - try - { - Break(e); - } - catch (Exception e2) - { - Log.Error("Further exception while breaking connector on keepalive failure", e2, Id); - } - } - finally - { - Monitor.Exit(this); - } - } -#pragma warning restore CA1801 // Review unused parameters - - #endregion - - #region Wait - - internal async Task Wait(bool async, int timeout, CancellationToken cancellationToken = default) - { - using var _ = StartUserAction(ConnectorState.Waiting, cancellationToken: cancellationToken, attemptPgCancellation: false); - - // We may have prepended messages in the connection's write buffer - these need to be flushed now. - await Flush(async, cancellationToken); - - var keepaliveMs = Settings.KeepAlive * 1000; - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - - var timeoutForKeepalive = _isKeepAliveEnabled && (timeout <= 0 || keepaliveMs < timeout); - UserTimeout = timeoutForKeepalive ? keepaliveMs : timeout; - try - { - var msg = await ReadMessageWithNotifications(async); - if (msg != null) - { - throw Break( - new NpgsqlException($"Received unexpected message of type {msg.Code} while waiting")); - } - return true; - } - catch (NpgsqlException e) when (e.InnerException is TimeoutException) - { - if (!timeoutForKeepalive) // We really timed out - return false; - } - - // Time for a keepalive - var keepaliveTime = Stopwatch.StartNew(); - await WritePregenerated(PregeneratedMessages.KeepAlive, async, cancellationToken); - await Flush(async, cancellationToken); - - var receivedNotification = false; - var expectedMessageCode = BackendMessageCode.RowDescription; - - while (true) - { - IBackendMessage? 
msg; - - try - { - msg = await ReadMessageWithNotifications(async); - } - catch (Exception e) when (e is OperationCanceledException || e is NpgsqlException npgEx && npgEx.InnerException is TimeoutException) - { - // We're somewhere in the middle of a reading keepalive messages - // Breaking the connection, as we've lost protocol sync - throw Break(e); - } - - if (msg == null) - { - receivedNotification = true; - continue; - } - - if (msg.Code != expectedMessageCode) - throw new NpgsqlException($"Received unexpected message of type {msg.Code} while expecting {expectedMessageCode} as part of keepalive"); - - switch (msg.Code) - { - case BackendMessageCode.RowDescription: - expectedMessageCode = BackendMessageCode.DataRow; - continue; - case BackendMessageCode.DataRow: - // DataRow is usually consumed by a reader, here we have to skip it manually. - await ReadBuffer.Skip(((DataRowMessage)msg).Length, async); - expectedMessageCode = BackendMessageCode.CommandComplete; - continue; - case BackendMessageCode.CommandComplete: - expectedMessageCode = BackendMessageCode.ReadyForQuery; - continue; - case BackendMessageCode.ReadyForQuery: - break; - } - Log.Trace("Performed keepalive", Id); - - if (receivedNotification) - return true; // Notification was received during the keepalive - cancellationToken.ThrowIfCancellationRequested(); - break; - } - - if (timeout > 0) - timeout -= (keepaliveMs + (int)keepaliveTime.ElapsedMilliseconds); - } - } - - #endregion - - #region Supported features and PostgreSQL settings - - /// - /// The connection's timezone as reported by PostgreSQL, in the IANA/Olson database format. 
- /// - internal string Timezone { get; private set; } = default!; - - #endregion Supported features and PostgreSQL settings - - #region Execute internal command - - internal void ExecuteInternalCommand(string query) - => ExecuteInternalCommand(query, false).GetAwaiter().GetResult(); - - internal async Task ExecuteInternalCommand(string query, bool async, CancellationToken cancellationToken = default) - { - Log.Trace($"Executing internal command: {query}", Id); - - await WriteQuery(query, async, cancellationToken); - await Flush(async, cancellationToken); - Expect(await ReadMessage(async), this); - Expect(await ReadMessage(async), this); - } - - internal async Task ExecuteInternalCommand(byte[] data, bool async, CancellationToken cancellationToken = default) - { - Debug.Assert(State != ConnectorState.Ready, "Forgot to start a user action..."); - - Log.Trace("Executing internal pregenerated command", Id); - - await WritePregenerated(data, async, cancellationToken); - await Flush(async, cancellationToken); - Expect(await ReadMessage(async), this); - Expect(await ReadMessage(async), this); - } - - #endregion - - #region Misc - - void ReadParameterStatus(ReadOnlySpan incomingName, ReadOnlySpan incomingValue) - { - byte[] rawName; - byte[] rawValue; - - foreach (var current in _rawParameters) - if (incomingName.SequenceEqual(current.Name)) - { - if (incomingValue.SequenceEqual(current.Value)) - return; - - rawName = current.Name; - rawValue = incomingValue.ToArray(); - goto ProcessParameter; - } - - rawName = incomingName.ToArray(); - rawValue = incomingValue.ToArray(); - _rawParameters.Add((rawName, rawValue)); - - ProcessParameter: - var name = TextEncoding.GetString(rawName); - var value = TextEncoding.GetString(rawValue); - - PostgresParameters[name] = value; - - switch (name) - { - case "standard_conforming_strings": - if (value != "on") - throw Break(new NotSupportedException("standard_conforming_strings must be on")); - return; - - case "TimeZone": - Timezone = 
value; - return; - } - } - - #endregion Misc - } - - #region Enums - - /// - /// Expresses the exact state of a connector. - /// - enum ConnectorState - { - /// - /// The connector has either not yet been opened or has been closed. - /// - Closed, - - /// - /// The connector is currently connecting to a PostgreSQL server. - /// - Connecting, - - /// - /// The connector is connected and may be used to send a new query. - /// - Ready, - - /// - /// The connector is waiting for a response to a query which has been sent to the server. - /// - Executing, - - /// - /// The connector is currently fetching and processing query results. - /// - Fetching, - - /// - /// The connector is currently waiting for asynchronous notifications to arrive. - /// - Waiting, - - /// - /// The connection was broken because an unexpected error occurred which left it in an unknown state. - /// This state isn't implemented yet. - /// - Broken, - - /// - /// The connector is engaged in a COPY operation. - /// - Copy, - - /// - /// The connector is engaged in streaming replication. - /// - Replication, - } - -#pragma warning disable CA1717 - enum TransactionStatus : byte -#pragma warning restore CA1717 - { - /// - /// Currently not in a transaction block - /// - Idle = (byte)'I', - - /// - /// Currently in a transaction block - /// - InTransactionBlock = (byte)'T', - - /// - /// Currently in a failed transaction block (queries will be rejected until block is ended) - /// - InFailedTransactionBlock = (byte)'E', - - /// - /// A new transaction has been requested but not yet transmitted to the backend. It will be transmitted - /// prepended to the next query. - /// This is a client-side state option only, and is never transmitted from the backend. - /// - Pending = byte.MaxValue, - } - - /// - /// Specifies how to load/parse DataRow messages as they're received from the backend. 
- /// - internal enum DataRowLoadingMode - { - /// - /// Load DataRows in non-sequential mode - /// - NonSequential, - - /// - /// Load DataRows in sequential mode - /// - Sequential, - - /// - /// Skip DataRow messages altogether - /// - Skip - } - - #endregion -} diff --git a/src/Npgsql/NpgsqlDataAdapter.cs b/src/Npgsql/NpgsqlDataAdapter.cs index 77957c09ae..c18773b2d6 100644 --- a/src/Npgsql/NpgsqlDataAdapter.cs +++ b/src/Npgsql/NpgsqlDataAdapter.cs @@ -1,220 +1,230 @@ using System; using System.Data; using System.Data.Common; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; -using JetBrains.Annotations; -namespace Npgsql +namespace Npgsql; + +/// +/// Represents the method that handles the events. +/// +/// The source of the event. +/// An that contains the event data. +public delegate void NpgsqlRowUpdatedEventHandler(object sender, NpgsqlRowUpdatedEventArgs e); + +/// +/// Represents the method that handles the events. +/// +/// The source of the event. +/// An that contains the event data. +public delegate void NpgsqlRowUpdatingEventHandler(object sender, NpgsqlRowUpdatingEventArgs e); + +/// +/// This class represents an adapter from many commands: select, update, insert and delete to fill a . +/// +[System.ComponentModel.DesignerCategory("")] +public sealed class NpgsqlDataAdapter : DbDataAdapter { /// - /// Represents the method that handles the RowUpdated events. + /// Row updated event. /// - /// The source of the event. - /// A NpgsqlRowUpdatedEventArgs that contains the event data. - public delegate void NpgsqlRowUpdatedEventHandler(object sender, NpgsqlRowUpdatedEventArgs e); + public event NpgsqlRowUpdatedEventHandler? RowUpdated; /// - /// Represents the method that handles the RowUpdating events. + /// Row updating event. /// - /// The source of the event. - /// A NpgsqlRowUpdatingEventArgs that contains the event data. 
- public delegate void NpgsqlRowUpdatingEventHandler(object sender, NpgsqlRowUpdatingEventArgs e); + public event NpgsqlRowUpdatingEventHandler? RowUpdating; /// - /// This class represents an adapter from many commands: select, update, insert and delete to fill Datasets. + /// Default constructor. /// - [System.ComponentModel.DesignerCategory("")] - public sealed class NpgsqlDataAdapter : DbDataAdapter + public NpgsqlDataAdapter() {} + + /// + /// Constructor. + /// + /// + public NpgsqlDataAdapter(NpgsqlCommand selectCommand) + => SelectCommand = selectCommand; + + /// + /// Constructor. + /// + /// + /// + public NpgsqlDataAdapter(string selectCommandText, NpgsqlConnection selectConnection) + : this(new NpgsqlCommand(selectCommandText, selectConnection)) {} + + /// + /// Constructor. + /// + /// + /// + public NpgsqlDataAdapter(string selectCommandText, string selectConnectionString) + : this(selectCommandText, new NpgsqlConnection(selectConnectionString)) {} + + /// + /// Create row updated event. + /// + protected override RowUpdatedEventArgs CreateRowUpdatedEvent(DataRow dataRow, IDbCommand? command, + System.Data.StatementType statementType, + DataTableMapping tableMapping) + => new NpgsqlRowUpdatedEventArgs(dataRow, command, statementType, tableMapping); + + /// + /// Create row updating event. + /// + protected override RowUpdatingEventArgs CreateRowUpdatingEvent(DataRow dataRow, IDbCommand? command, + System.Data.StatementType statementType, + DataTableMapping tableMapping) + => new NpgsqlRowUpdatingEventArgs(dataRow, command, statementType, tableMapping); + + /// + /// Raise the RowUpdated event. + /// + /// + protected override void OnRowUpdated(RowUpdatedEventArgs value) { - /// - /// Row updated event. - /// - public event NpgsqlRowUpdatedEventHandler? RowUpdated; - - /// - /// Row updating event. - /// - public event NpgsqlRowUpdatingEventHandler? RowUpdating; - - /// - /// Default constructor. 
- /// - public NpgsqlDataAdapter() {} - - /// - /// Constructor. - /// - /// - public NpgsqlDataAdapter(NpgsqlCommand selectCommand) - => SelectCommand = selectCommand; - - /// - /// Constructor. - /// - /// - /// - public NpgsqlDataAdapter(string selectCommandText, NpgsqlConnection selectConnection) - : this(new NpgsqlCommand(selectCommandText, selectConnection)) {} - - /// - /// Constructor. - /// - /// - /// - public NpgsqlDataAdapter(string selectCommandText, string selectConnectionString) - : this(selectCommandText, new NpgsqlConnection(selectConnectionString)) {} - - /// - /// Create row updated event. - /// - protected override RowUpdatedEventArgs CreateRowUpdatedEvent(DataRow dataRow, IDbCommand? command, - System.Data.StatementType statementType, - DataTableMapping tableMapping) - => new NpgsqlRowUpdatedEventArgs(dataRow, command, statementType, tableMapping); - - /// - /// Create row updating event. - /// - protected override RowUpdatingEventArgs CreateRowUpdatingEvent(DataRow dataRow, IDbCommand? command, - System.Data.StatementType statementType, - DataTableMapping tableMapping) - => new NpgsqlRowUpdatingEventArgs(dataRow, command, statementType, tableMapping); - - /// - /// Raise the RowUpdated event. - /// - /// - protected override void OnRowUpdated(RowUpdatedEventArgs value) - { - //base.OnRowUpdated(value); - if (value is NpgsqlRowUpdatedEventArgs args) - RowUpdated?.Invoke(this, args); - //if (RowUpdated != null && value is NpgsqlRowUpdatedEventArgs args) - // RowUpdated(this, args); - } + //base.OnRowUpdated(value); + if (value is NpgsqlRowUpdatedEventArgs args) + RowUpdated?.Invoke(this, args); + //if (RowUpdated != null && value is NpgsqlRowUpdatedEventArgs args) + // RowUpdated(this, args); + } - /// - /// Raise the RowUpdating event. - /// - /// - protected override void OnRowUpdating(RowUpdatingEventArgs value) - { - if (value is NpgsqlRowUpdatingEventArgs args) - RowUpdating?.Invoke(this, args); - } + /// + /// Raise the RowUpdating event. 
+ /// + /// + protected override void OnRowUpdating(RowUpdatingEventArgs value) + { + if (value is NpgsqlRowUpdatingEventArgs args) + RowUpdating?.Invoke(this, args); + } - /// - /// Delete command. - /// - public new NpgsqlCommand? DeleteCommand - { - get => (NpgsqlCommand?)base.DeleteCommand; - set => base.DeleteCommand = value; - } + /// + /// Delete command. + /// + public new NpgsqlCommand? DeleteCommand + { + get => (NpgsqlCommand?)base.DeleteCommand; + set => base.DeleteCommand = value; + } - /// - /// Select command. - /// - public new NpgsqlCommand? SelectCommand - { - get => (NpgsqlCommand?)base.SelectCommand; - set => base.SelectCommand = value; - } + /// + /// Select command. + /// + public new NpgsqlCommand? SelectCommand + { + get => (NpgsqlCommand?)base.SelectCommand; + set => base.SelectCommand = value; + } - /// - /// Update command. - /// - public new NpgsqlCommand? UpdateCommand - { - get => (NpgsqlCommand?)base.UpdateCommand; - set => base.UpdateCommand = value; - } + /// + /// Update command. + /// + public new NpgsqlCommand? UpdateCommand + { + get => (NpgsqlCommand?)base.UpdateCommand; + set => base.UpdateCommand = value; + } - /// - /// Insert command. - /// - public new NpgsqlCommand? InsertCommand - { - get => (NpgsqlCommand?)base.InsertCommand; - set => base.InsertCommand = value; - } + /// + /// Insert command. + /// + public new NpgsqlCommand? 
InsertCommand + { + get => (NpgsqlCommand?)base.InsertCommand; + set => base.InsertCommand = value; + } - // Temporary implementation, waiting for official support in System.Data via https://github.com/dotnet/runtime/issues/22109 - internal async Task Fill(DataTable dataTable, bool async, CancellationToken cancellationToken = default) + // Temporary implementation, waiting for official support in System.Data via https://github.com/dotnet/runtime/issues/22109 + [RequiresUnreferencedCode("Members from serialized types or types used in expressions may be trimmed if not referenced directly.")] + internal async Task Fill(DataTable dataTable, bool async, CancellationToken cancellationToken = default) + { + var command = SelectCommand; + var activeConnection = command?.Connection ?? throw new InvalidOperationException("Connection required"); + var originalState = ConnectionState.Closed; + + try { - var command = SelectCommand; - var activeConnection = command?.Connection ?? throw new InvalidOperationException("Connection required"); - var originalState = ConnectionState.Closed; + originalState = activeConnection.State; + if (ConnectionState.Closed == originalState) + await activeConnection.Open(async, cancellationToken).ConfigureAwait(false); + var dataReader = await command.ExecuteReader(async, CommandBehavior.Default, cancellationToken).ConfigureAwait(false); try { - originalState = activeConnection.State; - if (ConnectionState.Closed == originalState) - await activeConnection.Open(async, cancellationToken); - - using var dataReader = await command.ExecuteReader(CommandBehavior.Default, async, cancellationToken); - - return await Fill(dataTable, dataReader, async, cancellationToken); + return await Fill(dataTable, dataReader, async, cancellationToken).ConfigureAwait(false); } finally { - if (ConnectionState.Closed == originalState) - activeConnection.Close(); + if (async) + await dataReader.DisposeAsync().ConfigureAwait(false); + else + dataReader.Dispose(); } } + 
finally + { + if (ConnectionState.Closed == originalState) + activeConnection.Close(); + } + } - async Task Fill(DataTable dataTable, NpgsqlDataReader dataReader, bool async, CancellationToken cancellationToken = default) + [RequiresUnreferencedCode("Members from serialized types or types used in expressions may be trimmed if not referenced directly.")] + async Task Fill(DataTable dataTable, NpgsqlDataReader dataReader, bool async, CancellationToken cancellationToken = default) + { + dataTable.BeginLoadData(); + try { - dataTable.BeginLoadData(); - try + var rowsAdded = 0; + var count = dataReader.FieldCount; + var columnCollection = dataTable.Columns; + for (var i = 0; i < count; ++i) { - var rowsAdded = 0; - var count = dataReader.FieldCount; - var columnCollection = dataTable.Columns; - for (var i = 0; i < count; ++i) + var fieldName = dataReader.GetName(i); + if (!columnCollection.Contains(fieldName)) { - var fieldName = dataReader.GetName(i); - if (!columnCollection.Contains(fieldName)) - { - var fieldType = dataReader.GetFieldType(i); - var dataColumn = new DataColumn(fieldName, fieldType); - columnCollection.Add(dataColumn); - } + var fieldType = dataReader.GetFieldType(i); + var dataColumn = new DataColumn(fieldName, fieldType); + columnCollection.Add(dataColumn); } + } - var values = new object[count]; + var values = new object[count]; - while (async ? await dataReader.ReadAsync(cancellationToken) : dataReader.Read()) - { - dataReader.GetValues(values); - dataTable.LoadDataRow(values, true); - rowsAdded++; - } - return rowsAdded; - } - finally + while (async ? 
await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false) : dataReader.Read()) { - dataTable.EndLoadData(); + dataReader.GetValues(values); + dataTable.LoadDataRow(values, true); + rowsAdded++; } + return rowsAdded; + } + finally + { + dataTable.EndLoadData(); } } +} #pragma warning disable 1591 - public class NpgsqlRowUpdatingEventArgs : RowUpdatingEventArgs - { - public NpgsqlRowUpdatingEventArgs(DataRow dataRow, IDbCommand? command, System.Data.StatementType statementType, - DataTableMapping tableMapping) - : base(dataRow, command, statementType, tableMapping) {} - } +public class NpgsqlRowUpdatingEventArgs : RowUpdatingEventArgs +{ + public NpgsqlRowUpdatingEventArgs(DataRow dataRow, IDbCommand? command, System.Data.StatementType statementType, + DataTableMapping tableMapping) + : base(dataRow, command, statementType, tableMapping) {} +} - public class NpgsqlRowUpdatedEventArgs : RowUpdatedEventArgs - { - public NpgsqlRowUpdatedEventArgs(DataRow dataRow, IDbCommand? command, System.Data.StatementType statementType, - DataTableMapping tableMapping) - : base(dataRow, command, statementType, tableMapping) {} - } +public class NpgsqlRowUpdatedEventArgs : RowUpdatedEventArgs +{ + public NpgsqlRowUpdatedEventArgs(DataRow dataRow, IDbCommand? 
command, System.Data.StatementType statementType, + DataTableMapping tableMapping) + : base(dataRow, command, statementType, tableMapping) {} +} #pragma warning restore 1591 -} diff --git a/src/Npgsql/NpgsqlDataReader.cs b/src/Npgsql/NpgsqlDataReader.cs index e4296f5b17..0dba265918 100644 --- a/src/Npgsql/NpgsqlDataReader.cs +++ b/src/Npgsql/NpgsqlDataReader.cs @@ -1,650 +1,842 @@ using System; +using System.Buffers; using System.Collections; using System.Collections.Generic; using System.Collections.ObjectModel; using System.Data; using System.Data.Common; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO; -using System.Linq; using System.Runtime.CompilerServices; -using System.Text; +using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; -using JetBrains.Annotations; +using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; -using Npgsql.Logging; +using Npgsql.Internal; +using Npgsql.Internal.Converters; using Npgsql.PostgresTypes; using Npgsql.Schema; -using Npgsql.TypeHandlers; -using Npgsql.TypeHandling; -using Npgsql.Util; using NpgsqlTypes; using static Npgsql.Util.Statics; #pragma warning disable CA2222 // Do not decrease inherited member visibility -namespace Npgsql +namespace Npgsql; + +/// +/// Reads a forward-only stream of rows from a data source. +/// +#pragma warning disable CA1010 +public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator +#pragma warning restore CA1010 { + static readonly Task TrueTask = Task.FromResult(true); + static readonly Task FalseTask = Task.FromResult(false); + + internal NpgsqlCommand Command { get; private set; } = default!; + internal NpgsqlConnector Connector { get; } + NpgsqlConnection? _connection; + /// - /// Reads a forward-only stream of rows from a data source. + /// The behavior of the command with which this reader was executed. 
/// -#pragma warning disable CA1010 - public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator -#pragma warning restore CA1010 - { - internal NpgsqlCommand Command { get; private set; } = default!; - internal NpgsqlConnector Connector { get; } - NpgsqlConnection _connection = default!; - - /// - /// The behavior of the command with which this reader was executed. - /// - CommandBehavior _behavior; - - /// - /// In multiplexing, this is as the sending is managed in the write multiplexing loop, - /// and does not need to be awaited by the reader. - /// - Task? _sendTask; - - internal ReaderState State = ReaderState.Disposed; - - internal NpgsqlReadBuffer Buffer = default!; - - /// - /// Holds the list of statements being executed by this reader. - /// - List _statements = default!; - - /// - /// The index of the current query resultset we're processing (within a multiquery) - /// - internal int StatementIndex { get; private set; } - - /// - /// The number of columns in the current row - /// - int _numColumns; - - /// - /// Records, for each column, its starting offset and length in the current row. - /// Used only in non-sequential mode. - /// - readonly List<(int Offset, int Length)> _columns = new List<(int Offset, int Length)>(); - - /// - /// The index of the column that we're on, i.e. that has already been parsed, is - /// is memory and can be retrieved. Initialized to -1, which means we're on the column - /// count (which comes before the first column). - /// - int _column; - - /// - /// For streaming types (e.g. bytea), holds the byte length of the column. - /// Does not include the length prefix. - /// - internal int ColumnLen; - - internal int PosInColumn; - - /// - /// The position in the buffer at which the current data row message ends. - /// Used only in non-sequential mode. 
- /// - int _dataMsgEnd; - - int _charPos; - - /// - /// The RowDescription message for the current resultset being processed - /// - internal RowDescriptionMessage? RowDescription; - - ulong? _recordsAffected; - - /// - /// Whether the current result set has rows - /// - bool _hasRows; - - /// - /// Is raised whenever Close() is called. - /// - public event EventHandler? ReaderClosed; - - bool _isSchemaOnly; - bool _isSequential; - - /// - /// A stream that has been opened on a column. - /// - NpgsqlReadBuffer.ColumnStream? _columnStream; - - /// - /// Used for internal temporary purposes - /// - char[]? _tempCharBuf; - - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlDataReader)); - - internal NpgsqlDataReader(NpgsqlConnector connector) - { - Connector = connector; - } + CommandBehavior _behavior; - internal void Init( - NpgsqlCommand command, CommandBehavior behavior, List statements, Task? sendTask = null) - { - Command = command; - _connection = command.Connection!; - _behavior = behavior; - _isSchemaOnly = _behavior.HasFlag(CommandBehavior.SchemaOnly); - _isSequential = _behavior.HasFlag(CommandBehavior.SequentialAccess); - _statements = statements; - StatementIndex = -1; - _sendTask = sendTask; - State = ReaderState.BetweenResults; - _recordsAffected = null; - } + /// + /// In multiplexing, this is as the sending is managed in the write multiplexing loop, + /// and does not need to be awaited by the reader. + /// + Task? _sendTask; - #region Read + internal ReaderState State = ReaderState.Disposed; - /// - /// Advances the reader to the next record in a result set. - /// - /// true if there are more rows; otherwise false. - /// - /// The default position of a data reader is before the first record. Therefore, you must call Read to begin accessing data. 
- /// - public override bool Read() - { - CheckClosedOrDisposed(); + internal NpgsqlReadBuffer Buffer = default!; + PgReader PgReader => Buffer.PgReader; - var fastRead = TryFastRead(); - return fastRead.HasValue - ? fastRead.Value - : Read(false).GetAwaiter().GetResult(); - } + /// + /// Holds the list of statements being executed by this reader. + /// + List _statements = default!; - /// - /// This is the asynchronous version of - /// - /// The token to monitor for cancellation requests. - /// A task representing the asynchronous operation. - public override Task ReadAsync(CancellationToken cancellationToken) - { - CheckClosedOrDisposed(); + /// + /// The index of the current query resultset we're processing (within a multiquery) + /// + internal int StatementIndex { get; private set; } + + /// + /// Records, for each column, its starting offset and length in the current row. + /// Used only in non-sequential mode. + /// + readonly List<(int Offset, int Length)> _columns = new(); + int _columnsStartPos; - var fastRead = TryFastRead(); - if (fastRead.HasValue) - return fastRead.Value ? PGUtil.TrueTask : PGUtil.FalseTask; + /// + /// The index of the column that we're on, i.e. that has already been parsed, is + /// is memory and can be retrieved. Initialized to -1, which means we're on the column + /// count (which comes before the first column). + /// + int _column; - using (NoSynchronizationContextScope.Enter()) - return Read(true, cancellationToken); - } + /// + /// The position in the buffer at which the current data row message ends. + /// Used only when the row is consumed non-sequentially. + /// + int _dataMsgEnd; + + /// + /// Determines, if we can consume the row non-sequentially. + /// Mostly useful for a sequential mode, when the row is already in the buffer. + /// Should always be true for the non-sequential mode. 
+ /// + bool _canConsumeRowNonSequentially; + + /// + /// The RowDescription message for the current resultset being processed + /// + internal RowDescriptionMessage? RowDescription; + + int ColumnCount => RowDescription!.Count; - bool? TryFastRead() + /// + /// Stores the last converter info resolved by column, to speed up repeated reading. + /// + ColumnInfo[]? ColumnInfoCache { get; set; } + + ulong? _recordsAffected; + + /// + /// Whether the current result set has rows + /// + bool _hasRows; + + /// + /// Is raised whenever Close() is called. + /// + public event EventHandler? ReaderClosed; + + bool _isSchemaOnly; + bool _isSequential; + + internal NpgsqlNestedDataReader? CachedFreeNestedDataReader; + + long _startTimestamp; + readonly ILogger _commandLogger; + + internal NpgsqlDataReader(NpgsqlConnector connector) + { + Connector = connector; + _commandLogger = connector.CommandLogger; + } + + internal void Init( + NpgsqlCommand command, + CommandBehavior behavior, + List statements, + long startTimestamp = 0, + Task? sendTask = null) + { + Debug.Assert(ColumnInfoCache is null); + Command = command; + _connection = command.InternalConnection; + _behavior = behavior; + _isSchemaOnly = _behavior.HasFlag(CommandBehavior.SchemaOnly); + _isSequential = _behavior.HasFlag(CommandBehavior.SequentialAccess); + _statements = statements; + StatementIndex = -1; + _sendTask = sendTask; + State = ReaderState.BetweenResults; + _recordsAffected = null; + _startTimestamp = startTimestamp; + } + + #region Read + + /// + /// Advances the reader to the next record in a result set. + /// + /// true if there are more rows; otherwise false. + /// + /// The default position of a data reader is before the first record. Therefore, you must call Read to begin accessing data. + /// + public override bool Read() + { + CheckClosedOrDisposed(); + return TryRead()?.Result ?? 
Read(false).GetAwaiter().GetResult(); + } + + /// + /// This is the asynchronous version of + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task representing the asynchronous operation. + public override Task ReadAsync(CancellationToken cancellationToken) + { + CheckClosedOrDisposed(); + return TryRead() ?? Read(async: true, cancellationToken); + } + + // This is an optimized execution path that avoids calling any async methods for the (usual) + // case where the next row (or CommandComplete) is already in memory. + Task? TryRead() + { + switch (State) { - // This is an optimized execution path that avoids calling any async methods for the (usual) - // case where the next row (or CommandComplete) is already in memory. + case ReaderState.BeforeResult: + // First Read() after NextResult. Data row has already been processed. + State = ReaderState.InResult; + return TrueTask; + case ReaderState.InResult: + break; + default: + return FalseTask; + } - if (_behavior.HasFlag(CommandBehavior.SingleRow)) - return null; + // We have a special case path for SingleRow. + if (_behavior.HasFlag(CommandBehavior.SingleRow) || !_canConsumeRowNonSequentially) + return null; + + ConsumeRowNonSequential(); + + const int headerSize = sizeof(byte) + sizeof(int); + var buffer = Buffer; + var readPosition = buffer.ReadPosition; + var bytesLeft = buffer.FilledBytes - readPosition; + if (bytesLeft < headerSize) + return null; + var messageCode = (BackendMessageCode)buffer.ReadByte(); + var len = buffer.ReadInt32() - sizeof(int); // Transmitted length includes itself + var isDataRow = messageCode is BackendMessageCode.DataRow; + // sizeof(short) is for the number of columns + var sufficientBytes = isDataRow && _isSequential ? headerSize + sizeof(short) : headerSize + len; + if (bytesLeft < sufficientBytes + || !isDataRow && (_statements[StatementIndex].AppendErrorBarrier ?? 
Command.EnableErrorBarriers) + // Could be an error, let main read handle it. + || Connector.ParseResultSetMessage(buffer, messageCode, len) is not { } msg) + { + buffer.ReadPosition = readPosition; + return null; + } + ProcessMessage(msg); + return isDataRow ? TrueTask : FalseTask; + } + async Task Read(bool async, CancellationToken cancellationToken = default) + { + using var registration = Connector.StartNestedCancellableOperation(cancellationToken); + try + { switch (State) { case ReaderState.BeforeResult: // First Read() after NextResult. Data row has already been processed. State = ReaderState.InResult; return true; + case ReaderState.InResult: - if (_isSequential) - return null; - ConsumeRowNonSequential(); + await ConsumeRow(async).ConfigureAwait(false); + if (_behavior.HasFlag(CommandBehavior.SingleRow)) + { + // TODO: See optimization proposal in #410 + await Consume(async).ConfigureAwait(false); + return false; + } break; + case ReaderState.BetweenResults: case ReaderState.Consumed: case ReaderState.Closed: case ReaderState.Disposed: return false; + default: + ThrowHelper.ThrowArgumentOutOfRangeException(); + return false; } - var readBuf = Connector.ReadBuffer; - if (readBuf.ReadBytesLeft < 5) - return null; - var messageCode = (BackendMessageCode)readBuf.ReadByte(); - var len = readBuf.ReadInt32() - 4; // Transmitted length includes itself - if (messageCode != BackendMessageCode.DataRow || readBuf.ReadBytesLeft < len) - { - readBuf.ReadPosition -= 5; - return null; - } - - var msg = Connector.ParseServerMessage(readBuf, messageCode, len, false)!; - Debug.Assert(msg.Code == BackendMessageCode.DataRow); - ProcessMessage(msg); - return true; - } - - async Task Read(bool async, CancellationToken cancellationToken = default) - { - var registration = Connector.StartNestedCancellableOperation(cancellationToken); + var msg = await ReadMessage(async).ConfigureAwait(false); - try + switch (msg.Code) { - switch (State) - { - case ReaderState.BeforeResult: - // 
First Read() after NextResult. Data row has already been processed. - State = ReaderState.InResult; - return true; - - case ReaderState.InResult: - await ConsumeRow(async); - if (_behavior.HasFlag(CommandBehavior.SingleRow)) - { - // TODO: See optimization proposal in #410 - await Consume(async); - return false; - } - break; + case BackendMessageCode.DataRow: + ProcessMessage(msg); + return true; - case ReaderState.BetweenResults: - case ReaderState.Consumed: - case ReaderState.Closed: - case ReaderState.Disposed: - return false; - default: - throw new ArgumentOutOfRangeException(); - } + case BackendMessageCode.CommandComplete: + case BackendMessageCode.EmptyQueryResponse: + ProcessMessage(msg); + if (_statements[StatementIndex].AppendErrorBarrier ?? Command.EnableErrorBarriers) + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); + return false; - var msg2 = await ReadMessage(async); - ProcessMessage(msg2); - return msg2.Code == BackendMessageCode.DataRow; + default: + throw Connector.UnexpectedMessageReceived(msg.Code); } - catch - { + } + catch + { + // Break may have progressed the reader already. + if (State is not ReaderState.Closed) State = ReaderState.Consumed; - throw; - } - finally - { - registration.Dispose(); - } + throw; } + } - ValueTask ReadMessage(bool async) - { - return _isSequential ? ReadMessageSequential(async) : Connector.ReadMessage(async); + ValueTask ReadMessage(bool async) + { + return _isSequential ? 
ReadMessageSequential(Connector, async) : Connector.ReadMessage(async); - async ValueTask ReadMessageSequential(bool async2) + static async ValueTask ReadMessageSequential(NpgsqlConnector connector, bool async) + { + var msg = await connector.ReadMessage(async, DataRowLoadingMode.Sequential).ConfigureAwait(false); + if (msg.Code == BackendMessageCode.DataRow) { - var msg = await Connector.ReadMessage(async2, DataRowLoadingMode.Sequential); - if (msg.Code == BackendMessageCode.DataRow) - { - // Make sure that the datarow's column count is already buffered - await Connector.ReadBuffer.Ensure(2, async); - return msg; - } + // Make sure that the datarow's column count is already buffered + await connector.ReadBuffer.Ensure(2, async).ConfigureAwait(false); return msg; } + return msg; } + } - #endregion + #endregion - #region NextResult + #region NextResult - /// - /// Advances the reader to the next result when reading the results of a batch of statements. - /// - /// - public override bool NextResult() => (_isSchemaOnly ? NextResultSchemaOnly(false) : NextResult(false)) - .GetAwaiter().GetResult(); + /// + /// Advances the reader to the next result when reading the results of a batch of statements. + /// + /// + public override bool NextResult() => (_isSchemaOnly ? NextResultSchemaOnly(false) : NextResult(false)) + .GetAwaiter().GetResult(); - /// - /// This is the asynchronous version of NextResult. - /// - /// The token to monitor for cancellation requests. - /// A task representing the asynchronous operation. - public override Task NextResultAsync(CancellationToken cancellationToken) - { - using (NoSynchronizationContextScope.Enter()) - return _isSchemaOnly - ? NextResultSchemaOnly(async: true, cancellationToken: cancellationToken) - : NextResult(async: true, cancellationToken: cancellationToken); - } + /// + /// This is the asynchronous version of NextResult. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + /// A task representing the asynchronous operation. + public override Task NextResultAsync(CancellationToken cancellationToken) + => _isSchemaOnly + ? NextResultSchemaOnly(async: true, cancellationToken: cancellationToken) + : NextResult(async: true, cancellationToken: cancellationToken); - /// - /// Internal implementation of NextResult - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - async Task NextResult(bool async, bool isConsuming = false, CancellationToken cancellationToken = default) - { - CheckClosedOrDisposed(); + /// + /// Internal implementation of NextResult + /// + async Task NextResult(bool async, bool isConsuming = false, CancellationToken cancellationToken = default) + { + Debug.Assert(!_isSchemaOnly); + CheckClosedOrDisposed(); - IBackendMessage msg; - Debug.Assert(!_isSchemaOnly); + if (State is ReaderState.Consumed) + return false; + try + { using var registration = isConsuming ? default : Connector.StartNestedCancellableOperation(cancellationToken); + // If we're in the middle of a resultset, consume it + if (State is ReaderState.BeforeResult or ReaderState.InResult) + await ConsumeResultSet(async).ConfigureAwait(false); - try - { - // If we're in the middle of a resultset, consume it - switch (State) - { - case ReaderState.BeforeResult: - case ReaderState.InResult: - await ConsumeRow(async); - while (true) - { - var completedMsg = await Connector.ReadMessage(async, DataRowLoadingMode.Skip); - switch (completedMsg.Code) - { - case BackendMessageCode.CommandComplete: - case BackendMessageCode.EmptyQueryResponse: - ProcessMessage(completedMsg); - break; - default: - continue; - } + Debug.Assert(State is ReaderState.BetweenResults); - break; - } - - break; + _hasRows = false; - case ReaderState.BetweenResults: - break; + var statements = _statements; + var statementIndex = StatementIndex; + if (statementIndex >= 0) + { + if (RowDescription is { } description && statements[statementIndex].IsPrepared && ColumnInfoCache is { } 
cache) + description.SetColumnInfoCache(new(cache, 0, ColumnCount)); - case ReaderState.Consumed: - case ReaderState.Closed: - case ReaderState.Disposed: + if (statementIndex is 0 && _behavior.HasFlag(CommandBehavior.SingleResult) && !isConsuming) + { + await Consume(async).ConfigureAwait(false); return false; - default: - throw new ArgumentOutOfRangeException(); } + } - Debug.Assert(State == ReaderState.BetweenResults); - _hasRows = false; + // We are now at the end of the previous result set. Read up to the next result set, if any. + // Non-prepared statements receive ParseComplete, BindComplete, DescriptionRow/NoData, + // prepared statements receive only BindComplete + for (statementIndex = ++StatementIndex; statementIndex < statements.Count; statementIndex = ++StatementIndex) + { + var statement = statements[statementIndex]; - if (_behavior.HasFlag(CommandBehavior.SingleResult) && StatementIndex == 0 && !isConsuming) + IBackendMessage msg; + if (statement.TryGetPrepared(out var preparedStatement)) { - await Consume(async); - return false; + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); + RowDescription = preparedStatement.Description; } - - // We are now at the end of the previous result set. Read up to the next result set, if any. 
- // Non-prepared statements receive ParseComplete, BindComplete, DescriptionRow/NoData, - // prepared statements receive only BindComplete - for (StatementIndex++; StatementIndex < _statements.Count; StatementIndex++) + else // Non-prepared/preparing flow { - var statement = _statements[StatementIndex]; - - if (statement.IsPrepared) - { - Expect(await Connector.ReadMessage(async), Connector); - RowDescription = statement.Description; - } - else // Non-prepared/preparing flow + preparedStatement = statement.PreparedStatement; + if (preparedStatement != null) { - var pStatement = statement.PreparedStatement; - if (pStatement != null) + Debug.Assert(!preparedStatement.IsPrepared); + if (preparedStatement.StatementBeingReplaced != null) { - Debug.Assert(!pStatement.IsPrepared); - if (pStatement.StatementBeingReplaced != null) - { - Expect(await Connector.ReadMessage(async), Connector); - pStatement.StatementBeingReplaced.CompleteUnprepare(); - pStatement.StatementBeingReplaced = null; - } + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); + preparedStatement.StatementBeingReplaced.CompleteUnprepare(); + preparedStatement.StatementBeingReplaced = null; } + } - Expect(await Connector.ReadMessage(async), Connector); - Expect(await Connector.ReadMessage(async), Connector); - msg = await Connector.ReadMessage(async); - - RowDescription = statement.Description = msg.Code switch - { - BackendMessageCode.NoData => null, - - // RowDescription messages are cached on the connector, but if we're auto-preparing, we need to - // clone our own copy which will last beyond the lifetime of this invocation. - BackendMessageCode.RowDescription => pStatement == null - ? 
(RowDescriptionMessage)msg - : ((RowDescriptionMessage)msg).Clone(), - - _ => throw Connector.UnexpectedMessageReceived(msg.Code) - }; + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); - if (statement.IsPreparing) - { - statement.IsPreparing = false; - pStatement!.CompletePrepare(); - } + if (statement.IsPreparing) + { + preparedStatement!.State = PreparedState.Prepared; + Connector.PreparedStatementManager.NumPrepared++; + statement.IsPreparing = false; } - if (RowDescription == null) + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); + msg = await Connector.ReadMessage(async).ConfigureAwait(false); + + RowDescription = statement.Description = msg.Code switch { - // Statement did not generate a resultset (e.g. INSERT) - // Read and process its completion message and move on to the next statement + BackendMessageCode.NoData => null, - msg = await ReadMessage(async); - switch (msg.Code) - { - case BackendMessageCode.CommandComplete: - case BackendMessageCode.EmptyQueryResponse: - break; - default: - throw Connector.UnexpectedMessageReceived(msg.Code); - } + // RowDescription messages are cached on the connector, but if we're auto-preparing, we need to + // clone our own copy which will last beyond the lifetime of this invocation. + BackendMessageCode.RowDescription => preparedStatement == null + ? (RowDescriptionMessage)msg + : ((RowDescriptionMessage)msg).Clone(), - ProcessMessage(msg); - continue; - } + _ => throw Connector.UnexpectedMessageReceived(msg.Code) + }; + } - if (StatementIndex == 0 && Command.Parameters.HasOutputParameters) - { - // If output parameters are present and this is the first row of the first resultset, - // we must always read it in non-sequential mode because it will be traversed twice (once - // here for the parameters, then as a regular row). 
- msg = await Connector.ReadMessage(async); - ProcessMessage(msg); - if (msg.Code == BackendMessageCode.DataRow) - PopulateOutputParameters(); - } + if (RowDescription is not null) + { + if (ColumnInfoCache?.Length >= ColumnCount) + Array.Clear(ColumnInfoCache, 0, ColumnCount); else { - msg = await ReadMessage(async); - ProcessMessage(msg); + if (ColumnInfoCache is { } cache) + ArrayPool.Shared.Return(cache, clearArray: true); + ColumnInfoCache = ArrayPool.Shared.Rent(ColumnCount); } - + if (statement.IsPrepared) + RowDescription.LoadColumnInfoCache(Connector.SerializerOptions, ColumnInfoCache); + } + else + { + // Statement did not generate a resultset (e.g. INSERT) + // Read and process its completion message and move on to the next statement + // No need to read sequentially as it's not a DataRow + msg = await Connector.ReadMessage(async).ConfigureAwait(false); switch (msg.Code) { - case BackendMessageCode.DataRow: case BackendMessageCode.CommandComplete: + case BackendMessageCode.EmptyQueryResponse: break; + case BackendMessageCode.CopyInResponse: + throw Connector.Break(new NotSupportedException( + "COPY isn't supported in regular command execution - see https://www.npgsql.org/doc/copy.html for documentation on COPY with Npgsql. " + + "If you are trying to execute a SQL script created by pg_dump, pass the '--inserts' switch to disable generating COPY statements.")); + case BackendMessageCode.CopyOutResponse: + throw Connector.Break(new NotSupportedException( + "COPY isn't supported in regular command execution - see https://www.npgsql.org/doc/copy.html for documentation on COPY with Npgsql.")); default: throw Connector.UnexpectedMessageReceived(msg.Code); } - return true; + ProcessMessage(msg); + + if (statement.AppendErrorBarrier ?? Command.EnableErrorBarriers) + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); + + continue; } - // There are no more queries, we're done. Read the RFQ. 
- ProcessMessage(Expect(await Connector.ReadMessage(async), Connector)); - RowDescription = null; - return false; + if ((Command.IsWrappedByBatch || StatementIndex is 0) && Command.InternalBatchCommands[StatementIndex]._parameters?.HasOutputParameters == true) + { + // If output parameters are present and this is the first row of the resultset, + // we must always read it in non-sequential mode because it will be traversed twice (once + // here for the parameters, then as a regular row). + msg = await Connector.ReadMessage(async).ConfigureAwait(false); + ProcessMessage(msg); + if (msg.Code == BackendMessageCode.DataRow) + PopulateOutputParameters(Command.InternalBatchCommands[StatementIndex]._parameters!); + } + else + { + msg = await ReadMessage(async).ConfigureAwait(false); + ProcessMessage(msg); + } + + switch (msg.Code) + { + case BackendMessageCode.DataRow: + Connector.State = ConnectorState.Fetching; + return true; + case BackendMessageCode.CommandComplete: + if (statement.AppendErrorBarrier ?? Command.EnableErrorBarriers) + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); + return true; + default: + Connector.UnexpectedMessageReceived(msg.Code); + break; + } } - catch (Exception e) + + // There are no more queries, we're done. Read the RFQ. + if (_statements.Count is 0 || !(_statements[_statements.Count - 1].AppendErrorBarrier ?? 
Command.EnableErrorBarriers)) + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); + + State = ReaderState.Consumed; + RowDescription = null; + return false; + } + catch (Exception e) + { + if (e is PostgresException postgresException && StatementIndex >= 0 && StatementIndex < _statements.Count) { - State = ReaderState.Consumed; + var statement = _statements[StatementIndex]; - // Reference the triggering statement from the exception (for batching) - if (e is PostgresException postgresException && - StatementIndex >= 0 && StatementIndex < _statements.Count) + // Reference the triggering statement from the exception + postgresException.BatchCommand = statement; + + // Prevent the command or batch from being recycled (by the connection) when it's disposed. This is important since + // the exception is very likely to escape the using statement of the command, and by that time some other user may + // already be using the recycled instance. + Command.IsCacheable = false; + + // If the schema of a table changes after a statement is prepared on that table, PostgreSQL errors with + // 0A000: cached plan must not change result type. 0A000 seems like a non-specific code, but it's very unlikely the + // statement would successfully execute anyway, so invalidate the prepared statement. + if (postgresException.SqlState == PostgresErrorCodes.FeatureNotSupported && + statement.PreparedStatement is { } preparedStatement) { - postgresException.Statement = _statements[StatementIndex]; + preparedStatement.State = PreparedState.Invalidated; + Command.ResetPreparation(); + foreach (var s in Command.InternalBatchCommands) + s.ResetPreparation(); } + } - // An error means all subsequent statements were skipped by PostgreSQL. - // If any of them were being prepared, we need to update our bookkeeping to put - // them back in unprepared state. 
- for (; StatementIndex < _statements.Count; StatementIndex++) + // For the statement that errored, if it was being prepared we need to update our bookkeeping to put them back in unprepared + // state. + for (; StatementIndex < _statements.Count; StatementIndex++) + { + var statement = _statements[StatementIndex]; + if (statement.IsPreparing) { - var statement = _statements[StatementIndex]; - if (statement.IsPreparing) + statement.IsPreparing = false; + statement.PreparedStatement!.AbortPrepare(); + } + + // In normal, non-isolated batching, we've consumed the result set and are done. + // However, if the command has error barrier, we now have to consume results from the commands after it (unless it's the + // last one). + // Note that Consume calls NextResult (this method) recursively, the isConsuming flag tells us we're in this mode. + if ((statement.AppendErrorBarrier ?? Command.EnableErrorBarriers) && StatementIndex < _statements.Count - 1) + { + if (isConsuming) + throw; + switch (State) { - statement.IsPreparing = false; - statement.PreparedStatement!.CompleteUnprepare(); + case ReaderState.Consumed: + case ReaderState.Closed: + case ReaderState.Disposed: + // The exception may have caused the connector to break (e.g. I/O), and so the reader is already closed. + break; + default: + // We provide Consume with the first exception which we've just caught. + // If it encounters other exceptions while consuming the rest of the result set, it will raise an AggregateException, + // otherwise it will rethrow this first exception. + await Consume(async, firstException: e).ConfigureAwait(false); + break; // Never reached, Consume always throws above } } - - throw; } + + // Break may have progressed the reader already. 
+ if (State is not ReaderState.Closed) + State = ReaderState.Consumed; + throw; } - void PopulateOutputParameters() + async ValueTask ConsumeResultSet(bool async) { - // The first row in a stored procedure command that has output parameters needs to be traversed twice - - // once for populating the output parameters and once for the actual result set traversal. So in this - // case we can't be sequential. - Debug.Assert(Command.Parameters.Any(p => p.IsOutputDirection)); - Debug.Assert(StatementIndex == 0); - Debug.Assert(RowDescription != null); - Debug.Assert(State == ReaderState.BeforeResult); + await ConsumeRow(async).ConfigureAwait(false); + while (true) + { + var completedMsg = await Connector.ReadMessage(async, DataRowLoadingMode.Skip).ConfigureAwait(false); + switch (completedMsg.Code) + { + case BackendMessageCode.CommandComplete: + case BackendMessageCode.EmptyQueryResponse: + ProcessMessage(completedMsg); - var currentPosition = Buffer.ReadPosition; + var statement = _statements[StatementIndex]; + if (statement.IsPrepared && ColumnInfoCache is not null) + RowDescription!.SetColumnInfoCache(new(ColumnInfoCache, 0, ColumnCount)); - // Temporarily set our state to InResult to allow us to read the values - State = ReaderState.InResult; + if (statement.AppendErrorBarrier ?? Command.EnableErrorBarriers) + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); - var pending = new Queue(); - var taken = new List(); - for (var i = 0; i < FieldCount; i++) - { - if (Command.Parameters.TryGetValue(GetName(i), out var p) && p.IsOutputDirection) - { - p.Value = GetValue(i); - taken.Add(p); + break; + default: + // TODO if we hit an ErrorResponse here (PG doesn't do this *today*) we should probably throw. 
+ continue; } - else - pending.Enqueue(GetValue(i)); + + break; } + } + } + + + void PopulateOutputParameters(NpgsqlParameterCollection parameters) + { + // The first row in a stored procedure command that has output parameters needs to be traversed twice - + // once for populating the output parameters and once for the actual result set traversal. So in this + // case we can't be sequential. + Debug.Assert(RowDescription != null); + Debug.Assert(State == ReaderState.BeforeResult); + + var currentPosition = Buffer.ReadPosition; + + // Temporarily set our state to InResult to allow us to read the values + State = ReaderState.InResult; - // Not sure where this odd behavior comes from: all output parameters which did not get matched by - // name now get populated with column values which weren't matched. Keeping this for backwards compat, - // opened #2252 for investigation. - foreach (var p in Command.Parameters.Where(p => p.IsOutputDirection && !taken.Contains(p))) + var pending = new Queue(); + var taken = new List(); + for (var i = 0; i < FieldCount; i++) + { + if (parameters.TryGetValue(GetName(i), out var p) && p.IsOutputDirection) { - if (pending.Count == 0) - break; - p.Value = pending.Dequeue(); + p.Value = GetValue(i); + taken.Add(p); } + else + pending.Enqueue(GetValue(i)); + } - State = ReaderState.BeforeResult; // Set the state back - Buffer.ReadPosition = currentPosition; // Restore position + // Not sure where this odd behavior comes from: all output parameters which did not get matched by + // name now get populated with column values which weren't matched. Keeping this for backwards compat, + // opened #2252 for investigation. 
+ foreach (var p in (IEnumerable)parameters) + { + if (!p.IsOutputDirection || taken.Contains(p)) + continue; - _column = -1; - ColumnLen = -1; - PosInColumn = 0; + if (pending.Count == 0) + break; + p.Value = pending.Dequeue(); } - /// - /// Note that in SchemaOnly mode there are no resultsets, and we read nothing from the backend (all - /// RowDescriptions have already been processed and are available) - /// - async Task NextResultSchemaOnly(bool async, bool isConsuming = false, CancellationToken cancellationToken = default) - { - Debug.Assert(_isSchemaOnly); + PgReader.Commit(resuming: false); + State = ReaderState.BeforeResult; // Set the state back + Buffer.ReadPosition = currentPosition; // Restore position - using var registration = isConsuming ? default : Connector.StartNestedCancellableOperation(cancellationToken); + _column = -1; + } - try + /// + /// Note that in SchemaOnly mode there are no resultsets, and we read nothing from the backend (all + /// RowDescriptions have already been processed and are available) + /// + async Task NextResultSchemaOnly(bool async, bool isConsuming = false, CancellationToken cancellationToken = default) + { + Debug.Assert(_isSchemaOnly); + + using var registration = isConsuming ? 
default : Connector.StartNestedCancellableOperation(cancellationToken); + + try + { + switch (State) + { + case ReaderState.BeforeResult: + case ReaderState.InResult: + case ReaderState.BetweenResults: + break; + case ReaderState.Consumed: + case ReaderState.Closed: + case ReaderState.Disposed: + return false; + default: + ThrowHelper.ThrowArgumentOutOfRangeException(); + return false; + } + + for (StatementIndex++; StatementIndex < _statements.Count; StatementIndex++) { - switch (State) + var statement = _statements[StatementIndex]; + if (statement.TryGetPrepared(out var preparedStatement)) { - case ReaderState.BeforeResult: - case ReaderState.InResult: - case ReaderState.BetweenResults: - break; - case ReaderState.Consumed: - case ReaderState.Closed: - case ReaderState.Disposed: - return false; - default: - throw new ArgumentOutOfRangeException(); + // Row descriptions have already been populated in the statement objects at the + // Prepare phase + RowDescription = preparedStatement.Description; } - - for (StatementIndex++; StatementIndex < _statements.Count; StatementIndex++) + else { - var statement = _statements[StatementIndex]; - if (statement.IsPrepared) + var pStatement = statement.PreparedStatement; + if (pStatement != null) { - // Row descriptions have already been populated in the statement objects at the - // Prepare phase - RowDescription = _statements[StatementIndex].Description; + Debug.Assert(!pStatement.IsPrepared); + if (pStatement.StatementBeingReplaced != null) + { + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); + pStatement.StatementBeingReplaced.CompleteUnprepare(); + pStatement.StatementBeingReplaced = null; + } } - else + + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); + + if (statement.IsPreparing) + { + pStatement!.State = PreparedState.Prepared; + Connector.PreparedStatementManager.NumPrepared++; + statement.IsPreparing = false; + } + + Expect(await 
Connector.ReadMessage(async).ConfigureAwait(false), Connector); + var msg = await Connector.ReadMessage(async).ConfigureAwait(false); + switch (msg.Code) { - Expect(await Connector.ReadMessage(async), Connector); - Expect(await Connector.ReadMessage(async), Connector); - var msg = await Connector.ReadMessage(async); - switch (msg.Code) + case BackendMessageCode.NoData: + RowDescription = _statements[StatementIndex].Description = null; + break; + case BackendMessageCode.RowDescription: + // We have a resultset + RowDescription = _statements[StatementIndex].Description = (RowDescriptionMessage)msg; + Command.FixupRowDescription(RowDescription, StatementIndex == 0); + break; + default: + throw Connector.UnexpectedMessageReceived(msg.Code); + } + + var forall = true; + for (var i = StatementIndex + 1; i < _statements.Count; i++) + if (!_statements[i].IsPrepared) { - case BackendMessageCode.NoData: - RowDescription = _statements[StatementIndex].Description = null; - break; - case BackendMessageCode.RowDescription: - // We have a resultset - RowDescription = _statements[StatementIndex].Description = (RowDescriptionMessage)msg; - Command.FixupRowDescription(RowDescription, StatementIndex == 0); + forall = false; break; - default: - throw Connector.UnexpectedMessageReceived(msg.Code); } - } - - // Found a resultset - if (RowDescription != null) - return true; + // There are no more queries, we're done. Read to the RFQ. + if (forall) + Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); } - // There are no more queries, we're done. Read to the RFQ. 
- if (!_statements.All(s => s.IsPrepared)) + // Found a resultset + if (RowDescription is not null) { - ProcessMessage(Expect(await Connector.ReadMessage(async), Connector)); - RowDescription = null; + if (ColumnInfoCache?.Length >= ColumnCount) + Array.Clear(ColumnInfoCache, 0, ColumnCount); + else + { + if (ColumnInfoCache is { } cache) + ArrayPool.Shared.Return(cache, clearArray: true); + ColumnInfoCache = ArrayPool.Shared.Rent(ColumnCount); + } + return true; } + } - return false; + State = ReaderState.Consumed; + RowDescription = null; + return false; + } + catch (Exception e) + { + // Break may have progressed the reader already. + if (State is not ReaderState.Closed) + State = ReaderState.Consumed; + + // Reference the triggering statement from the exception + if (e is PostgresException postgresException && StatementIndex >= 0 && StatementIndex < _statements.Count) + { + postgresException.BatchCommand = _statements[StatementIndex]; + + // Prevent the command or batch from being recycled (by the connection) when it's disposed. This is important since + // the exception is very likely to escape the using statement of the command, and by that time some other user may + // already be using the recycled instance. + Command.IsCacheable = false; } - catch (Exception e) + + // An error means all subsequent statements were skipped by PostgreSQL. + // If any of them were being prepared, we need to update our bookkeeping to put + // them back in unprepared state. 
+ for (; StatementIndex < _statements.Count; StatementIndex++) { - State = ReaderState.Consumed; + var statement = _statements[StatementIndex]; + if (statement.IsPreparing) + { + statement.IsPreparing = false; + statement.PreparedStatement!.AbortPrepare(); + } + } + + throw; + } + } + + #endregion - // Reference the triggering statement from the exception (for batching) - if (e is PostgresException postgresException && - StatementIndex >= 0 && StatementIndex < _statements.Count) - { - postgresException.Statement = _statements[StatementIndex]; - } + #region ProcessMessage - throw; - } + internal void ProcessMessage(IBackendMessage msg) + { + if (msg.Code is not BackendMessageCode.DataRow) + { + HandleUncommon(msg); + return; } - #endregion - - #region ProcessMessage + var dataRow = (DataRowMessage)msg; + // The connector's buffer can actually change between DataRows: + // If a large DataRow exceeding the connector's current read buffer arrives, and we're + // reading in non-sequential mode, a new oversize buffer is allocated. We thus have to + // recapture the connector's buffer on each new DataRow. 
+ // Note that this can happen even in sequential mode, if the row description message is big + // (see #2003) + if (!ReferenceEquals(Buffer, Connector.ReadBuffer)) + Buffer = Connector.ReadBuffer; + // We assume that the row's number of columns is identical to the description's + var numColumns = Buffer.ReadInt16(); + if (ColumnCount != numColumns) + ThrowHelper.ThrowArgumentException($"Row's number of columns ({numColumns}) differs from the row description's ({ColumnCount})"); + + var readPosition = Buffer.ReadPosition; + var msgRemainder = dataRow.Length - sizeof(short); + _dataMsgEnd = readPosition + msgRemainder; + _columnsStartPos = readPosition; + _canConsumeRowNonSequentially = msgRemainder <= Buffer.FilledBytes - readPosition; + _column = -1; + + if (_columns.Count > 0) + _columns.Clear(); + + switch (State) + { + case ReaderState.BetweenResults: + _hasRows = true; + State = ReaderState.BeforeResult; + break; + case ReaderState.BeforeResult: + State = ReaderState.InResult; + break; + case ReaderState.InResult: + break; + default: + Connector.UnexpectedMessageReceived(BackendMessageCode.DataRow); + break; + } - internal void ProcessMessage(IBackendMessage msg) + [MethodImpl(MethodImplOptions.NoInlining)] + void HandleUncommon(IBackendMessage msg) { switch (msg.Code) { - case BackendMessageCode.DataRow: - ProcessDataRowMessage((DataRowMessage)msg); - return; - case BackendMessageCode.CommandComplete: var completed = (CommandCompleteMessage)msg; switch (completed.StatementType) @@ -654,1533 +846,1505 @@ internal void ProcessMessage(IBackendMessage msg) case StatementType.Delete: case StatementType.Copy: case StatementType.Move: - if (!_recordsAffected.HasValue) - _recordsAffected = 0; + case StatementType.Merge: + _recordsAffected ??= 0; _recordsAffected += completed.Rows; break; } _statements[StatementIndex].ApplyCommandComplete(completed); - goto case BackendMessageCode.EmptyQueryResponse; - + State = ReaderState.BetweenResults; + break; case 
BackendMessageCode.EmptyQueryResponse: State = ReaderState.BetweenResults; - return; - - case BackendMessageCode.ReadyForQuery: - State = ReaderState.Consumed; - return; - + break; default: - throw new Exception("Received unexpected backend message of type " + msg.Code); + Connector.UnexpectedMessageReceived(msg.Code); + break; } } + } - void ProcessDataRowMessage(DataRowMessage msg) - { - Connector.State = ConnectorState.Fetching; - - // The connector's buffer can actually change between DataRows: - // If a large DataRow exceeding the connector's current read buffer arrives, and we're - // reading in non-sequential mode, a new oversize buffer is allocated. We thus have to - // recapture the connector's buffer on each new DataRow. - // Note that this can happen even in sequential mode, if the row description message is big - // (see #2003) - Buffer = Connector.ReadBuffer; + #endregion - _hasRows = true; - _column = -1; - ColumnLen = -1; - PosInColumn = 0; + /// + /// Gets a value indicating the depth of nesting for the current row. Always returns zero. + /// + public override int Depth => 0; - // We assume that the row's number of columns is identical to the description's - _numColumns = Buffer.ReadInt16(); - Debug.Assert(_numColumns == RowDescription!.NumFields, - $"Row's number of columns ({_numColumns}) differs from the row description's ({RowDescription.NumFields})"); + /// + /// Gets a value indicating whether the data reader is closed. + /// + public override bool IsClosed => State == ReaderState.Closed || State == ReaderState.Disposed; - if (!_isSequential) - { - _dataMsgEnd = Buffer.ReadPosition + msg.Length - 2; + /// + /// Gets the number of rows changed, inserted, or deleted by execution of the SQL statement. + /// + /// + /// The number of rows changed, inserted, or deleted. -1 for SELECT statements; 0 if no rows were affected or the statement failed. + /// + public override int RecordsAffected + => !_recordsAffected.HasValue + ? 
-1 + : _recordsAffected > int.MaxValue + ? throw new OverflowException( + $"The number of records affected exceeds int.MaxValue. Use {nameof(Rows)}.") + : (int)_recordsAffected; - // Initialize our columns array with the offset and length of the first column - _columns.Clear(); - var len = Buffer.ReadInt32(); - _columns.Add((Buffer.ReadPosition, len)); - } + /// + /// Gets the number of rows changed, inserted, or deleted by execution of the SQL statement. + /// + /// + /// The number of rows changed, inserted, or deleted. 0 for SELECT statements, if no rows were affected or the statement failed. + /// + public ulong Rows => _recordsAffected ?? 0; - switch (State) - { - case ReaderState.BetweenResults: - State = ReaderState.BeforeResult; - break; - case ReaderState.BeforeResult: - State = ReaderState.InResult; - break; - case ReaderState.InResult: - break; - default: - throw Connector.UnexpectedMessageReceived(BackendMessageCode.DataRow); - } - } + /// + /// Returns details about each statement that this reader will or has executed. + /// + /// + /// Note that some fields (i.e. rows and oid) are only populated as the reader + /// traverses the result. + /// + /// For commands with multiple queries, this exposes the number of rows affected on + /// a statement-by-statement basis, unlike + /// which exposes an aggregation across all statements. + /// + [Obsolete("Use the new DbBatch API")] + public IReadOnlyList Statements => _statements.AsReadOnly(); - #endregion - - void Cancel() => Connector.PerformPostgresCancellation(); - - /// - /// Gets a value indicating the depth of nesting for the current row. Always returns zero. - /// - public override int Depth => 0; - - /// - /// Gets a value indicating whether the data reader is closed. - /// - public override bool IsClosed => State == ReaderState.Closed; - - /// - /// Gets the number of rows changed, inserted, or deleted by execution of the SQL statement. 
- /// - public override int RecordsAffected => _recordsAffected.HasValue ? (int)_recordsAffected.Value : -1; - - /// - /// Returns details about each statement that this reader will or has executed. - /// - /// - /// Note that some fields (i.e. rows and oid) are only populated as the reader - /// traverses the result. - /// - /// For commands with multiple queries, this exposes the number of rows affected on - /// a statement-by-statement basis, unlike - /// which exposes an aggregation across all statements. - /// - public IReadOnlyList Statements => _statements.AsReadOnly(); - - /// - /// Gets a value that indicates whether this DbDataReader contains one or more rows. - /// - public override bool HasRows => State == ReaderState.Closed - ? throw new InvalidOperationException("Invalid attempt to call HasRows when reader is closed.") - : _hasRows; - - /// - /// Indicates whether the reader is currently positioned on a row, i.e. whether reading a - /// column is possible. - /// This property is different from in that will - /// return true even if attempting to read a column will fail, e.g. before - /// has been called - /// - public bool IsOnRow => State == ReaderState.InResult; - - /// - /// Gets the name of the column, given the zero-based column ordinal. - /// - /// The zero-based column ordinal. - /// The name of the specified column. - public override string GetName(int ordinal) => CheckRowDescriptionAndGetField(ordinal).Name; - - /// - /// Gets the number of columns in the current row. - /// - public override int FieldCount + /// + /// Gets a value that indicates whether this DbDataReader contains one or more rows. + /// + public override bool HasRows + => State switch { - get - { - CheckClosedOrDisposed(); - return RowDescription?.NumFields ?? 
0; - } - } + ReaderState.Closed => throw new InvalidOperationException("Invalid attempt to call HasRows when reader is closed."), + ReaderState.Disposed => throw new ObjectDisposedException(nameof(NpgsqlDataReader)), + _ => _hasRows + }; + + /// + /// Indicates whether the reader is currently positioned on a row, i.e. whether reading a + /// column is possible. + /// This property is different from in that will + /// return true even if attempting to read a column will fail, e.g. before + /// has been called + /// + public bool IsOnRow => State == ReaderState.InResult; - #region Cleanup / Dispose + /// + /// Gets the name of the column, given the zero-based column ordinal. + /// + /// The zero-based column ordinal. + /// The name of the specified column. + public override string GetName(int ordinal) => GetField(ordinal).Name; - /// - /// Consumes all result sets for this reader, leaving the connector ready for sending and processing further - /// queries - /// - async Task Consume(bool async) + /// + /// Gets the number of columns in the current row. + /// + public override int FieldCount + { + get { - // Skip over the other result sets. Note that this does tally records affected - // from CommandComplete messages, and properly sets state for auto-prepared statements - if (_isSchemaOnly) - while (await NextResultSchemaOnly(async, isConsuming: true)) {} - else - while (await NextResult(async, isConsuming: true)) {} + CheckClosedOrDisposed(); + return RowDescription?.Count ?? 0; } + } + + #region Cleanup / Dispose + + /// + /// Consumes all result sets for this reader, leaving the connector ready for sending and processing further + /// queries + /// + async Task Consume(bool async, Exception? firstException = null) + { + var exceptions = firstException is null ? null : new List { firstException }; - /// - /// Releases the resources used by the NpgsqlDataReader. - /// - protected override void Dispose(bool disposing) + // Skip over the other result sets. 
Note that this does tally records affected from CommandComplete messages, and properly sets + // state for auto-prepared statements + while (true) { try { - Close(connectionClosing: false, async: false, isDisposing: true).GetAwaiter().GetResult(); + if (!(_isSchemaOnly + ? await NextResultSchemaOnly(async, isConsuming: true).ConfigureAwait(false) + : await NextResult(async, isConsuming: true).ConfigureAwait(false))) + { + break; + } } catch (Exception e) { - Log.Error("Exception caught while disposing a reader", e, Connector.Id); + exceptions ??= new(); + exceptions.Add(e); } - finally + } + + Debug.Assert(exceptions?.Count != 0); + + switch (exceptions?.Count) + { + case null: + return; + case 1: + ExceptionDispatchInfo.Capture(exceptions[0]).Throw(); + return; + default: + throw new NpgsqlException( + "Multiple exceptions occurred when consuming the result set", + new AggregateException(exceptions)); + } + } + + /// + /// Releases the resources used by the . + /// + protected override void Dispose(bool disposing) + { + try + { + Close(connectionClosing: false, async: false, isDisposing: true).GetAwaiter().GetResult(); + } + catch (Exception ex) + { + // In the case of a PostgresException (or multiple ones, if we have error barriers), the reader's state has already been set + // to Disposed in Close above; in multiplexing, we also unbind the connector (with its reader), and at that point it can be used + // by other consumers. Therefore, we only set the state fo Disposed if the exception *wasn't* a PostgresException. + if (!(ex is PostgresException || + ex is NpgsqlException { InnerException: AggregateException aggregateException } && + AllPostgresExceptions(aggregateException.InnerExceptions))) { State = ReaderState.Disposed; } + + throw; } + finally + { + Command.TraceCommandStop(); + } + } - /// - /// Releases the resources used by the NpgsqlDataReader. - /// + /// + /// Releases the resources used by the . 
+ /// #if NETSTANDARD2_0 - public ValueTask DisposeAsync() + public async ValueTask DisposeAsync() #else - public override ValueTask DisposeAsync() + public override async ValueTask DisposeAsync() #endif + { + try { - using (NoSynchronizationContextScope.Enter()) - return DisposeAsyncCore(); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - async ValueTask DisposeAsyncCore() + await Close(connectionClosing: false, async: true, isDisposing: true).ConfigureAwait(false); + } + catch (Exception ex) + { + // In the case of a PostgresException (or multiple ones, if we have error barriers), the reader's state has already been set + // to Disposed in Close above; in multiplexing, we also unbind the connector (with its reader), and at that point it can be used + // by other consumers. Therefore, we only set the state to Disposed if the exception *wasn't* a PostgresException. + if (!(ex is PostgresException || + ex is NpgsqlException { InnerException: AggregateException aggregateException } && + AllPostgresExceptions(aggregateException.InnerExceptions))) { - try - { - await Close(connectionClosing: false, async: true, isDisposing: true); - } - catch (Exception e) - { - Log.Error("Exception caught while disposing a reader", e, Connector.Id); - } - finally - { - State = ReaderState.Disposed; - } + State = ReaderState.Disposed; } + throw; + } + finally + { + Command.TraceCommandStop(); } + } + + static bool AllPostgresExceptions(ReadOnlyCollection collection) + { + foreach (var exception in collection) + if (exception is not PostgresException) + return false; + return true; + } - /// - /// Closes the reader, allowing a new command to be executed. - /// - public override void Close() => Close(connectionClosing: false, async: false, isDisposing: false).GetAwaiter().GetResult(); + /// + /// Closes the reader, allowing a new command to be executed. 
+ /// + public override void Close() => Close(connectionClosing: false, async: false, isDisposing: false).GetAwaiter().GetResult(); - /// - /// Closes the reader, allowing a new command to be executed. - /// + /// + /// Closes the reader, allowing a new command to be executed. + /// #if NETSTANDARD2_0 - public Task CloseAsync() + public Task CloseAsync() #else - public override Task CloseAsync() + public override Task CloseAsync() #endif - => Close(connectionClosing: false, async: true, isDisposing: false); + => Close(async: true, connectionClosing: false, isDisposing: false); - internal async Task Close(bool connectionClosing, bool async, bool isDisposing) + internal async Task Close(bool async, bool connectionClosing, bool isDisposing) + { + if (State is ReaderState.Closed or ReaderState.Disposed) { - if (State == ReaderState.Closed || State == ReaderState.Disposed) - return; + if (isDisposing) + State = ReaderState.Disposed; + return; + } - switch (Connector.State) + // Whenever a connector is broken, it also closes the current reader. 
+ Connector.CurrentReader = null; + + switch (Connector.State) + { + case ConnectorState.Ready: + case ConnectorState.Fetching: + case ConnectorState.Executing: + case ConnectorState.Connecting: + if (State != ReaderState.Consumed) { - case ConnectorState.Ready: - case ConnectorState.Fetching: - case ConnectorState.Executing: - case ConnectorState.Connecting: - if (State != ReaderState.Consumed) - await Consume(async); - break; - case ConnectorState.Closed: - case ConnectorState.Broken: - break; - case ConnectorState.Waiting: - case ConnectorState.Copy: - case ConnectorState.Replication: - Debug.Fail("Bad connector state when closing reader: " + Connector.State); - break; - default: - throw new ArgumentOutOfRangeException(); + try + { + await Consume(async).ConfigureAwait(false); + } + catch (Exception ex) when (ex is OperationCanceledException or NpgsqlException { InnerException: TimeoutException }) + { + // Timeout/cancellation - completely normal, consume has basically completed. + } + catch (Exception ex) when ( + ex is PostgresException || + ex is NpgsqlException { InnerException: AggregateException aggregateException } && + AllPostgresExceptions(aggregateException.InnerExceptions)) + { + // In the case of a PostgresException (or multiple ones, if we have error barriers), the connection is fine and consume + // has basically completed. Defer throwing the exception until Cleanup is complete. 
+ await Cleanup(async, connectionClosing, isDisposing).ConfigureAwait(false); + throw; + } + catch + { + Debug.Assert(Connector.IsBroken); + throw; + } } - - await Cleanup(async, connectionClosing, isDisposing); + break; + case ConnectorState.Closed: + case ConnectorState.Broken: + break; + case ConnectorState.Waiting: + case ConnectorState.Copy: + case ConnectorState.Replication: + Debug.Fail("Bad connector state when closing reader: " + Connector.State); + break; + default: + throw new ArgumentOutOfRangeException(); } - internal async Task Cleanup(bool async, bool connectionClosing = false, bool isDisposing = false) + await Cleanup(async, connectionClosing, isDisposing).ConfigureAwait(false); + } + + internal async Task Cleanup(bool async, bool connectionClosing = false, bool isDisposing = false) + { + LogMessages.ReaderCleanup(_commandLogger, Connector.Id); + + // If multiplexing isn't on, _sendTask contains the task for the writing of this command. + // Make sure that this task, which may have executed asynchronously and in parallel with the reading, + // has completed, throwing any exceptions it generated. If we don't do this, there's the possibility of a race condition where the + // user executes a new command after reader.Dispose() returns, but some additional write stuff is still finishing up from the last + // command. + if (_sendTask is { Status: not TaskStatus.RanToCompletion }) { - Log.Trace("Cleaning up reader", Connector.Id); - - // If multiplexing isn't on, _sendTask contains the task for the writing of this command. - // Make sure that this task, which may have executed asynchronously and in parallel with the reading, - // has completed, throwing any exceptions it generated. - // Note: if the following is removed, mysterious concurrent connection usage errors start happening - // on .NET Framework. 
- if (_sendTask != null) + // If the connector is broken, we have no reason to wait for the sendTask to complete + // as we're not going to send anything else over it + // and that can lead to deadlocks (concurrent write and read failure, see #4804) + if (Connector.IsBroken) + { + // Prevent unobserved Task notifications by observing the failed Task exception. + _ = _sendTask.ContinueWith(t => _ = t.Exception, CancellationToken.None, TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Current); + } + else { try { if (async) - await _sendTask; + await _sendTask.ConfigureAwait(false); else _sendTask.GetAwaiter().GetResult(); } catch (Exception e) { - // TODO: think of a better way to handle exceptios, see #1323 and #3163 - Log.Debug("Exception caught while sending the request", e, Connector.Id); + // TODO: think of a better way to handle exceptions, see #1323 and #3163 + _commandLogger.LogDebug(e, "Exception caught while sending the request", Connector.Id); } } + } - State = ReaderState.Closed; - Command.State = CommandState.Idle; - Connector.CurrentReader = null; - if (Log.IsEnabled(NpgsqlLogLevel.Debug)) { - Connector.QueryLogStopWatch.Stop(); - Log.Debug($"Query duration time: {Connector.QueryLogStopWatch.ElapsedMilliseconds}ms", Connector.Id); - Connector.QueryLogStopWatch.Reset(); - } - Connector.EndUserAction(); - NpgsqlEventSource.Log.CommandStop(); + if (ColumnInfoCache is { } cache) + { + ColumnInfoCache = null; + ArrayPool.Shared.Return(cache, clearArray: true); + } - // The reader shouldn't be unbound, if we're disposing - so the state is set prematurely - if (isDisposing) - State = ReaderState.Disposed; + State = ReaderState.Closed; + Command.State = CommandState.Idle; + Connector.CurrentReader = null; + if (_commandLogger.IsEnabled(LogLevel.Information)) + Command.LogExecutingCompleted(Connector, executing: false); + NpgsqlEventSource.Log.CommandStop(); + Connector.DataSource.MetricsReporter.ReportCommandStop(_startTimestamp); + 
Connector.EndUserAction(); - if (_connection.ConnectorBindingScope == ConnectorBindingScope.Reader) - { - // We may unbind the current reader, which also sets the connector to null - var connector = Connector; - UnbindIfNecessary(); + // The reader shouldn't be unbound, if we're disposing - so the state is set prematurely + if (isDisposing) + State = ReaderState.Disposed; - // TODO: Refactor... Use proper scope - _connection.Connector = null; - connector.Connection = null; - _connection.ConnectorBindingScope = ConnectorBindingScope.None; + if (_connection?.ConnectorBindingScope == ConnectorBindingScope.Reader) + { + UnbindIfNecessary(); - // If the reader is being closed as part of the connection closing, we don't apply - // the reader's CommandBehavior.CloseConnection - if (_behavior.HasFlag(CommandBehavior.CloseConnection) && !connectionClosing) - _connection.Close(); + // TODO: Refactor... Use proper scope + _connection.Connector = null; + Connector.Connection = null; + _connection.ConnectorBindingScope = ConnectorBindingScope.None; - connector.ReaderCompleted.SetResult(null); - } - else if (_behavior.HasFlag(CommandBehavior.CloseConnection) && !connectionClosing) + // If the reader is being closed as part of the connection closing, we don't apply + // the reader's CommandBehavior.CloseConnection + if (_behavior.HasFlag(CommandBehavior.CloseConnection) && !connectionClosing) _connection.Close(); - if (ReaderClosed != null) - { - ReaderClosed(this, EventArgs.Empty); - ReaderClosed = null; - } + Connector.ReaderCompleted.SetResult(null); } - - #endregion - - #region Simple value getters - - /// - /// Gets the value of the specified column as a Boolean. - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override bool GetBoolean(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a byte. - /// - /// The zero-based column ordinal. - /// The value of the specified column. 
- public override byte GetByte(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a single character. - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override char GetChar(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a 16-bit signed integer. - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override short GetInt16(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a 32-bit signed integer. - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override int GetInt32(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a 64-bit signed integer. - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override long GetInt64(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a object. - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override DateTime GetDateTime(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as an instance of . - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override string GetString(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a object. - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override decimal GetDecimal(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a double-precision floating point number. - /// - /// The zero-based column ordinal. - /// The value of the specified column. 
- public override double GetDouble(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a single-precision floating point number. - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override float GetFloat(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a globally-unique identifier (GUID). - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override Guid GetGuid(int ordinal) => GetFieldValue(ordinal); - - /// - /// Populates an array of objects with the column values of the current row. - /// - /// An array of Object into which to copy the attribute columns. - /// The number of instances of in the array. - public override int GetValues(object[] values) + else if (_behavior.HasFlag(CommandBehavior.CloseConnection) && !connectionClosing) { - if (values == null) - throw new ArgumentNullException(nameof(values)); - CheckResultSet(); - - var count = Math.Min(FieldCount, values.Length); - for (var i = 0; i < count; i++) - values[i] = GetValue(i); - return count; + Debug.Assert(_connection is not null); + _connection.Close(); } - /// - /// Gets the value of the specified column as an instance of . - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override object this[int ordinal] => GetValue(ordinal); - - #endregion - - #region Provider-specific simple type getters - - /// - /// Gets the value of the specified column as an , - /// Npgsql's provider-specific type for dates. - /// - /// - /// PostgreSQL's date type represents dates from 4713 BC to 5874897 AD, while .NET's DateTime - /// only supports years from 1 to 1999. If you require years outside this range use this accessor. - /// The standard method will also return this type, but has - /// the disadvantage of boxing the value. 
- /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public NpgsqlDate GetDate(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as a TimeSpan, - /// - /// - /// PostgreSQL's interval type has has a resolution of 1 microsecond and ranges from - /// -178000000 to 178000000 years, while .NET's TimeSpan has a resolution of 100 nanoseconds - /// and ranges from roughly -29247 to 29247 years. - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public TimeSpan GetTimeSpan(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as an , - /// Npgsql's provider-specific type for time spans. - /// - /// - /// PostgreSQL's interval type has has a resolution of 1 microsecond and ranges from - /// -178000000 to 178000000 years, while .NET's TimeSpan has a resolution of 100 nanoseconds - /// and ranges from roughly -29247 to 29247 years. If you require values from outside TimeSpan's - /// range use this accessor. - /// The standard ADO.NET method will also return this - /// type, but has the disadvantage of boxing the value. - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public NpgsqlTimeSpan GetInterval(int ordinal) => GetFieldValue(ordinal); - - /// - /// Gets the value of the specified column as an , - /// Npgsql's provider-specific type for date/time timestamps. Note that this type covers - /// both PostgreSQL's "timestamp with time zone" and "timestamp without time zone" types, - /// which differ only in how they are converted upon input/output. 
- /// - /// - /// PostgreSQL's timestamp type represents dates from 4713 BC to 5874897 AD, while .NET's DateTime - /// only supports years from 1 to 1999. If you require years outside this range use this accessor. - /// The standard method will also return this type, but has - /// the disadvantage of boxing the value. - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public NpgsqlDateTime GetTimeStamp(int ordinal) => GetFieldValue(ordinal); - - #endregion - - #region Special binary getters - - /// - /// Reads a stream of bytes from the specified column, starting at location indicated by dataOffset, into the buffer, starting at the location indicated by bufferOffset. - /// - /// The zero-based column ordinal. - /// The index within the row from which to begin the read operation. - /// The buffer into which to copy the data. - /// The index with the buffer to which the data will be copied. - /// The maximum number of characters to read. - /// The actual number of bytes read. - public override long GetBytes(int ordinal, long dataOffset, byte[]? 
buffer, int bufferOffset, int length) + if (ReaderClosed != null) { - if (dataOffset < 0 || dataOffset > int.MaxValue) - throw new ArgumentOutOfRangeException(nameof(dataOffset), dataOffset, $"dataOffset must be between {0} and {int.MaxValue}"); - if (buffer != null && (bufferOffset < 0 || bufferOffset >= buffer.Length + 1)) - throw new IndexOutOfRangeException($"bufferOffset must be between {0} and {(buffer.Length)}"); - if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) - throw new IndexOutOfRangeException($"length must be between {0} and {buffer.Length - bufferOffset}"); - - var field = CheckRowAndGetField(ordinal); - var handler = field.Handler; - if (!(handler is ByteaHandler)) - throw new InvalidCastException("GetBytes() not supported for type " + field.Name); - - SeekToColumn(ordinal, false).GetAwaiter().GetResult(); - if (ColumnLen == -1) - ThrowHelper.ThrowInvalidCastException_NoValue(field); - - if (buffer == null) - return ColumnLen; - - var dataOffset2 = (int)dataOffset; - SeekInColumn(dataOffset2, false).GetAwaiter().GetResult(); - - // Attempt to read beyond the end of the column - if (dataOffset2 + length > ColumnLen) - length = Math.Max(ColumnLen - dataOffset2, 0); - - var left = length; - while (left > 0) - { - var read = Buffer.Read(new Span(buffer, bufferOffset, left)); - bufferOffset += read; - left -= read; - } - - PosInColumn += length; - - return length; + ReaderClosed(this, EventArgs.Empty); + ReaderClosed = null; } + } - /// - /// Retrieves data as a . - /// - /// The zero-based column ordinal. - /// The returned object. - public override Stream GetStream(int ordinal) => GetStream(ordinal, false).Result; - - /// - /// Retrieves data as a . - /// - /// The zero-based column ordinal. - /// The token to monitor for cancellation requests. The default value is . - /// The returned object. 
- public Task GetStreamAsync(int ordinal, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return GetStream(ordinal, true, cancellationToken).AsTask(); - } + #endregion - ValueTask GetStream(int ordinal, bool async, CancellationToken cancellationToken = default) => - GetStreamInternal(CheckRowAndGetField(ordinal), ordinal, async, cancellationToken); + #region Simple value getters - ValueTask GetStreamInternal(FieldDescription field, int ordinal, bool async, CancellationToken cancellationToken = default) - { - if (_columnStream != null && !_columnStream.IsDisposed) - throw new InvalidOperationException("A stream is already open for this reader"); + /// + /// Gets the value of the specified column as a Boolean. + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override bool GetBoolean(int ordinal) => GetFieldValueCore(ordinal); - var t = SeekToColumn(ordinal, async, cancellationToken); - if (!t.IsCompleted) - return new ValueTask(GetStreamLong(this, field, t, cancellationToken)); + /// + /// Gets the value of the specified column as a byte. + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override byte GetByte(int ordinal) => GetFieldValueCore(ordinal); - if (ColumnLen == -1) - ThrowHelper.ThrowInvalidCastException_NoValue(field); + /// + /// Gets the value of the specified column as a single character. + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override char GetChar(int ordinal) => GetFieldValueCore(ordinal); - PosInColumn += ColumnLen; - return new ValueTask(_columnStream = (NpgsqlReadBuffer.ColumnStream)Buffer.GetStream(ColumnLen, !_isSequential)); + /// + /// Gets the value of the specified column as a 16-bit signed integer. + /// + /// The zero-based column ordinal. + /// The value of the specified column. 
+ public override short GetInt16(int ordinal) => GetFieldValueCore(ordinal); - static async Task GetStreamLong(NpgsqlDataReader reader, FieldDescription field, Task seekTask, CancellationToken cancellationToken) - { - using var registration = reader.Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + /// + /// Gets the value of the specified column as a 32-bit signed integer. + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override int GetInt32(int ordinal) => GetFieldValueCore(ordinal); - await seekTask; - if (reader.ColumnLen == -1) - ThrowHelper.ThrowInvalidCastException_NoValue(field); + /// + /// Gets the value of the specified column as a 64-bit signed integer. + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override long GetInt64(int ordinal) => GetFieldValueCore(ordinal); - reader.PosInColumn += reader.ColumnLen; - return reader._columnStream = (NpgsqlReadBuffer.ColumnStream)reader.Buffer.GetStream(reader.ColumnLen, !reader._isSequential); - } - } + /// + /// Gets the value of the specified column as a object. + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override DateTime GetDateTime(int ordinal) => GetFieldValueCore(ordinal); - #endregion + /// + /// Gets the value of the specified column as an instance of . + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override string GetString(int ordinal) => GetFieldValueCore(ordinal); - #region Special text getters + /// + /// Gets the value of the specified column as a object. + /// + /// The zero-based column ordinal. + /// The value of the specified column. 
+ public override decimal GetDecimal(int ordinal) => GetFieldValueCore(ordinal); - /// - /// Reads a stream of characters from the specified column, starting at location indicated by dataOffset, into the buffer, starting at the location indicated by bufferOffset. - /// - /// The zero-based column ordinal. - /// The index within the row from which to begin the read operation. - /// The buffer into which to copy the data. - /// The index with the buffer to which the data will be copied. - /// The maximum number of characters to read. - /// The actual number of characters read. - public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) - { - if (dataOffset < 0 || dataOffset > int.MaxValue) - throw new ArgumentOutOfRangeException(nameof(dataOffset), dataOffset, $"dataOffset must be between {0} and {int.MaxValue}"); - if (buffer != null && (bufferOffset < 0 || bufferOffset >= buffer.Length + 1)) - throw new IndexOutOfRangeException($"bufferOffset must be between {0} and {(buffer.Length)}"); - if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) - throw new IndexOutOfRangeException($"length must be between {0} and {buffer.Length - bufferOffset}"); + /// + /// Gets the value of the specified column as a double-precision floating point number. + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override double GetDouble(int ordinal) => GetFieldValueCore(ordinal); - var field = CheckRowAndGetField(ordinal); - var handler = field.Handler as TextHandler; - if (handler == null) - throw new InvalidCastException("The GetChars method is not supported for type " + field.Name); + /// + /// Gets the value of the specified column as a single-precision floating point number. + /// + /// The zero-based column ordinal. + /// The value of the specified column. 
+ public override float GetFloat(int ordinal) => GetFieldValueCore(ordinal); - SeekToColumn(ordinal, false).GetAwaiter().GetResult(); - if (ColumnLen == -1) - ThrowHelper.ThrowInvalidCastException_NoValue(field); + /// + /// Gets the value of the specified column as a globally-unique identifier (GUID). + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override Guid GetGuid(int ordinal) => GetFieldValueCore(ordinal); - if (PosInColumn == 0) - _charPos = 0; + /// + /// Populates an array of objects with the column values of the current row. + /// + /// An array of Object into which to copy the attribute columns. + /// The number of instances of in the array. + public override int GetValues(object[] values) + { + if (values == null) + throw new ArgumentNullException(nameof(values)); + CheckResultSet(); + + var count = Math.Min(FieldCount, values.Length); + for (var i = 0; i < count; i++) + values[i] = GetValue(i); + return count; + } - var decoder = Buffer.TextEncoding.GetDecoder(); + /// + /// Gets the value of the specified column as an instance of . + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override object this[int ordinal] => GetValue(ordinal); - if (buffer == null) - { - // Note: Getting the length of a text column means decoding the entire field, - // very inefficient and also consumes the column in sequential mode. But this seems to - // be SqlClient's behavior as well. - var (bytesSkipped, charsSkipped) = SkipChars(decoder, int.MaxValue, ColumnLen - PosInColumn); - Debug.Assert(bytesSkipped == ColumnLen - PosInColumn); - PosInColumn += bytesSkipped; - _charPos += charsSkipped; - return _charPos; - } + #endregion - if (PosInColumn == ColumnLen || dataOffset < _charPos) - { - // Either the column has already been read (e.g. GetString()) or a previous GetChars() - // has positioned us in the column *after* the requested read start offset. 
Seek back - // (this will throw for sequential) - SeekInColumn(0, false).GetAwaiter().GetResult(); - _charPos = 0; - } + #region Provider-specific simple type getters - if (dataOffset > _charPos) - { - var charsToSkip = (int)dataOffset - _charPos; - var (bytesSkipped, charsSkipped) = SkipChars(decoder, charsToSkip, ColumnLen - PosInColumn); - decoder.Reset(); - PosInColumn += bytesSkipped; - _charPos += charsSkipped; - if (charsSkipped < charsToSkip) // data offset is beyond the column's end - return 0; - } + /// + /// Gets the value of the specified column as a TimeSpan, + /// + /// + /// PostgreSQL's interval type has has a resolution of 1 microsecond and ranges from + /// -178000000 to 178000000 years, while .NET's TimeSpan has a resolution of 100 nanoseconds + /// and ranges from roughly -29247 to 29247 years. + /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public TimeSpan GetTimeSpan(int ordinal) => GetFieldValueCore(ordinal); + + /// + protected override DbDataReader GetDbDataReader(int ordinal) => GetData(ordinal); - // We're now positioned at the start of the segment of characters we need to read. - if (length == 0) - return 0; + /// + /// Returns a nested data reader for the requested column. + /// The column type must be a record or a to Npgsql known composite type, or an array thereof. + /// Currently only supported in non-sequential mode. + /// + /// The zero-based column ordinal. + /// A data reader. + public new NpgsqlNestedDataReader GetData(int ordinal) + { + if (_isSequential) + throw new NotSupportedException("GetData() not supported in sequential mode."); + + var field = CheckRowAndGetField(ordinal); + var type = field.PostgresType; + var isArray = type is PostgresArrayType; + var elementType = isArray ? 
((PostgresArrayType)type).Element : type; + var compositeType = elementType as PostgresCompositeType; + if (field.DataFormat is DataFormat.Text || (elementType.InternalName != "record" && compositeType == null)) + throw new InvalidCastException("GetData() not supported for type " + field.TypeDisplayName); + + var columnLength = SeekToColumn(async: false, ordinal, field.DataFormat, resumableOp: true).GetAwaiter().GetResult(); + if (columnLength is -1) + ThrowHelper.ThrowInvalidCastException_NoValue(field); + + if (PgReader.FieldOffset > 0) + PgReader.Rewind(PgReader.FieldOffset); + + var reader = CachedFreeNestedDataReader; + if (reader != null) + { + CachedFreeNestedDataReader = null; + reader.Init(compositeType); + } + else + { + reader = new NpgsqlNestedDataReader(this, null, 1, compositeType); + } + if (isArray) + reader.InitArray(); + else + reader.InitSingleRow(); + return reader; + } - var (bytesRead, charsRead) = DecodeChars(decoder, buffer, bufferOffset, length, ColumnLen - PosInColumn); + #endregion - PosInColumn += bytesRead; - _charPos += charsRead; - return charsRead; - } + #region Special binary getters - (int BytesRead, int CharsRead) DecodeChars(Decoder decoder, char[] output, int outputOffset, int charCount, int byteCount) - { - var (bytesRead, charsRead) = (0, 0); + /// + /// Reads a stream of bytes from the specified column, starting at location indicated by dataOffset, into the buffer, starting at the location indicated by bufferOffset. + /// + /// The zero-based column ordinal. + /// The index within the row from which to begin the read operation. + /// The buffer into which to copy the data. + /// The index with the buffer to which the data will be copied. + /// The maximum number of characters to read. + /// The actual number of bytes read. + public override long GetBytes(int ordinal, long dataOffset, byte[]? 
buffer, int bufferOffset, int length) + { + if (dataOffset is < 0 or > int.MaxValue) + throw new ArgumentOutOfRangeException(nameof(dataOffset), dataOffset, $"dataOffset must be between {0} and {int.MaxValue}"); + if (buffer != null && (bufferOffset < 0 || bufferOffset >= buffer.Length + 1)) + throw new IndexOutOfRangeException($"bufferOffset must be between 0 and {buffer.Length}"); + if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) + throw new IndexOutOfRangeException($"length must be between 0 and {buffer.Length - bufferOffset}"); + + var field = CheckRowAndGetField(ordinal); + var columnLength = SeekToColumn(async: false, ordinal, field.DataFormat, resumableOp: true).GetAwaiter().GetResult(); + if (columnLength == -1) + ThrowHelper.ThrowInvalidCastException_NoValue(field); + + if (buffer is null) + return columnLength; + + // Move to offset + if (_isSequential && PgReader.FieldOffset > dataOffset) + ThrowHelper.ThrowInvalidOperationException("Attempt to read a position in the column which has already been read"); + + PgReader.Seek((int)dataOffset); + + // At offset, read into buffer. + length = Math.Min(length, PgReader.FieldRemaining); + PgReader.ReadBytes(new Span(buffer, bufferOffset, length)); + return length; + } - while (true) - { - Buffer.Ensure(1); // Make sure we have at least some data - - var maxBytes = Math.Min(byteCount - bytesRead, Buffer.ReadBytesLeft); - decoder.Convert(Buffer.Buffer, Buffer.ReadPosition, maxBytes, output, outputOffset, charCount - charsRead, false, - out var bytesUsed, out var charsUsed, out _); - Buffer.ReadPosition += bytesUsed; - bytesRead += bytesUsed; - charsRead += charsUsed; - if (charsRead == charCount || bytesRead == byteCount) - break; - outputOffset += charsUsed; - Buffer.Clear(); - } + /// + /// Retrieves data as a . + /// + /// The zero-based column ordinal. + /// The returned object. 
+ public override Stream GetStream(int ordinal) + => GetFieldValueCore(ordinal); - return (bytesRead, charsRead); - } + /// + /// Retrieves data as a . + /// + /// The zero-based column ordinal. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The returned object. + public Task GetStreamAsync(int ordinal, CancellationToken cancellationToken = default) + => GetFieldValueAsync(ordinal, cancellationToken); - internal (int BytesSkipped, int CharsSkipped) SkipChars(Decoder decoder, int charCount, int byteCount) - { - // TODO: Allocate on the stack with Span - if (_tempCharBuf == null) - _tempCharBuf = new char[1024]; - var (charsSkipped, bytesSkipped) = (0, 0); - while (charsSkipped < charCount && bytesSkipped < byteCount) - { - var (bytesRead, charsRead) = DecodeChars(decoder, _tempCharBuf, 0, Math.Min(charCount, _tempCharBuf.Length), byteCount); - bytesSkipped += bytesRead; - charsSkipped += charsRead; - } - return (bytesSkipped, charsSkipped); - } + #endregion - /// - /// Retrieves data as a . - /// - /// The zero-based column ordinal. - /// The returned object. - public override TextReader GetTextReader(int ordinal) - => GetTextReader(ordinal, false).Result; - - /// - /// Retrieves data as a . - /// - /// The zero-based column ordinal. - /// The token to monitor for cancellation requests. The default value is . - /// The returned object. - public Task GetTextReaderAsync(int ordinal, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return GetTextReader(ordinal, true, cancellationToken).AsTask(); - } + #region Special text getters - async ValueTask GetTextReader(int ordinal, bool async, CancellationToken cancellationToken = default) - { - var field = CheckRowAndGetField(ordinal); - if (field.Handler is ITextReaderHandler handler) - return handler.GetTextReader(async - ? 
await GetStreamInternal(field, ordinal, true, cancellationToken) - : GetStreamInternal(field, ordinal, false, CancellationToken.None).Result); + /// + /// Reads a stream of characters from the specified column, starting at location indicated by dataOffset, into the buffer, starting at the location indicated by bufferOffset. + /// + /// The zero-based column ordinal. + /// The index within the row from which to begin the read operation. + /// The buffer into which to copy the data. + /// The index with the buffer to which the data will be copied. + /// The maximum number of characters to read. + /// The actual number of characters read. + public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) + { + if (dataOffset is < 0 or > int.MaxValue) + throw new ArgumentOutOfRangeException(nameof(dataOffset), dataOffset, $"dataOffset must be between 0 and {int.MaxValue}"); + if (buffer != null && (bufferOffset < 0 || bufferOffset >= buffer.Length + 1)) + throw new IndexOutOfRangeException($"bufferOffset must be between 0 and {buffer.Length}"); + if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) + throw new IndexOutOfRangeException($"length must be between 0 and {buffer.Length - bufferOffset}"); + + // Check whether we have a GetChars implementation for this column type. + var field = GetInfo(ordinal, typeof(GetChars), out var converter, out var bufferRequirement, out var asObject); + + var columnLength = SeekToColumn(async: false, ordinal, field, resumableOp: true).GetAwaiter().GetResult(); + if (columnLength == -1) + ThrowHelper.ThrowInvalidCastException_NoValue(CheckRowAndGetField(ordinal)); + + var reader = PgReader; + dataOffset = buffer is null ? 
0 : dataOffset; + if (_isSequential && reader.CharsRead > dataOffset) + ThrowHelper.ThrowInvalidOperationException("Attempt to read a position in the column which has already been read"); + + reader.StartCharsRead(checked((int)dataOffset), + buffer is not null ? new ArraySegment(buffer, bufferOffset, length) : (ArraySegment?)null); + + reader.StartRead(bufferRequirement); + var result = asObject + ? (GetChars)converter.ReadAsObject(reader) + : ((PgConverter)converter).Read(reader); + reader.EndRead(); + + reader.EndCharsRead(); + return result.Read; + } - throw new InvalidCastException($"The GetTextReader method is not supported for type {field.Handler.PgDisplayName}"); - } + /// + /// Retrieves data as a . + /// + /// The zero-based column ordinal. + /// The returned object. + public override TextReader GetTextReader(int ordinal) + => GetFieldValueCore(ordinal); - #endregion + /// + /// Retrieves data as a . + /// + /// The zero-based column ordinal. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The returned object. + public Task GetTextReaderAsync(int ordinal, CancellationToken cancellationToken = default) + => GetFieldValueAsync(ordinal, cancellationToken); - #region GetFieldValue + #endregion - /// - /// Asynchronously gets the value of the specified column as a type. - /// - /// The type of the value to be returned. - /// The type of the value to be returned. - /// The token to monitor for cancellation requests. - /// - public override Task GetFieldValueAsync(int ordinal, CancellationToken cancellationToken) - { - if (typeof(T) == typeof(Stream)) - return (Task)(object)GetStreamAsync(ordinal, cancellationToken); + #region GetFieldValue - if (typeof(T) == typeof(TextReader)) - return (Task)(object)GetTextReaderAsync(ordinal, cancellationToken); + /// + /// Asynchronously gets the value of the specified column as a type. + /// + /// The type of the value to be returned. 
+ /// The type of the value to be returned. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// + public override Task GetFieldValueAsync(int ordinal, CancellationToken cancellationToken) + { + // In non-sequential, we know that the column is already buffered - no I/O will take place + if (!_isSequential) + return Task.FromResult(GetFieldValueCore(ordinal)); - // In non-sequential, we know that the column is already buffered - no I/O will take place - if (!_isSequential) - return Task.FromResult(GetFieldValue(ordinal)); + // The only statically mapped converter, it always exists. + if (typeof(T) == typeof(Stream)) + return GetStream(ordinal, cancellationToken); - using (NoSynchronizationContextScope.Enter()) - return GetFieldValueSequential(ordinal, true, cancellationToken).AsTask(); - } + return Core(ordinal, cancellationToken).AsTask(); - /// - /// Synchronously gets the value of the specified column as a type. - /// - /// Synchronously gets the value of the specified column as a type. - /// The column to be retrieved. - /// The column to be retrieved. 
- public override T GetFieldValue(int ordinal) + async ValueTask Core(int ordinal, CancellationToken cancellationToken) { - if (typeof(T) == typeof(Stream)) - return (T)(object)GetStream(ordinal); + using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - if (typeof(T) == typeof(TextReader)) - return (T)(object)GetTextReader(ordinal); + var field = GetInfo(ordinal, typeof(T), out var converter, out var bufferRequirement, out var asObject); - if (_isSequential) - return GetFieldValueSequential(ordinal, false).GetAwaiter().GetResult(); + var columnLength = await SeekToColumn(async: true, ordinal, field).ConfigureAwait(false); + if (columnLength is -1) + return DbNullValueOrThrow(ordinal); - // In non-sequential, we know that the column is already buffered - no I/O will take place + if (typeof(T) == typeof(TextReader)) + PgReader.ThrowIfStreamActive(); + + Debug.Assert(asObject || converter is PgConverter); + await PgReader.StartReadAsync(bufferRequirement, cancellationToken).ConfigureAwait(false); + var result = asObject + ? 
(T)await converter.ReadAsObjectAsync(PgReader, cancellationToken).ConfigureAwait(false) + : await converter.UnsafeDowncast().ReadAsync(PgReader, cancellationToken).ConfigureAwait(false); + await PgReader.EndReadAsync().ConfigureAwait(false); + return result; + } - var field = CheckRowAndGetField(ordinal); - SeekToColumnNonSequential(ordinal); + async Task GetStream(int ordinal, CancellationToken cancellationToken) + { + using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - if (ColumnLen == -1) - { - // When T is a Nullable (and only in that case), we support returning null - if (NullableHandler.Exists) - return default!; + var field = GetDefaultInfo(ordinal, out _, out _); + PgReader.ThrowIfStreamActive(); - if (typeof(T) == typeof(object)) - return (T)(object)DBNull.Value; + var columnLength = await SeekToColumn(async: true, ordinal, field).ConfigureAwait(false); - ThrowHelper.ThrowInvalidCastException_NoValue(field); - } + if (columnLength == -1) + return DbNullValueOrThrow(ordinal); - var position = Buffer.ReadPosition; - try - { - return NullableHandler.Exists - ? NullableHandler.Read(field.Handler, Buffer, ColumnLen, field) - : typeof(T) == typeof(object) - ? 
(T)field.Handler.ReadAsObject(Buffer, ColumnLen, field) - : field.Handler.Read(Buffer, ColumnLen, field); - } - catch - { - if (Connector.State != ConnectorState.Broken) - { - var writtenBytes = Buffer.ReadPosition - position; - var remainingBytes = ColumnLen - writtenBytes; - if (remainingBytes > 0) - Buffer.Skip(remainingBytes, false).GetAwaiter().GetResult(); - } - throw; - } - finally - { - // Important: position must still be updated - PosInColumn += ColumnLen; - } + return (T)(object)PgReader.GetStream(canSeek: !_isSequential); } + } - async ValueTask GetFieldValueSequential(int column, bool async, CancellationToken cancellationToken = default) - { - using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - - var field = CheckRowAndGetField(column); - await SeekToColumnSequential(column, async, CancellationToken.None); - CheckColumnStart(); + /// + /// Synchronously gets the value of the specified column as a type. + /// + /// Synchronously gets the value of the specified column as a type. + /// The column to be retrieved. + /// The column to be retrieved. + public override T GetFieldValue(int ordinal) => GetFieldValueCore(ordinal); - if (ColumnLen == -1) - { - // When T is a Nullable (and only in that case), we support returning null - if (NullableHandler.Exists) - return default!; + T GetFieldValueCore(int ordinal) + { + // The only statically mapped converter, it always exists. + if (typeof(T) == typeof(Stream)) + return GetStream(ordinal); + + var field = GetInfo(ordinal, typeof(T), out var converter, out var bufferRequirement, out var asObject); + + if (typeof(T) == typeof(TextReader)) + PgReader.ThrowIfStreamActive(); + + var columnLength = + _isSequential + ? 
SeekToColumnSequential(async: false, ordinal, field).GetAwaiter().GetResult() + : SeekToColumnNonSequential(ordinal, field); + if (columnLength is -1) + return DbNullValueOrThrow(ordinal); + + Debug.Assert(asObject || converter is PgConverter); + PgReader.StartRead(bufferRequirement); + var result = asObject + ? (T)converter.ReadAsObject(PgReader) + : converter.UnsafeDowncast().Read(PgReader); + PgReader.EndRead(); + return result; + + [MethodImpl(MethodImplOptions.NoInlining)] + T GetStream(int ordinal) + { + var field = GetDefaultInfo(ordinal, out _, out _); + PgReader.ThrowIfStreamActive(); - if (typeof(T) == typeof(object)) - return (T)(object)DBNull.Value; + var columnLength = + _isSequential + ? SeekToColumnSequential(async: false, ordinal, field).GetAwaiter().GetResult() + : SeekToColumnNonSequential(ordinal, field); - ThrowHelper.ThrowInvalidCastException_NoValue(field); - } + if (columnLength == -1) + return DbNullValueOrThrow(ordinal); - var position = Buffer.ReadPosition; - try - { - return NullableHandler.Exists - ? ColumnLen <= Buffer.ReadBytesLeft - ? NullableHandler.Read(field.Handler, Buffer, ColumnLen, field) - : await NullableHandler.ReadAsync(field.Handler, Buffer, ColumnLen, async, field) - : typeof(T) == typeof(object) - ? ColumnLen <= Buffer.ReadBytesLeft - ? (T)field.Handler.ReadAsObject(Buffer, ColumnLen, field) - : (T)await field.Handler.ReadAsObject(Buffer, ColumnLen, async, field) - : ColumnLen <= Buffer.ReadBytesLeft - ? 
field.Handler.Read(Buffer, ColumnLen, field) - : await field.Handler.Read(Buffer, ColumnLen, async, field); - } - catch - { - if (Connector.State != ConnectorState.Broken) - { - var writtenBytes = Buffer.ReadPosition - position; - var remainingBytes = ColumnLen - writtenBytes; - if (remainingBytes > 0) - await Buffer.Skip(remainingBytes, async); - } - throw; - } - finally - { - // Important: position must still be updated - PosInColumn += ColumnLen; - } + return (T)(object)PgReader.GetStream(canSeek: !_isSequential); } + } - #endregion - - #region GetValue - - /// - /// Gets the value of the specified column as an instance of . - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override object GetValue(int ordinal) - { - var fieldDescription = CheckRowAndGetField(ordinal); + #endregion - if (_isSequential) { - SeekToColumnSequential(ordinal, false).GetAwaiter().GetResult(); - CheckColumnStart(); - } else - SeekToColumnNonSequential(ordinal); + #region GetValue - if (ColumnLen == -1) - return DBNull.Value; + /// + /// Gets the value of the specified column as an instance of . + /// + /// The zero-based column ordinal. + /// The value of the specified column. + public override object GetValue(int ordinal) + { + var field = GetDefaultInfo(ordinal, out var converter, out var bufferRequirement); + var columnLength = + _isSequential + ? SeekToColumnSequential(async: false, ordinal, field).GetAwaiter().GetResult() + : SeekToColumnNonSequential(ordinal, field); + if (columnLength == -1) + return DBNull.Value; + + PgReader.StartRead(bufferRequirement); + var result = converter.ReadAsObject(PgReader); + PgReader.EndRead(); + + return result; + } - object result; - var position = Buffer.ReadPosition; - try - { - result = _isSequential - ? 
fieldDescription.Handler.ReadAsObject(Buffer, ColumnLen, false, fieldDescription).GetAwaiter().GetResult() - : fieldDescription.Handler.ReadAsObject(Buffer, ColumnLen, fieldDescription); - } - catch - { - if (Connector.State != ConnectorState.Broken) - { - var writtenBytes = Buffer.ReadPosition - position; - var remainingBytes = ColumnLen - writtenBytes; - if (remainingBytes > 0) - Buffer.Skip(remainingBytes, false).GetAwaiter().GetResult(); - } - throw; - } - finally - { - // Important: position must still be updated - PosInColumn += ColumnLen; - } + /// + /// Gets the value of the specified column as an instance of . + /// + /// The name of the column. + /// The value of the specified column. + public override object this[string name] => GetValue(GetOrdinal(name)); - // Used for Entity Framework <= 6 compability - var objectResultType = Command.ObjectResultTypes?[ordinal]; - if (objectResultType != null) - { - result = objectResultType == typeof(DateTimeOffset) - ? new DateTimeOffset((DateTime)result) - : Convert.ChangeType(result, objectResultType)!; - } + #endregion - return result; - } + #region IsDBNull - /// - /// Gets the value of the specified column as an instance of . - /// - /// The zero-based column ordinal. - /// The value of the specified column. - public override object GetProviderSpecificValue(int ordinal) - { - var fieldDescription = CheckRowAndGetField(ordinal); + /// + /// Gets a value that indicates whether the column contains nonexistent or missing values. + /// + /// The zero-based column ordinal. + /// true if the specified column is equivalent to ; otherwise false. 
+ public override bool IsDBNull(int ordinal) + => SeekToColumn(async: false, ordinal, CheckRowAndGetField(ordinal).DataFormat, resumableOp: true).GetAwaiter().GetResult() is -1; - if (_isSequential) - { - SeekToColumnSequential(ordinal, false).GetAwaiter().GetResult(); - CheckColumnStart(); - } - else - SeekToColumnNonSequential(ordinal); + /// + /// An asynchronous version of , which gets a value that indicates whether the column contains non-existent or missing values. + /// The parameter is currently ignored. + /// + /// The zero-based column to be retrieved. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// true if the specified column value is equivalent to otherwise false. + public override Task IsDBNullAsync(int ordinal, CancellationToken cancellationToken) + { + if (!_isSequential) + return IsDBNull(ordinal) ? TrueTask : FalseTask; - if (ColumnLen == -1) - return DBNull.Value; + return Core(ordinal, cancellationToken); - var position = Buffer.ReadPosition; - try - { - return _isSequential - ? 
fieldDescription.Handler.ReadPsvAsObject(Buffer, ColumnLen, false, fieldDescription).GetAwaiter().GetResult() - : fieldDescription.Handler.ReadPsvAsObject(Buffer, ColumnLen, fieldDescription); - } - catch - { - if (Connector.State != ConnectorState.Broken) - { - var writtenBytes = Buffer.ReadPosition - position; - var remainingBytes = ColumnLen - writtenBytes; - if (remainingBytes > 0) - Buffer.Skip(remainingBytes, false).GetAwaiter().GetResult(); - } - throw; - } - finally - { - // Important: position must still be updated - PosInColumn += ColumnLen; - } + async Task Core(int ordinal, CancellationToken cancellationToken) + { + using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + return await SeekToColumn(async: true, ordinal, CheckRowAndGetField(ordinal).DataFormat, resumableOp: true).ConfigureAwait(false) is -1; } + } - /// - /// Gets the value of the specified column as an instance of . - /// - /// The name of the column. - /// The value of the specified column. - public override object this[string name] => GetValue(GetOrdinal(name)); + #endregion - #endregion + #region Other public accessors - #region IsDBNull + /// + /// Gets the column ordinal given the name of the column. + /// + /// The name of the column. + /// The zero-based column ordinal. + public override int GetOrdinal(string name) + { + if (string.IsNullOrEmpty(name)) + ThrowHelper.ThrowArgumentException($"{nameof(name)} cannot be empty", nameof(name)); + CheckClosedOrDisposed(); + if (RowDescription is null) + ThrowHelper.ThrowInvalidOperationException("No resultset is currently being traversed"); + return RowDescription.GetFieldIndex(name); + } - /// - /// Gets a value that indicates whether the column contains nonexistent or missing values. - /// - /// The zero-based column ordinal. - /// true if the specified column is equivalent to ; otherwise false. 
- public override bool IsDBNull(int ordinal) - { - CheckRowAndGetField(ordinal); + /// + /// Gets a representation of the PostgreSQL data type for the specified field. + /// The returned representation can be used to access various information about the field. + /// + /// The zero-based column index. + public PostgresType GetPostgresType(int ordinal) => GetField(ordinal).PostgresType; - if (_isSequential) - SeekToColumnSequential(ordinal, false).GetAwaiter().GetResult(); - else - SeekToColumnNonSequential(ordinal); + /// + /// Gets the data type information for the specified field. + /// This is the PostgreSQL type name (e.g. double precision), not the .NET type + /// (see for that). + /// + /// The zero-based column index. + public override string GetDataTypeName(int ordinal) => GetField(ordinal).TypeDisplayName; - return ColumnLen == -1; - } + /// + /// Gets the OID for the PostgreSQL type for the specified field, as it appears in the pg_type table. + /// + /// + /// This is a PostgreSQL-internal value that should not be relied upon and should only be used for + /// debugging purposes. + /// + /// The zero-based column index. + public uint GetDataTypeOID(int ordinal) => GetField(ordinal).TypeOID; - /// - /// An asynchronous version of , which gets a value that indicates whether the column contains non-existent or missing values. - /// The parameter is currently ignored. - /// - /// The zero-based column to be retrieved. - /// The token to monitor for cancellation requests. - /// true if the specified column value is equivalent to otherwise false. - public override Task IsDBNullAsync(int ordinal, CancellationToken cancellationToken) - { - CheckRowAndGetField(ordinal); + /// + /// Gets the data type of the specified column. + /// + /// The zero-based column ordinal. + /// The data type of the specified column. 
+ [UnconditionalSuppressMessage("ILLink", "IL2093", + Justification = "Members are only dynamically accessed by Npgsql via GetFieldType by GetSchema, and only in certain cases. " + + "Holding PublicFields and PublicProperties metadata on all our mapped types just for that case is the wrong tradeoff.")] + public override Type GetFieldType(int ordinal) + => GetField(ordinal).FieldType; - if (!_isSequential) - return IsDBNull(ordinal) ? PGUtil.TrueTask : PGUtil.FalseTask; + /// + /// Returns an that can be used to iterate through the rows in the data reader. + /// + /// An that can be used to iterate through the rows in the data reader. + public override IEnumerator GetEnumerator() + => new DbEnumerator(this); - using (NoSynchronizationContextScope.Enter()) - return IsDBNullAsyncInternal(); + /// + /// Returns schema information for the columns in the current resultset. + /// + /// + public ReadOnlyCollection GetColumnSchema() + => GetColumnSchema(async: false).GetAwaiter().GetResult(); - // ReSharper disable once InconsistentNaming - async Task IsDBNullAsyncInternal() - { - using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + ReadOnlyCollection IDbColumnSchemaGenerator.GetColumnSchema() + { + var columns = GetColumnSchema(); + var result = new DbColumn[columns.Count]; + var i = 0; + foreach (var column in columns) + result[i++] = column; - await SeekToColumn(ordinal, true, cancellationToken); - return ColumnLen == -1; - } - } + return new ReadOnlyCollection(result); + } - #endregion + /// + /// Asynchronously returns schema information for the columns in the current resultset. 
+ /// + /// +#if NET5_0_OR_GREATER + public new Task> GetColumnSchemaAsync(CancellationToken cancellationToken = default) +#else + public Task> GetColumnSchemaAsync(CancellationToken cancellationToken = default) +#endif + => GetColumnSchema(async: true, cancellationToken); - #region Other public accessors + Task> GetColumnSchema(bool async, CancellationToken cancellationToken = default) + => RowDescription == null || ColumnCount == 0 + ? Task.FromResult(new List().AsReadOnly()) + : new DbColumnSchemaGenerator(_connection!, RowDescription, _behavior.HasFlag(CommandBehavior.KeyInfo)) + .GetColumnSchema(async, cancellationToken); - /// - /// Gets the column ordinal given the name of the column. - /// - /// The name of the column. - /// The zero-based column ordinal. - public override int GetOrdinal(string name) - { - if (string.IsNullOrEmpty(name)) - throw new ArgumentException("name cannot be empty", nameof(name)); - if (State == ReaderState.Closed) - throw new InvalidOperationException("The reader is closed"); - if (RowDescription is null) - throw new InvalidOperationException("No resultset is currently being traversed"); - return RowDescription.GetFieldIndex(name); - } + #endregion - /// - /// Gets a representation of the PostgreSQL data type for the specified field. - /// The returned representation can be used to access various information about the field. - /// - /// The zero-based column index. - public PostgresType GetPostgresType(int ordinal) => CheckRowDescriptionAndGetField(ordinal).PostgresType; - - /// - /// Gets the data type information for the specified field. - /// This will be the PostgreSQL type name (e.g. double precision), not the .NET type - /// (see for that). - /// - /// The zero-based column index. - public override string GetDataTypeName(int ordinal) => CheckRowDescriptionAndGetField(ordinal).TypeDisplayName; - - /// - /// Gets the OID for the PostgreSQL type for the specified field, as it appears in the pg_type table. 
- /// - /// - /// This is a PostgreSQL-internal value that should not be relied upon and should only be used for - /// debugging purposes. - /// - /// The zero-based column index. - public uint GetDataTypeOID(int ordinal) => CheckRowDescriptionAndGetField(ordinal).TypeOID; - - /// - /// Gets the data type of the specified column. - /// - /// The zero-based column ordinal. - /// The data type of the specified column. - public override Type GetFieldType(int ordinal) - => Command.ObjectResultTypes?[ordinal] - ?? CheckRowDescriptionAndGetField(ordinal).FieldType; - - /// - /// Returns the provider-specific field type of the specified column. - /// - /// The zero-based column ordinal. - /// The Type object that describes the data type of the specified column. - public override Type GetProviderSpecificFieldType(int ordinal) - { - var fieldDescription = CheckRowDescriptionAndGetField(ordinal); - return fieldDescription.Handler.GetProviderSpecificFieldType(fieldDescription); - } + #region Schema metadata table - /// - /// Gets all provider-specific attribute columns in the collection for the current row. - /// - /// An array of Object into which to copy the attribute columns. - /// The number of instances of in the array. - public override int GetProviderSpecificValues(object[] values) - { - if (values == null) - throw new ArgumentNullException(nameof(values)); - if (State != ReaderState.InResult) - throw new InvalidOperationException("No row is available"); - - var count = Math.Min(FieldCount, values.Length); - for (var i = 0; i < count; i++) - values[i] = GetProviderSpecificValue(i); - return count; - } + /// + /// Returns a System.Data.DataTable that describes the column metadata of the DataReader. + /// + [UnconditionalSuppressMessage( + "Composite type mapping currently isn't trimming-safe, and warnings are generated at the MapComposite level.", "IL2026")] + public override DataTable? 
GetSchemaTable() + => GetSchemaTable(async: false).GetAwaiter().GetResult(); - /// - /// Returns an that can be used to iterate through the rows in the data reader. - /// - /// An that can be used to iterate through the rows in the data reader. - public override IEnumerator GetEnumerator() - => new DbEnumerator(this); - - /// - /// Returns schema information for the columns in the current resultset. - /// - /// - public ReadOnlyCollection GetColumnSchema() - => GetColumnSchema(async: false).GetAwaiter().GetResult(); - - ReadOnlyCollection IDbColumnSchemaGenerator.GetColumnSchema() - => new ReadOnlyCollection(GetColumnSchema().Select(c => (DbColumn)c).ToList()); - - /// - /// Asynchronously returns schema information for the columns in the current resultset. - /// - /// -#if NET - public new Task> GetColumnSchemaAsync(CancellationToken cancellationToken = default) + /// + /// Asynchronously returns a System.Data.DataTable that describes the column metadata of the DataReader. + /// + [UnconditionalSuppressMessage( + "Composite type mapping currently isn't trimming-safe, and warnings are generated at the MapComposite level.", "IL2026")] +#if NET5_0_OR_GREATER + public override Task GetSchemaTableAsync(CancellationToken cancellationToken = default) #else - public Task> GetColumnSchemaAsync(CancellationToken cancellationToken = default) + public Task GetSchemaTableAsync(CancellationToken cancellationToken = default) #endif + => GetSchemaTable(async: true, cancellationToken); + + [UnconditionalSuppressMessage("Trimming", "IL2111", Justification = "typeof(Type).TypeInitializer is not used.")] + async Task GetSchemaTable(bool async, CancellationToken cancellationToken = default) + { + if (FieldCount == 0) // No resultset + return null; + + var table = new DataTable("SchemaTable"); + + // Note: column order is important to match SqlClient's, some ADO.NET users appear + // to assume ordering (see #1671) + table.Columns.Add("ColumnName", typeof(string)); + 
table.Columns.Add("ColumnOrdinal", typeof(int)); + table.Columns.Add("ColumnSize", typeof(int)); + table.Columns.Add("NumericPrecision", typeof(int)); + table.Columns.Add("NumericScale", typeof(int)); + table.Columns.Add("IsUnique", typeof(bool)); + table.Columns.Add("IsKey", typeof(bool)); + table.Columns.Add("BaseServerName", typeof(string)); + table.Columns.Add("BaseCatalogName", typeof(string)); + table.Columns.Add("BaseColumnName", typeof(string)); + table.Columns.Add("BaseSchemaName", typeof(string)); + table.Columns.Add("BaseTableName", typeof(string)); + table.Columns.Add("DataType", typeof(Type)); + table.Columns.Add("AllowDBNull", typeof(bool)); + table.Columns.Add("ProviderType", typeof(int)); + table.Columns.Add("IsAliased", typeof(bool)); + table.Columns.Add("IsExpression", typeof(bool)); + table.Columns.Add("IsIdentity", typeof(bool)); + table.Columns.Add("IsAutoIncrement", typeof(bool)); + table.Columns.Add("IsRowVersion", typeof(bool)); + table.Columns.Add("IsHidden", typeof(bool)); + table.Columns.Add("IsLong", typeof(bool)); + table.Columns.Add("IsReadOnly", typeof(bool)); + table.Columns.Add("ProviderSpecificDataType", typeof(Type)); + table.Columns.Add("DataTypeName", typeof(string)); + + foreach (var column in await GetColumnSchema(async, cancellationToken).ConfigureAwait(false)) { - using (NoSynchronizationContextScope.Enter()) - return GetColumnSchema(async: true, cancellationToken); + var row = table.NewRow(); + + row["ColumnName"] = column.ColumnName; + row["ColumnOrdinal"] = column.ColumnOrdinal ?? -1; + row["ColumnSize"] = column.ColumnSize ?? -1; + row["NumericPrecision"] = column.NumericPrecision ?? 0; + row["NumericScale"] = column.NumericScale ?? 
0; + row["IsUnique"] = column.IsUnique == true; + row["IsKey"] = column.IsKey == true; + row["BaseServerName"] = ""; + row["BaseCatalogName"] = column.BaseCatalogName; + row["BaseColumnName"] = column.BaseColumnName; + row["BaseSchemaName"] = column.BaseSchemaName; + row["BaseTableName"] = column.BaseTableName; + row["DataType"] = column.DataType; + row["AllowDBNull"] = (object?)column.AllowDBNull ?? DBNull.Value; + row["ProviderType"] = column.NpgsqlDbType ?? NpgsqlDbType.Unknown; + row["IsAliased"] = column.IsAliased == true; + row["IsExpression"] = column.IsExpression == true; + row["IsIdentity"] = column.IsIdentity == true; + row["IsAutoIncrement"] = column.IsAutoIncrement == true; + row["IsRowVersion"] = false; + row["IsHidden"] = column.IsHidden == true; + row["IsLong"] = column.IsLong == true; + row["IsReadOnly"] = column.IsReadOnly == true; + row["DataTypeName"] = column.DataTypeName; + + table.Rows.Add(row); } - Task> GetColumnSchema(bool async, CancellationToken cancellationToken = default) - => RowDescription == null || RowDescription.Fields.Count == 0 - ? Task.FromResult(new List().AsReadOnly()) - : new DbColumnSchemaGenerator(_connection, RowDescription, _behavior.HasFlag(CommandBehavior.KeyInfo)) - .GetColumnSchema(async, cancellationToken); + return table; + } - #endregion + #endregion Schema metadata table - #region Schema metadata table + #region Seeking - /// - /// Returns a System.Data.DataTable that describes the column metadata of the DataReader. - /// - public override DataTable? GetSchemaTable() - => GetSchemaTable(async: false).GetAwaiter().GetResult(); + /// + /// Seeks to the given column. The 4-byte length is read and returned. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + ValueTask SeekToColumn(bool async, int ordinal, DataFormat dataFormat, bool resumableOp = false) + => _isSequential + ? 
SeekToColumnSequential(async, ordinal, dataFormat, resumableOp) + : new(SeekToColumnNonSequential(ordinal, dataFormat, resumableOp)); - /// - /// Asynchronously returns a System.Data.DataTable that describes the column metadata of the DataReader. - /// -#if NET - public override Task GetSchemaTableAsync(CancellationToken cancellationToken = default) -#else - public Task GetSchemaTableAsync(CancellationToken cancellationToken = default) -#endif + int SeekToColumnNonSequential(int ordinal, DataFormat dataFormat, bool resumableOp = false) + { + var currentColumn = _column; + var buffer = Buffer; + var pgReader = PgReader; + + // Deals with current column commit and rereads + int columnLength; + if (currentColumn >= 0) { - using (NoSynchronizationContextScope.Enter()) - return GetSchemaTable(async: true, cancellationToken); + if (currentColumn == ordinal) + return HandleReread(pgReader.Resumable && resumableOp); + pgReader.Commit(resuming: false); } - async Task GetSchemaTable(bool async, CancellationToken cancellationToken = default) + // Deals with forward movement + Debug.Assert(ordinal != currentColumn); + if (ordinal > currentColumn) { - if (FieldCount == 0) // No resultset - return null; - - var table = new DataTable("SchemaTable"); - - // Note: column order is important to match SqlClient's, some ADO.NET users appear - // to assume ordering (see #1671) - table.Columns.Add("ColumnName", typeof(string)); - table.Columns.Add("ColumnOrdinal", typeof(int)); - table.Columns.Add("ColumnSize", typeof(int)); - table.Columns.Add("NumericPrecision", typeof(int)); - table.Columns.Add("NumericScale", typeof(int)); - table.Columns.Add("IsUnique", typeof(bool)); - table.Columns.Add("IsKey", typeof(bool)); - table.Columns.Add("BaseServerName", typeof(string)); - table.Columns.Add("BaseCatalogName", typeof(string)); - table.Columns.Add("BaseColumnName", typeof(string)); - table.Columns.Add("BaseSchemaName", typeof(string)); - table.Columns.Add("BaseTableName", typeof(string)); - 
table.Columns.Add("DataType", typeof(Type)); - table.Columns.Add("AllowDBNull", typeof(bool)); - table.Columns.Add("ProviderType", typeof(int)); - table.Columns.Add("IsAliased", typeof(bool)); - table.Columns.Add("IsExpression", typeof(bool)); - table.Columns.Add("IsIdentity", typeof(bool)); - table.Columns.Add("IsAutoIncrement", typeof(bool)); - table.Columns.Add("IsRowVersion", typeof(bool)); - table.Columns.Add("IsHidden", typeof(bool)); - table.Columns.Add("IsLong", typeof(bool)); - table.Columns.Add("IsReadOnly", typeof(bool)); - table.Columns.Add("ProviderSpecificDataType", typeof(Type)); - table.Columns.Add("DataTypeName", typeof(string)); - - foreach (var column in await GetColumnSchema(async, cancellationToken)) + // Written as a while to be able to increment _column directly after reading into it. + while (_column < ordinal - 1) { - var row = table.NewRow(); - - row["ColumnName"] = column.ColumnName; - row["ColumnOrdinal"] = column.ColumnOrdinal ?? -1; - row["ColumnSize"] = column.ColumnSize ?? -1; - row["NumericPrecision"] = column.NumericPrecision ?? 0; - row["NumericScale"] = column.NumericScale ?? 0; - row["IsUnique"] = column.IsUnique == true; - row["IsKey"] = column.IsKey == true; - row["BaseServerName"] = ""; - row["BaseCatalogName"] = column.BaseCatalogName; - row["BaseColumnName"] = column.BaseColumnName; - row["BaseSchemaName"] = column.BaseSchemaName; - row["BaseTableName"] = column.BaseTableName; - row["DataType"] = column.DataType; - row["AllowDBNull"] = (object?)column.AllowDBNull ?? DBNull.Value; - row["ProviderType"] = column.NpgsqlDbType ?? 
NpgsqlDbType.Unknown; - row["IsAliased"] = column.IsAliased == true; - row["IsExpression"] = column.IsExpression == true; - row["IsIdentity"] = column.IsIdentity == true; - row["IsAutoIncrement"] = column.IsAutoIncrement == true; - row["IsRowVersion"] = false; - row["IsHidden"] = column.IsHidden == true; - row["IsLong"] = column.IsLong == true; - row["DataTypeName"] = column.DataTypeName; - - table.Rows.Add(row); + columnLength = buffer.ReadInt32(); + _column++; + Debug.Assert(columnLength >= -1); + if (columnLength > 0) + buffer.Skip(columnLength); } - - return table; + columnLength = buffer.ReadInt32(); } + else + columnLength = SeekBackwards(); - #endregion Schema metadata table + pgReader.Init(columnLength, dataFormat, resumableOp); + _column = ordinal; - #region Seeking + return columnLength; - Task SeekToColumn(int column, bool async, CancellationToken cancellationToken = default) + int HandleReread(bool resuming) { - if (_isSequential) - return SeekToColumnSequential(column, async, cancellationToken); - SeekToColumnNonSequential(column); - return Task.CompletedTask; + Debug.Assert(pgReader.Initialized); + var columnLength = pgReader.FieldSize; + pgReader.Commit(resuming); + if (!resuming && columnLength > 0) + buffer.ReadPosition -= columnLength; + pgReader.Init(columnLength, dataFormat, resumableOp); + return columnLength; } - void SeekToColumnNonSequential(int column) + // On the first call to SeekBackwards we'll fill up the columns list as we may need seek positions more than once. + [MethodImpl(MethodImplOptions.NoInlining)] + int SeekBackwards() { - // Shut down any streaming going on on the column - if (_columnStream != null) + // Backfill the first column. 
+ if (_columns.Count is 0) { - _columnStream.Dispose(); - _columnStream = null; + buffer.ReadPosition = _columnsStartPos; + var len = buffer.ReadInt32(); + _columns.Add((buffer.ReadPosition, len)); } - - for (var lastColumnRead = _columns.Count; column >= lastColumnRead; lastColumnRead++) + for (var lastColumnRead = _columns.Count; ordinal >= lastColumnRead; lastColumnRead++) { - int lastColumnLen; - (Buffer.ReadPosition, lastColumnLen) = _columns[lastColumnRead-1]; - if (lastColumnLen != -1) - Buffer.ReadPosition += lastColumnLen; + (Buffer.ReadPosition, var lastLen) = _columns[lastColumnRead - 1]; + if (lastLen > 0) + buffer.Skip(lastLen); var len = Buffer.ReadInt32(); _columns.Add((Buffer.ReadPosition, len)); } - - (Buffer.ReadPosition, ColumnLen) = _columns[column]; - _column = column; - PosInColumn = 0; + (Buffer.ReadPosition, var columnLength) = _columns[ordinal]; + return columnLength; } + } - /// - /// Seeks to the given column. The 4-byte length is read and stored in . - /// - async Task SeekToColumnSequential(int column, bool async, CancellationToken cancellationToken = default) + ValueTask SeekToColumnSequential(bool async, int ordinal, DataFormat dataFormat, bool resumableOp = false) + { + var reread = _column == ordinal; + // Column rereading rules for sequential mode: + // * We never allow rereading if the column didn't get initialized as resumable the previous time + // * If it did get initialized as resumable we only allow rereading when either of the following is true: + // - The op is a resumable one again + // - The op isn't resumable but the field is still entirely unconsumed + if (ordinal < _column || (reread && (!PgReader.Resumable || (!resumableOp && !PgReader.IsAtStart)))) + ThrowHelper.ThrowInvalidOperationException( + $"Invalid attempt to read from column ordinal '{ordinal}'. 
With CommandBehavior.SequentialAccess, " + + $"you may only read from column ordinal '{_column}' or greater."); + + var committed = false; + if (!PgReader.CommitHasIO(reread)) { - if (column < 0 || column >= _numColumns) - throw new IndexOutOfRangeException("Column index out of range"); - - if (column < _column) - throw new InvalidOperationException($"Invalid attempt to read from column ordinal '{column}'. With CommandBehavior.SequentialAccess, you may only read from column ordinal '{_column}' or greater."); - - if (column == _column) - return; - - // Need to seek forward - - // Shut down any streaming going on on the column - if (_columnStream != null) + var columnLength = PgReader.FieldSize; + PgReader.Commit(reread); + committed = true; + if (reread) { - _columnStream.Dispose(); - _columnStream = null; - // Disposing the stream leaves us at the end of the column - PosInColumn = ColumnLen; + PgReader.Init(columnLength, dataFormat, columnLength is -1 || resumableOp); + return new(columnLength); } - // Skip to end of column if needed - // TODO: Simplify by better initializing _columnLen/_posInColumn - var remainingInColumn = ColumnLen == -1 ? 0 : ColumnLen - PosInColumn; - if (remainingInColumn > 0) - await Buffer.Skip(remainingInColumn, async); - - // Skip over unwanted fields - for (; _column < column - 1; _column++) + if (TrySeekBuffered(ordinal, out columnLength)) { - await Buffer.Ensure(4, async); - var len = Buffer.ReadInt32(); - if (len != -1) - await Buffer.Skip(len, async); + PgReader.Init(columnLength, dataFormat, columnLength is -1 || resumableOp); + return new(columnLength); } - await Buffer.Ensure(4, async); - ColumnLen = Buffer.ReadInt32(); - PosInColumn = 0; - _column = column; + // If we couldn't consume the column TrySeekBuffered had to stop at, do so now. + if (columnLength > -1) + { + // Resumable: true causes commit to consume without error. 
+ PgReader.Init(columnLength, dataFormat, resumable: true); + committed = false; + } } - Task SeekInColumn(int posInColumn, bool async, CancellationToken cancellationToken = default) - { - if (_isSequential) - return SeekInColumnSequential(posInColumn, async, cancellationToken); + return Core(async, reread, !committed, ordinal, dataFormat, resumableOp); - if (posInColumn > ColumnLen) - posInColumn = ColumnLen; +#if NET6_0_OR_GREATER + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] +#endif + async ValueTask Core(bool async, bool reread, bool commit, int ordinal, DataFormat dataFormat, bool resumableOp) + { + if (commit) + { + Debug.Assert(ordinal != _column); + if (async) + await PgReader.CommitAsync(reread).ConfigureAwait(false); + else + PgReader.Commit(reread); + } - Buffer.ReadPosition = _columns[_column].Offset + posInColumn; - PosInColumn = posInColumn; - return Task.CompletedTask; + if (reread) + { + PgReader.Init(PgReader.FieldSize, dataFormat, PgReader.FieldSize is -1 || resumableOp); + return PgReader.FieldSize; + } - async Task SeekInColumnSequential(int posInColumn2, bool async2, CancellationToken cancellationToken2) + // Seek to the requested column + int columnLength; + var buffer = Buffer; + // Written as a while to be able to increment _column directly after reading into it. 
+ while (_column < ordinal - 1) { - Debug.Assert(_column > -1); + await buffer.Ensure(4, async).ConfigureAwait(false); + columnLength = buffer.ReadInt32(); + _column++; + Debug.Assert(columnLength >= -1); + if (columnLength > 0) + await buffer.Skip(columnLength, async).ConfigureAwait(false); + } - if (posInColumn2 < PosInColumn) - throw new InvalidOperationException("Attempt to read a position in the column which has already been read"); + await buffer.Ensure(4, async).ConfigureAwait(false); + columnLength = buffer.ReadInt32(); + _column = ordinal; - if (posInColumn2 > ColumnLen) - posInColumn2 = ColumnLen; + PgReader.Init(columnLength, dataFormat, resumableOp); + return columnLength; + } - if (posInColumn2 > PosInColumn) + bool TrySeekBuffered(int ordinal, out int columnLength) + { + // Skip over unwanted fields + columnLength = -1; + var buffer = Buffer; + // Written as a while to be able to increment _column directly after reading into it. + while (_column < ordinal - 1) + { + if (buffer.ReadBytesLeft < 4) + { + columnLength = -1; + return false; + } + columnLength = buffer.ReadInt32(); + _column++; + Debug.Assert(columnLength >= -1); + if (columnLength > 0) { - await Buffer.Skip(posInColumn2 - PosInColumn, async2); - PosInColumn = posInColumn2; + if (buffer.ReadBytesLeft < columnLength) + return false; + buffer.Skip(columnLength); } } - } - #endregion + if (buffer.ReadBytesLeft < 4) + { + columnLength = -1; + return false; + } - #region ConsumeRow + columnLength = buffer.ReadInt32(); + _column = ordinal; + return true; + } + } - Task ConsumeRow(bool async) - { - Debug.Assert(State == ReaderState.InResult || State == ReaderState.BeforeResult); + #endregion - if (_isSequential) - return ConsumeRowSequential(async); - ConsumeRowNonSequential(); - return Task.CompletedTask; + #region ConsumeRow - async Task ConsumeRowSequential(bool async2) - { - if (_columnStream != null) - { - _columnStream.Dispose(); - _columnStream = null; - // Disposing the stream leaves us at 
the end of the column - PosInColumn = ColumnLen; - } + Task ConsumeRow(bool async) + { + Debug.Assert(State is ReaderState.InResult or ReaderState.BeforeResult); - // TODO: Potential for code-sharing with ReadColumn above, which also skips - // Skip to end of column if needed - var remainingInColumn = ColumnLen == -1 ? 0 : ColumnLen - PosInColumn; - if (remainingInColumn > 0) - await Buffer.Skip(remainingInColumn, async2); + if (!_canConsumeRowNonSequentially) + return ConsumeRowSequential(async); - // Skip over the remaining columns in the row - for (; _column < _numColumns - 1; _column++) - { - await Buffer.Ensure(4, async2); - var len = Buffer.ReadInt32(); - if (len != -1) - await Buffer.Skip(len, async2); - } - } - } + // We get here, if we're in a non-sequential mode (or the row is already in the buffer) + ConsumeRowNonSequential(); + return Task.CompletedTask; - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void ConsumeRowNonSequential() + async Task ConsumeRowSequential(bool async) { - Debug.Assert(State == ReaderState.InResult || State == ReaderState.BeforeResult); + if (async) + await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); + else + PgReader.Commit(resuming: false); - if (_columnStream != null) + // Skip over the remaining columns in the row + var buffer = Buffer; + // Written as a while to be able to increment _column directly after reading into it. 
+ while (_column < ColumnCount - 1) { - _columnStream.Dispose(); - _columnStream = null; - // Disposing the stream leaves us at the end of the column - PosInColumn = ColumnLen; + await buffer.Ensure(4, async).ConfigureAwait(false); + var columnLength = buffer.ReadInt32(); + _column++; + Debug.Assert(columnLength >= -1); + if (columnLength > 0) + await buffer.Skip(columnLength, async).ConfigureAwait(false); } - Buffer.ReadPosition = _dataMsgEnd; } + } - #endregion + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void ConsumeRowNonSequential() + { + Debug.Assert(State is ReaderState.InResult or ReaderState.BeforeResult); + PgReader.Commit(resuming: false); + Buffer.ReadPosition = _dataMsgEnd; + } - #region Checks + #endregion - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void CheckResultSet() + #region Checks + + void CheckResultSet() + { + switch (State) { - switch (State) - { - case ReaderState.BeforeResult: - case ReaderState.InResult: - break; - case ReaderState.Closed: - throw new InvalidOperationException("The reader is closed"); - case ReaderState.Disposed: - throw new ObjectDisposedException(nameof(NpgsqlDataReader)); - default: - throw new InvalidOperationException("No resultset is currently being traversed"); - } + case ReaderState.BeforeResult: + case ReaderState.InResult: + return; + case ReaderState.Closed: + ThrowHelper.ThrowInvalidOperationException("The reader is closed"); + return; + case ReaderState.Disposed: + ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlDataReader)); + return; + default: + ThrowHelper.ThrowInvalidOperationException("No resultset is currently being traversed"); + return; } + } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - FieldDescription CheckRowAndGetField(int column) - { - switch (State) - { - case ReaderState.InResult: - break; - case ReaderState.Closed: - throw new InvalidOperationException("The reader is closed"); - case ReaderState.Disposed: - throw new 
ObjectDisposedException(nameof(NpgsqlDataReader)); - default: - throw new InvalidOperationException("No row is available"); - } + [MethodImpl(MethodImplOptions.NoInlining)] + T DbNullValueOrThrow(int ordinal) + { + // When T is a Nullable (and only in that case), we support returning null + if (default(T) is null && typeof(T).IsValueType) + return default!; - if (column < 0 || column >= RowDescription!.NumFields) - throw new IndexOutOfRangeException($"Column must be between {0} and {RowDescription!.NumFields - 1}"); + if (typeof(T) == typeof(object)) + return (T)(object)DBNull.Value; - return RowDescription[column]; + ThrowHelper.ThrowInvalidCastException_NoValue(CheckRowAndGetField(ordinal)); + return default; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + DataFormat GetInfo(int ordinal, Type type, out PgConverter converter, out Size bufferRequirement, out bool asObject) + { + var state = State; + if (state is not ReaderState.InResult || (uint)ordinal > (uint)ColumnCount) + { + Unsafe.SkipInit(out converter); + Unsafe.SkipInit(out bufferRequirement); + Unsafe.SkipInit(out asObject); + HandleInvalidState(state, ColumnCount); + Debug.Fail("Should never get here"); } - /// - /// Checks that we have a RowDescription, but not necessary an actual resultset - /// (for operations which work in SchemaOnly mode. 
- /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - FieldDescription CheckRowDescriptionAndGetField(int column) + ref var info = ref ColumnInfoCache![ordinal]; + + Debug.Assert(info.ConverterInfo.IsDefault || ReferenceEquals(Connector.SerializerOptions, info.ConverterInfo.TypeInfo.Options), "Cache is bleeding over"); + + if (info.ConverterInfo.TypeToConvert == type) { - if (RowDescription == null) - throw new InvalidOperationException("No resultset is currently being traversed"); + converter = info.ConverterInfo.Converter; + bufferRequirement = info.ConverterInfo.BufferRequirement; + asObject = info.AsObject; + return info.DataFormat; + } - if (column < 0 || column >= RowDescription.NumFields) - throw new IndexOutOfRangeException($"Column must be between {0} and {RowDescription.NumFields - 1}"); + return Slow(ref info, out converter, out bufferRequirement, out asObject); - return RowDescription[column]; + [MethodImpl(MethodImplOptions.NoInlining)] + DataFormat Slow(ref ColumnInfo info, out PgConverter converter, out Size bufferRequirement, out bool asObject) + { + var field = CheckRowAndGetField(ordinal); + field.GetInfo(type, ref info); + converter = info.ConverterInfo.Converter; + bufferRequirement = info.ConverterInfo.BufferRequirement; + asObject = info.AsObject; + return field.DataFormat; } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + DataFormat GetDefaultInfo(int ordinal, out PgConverter converter, out Size bufferRequirement) + { + var field = CheckRowAndGetField(ordinal); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void CheckColumnStart() + converter = field.ObjectOrDefaultInfo.Converter; + bufferRequirement = field.ObjectOrDefaultInfo.BufferRequirement; + return field.DataFormat; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + FieldDescription CheckRowAndGetField(int column) + { + var columns = RowDescription; + var state = State; + if (state is ReaderState.InResult && (uint)column < (uint)columns!.Count) + 
return columns[column]; + + return HandleInvalidState(state, columns?.Count ?? 0); + } + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static FieldDescription HandleInvalidState(ReaderState state, int maxColumns) + { + switch (state) { - Debug.Assert(_isSequential); - if (PosInColumn != 0) - throw new InvalidOperationException("Attempt to read a position in the column which has already been read"); + case ReaderState.InResult: + ThrowColumnOutOfRange(maxColumns); + break; + case ReaderState.Closed: + ThrowHelper.ThrowInvalidOperationException("The reader is closed"); + break; + case ReaderState.Disposed: + ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlDataReader)); + break; + default: + ThrowHelper.ThrowInvalidOperationException("No row is available"); + break; } + return default!; + } + + /// + /// Checks that we have a RowDescription, but not necessary an actual resultset + /// (for operations which work in SchemaOnly mode. + /// + FieldDescription GetField(int column) + { + if (RowDescription is null) + ThrowHelper.ThrowInvalidOperationException("No resultset is currently being traversed"); + + var columns = RowDescription; + if (column < 0 || column >= columns.Count) + ThrowColumnOutOfRange(columns.Count); + + return columns[column]; + } + + void CheckClosedOrDisposed() + { + if (State is (ReaderState.Closed or ReaderState.Disposed) and var state) + Throw(state); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void CheckClosedOrDisposed() + [MethodImpl(MethodImplOptions.NoInlining)] + static void Throw(ReaderState state) { - switch (State) + switch (state) { case ReaderState.Closed: - throw new InvalidOperationException("The reader is closed"); + ThrowHelper.ThrowInvalidOperationException("The reader is closed"); + return; case ReaderState.Disposed: - throw new ObjectDisposedException(nameof(NpgsqlDataReader)); - } + ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlDataReader)); + return; + } } + } - #endregion + 
[DoesNotReturn] + static void ThrowColumnOutOfRange(int maxIndex) => + throw new IndexOutOfRangeException($"Column must be between {0} and {maxIndex - 1}"); - #region Misc + #endregion - /// - /// Unbinds reader from the connector. - /// Should be called before the connector is returned to the pool. - /// - internal void UnbindIfNecessary() + #region Misc + + /// + /// Unbinds reader from the connector. + /// Should be called before the connector is returned to the pool. + /// + internal void UnbindIfNecessary() + { + // We're closing the connection, but reader is not yet disposed + // We have to unbind the reader from the connector, otherwise there could be a concurrency issues + // See #3126 and #3290 + if (State != ReaderState.Disposed) { - // We're closing the connection, but reader is not yet disposed - // We have to unbind the reader from the connector, otherwise there could be a concurency issues - // See #3126 and #3290 - if (State != ReaderState.Disposed) - Connector.DataReader = new NpgsqlDataReader(Connector); + Connector.DataReader = Connector.UnboundDataReader is { State: ReaderState.Disposed } previousReader + ? 
previousReader + : new NpgsqlDataReader(Connector); + Connector.UnboundDataReader = this; } - - #endregion } - enum ReaderState - { - BeforeResult, - InResult, - BetweenResults, - Consumed, - Closed, - Disposed, - } + #endregion +} + +enum ReaderState +{ + BeforeResult, + InResult, + BetweenResults, + Consumed, + Closed, + Disposed, } diff --git a/src/Npgsql/NpgsqlDataSource.cs b/src/Npgsql/NpgsqlDataSource.cs new file mode 100644 index 0000000000..ac87d2d0cd --- /dev/null +++ b/src/Npgsql/NpgsqlDataSource.cs @@ -0,0 +1,539 @@ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Net.Security; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using System.Transactions; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; +using Npgsql.Internal.ResolverFactories; +using Npgsql.Properties; +using Npgsql.Util; + +namespace Npgsql; + +/// +public abstract class NpgsqlDataSource : DbDataSource +{ + /// + public override string ConnectionString { get; } + + /// + /// Contains the connection string returned to the user from + /// after the connection has been opened. Does not contain the password unless Persist Security Info=true. + /// + internal NpgsqlConnectionStringBuilder Settings { get; } + + internal NpgsqlDataSourceConfiguration Configuration { get; } + internal NpgsqlLoggingConfiguration LoggingConfiguration { get; } + + readonly PgTypeInfoResolverChain _resolverChain; + internal PgSerializerOptions SerializerOptions { get; private set; } = null!; // Initialized at bootstrapping + + /// + /// Information about PostgreSQL and PostgreSQL-like databases (e.g. type definitions, capabilities...). 
+ /// + internal NpgsqlDatabaseInfo DatabaseInfo { get; private set; } = null!; // Initialized at bootstrapping + + internal TransportSecurityHandler TransportSecurityHandler { get; } + internal RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; } + internal Action? ClientCertificatesCallback { get; } + + readonly Func? _passwordProvider; + readonly Func>? _passwordProviderAsync; + readonly Func>? _periodicPasswordProvider; + readonly TimeSpan _periodicPasswordSuccessRefreshInterval, _periodicPasswordFailureRefreshInterval; + + internal IntegratedSecurityHandler IntegratedSecurityHandler { get; } + + internal Action? ConnectionInitializer { get; } + internal Func? ConnectionInitializerAsync { get; } + + readonly Timer? _periodicPasswordProviderTimer; + readonly CancellationTokenSource? _timerPasswordProviderCancellationTokenSource; + readonly Task _passwordRefreshTask = null!; + string? _password; + + internal bool IsBootstrapped { get; private set; } + + volatile DatabaseStateInfo _databaseStateInfo = new(); + + // Note that while the dictionary is protected by locking, we assume that the lists it contains don't need to be + // (i.e. access to connectors of a specific transaction won't be concurrent) + private protected readonly Dictionary> _pendingEnlistedConnectors + = new(); + + internal MetricsReporter MetricsReporter { get; } + internal string Name { get; } + + internal abstract (int Total, int Idle, int Busy) Statistics { get; } + + volatile int _isDisposed; + + readonly ILogger _connectionLogger; + + /// + /// Semaphore to ensure we don't perform type loading and mapping setup concurrently for this data source. + /// + readonly SemaphoreSlim _setupMappingsSemaphore = new(1); + + readonly INpgsqlNameTranslator _defaultNameTranslator; + + internal List? 
_hackyEnumTypeMappings; + + internal NpgsqlDataSource( + NpgsqlConnectionStringBuilder settings, + NpgsqlDataSourceConfiguration dataSourceConfig) + { + Settings = settings; + ConnectionString = settings.PersistSecurityInfo + ? settings.ToString() + : settings.ToStringWithoutPassword(); + + Configuration = dataSourceConfig; + + (var name, + LoggingConfiguration, + TransportSecurityHandler, + IntegratedSecurityHandler, + UserCertificateValidationCallback, + ClientCertificatesCallback, + _passwordProvider, + _passwordProviderAsync, + _periodicPasswordProvider, + _periodicPasswordSuccessRefreshInterval, + _periodicPasswordFailureRefreshInterval, + var resolverChain, + _hackyEnumTypeMappings, + _defaultNameTranslator, + ConnectionInitializer, + ConnectionInitializerAsync) + = dataSourceConfig; + _connectionLogger = LoggingConfiguration.ConnectionLogger; + + Debug.Assert(_passwordProvider is null || _passwordProviderAsync is not null); + + _resolverChain = resolverChain; + _password = settings.Password; + + if (_periodicPasswordSuccessRefreshInterval != default) + { + Debug.Assert(_periodicPasswordProvider is not null); + + _timerPasswordProviderCancellationTokenSource = new(); + + // Create the timer, but don't start it; the manual run below will will schedule the first refresh. + _periodicPasswordProviderTimer = new Timer(state => _ = RefreshPassword(), null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan); + // Trigger the first refresh attempt right now, outside the timer; this allows us to capture the Task so it can be observed + // in GetPasswordAsync. + _passwordRefreshTask = Task.Run(RefreshPassword); + } + + Name = name ?? 
ConnectionString; + MetricsReporter = new MetricsReporter(this); + } + + /// + public new NpgsqlConnection CreateConnection() + => NpgsqlConnection.FromDataSource(this); + + /// + public new NpgsqlConnection OpenConnection() + { + var connection = CreateConnection(); + + try + { + connection.Open(); + return connection; + } + catch + { + connection.Dispose(); + throw; + } + } + + /// + protected override DbConnection OpenDbConnection() + => OpenConnection(); + + /// + public new async ValueTask OpenConnectionAsync(CancellationToken cancellationToken = default) + { + var connection = CreateConnection(); + + try + { + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + return connection; + } + catch + { + await connection.DisposeAsync().ConfigureAwait(false); + throw; + } + } + + /// + protected override async ValueTask OpenDbConnectionAsync(CancellationToken cancellationToken = default) + => await OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + /// + protected override DbConnection CreateDbConnection() + => CreateConnection(); + + /// + protected override DbCommand CreateDbCommand(string? commandText = null) + => CreateCommand(commandText); + + /// + protected override DbBatch CreateDbBatch() + => CreateBatch(); + + /// + /// Creates a command ready for use against this . + /// + /// An optional SQL for the command. + public new NpgsqlCommand CreateCommand(string? commandText = null) + => new NpgsqlDataSourceCommand(CreateConnection()) { CommandText = commandText }; + + /// + /// Creates a batch ready for use against this . + /// + public new NpgsqlBatch CreateBatch() + => new NpgsqlDataSourceBatch(CreateConnection()); + + /// + /// Creates a new for the given . + /// + public static NpgsqlDataSource Create(string connectionString) + => new NpgsqlDataSourceBuilder(connectionString).Build(); + + /// + /// Creates a new for the given . 
+ /// + public static NpgsqlDataSource Create(NpgsqlConnectionStringBuilder connectionStringBuilder) + => Create(connectionStringBuilder.ToString()); + + internal async Task Bootstrap( + NpgsqlConnector connector, + NpgsqlTimeout timeout, + bool forceReload, + bool async, + CancellationToken cancellationToken) + { + if (IsBootstrapped && !forceReload) + return; + + var hasSemaphore = async + ? await _setupMappingsSemaphore.WaitAsync(timeout.CheckAndGetTimeLeft(), cancellationToken).ConfigureAwait(false) + : _setupMappingsSemaphore.Wait(timeout.CheckAndGetTimeLeft(), cancellationToken); + + if (!hasSemaphore) + throw new TimeoutException(); + + try + { + if (IsBootstrapped && !forceReload) + return; + + // The type loading below will need to send queries to the database, and that depends on a type mapper being set up (even if its + // empty). So we set up a minimal version here, and then later inject the actual DatabaseInfo. + connector.SerializerOptions = + new(PostgresMinimalDatabaseInfo.DefaultTypeCatalog) + { + TextEncoding = connector.TextEncoding, + TypeInfoResolver = AdoTypeInfoResolverFactory.Instance.CreateResolver(), + }; + + NpgsqlDatabaseInfo databaseInfo; + + using (connector.StartUserAction(ConnectorState.Executing, cancellationToken)) + databaseInfo = await NpgsqlDatabaseInfo.Load(connector, timeout, async).ConfigureAwait(false); + + connector.DatabaseInfo = DatabaseInfo = databaseInfo; + connector.SerializerOptions = SerializerOptions = + new(databaseInfo, _resolverChain, CreateTimeZoneProvider(connector.Timezone)) + { + ArrayNullabilityMode = Settings.ArrayNullabilityMode, + EnableDateTimeInfinityConversions = !Statics.DisableDateTimeInfinityConversions, + TextEncoding = connector.TextEncoding, + DefaultNameTranslator = _defaultNameTranslator, + + }; + + IsBootstrapped = true; + } + finally + { + _setupMappingsSemaphore.Release(); + } + + // Func in a static function to make sure we don't capture state that might not stay around, like a connector. 
+ static Func CreateTimeZoneProvider(string postgresTimeZone) + => () => + { + if (string.Equals(postgresTimeZone, "localtime", StringComparison.OrdinalIgnoreCase)) + throw new TimeZoneNotFoundException( + "The special PostgreSQL timezone 'localtime' is not supported when reading values of type 'timestamp with time zone'. " + + "Please specify a real timezone in 'postgresql.conf' on the server, or set the 'PGTZ' environment variable on the client."); + + return postgresTimeZone; + }; + } + + #region Password management + + /// + /// Manually sets the password to be used the next time a physical connection is opened. + /// Consider using instead. + /// + public string Password + { + set + { + if (_passwordProvider is not null || _periodicPasswordProvider is not null) + throw new NotSupportedException(NpgsqlStrings.CannotSetBothPasswordProviderAndPassword); + + _password = value; + } + } + + internal ValueTask GetPassword(bool async, CancellationToken cancellationToken = default) + { + if (_passwordProvider is not null) + return GetPassword(async, cancellationToken); + + // A periodic password provider is configured, but the first refresh hasn't completed yet (race condition). + if (_password is null && _periodicPasswordProvider is not null) + return GetInitialPeriodicPassword(async); + + return new(_password); + + async ValueTask GetInitialPeriodicPassword(bool async) + { + if (async) + await _passwordRefreshTask.ConfigureAwait(false); + else + _passwordRefreshTask.GetAwaiter().GetResult(); + Debug.Assert(_password is not null); + + return _password; + } + + async ValueTask GetPassword(bool async, CancellationToken cancellationToken) + { + try + { + return async ? 
await _passwordProviderAsync!(Settings, cancellationToken).ConfigureAwait(false) : _passwordProvider(Settings); + } + catch (Exception e) + { + _connectionLogger.LogError(e, "Password provider threw an exception"); + + throw new NpgsqlException("An exception was thrown from the password provider", e); + } + } + } + + async Task RefreshPassword() + { + try + { + _password = await _periodicPasswordProvider!(Settings, _timerPasswordProviderCancellationTokenSource!.Token).ConfigureAwait(false); + + _periodicPasswordProviderTimer!.Change(_periodicPasswordSuccessRefreshInterval, Timeout.InfiniteTimeSpan); + } + catch (Exception e) + { + _connectionLogger.LogError(e, "Periodic password provider threw an exception"); + + _periodicPasswordProviderTimer!.Change(_periodicPasswordFailureRefreshInterval, Timeout.InfiniteTimeSpan); + + throw new NpgsqlException("An exception was thrown from the periodic password provider", e); + } + } + + #endregion Password management + + internal abstract ValueTask Get( + NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken); + + internal abstract bool TryGetIdleConnector([NotNullWhen(true)] out NpgsqlConnector? connector); + + internal abstract ValueTask OpenNewConnector( + NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken); + + internal abstract void Return(NpgsqlConnector connector); + + internal abstract void Clear(); + + internal abstract bool OwnsConnectors { get; } + + #region Database state management + + internal DatabaseState GetDatabaseState(bool ignoreExpiration = false) + { + Debug.Assert(this is not NpgsqlMultiHostDataSource); + + var databaseStateInfo = _databaseStateInfo; + + return ignoreExpiration || !databaseStateInfo.Timeout.HasExpired + ? 
databaseStateInfo.State + : DatabaseState.Unknown; + } + + internal DatabaseState UpdateDatabaseState( + DatabaseState newState, + DateTime timeStamp, + TimeSpan stateExpiration, + bool ignoreTimeStamp = false) + { + Debug.Assert(this is not NpgsqlMultiHostDataSource); + + var databaseStateInfo = _databaseStateInfo; + + if (!ignoreTimeStamp && timeStamp <= databaseStateInfo.TimeStamp) + return _databaseStateInfo.State; + + _databaseStateInfo = new(newState, new NpgsqlTimeout(stateExpiration), timeStamp); + + return newState; + } + + #endregion Database state management + + #region Pending Enlisted Connections + + internal virtual void AddPendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) + { + lock (_pendingEnlistedConnectors) + { + if (!_pendingEnlistedConnectors.TryGetValue(transaction, out var list)) + list = _pendingEnlistedConnectors[transaction] = new List(1); + list.Add(connector); + } + } + + internal virtual bool TryRemovePendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) + { + lock (_pendingEnlistedConnectors) + { + if (!_pendingEnlistedConnectors.TryGetValue(transaction, out var list)) + return false; + list.Remove(connector); + if (list.Count == 0) + _pendingEnlistedConnectors.Remove(transaction); + return true; + } + } + + internal virtual bool TryRentEnlistedPending(Transaction transaction, NpgsqlConnection connection, + [NotNullWhen(true)] out NpgsqlConnector? 
connector) + { + lock (_pendingEnlistedConnectors) + { + if (!_pendingEnlistedConnectors.TryGetValue(transaction, out var list)) + { + connector = null; + return false; + } + connector = list[list.Count - 1]; + list.RemoveAt(list.Count - 1); + if (list.Count == 0) + _pendingEnlistedConnectors.Remove(transaction); + return true; + } + } + + #endregion + + #region Dispose + + /// + protected sealed override void Dispose(bool disposing) + { + if (disposing && Interlocked.CompareExchange(ref _isDisposed, 1, 0) == 0) + DisposeBase(); + } + + /// + protected virtual void DisposeBase() + { + var cancellationTokenSource = _timerPasswordProviderCancellationTokenSource; + if (cancellationTokenSource is not null) + { + cancellationTokenSource.Cancel(); + cancellationTokenSource.Dispose(); + } + + _periodicPasswordProviderTimer?.Dispose(); + _setupMappingsSemaphore.Dispose(); + MetricsReporter.Dispose(); + + Clear(); + } + + /// + protected sealed override ValueTask DisposeAsyncCore() + { + if (Interlocked.CompareExchange(ref _isDisposed, 1, 0) == 0) + return DisposeAsyncBase(); + + return default; + } + +#pragma warning disable CS1998 + /// + protected virtual async ValueTask DisposeAsyncBase() + { + var cancellationTokenSource = _timerPasswordProviderCancellationTokenSource; + if (cancellationTokenSource is not null) + { + cancellationTokenSource.Cancel(); + cancellationTokenSource.Dispose(); + } + + if (_periodicPasswordProviderTimer is not null) + { +#if NET5_0_OR_GREATER + await _periodicPasswordProviderTimer.DisposeAsync().ConfigureAwait(false); +#else + _periodicPasswordProviderTimer.Dispose(); +#endif + } + + _setupMappingsSemaphore.Dispose(); + MetricsReporter.Dispose(); + + // TODO: async Clear, #4499 + Clear(); + } +#pragma warning restore CS1998 + + private protected void CheckDisposed() + { + if (_isDisposed == 1) + ThrowHelper.ThrowObjectDisposedException(GetType().FullName); + } + + #endregion + + sealed class DatabaseStateInfo + { + internal readonly 
DatabaseState State; + internal readonly NpgsqlTimeout Timeout; + // While the TimeStamp is not strictly required, it does lower the risk of overwriting the current state with an old value + internal readonly DateTime TimeStamp; + + public DatabaseStateInfo() : this(default, default, default) { } + + public DatabaseStateInfo(DatabaseState state, NpgsqlTimeout timeout, DateTime timeStamp) + => (State, Timeout, TimeStamp) = (state, timeout, timeStamp); + } +} diff --git a/src/Npgsql/NpgsqlDataSourceBatch.cs b/src/Npgsql/NpgsqlDataSourceBatch.cs new file mode 100644 index 0000000000..fa239ee8e6 --- /dev/null +++ b/src/Npgsql/NpgsqlDataSourceBatch.cs @@ -0,0 +1,35 @@ +using System; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Properties; + +namespace Npgsql; + +sealed class NpgsqlDataSourceBatch : NpgsqlBatch +{ + internal NpgsqlDataSourceBatch(NpgsqlConnection connection) + : base(new NpgsqlDataSourceCommand(DefaultBatchCommandsSize, connection)) + { + } + + // The below are incompatible with batches executed directly against DbDataSource, since no DbConnection + // is involved at the user API level and the batch owns the DbConnection. + public override void Prepare() + => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceBatch); + + public override Task PrepareAsync(CancellationToken cancellationToken = default) + => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceBatch); + + protected override DbConnection? DbConnection + { + get => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceBatch); + set => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceBatch); + } + + protected override DbTransaction? 
DbTransaction + { + get => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceBatch); + set => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceBatch); + } +} diff --git a/src/Npgsql/NpgsqlDataSourceBuilder.cs b/src/Npgsql/NpgsqlDataSourceBuilder.cs new file mode 100644 index 0000000000..e304a559cc --- /dev/null +++ b/src/Npgsql/NpgsqlDataSourceBuilder.cs @@ -0,0 +1,458 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Net.Security; +using System.Security.Cryptography.X509Certificates; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; +using Npgsql.Internal.ResolverFactories; +using Npgsql.TypeMapping; +using NpgsqlTypes; + +namespace Npgsql; + +/// +/// Provides a simple API for configuring and creating an , from which database connections can be obtained. +/// +public sealed class NpgsqlDataSourceBuilder : INpgsqlTypeMapper +{ + static UnsupportedTypeInfoResolver UnsupportedTypeInfoResolver { get; } = new(); + + readonly NpgsqlSlimDataSourceBuilder _internalBuilder; + + /// + /// A diagnostics name used by Npgsql when generating tracing, logging and metrics. + /// + public string? Name + { + get => _internalBuilder.Name; + set => _internalBuilder.Name = value; + } + + /// + public INpgsqlNameTranslator DefaultNameTranslator + { + get => _internalBuilder.DefaultNameTranslator; + set => _internalBuilder.DefaultNameTranslator = value; + } + + /// + /// A connection string builder that can be used to configure the connection string on the builder. + /// + public NpgsqlConnectionStringBuilder ConnectionStringBuilder => _internalBuilder.ConnectionStringBuilder; + + /// + /// Returns the connection string, as currently configured on the builder. 
+ /// + public string ConnectionString => _internalBuilder.ConnectionString; + + internal static void ResetGlobalMappings(bool overwrite) + => GlobalTypeMapper.Instance.AddGlobalTypeMappingResolvers(new PgTypeInfoResolverFactory[] + { + overwrite ? new AdoTypeInfoResolverFactory() : AdoTypeInfoResolverFactory.Instance, + new ExtraConversionResolverFactory(), + new JsonTypeInfoResolverFactory(), + new RecordTypeInfoResolverFactory(), + new FullTextSearchTypeInfoResolverFactory(), + new NetworkTypeInfoResolverFactory(), + new GeometricTypeInfoResolverFactory(), + new LTreeTypeInfoResolverFactory(), + }, static () => + { + var builder = new PgTypeInfoResolverChainBuilder(); + builder.EnableRanges(); + builder.EnableMultiranges(); + builder.EnableArrays(); + return builder; + }, overwrite); + + static NpgsqlDataSourceBuilder() + => ResetGlobalMappings(overwrite: false); + + /// + /// Constructs a new , optionally starting out from the given . + /// + public NpgsqlDataSourceBuilder(string? connectionString = null) + { + _internalBuilder = new(new NpgsqlConnectionStringBuilder(connectionString)); + _internalBuilder.ConfigureDefaultFactories = static instance => + { + instance.AppendDefaultFactories(); + instance.AppendResolverFactory(new ExtraConversionResolverFactory()); + instance.AppendResolverFactory(() => new JsonTypeInfoResolverFactory(instance.JsonSerializerOptions)); + instance.AppendResolverFactory(new RecordTypeInfoResolverFactory()); + instance.AppendResolverFactory(new FullTextSearchTypeInfoResolverFactory()); + instance.AppendResolverFactory(new NetworkTypeInfoResolverFactory()); + instance.AppendResolverFactory(new GeometricTypeInfoResolverFactory()); + instance.AppendResolverFactory(new LTreeTypeInfoResolverFactory()); + }; + _internalBuilder.ConfigureResolverChain = static chain => chain.Add(UnsupportedTypeInfoResolver); + _internalBuilder.EnableTransportSecurity(); + _internalBuilder.EnableIntegratedSecurity(); + _internalBuilder.EnableRanges(); + 
_internalBuilder.EnableMultiranges(); + _internalBuilder.EnableArrays(); + } + + /// + /// Sets the that will be used for logging. + /// + /// The logger factory to be used. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseLoggerFactory(ILoggerFactory? loggerFactory) + { + _internalBuilder.UseLoggerFactory(loggerFactory); + return this; + } + + /// + /// Enables parameters to be included in logging. This includes potentially sensitive information from data sent to PostgreSQL. + /// You should only enable this flag in development, or if you have the appropriate security measures in place based on the + /// sensitivity of this data. + /// + /// If , then sensitive data is logged. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder EnableParameterLogging(bool parameterLoggingEnabled = true) + { + _internalBuilder.EnableParameterLogging(parameterLoggingEnabled); + return this; + } + + /// + /// Configures the JSON serializer options used when reading and writing all System.Text.Json data. + /// + /// Options to customize JSON serialization and deserialization. + /// + public NpgsqlDataSourceBuilder ConfigureJsonOptions(JsonSerializerOptions serializerOptions) + { + _internalBuilder.ConfigureJsonOptions(serializerOptions); + return this; + } + + /// + /// Sets up dynamic System.Text.Json mappings. This allows mapping arbitrary .NET types to PostgreSQL json and jsonb + /// types, as well as and its derived types. + /// + /// + /// A list of CLR types to map to PostgreSQL jsonb (no need to specify ). + /// + /// + /// A list of CLR types to map to PostgreSQL json (no need to specify ). + /// + /// + /// Due to the dynamic nature of these mappings, they are not compatible with NativeAOT or trimming. 
+ /// + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + public NpgsqlDataSourceBuilder EnableDynamicJson(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null) + { + _internalBuilder.EnableDynamicJson(jsonbClrTypes, jsonClrTypes); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL record type as a .NET or . + /// + /// The same builder instance so that multiple calls can be chained. + [RequiresUnreferencedCode("The mapping of PostgreSQL records as .NET tuples requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The mapping of PostgreSQL records as .NET tuples requires dynamic code usage which is incompatible with NativeAOT.")] + public NpgsqlDataSourceBuilder EnableRecordsAsTuples() + { + AddTypeInfoResolverFactory(new TupledRecordTypeInfoResolverFactory()); + return this; + } + + /// + /// Sets up mappings allowing the use of unmapped enum, range and multirange types. + /// + /// The same builder instance so that multiple calls can be chained. + [RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + public NpgsqlDataSourceBuilder EnableUnmappedTypes() + { + AddTypeInfoResolverFactory(new UnmappedTypeInfoResolverFactory()); + return this; + } + + #region Authentication + + /// + /// When using SSL/TLS, this is a callback that allows customizing how the PostgreSQL-provided certificate is verified. This is an + /// advanced API, consider using or instead. + /// + /// The callback containing custom callback verification logic. 
+ /// + /// + /// Cannot be used in conjunction with , or + /// . + /// + /// + /// See . + /// + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseUserCertificateValidationCallback(RemoteCertificateValidationCallback userCertificateValidationCallback) + { + _internalBuilder.UseUserCertificateValidationCallback(userCertificateValidationCallback); + return this; + } + + /// + /// Specifies an SSL/TLS certificate which Npgsql will send to PostgreSQL for certificate-based authentication. + /// + /// The client certificate to be sent to PostgreSQL when opening a connection. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseClientCertificate(X509Certificate? clientCertificate) + { + _internalBuilder.UseClientCertificate(clientCertificate); + return this; + } + + /// + /// Specifies a collection of SSL/TLS certificates which Npgsql will send to PostgreSQL for certificate-based authentication. + /// + /// The client certificate collection to be sent to PostgreSQL when opening a connection. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseClientCertificates(X509CertificateCollection? clientCertificates) + { + _internalBuilder.UseClientCertificates(clientCertificates); + return this; + } + + /// + /// Specifies a callback to modify the collection of SSL/TLS client certificates which Npgsql will send to PostgreSQL for + /// certificate-based authentication. This is an advanced API, consider using or + /// instead. + /// + /// The callback to modify the client certificate collection. + /// + /// + /// The callback is invoked every time a physical connection is opened, and is therefore suitable for rotating short-lived client + /// certificates. Simply make sure the certificate collection argument has the up-to-date certificate(s). 
+ /// + /// + /// The callback's collection argument already includes any client certificates specified via the connection string or environment + /// variables. + /// + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseClientCertificatesCallback(Action? clientCertificatesCallback) + { + _internalBuilder.UseClientCertificatesCallback(clientCertificatesCallback); + return this; + } + + /// + /// Sets the that will be used validate SSL certificate, received from the server. + /// + /// The CA certificate. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseRootCertificate(X509Certificate2? rootCertificate) + { + _internalBuilder.UseRootCertificate(rootCertificate); + return this; + } + + /// + /// Specifies a callback that will be used to validate SSL certificate, received from the server. + /// + /// The callback to get CA certificate. + /// The same builder instance so that multiple calls can be chained. + /// + /// This overload, which accepts a callback, is suitable for scenarios where the certificate rotates + /// and might change during the lifetime of the application. + /// When that's not the case, use the overload which directly accepts the certificate. + /// + public NpgsqlDataSourceBuilder UseRootCertificateCallback(Func? rootCertificateCallback) + { + _internalBuilder.UseRootCertificateCallback(rootCertificateCallback); + return this; + } + + /// + /// Configures a periodic password provider, which is automatically called by the data source at some regular interval. This is the + /// recommended way to fetch a rotating access token. + /// + /// A callback which returns the password to be sent to PostgreSQL. + /// How long to cache the password before re-invoking the callback. + /// + /// If a password refresh attempt fails, it will be re-attempted with this interval. + /// This should typically be much lower than . 
+ /// + /// The same builder instance so that multiple calls can be chained. + /// + /// + /// The provided callback is invoked in a timer, and not when opening connections. It therefore doesn't affect opening time. + /// + /// + /// The provided cancellation token is only triggered when the entire data source is disposed. If you'd like to apply a timeout to the + /// token fetching, do so within the provided callback. + /// + /// + public NpgsqlDataSourceBuilder UsePeriodicPasswordProvider( + Func>? passwordProvider, + TimeSpan successRefreshInterval, + TimeSpan failureRefreshInterval) + { + _internalBuilder.UsePeriodicPasswordProvider(passwordProvider, successRefreshInterval, failureRefreshInterval); + return this; + } + + /// + /// Configures a password provider, which is called by the data source when opening connections. + /// + /// + /// A callback that may be invoked during which returns the password to be sent to PostgreSQL. + /// + /// + /// A callback that may be invoked during which returns the password to be sent to PostgreSQL. + /// + /// The same builder instance so that multiple calls can be chained. + /// + /// + /// The provided callback is invoked when opening connections. Therefore its important the callback internally depends on cached + /// data or returns quickly otherwise. Any unnecessary delay will affect connection opening time. + /// + /// + public NpgsqlDataSourceBuilder UsePasswordProvider( + Func? passwordProvider, + Func>? 
passwordProviderAsync) + { + _internalBuilder.UsePasswordProvider(passwordProvider, passwordProviderAsync); + return this; + } + + #endregion Authentication + + #region Type mapping + + /// + public void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) + => _internalBuilder.AddTypeInfoResolverFactory(factory); + + /// + void INpgsqlTypeMapper.Reset() => ((INpgsqlTypeMapper)_internalBuilder).Reset(); + + /// + public INpgsqlTypeMapper MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum + { + _internalBuilder.MapEnum(pgName, nameTranslator); + return this; + } + + /// + public bool UnmapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum + => _internalBuilder.UnmapEnum(pgName, nameTranslator); + + /// + [RequiresDynamicCode("Calling MapEnum with a Type can require creating new generic types or methods. This may not work when AOT compiling.")] + public INpgsqlTypeMapper MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => _internalBuilder.MapEnum(clrType, pgName, nameTranslator); + + /// + public bool UnmapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => _internalBuilder.UnmapEnum(clrType, pgName, nameTranslator); + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. 
This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public INpgsqlTypeMapper MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + _internalBuilder.MapComposite(pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public INpgsqlTypeMapper MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + _internalBuilder.MapComposite(clrType, pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public bool UnmapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => _internalBuilder.UnmapComposite(pgName, nameTranslator); + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. 
This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public bool UnmapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => _internalBuilder.UnmapComposite(clrType, pgName, nameTranslator); + + #endregion Type mapping + + /// + /// Register a connection initializer, which allows executing arbitrary commands when a physical database connection is first opened. + /// + /// + /// A synchronous connection initialization lambda, which will be called from when a new physical + /// connection is opened. + /// + /// + /// An asynchronous connection initialization lambda, which will be called from + /// when a new physical connection is opened. + /// + /// + /// If an initializer is registered, both sync and async versions must be provided. If you do not use sync APIs in your code, simply + /// throw , which would also catch accidental cases of sync opening. + /// + /// + /// Take care that the setting you apply in the initializer does not get reverted when the connection is returned to the pool, since + /// Npgsql sends DISCARD ALL by default. The option can be used to + /// turn this off. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UsePhysicalConnectionInitializer( + Action? connectionInitializer, + Func? connectionInitializerAsync) + { + _internalBuilder.UsePhysicalConnectionInitializer(connectionInitializer, connectionInitializerAsync); + return this; + } + + /// + /// Builds and returns an which is ready for use. + /// + public NpgsqlDataSource Build() + => _internalBuilder.Build(); + + /// + /// Builds and returns a which is ready for use for load-balancing and failover scenarios. 
+ /// + public NpgsqlMultiHostDataSource BuildMultiHost() + => _internalBuilder.BuildMultiHost(); + + INpgsqlTypeMapper INpgsqlTypeMapper.ConfigureJsonOptions(JsonSerializerOptions serializerOptions) + => ConfigureJsonOptions(serializerOptions); + + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode( + "Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + INpgsqlTypeMapper INpgsqlTypeMapper.EnableDynamicJson(Type[]? jsonbClrTypes, Type[]? jsonClrTypes) + => EnableDynamicJson(jsonbClrTypes, jsonClrTypes); + + [RequiresUnreferencedCode( + "The mapping of PostgreSQL records as .NET tuples requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode( + "The mapping of PostgreSQL records as .NET tuples requires dynamic code usage which is incompatible with NativeAOT.")] + INpgsqlTypeMapper INpgsqlTypeMapper.EnableRecordsAsTuples() + => EnableRecordsAsTuples(); + + [RequiresUnreferencedCode( + "The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode( + "The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + INpgsqlTypeMapper INpgsqlTypeMapper.EnableUnmappedTypes() + => EnableUnmappedTypes(); +} diff --git a/src/Npgsql/NpgsqlDataSourceCommand.cs b/src/Npgsql/NpgsqlDataSourceCommand.cs new file mode 100644 index 0000000000..3ff565de66 --- /dev/null +++ b/src/Npgsql/NpgsqlDataSourceCommand.cs @@ -0,0 +1,71 @@ +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Properties; + +namespace Npgsql; + +sealed class NpgsqlDataSourceCommand : NpgsqlCommand +{ + internal NpgsqlDataSourceCommand(NpgsqlConnection connection) + : base(cmdText: null, connection) + 
{ + } + + // For NpgsqlBatch only + internal NpgsqlDataSourceCommand(int batchCommandCapacity, NpgsqlConnection connection) + : base(batchCommandCapacity, connection) + { + } + + internal override async ValueTask ExecuteReader( + bool async, CommandBehavior behavior, + CancellationToken cancellationToken) + { + await InternalConnection!.Open(async, cancellationToken).ConfigureAwait(false); + + try + { + return await base.ExecuteReader( + async, + behavior | CommandBehavior.CloseConnection, + cancellationToken) + .ConfigureAwait(false); + } + catch + { + try + { + await InternalConnection.Close(async).ConfigureAwait(false); + } + catch + { + // Swallow to allow the original exception to bubble up + } + + throw; + } + } + + // The below are incompatible with commands executed directly against DbDataSource, since no DbConnection + // is involved at the user API level and the command owns the DbConnection. + public override void Prepare() + => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceCommand); + + public override Task PrepareAsync(CancellationToken cancellationToken = default) + => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceCommand); + + protected override DbConnection? DbConnection + { + get => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceCommand); + set => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceCommand); + } + + protected override DbTransaction? 
DbTransaction + { + get => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceCommand); + set => throw new NotSupportedException(NpgsqlStrings.NotSupportedOnDataSourceCommand); + } +} diff --git a/src/Npgsql/NpgsqlDataSourceConfiguration.cs b/src/Npgsql/NpgsqlDataSourceConfiguration.cs new file mode 100644 index 0000000000..ec3e5e4611 --- /dev/null +++ b/src/Npgsql/NpgsqlDataSourceConfiguration.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Net.Security; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; + +namespace Npgsql; + +sealed record NpgsqlDataSourceConfiguration(string? Name, + NpgsqlLoggingConfiguration LoggingConfiguration, + TransportSecurityHandler TransportSecurityHandler, + IntegratedSecurityHandler userCertificateValidationCallback, + RemoteCertificateValidationCallback? UserCertificateValidationCallback, + Action? ClientCertificatesCallback, + Func? PasswordProvider, + Func>? PasswordProviderAsync, + Func>? PeriodicPasswordProvider, + TimeSpan PeriodicPasswordSuccessRefreshInterval, + TimeSpan PeriodicPasswordFailureRefreshInterval, + PgTypeInfoResolverChain ResolverChain, + List HackyEnumMappings, + INpgsqlNameTranslator DefaultNameTranslator, + Action? ConnectionInitializer, + Func? ConnectionInitializerAsync); diff --git a/src/Npgsql/NpgsqlDatabaseInfo.cs b/src/Npgsql/NpgsqlDatabaseInfo.cs deleted file mode 100644 index 6d09eff688..0000000000 --- a/src/Npgsql/NpgsqlDatabaseInfo.cs +++ /dev/null @@ -1,261 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Threading.Tasks; -using Npgsql.PostgresTypes; -using Npgsql.Util; - -namespace Npgsql -{ - /// - /// Base class for implementations which provide information about PostgreSQL and PostgreSQL-like databases - /// (e.g. type definitions, capabilities...). 
- /// - public abstract class NpgsqlDatabaseInfo - { - #region Fields - - internal static readonly ConcurrentDictionary Cache - = new ConcurrentDictionary(); - - static volatile INpgsqlDatabaseInfoFactory[] Factories = new INpgsqlDatabaseInfoFactory[] - { - new PostgresMinimalDatabaseInfoFactory(), - new PostgresDatabaseInfoFactory() - }; - - #endregion Fields - - #region General database info - - /// - /// The hostname of IP address of the database. - /// - public string Host { get; } - /// - /// The TCP port of the database. - /// - public int Port { get; } - /// - /// The database name. - /// - public string Name { get; } - /// - /// The version of the PostgreSQL database we're connected to, as reported in the "server_version" parameter. - /// Exposed via . - /// - public Version Version { get; } - - #endregion General database info - - #region Supported capabilities and features - - /// - /// Whether the backend supports range types. - /// - public virtual bool SupportsRangeTypes => Version.IsGreaterOrEqual(9, 2, 0); - /// - /// Whether the backend supports enum types. - /// - public virtual bool SupportsEnumTypes => Version.IsGreaterOrEqual(8, 3, 0); - /// - /// Whether the backend supports the CLOSE ALL statement. - /// - public virtual bool SupportsCloseAll => Version.IsGreaterOrEqual(8, 3, 0); - /// - /// Whether the backend supports advisory locks. - /// - public virtual bool SupportsAdvisoryLocks => Version.IsGreaterOrEqual(8, 2, 0); - /// - /// Whether the backend supports the DISCARD SEQUENCES statement. - /// - public virtual bool SupportsDiscardSequences => Version.IsGreaterOrEqual(9, 4, 0); - /// - /// Whether the backend supports the UNLISTEN statement. - /// - public virtual bool SupportsUnlisten => Version.IsGreaterOrEqual(6, 4, 0); // overridden by PostgresDatabase - /// - /// Whether the backend supports the DISCARD TEMP statement. 
- /// - public virtual bool SupportsDiscardTemp => Version.IsGreaterOrEqual(8, 3, 0); - /// - /// Whether the backend supports the DISCARD statement. - /// - public virtual bool SupportsDiscard => Version.IsGreaterOrEqual(8, 3, 0); - - /// - /// Reports whether the backend uses the newer integer timestamp representation. - /// - public virtual bool HasIntegerDateTimes { get; protected set; } = true; - - /// - /// Whether the database supports transactions. - /// - public virtual bool SupportsTransactions { get; protected set; } = true; - - #endregion Supported capabilities and features - - #region Types - - readonly List _baseTypesMutable = new List(); - readonly List _arrayTypesMutable = new List(); - readonly List _rangeTypesMutable = new List(); - readonly List _enumTypesMutable = new List(); - readonly List _compositeTypesMutable = new List(); - readonly List _domainTypesMutable = new List(); - - internal IReadOnlyList BaseTypes => _baseTypesMutable; - internal IReadOnlyList ArrayTypes => _arrayTypesMutable; - internal IReadOnlyList RangeTypes => _rangeTypesMutable; - internal IReadOnlyList EnumTypes => _enumTypesMutable; - internal IReadOnlyList CompositeTypes => _compositeTypesMutable; - internal IReadOnlyList DomainTypes => _domainTypesMutable; - - /// - /// Indexes backend types by their type OID. - /// - internal Dictionary ByOID { get; } = new Dictionary(); - - /// - /// Indexes backend types by their PostgreSQL name, including namespace (e.g. pg_catalog.int4). - /// Only used for enums and composites. - /// - internal Dictionary ByFullName { get; } = new Dictionary(); - - /// - /// Indexes backend types by their PostgreSQL name, not including namespace. - /// If more than one type exists with the same name (i.e. in different namespaces) this - /// table will contain an entry with a null value. - /// Only used for enums and composites. - /// - internal Dictionary ByName { get; } = new Dictionary(); - - /// - /// Initializes the instance of . 
- /// - protected NpgsqlDatabaseInfo(string host, int port, string databaseName, Version version) - { - Host = host; - Port = port; - Name = databaseName; - Version = version; - } - - internal void ProcessTypes() - { - foreach (var type in GetTypes()) - { - ByOID[type.OID] = type; - ByFullName[type.FullName] = type; - // If more than one type exists with the same partial name, we place a null value. - // This allows us to detect this case later and force the user to use full names only. - ByName[type.Name] = ByName.ContainsKey(type.Name) - ? null - : type; - - switch (type) - { - case PostgresBaseType baseType: - _baseTypesMutable.Add(baseType); - continue; - case PostgresArrayType arrayType: - _arrayTypesMutable.Add(arrayType); - continue; - case PostgresRangeType rangeType: - _rangeTypesMutable.Add(rangeType); - continue; - case PostgresEnumType enumType: - _enumTypesMutable.Add(enumType); - continue; - case PostgresCompositeType compositeType: - _compositeTypesMutable.Add(compositeType); - continue; - case PostgresDomainType domainType: - _domainTypesMutable.Add(domainType); - continue; - default: - throw new ArgumentOutOfRangeException(); - } - } - } - - /// - /// Provides all PostgreSQL types detected in this database. - /// - /// - protected abstract IEnumerable GetTypes(); - - #endregion Types - - #region Misc - - /// - /// Parses a PostgreSQL server version (e.g. 10.1, 9.6.3) and returns a CLR Version. - /// - protected static Version ParseServerVersion(string value) - { - var versionString = value.Trim(); - for (var idx = 0; idx != versionString.Length; ++idx) - { - var c = value[idx]; - if (!char.IsDigit(c) && c != '.') - { - versionString = versionString.Substring(0, idx); - break; - } - } - if (!versionString.Contains(".")) - versionString += ".0"; - return new Version(versionString); - } - - #endregion Misc - - #region Factory management - - /// - /// Registers a new database info factory, which is used to load information about databases. 
- /// - public static void RegisterFactory(INpgsqlDatabaseInfoFactory factory) - { - if (factory == null) - throw new ArgumentNullException(nameof(factory)); - - var factories = new INpgsqlDatabaseInfoFactory[Factories.Length + 1]; - factories[0] = factory; - Array.Copy(Factories, 0, factories, 1, Factories.Length); - Factories = factories; - - Cache.Clear(); - } - - internal static async Task Load(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async) - { - foreach (var factory in Factories) - { - var dbInfo = await factory.Load(conn, timeout, async); - if (dbInfo != null) - { - dbInfo.ProcessTypes(); - return dbInfo; - } - } - - // Should never be here - throw new NpgsqlException("No DatabaseInfoFactory could be found for this connection"); - } - - // For tests - internal static void ResetFactories() - { - Factories = new INpgsqlDatabaseInfoFactory[] - { - new PostgresMinimalDatabaseInfoFactory(), - new PostgresDatabaseInfoFactory() - }; - Cache.Clear(); - } - - #endregion Factory management - } -} diff --git a/src/Npgsql/NpgsqlDatabaseInfoCacheKey.cs b/src/Npgsql/NpgsqlDatabaseInfoCacheKey.cs deleted file mode 100644 index ff2205eeef..0000000000 --- a/src/Npgsql/NpgsqlDatabaseInfoCacheKey.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; - -namespace Npgsql -{ - readonly struct NpgsqlDatabaseInfoCacheKey : IEquatable - { - public readonly int Port; - public readonly string? Host; - public readonly string? Database; - public readonly ServerCompatibilityMode CompatibilityMode; - - public NpgsqlDatabaseInfoCacheKey(NpgsqlConnectionStringBuilder connectionString) - { - Port = connectionString.Port; - Host = connectionString.Host; - Database = connectionString.Database; - CompatibilityMode = connectionString.ServerCompatibilityMode; - } - - public bool Equals(NpgsqlDatabaseInfoCacheKey other) => - Port == other.Port && - Host == other.Host && - Database == other.Database && - CompatibilityMode == other.CompatibilityMode; - - public override bool Equals(object? 
obj) => - obj is NpgsqlDatabaseInfoCacheKey key && key.Equals(this); - - public override int GetHashCode() => - Port.GetHashCode() ^ - Host?.GetHashCode() ?? 0 ^ - Database?.GetHashCode() ?? 0 ^ - CompatibilityMode.GetHashCode(); - } -} diff --git a/src/Npgsql/NpgsqlDiagnostics.cs b/src/Npgsql/NpgsqlDiagnostics.cs new file mode 100644 index 0000000000..2037fec667 --- /dev/null +++ b/src/Npgsql/NpgsqlDiagnostics.cs @@ -0,0 +1,7 @@ +namespace Npgsql; + +static class NpgsqlDiagnostics +{ + public const string ConvertersExperimental = "NPG9001"; + public const string DatabaseInfoExperimental = "NPG9002"; +} diff --git a/src/Npgsql/NpgsqlEventId.cs b/src/Npgsql/NpgsqlEventId.cs new file mode 100644 index 0000000000..cf82ea063d --- /dev/null +++ b/src/Npgsql/NpgsqlEventId.cs @@ -0,0 +1,110 @@ +namespace Npgsql; + +#pragma warning disable CS1591 +#pragma warning disable RS0016 + +public static class NpgsqlEventId +{ + #region Connection + + public const int OpeningConnection = 1000; + public const int OpenedConnection = 1001; + public const int ClosingConnection = 1003; + public const int ClosedConnection = 1004; + + public const int OpeningPhysicalConnection = 1110; + public const int OpenedPhysicalConnection = 1111; + public const int ClosingPhysicalConnection = 1112; + public const int ClosedPhysicalConnection = 1113; + + public const int StartingWait = 1300; + public const int ReceivedNotice = 1301; + + public const int ConnectionExceededMaximumLifetime = 1500; + + public const int SendingKeepalive = 1600; + public const int CompletedKeepalive = 1601; + public const int KeepaliveFailed = 1602; + + public const int BreakingConnection = 1900; + public const int CaughtUserExceptionInNoticeEventHandler = 1901; + public const int CaughtUserExceptionInNotificationEventHandler = 1902; + public const int ExceptionWhenClosingPhysicalConnection = 1903; + public const int ExceptionWhenOpeningConnectionForMultiplexing = 1904; + + #endregion Connection + + #region Command + + 
public const int ExecutingCommand = 2000; + public const int CommandExecutionCompleted = 2001; + public const int CancellingCommand = 2002; + public const int ExecutingInternalCommand = 2003; + + public const int PreparingCommandExplicitly = 2100; + public const int CommandPreparedExplicitly = 2101; + public const int AutoPreparingStatement = 2102; + public const int UnpreparingCommand = 2103; + + public const int DerivingParameters = 2500; + + public const int ExceptionWhenWritingMultiplexedCommands = 2600; + + #endregion Command + + #region Transaction + + public const int StartedTransaction = 30000; + public const int CommittedTransaction = 30001; + public const int RolledBackTransaction = 30002; + + public const int CreatingSavepoint = 30100; + public const int RolledBackToSavepoint = 30101; + public const int ReleasedSavepoint = 30102; + + public const int ExceptionDuringTransactionDispose = 30200; + + public const int EnlistedVolatileResourceManager = 31000; + public const int CommittingSinglePhaseTransaction = 31001; + public const int RollingBackSinglePhaseTransaction = 31002; + public const int SinglePhaseTransactionRollbackFailed = 31003; + public const int PreparingTwoPhaseTransaction = 31004; + public const int CommittingTwoPhaseTransaction = 31005; + public const int TwoPhaseTransactionCommitFailed = 31006; + public const int RollingBackTwoPhaseTransaction = 31007; + public const int TwoPhaseTransactionRollbackFailed = 31008; + public const int TwoPhaseTransactionInDoubt = 31009; + public const int ConnectionInUseWhenRollingBack = 31010; + public const int CleaningUpResourceManager = 31011; + + #endregion Transaction + + #region Copy + + public const int StartingBinaryExport = 40000; + public const int StartingBinaryImport = 40001; + public const int StartingTextExport = 40002; + public const int StartingTextImport = 40003; + public const int StartingRawCopy = 40004; + + public const int CopyOperationCompleted = 40100; + public const int 
CopyOperationCancelled = 40101; + public const int ExceptionWhenDisposingCopyOperation = 40102; + + #endregion Copy + + #region Replication + + public const int CreatingReplicationSlot = 50000; + public const int DroppingReplicationSlot = 50001; + public const int StartingLogicalReplication = 50002; + public const int StartingPhysicalReplication = 50003; + public const int ExecutingReplicationCommand = 50004; + + public const int ReceivedReplicationPrimaryKeepalive = 50100; + public const int SendingReplicationStandbyStatusUpdate = 50101; + public const int SentReplicationFeedbackMessage = 50102; + public const int ReplicationFeedbackMessageSendingFailed = 50103; + + #endregion Replication +} diff --git a/src/Npgsql/NpgsqlEventSource.cs b/src/Npgsql/NpgsqlEventSource.cs index 6f00103128..1e9b82c9c5 100644 --- a/src/Npgsql/NpgsqlEventSource.cs +++ b/src/Npgsql/NpgsqlEventSource.cs @@ -1,211 +1,231 @@ using System; +using System.Collections.Generic; using System.Diagnostics; using System.Threading; using System.Diagnostics.Tracing; -using System.Runtime.CompilerServices; -namespace Npgsql +namespace Npgsql; + +sealed class NpgsqlEventSource : EventSource { - sealed class NpgsqlEventSource : EventSource - { - public static readonly NpgsqlEventSource Log = new NpgsqlEventSource(); + public static readonly NpgsqlEventSource Log = new(); - const string EventSourceName = "Npgsql"; + const string EventSourceName = "Npgsql"; - internal const int CommandStartId = 3; - internal const int CommandStopId = 4; + internal const int CommandStartId = 3; + internal const int CommandStopId = 4; #if !NETSTANDARD2_0 - IncrementingPollingCounter? _bytesWrittenPerSecondCounter; - IncrementingPollingCounter? _bytesReadPerSecondCounter; - - IncrementingPollingCounter? _commandsPerSecondCounter; - PollingCounter? _totalCommandsCounter; - PollingCounter? _failedCommandsCounter; - PollingCounter? _currentCommandsCounter; - PollingCounter? 
_preparedCommandsRatioCounter; + IncrementingPollingCounter? _bytesWrittenPerSecondCounter; + IncrementingPollingCounter? _bytesReadPerSecondCounter; - PollingCounter? _poolsCounter; - PollingCounter? _idleConnectionsCounter; - PollingCounter? _busyConnectionsCounter; + IncrementingPollingCounter? _commandsPerSecondCounter; + PollingCounter? _totalCommandsCounter; + PollingCounter? _failedCommandsCounter; + PollingCounter? _currentCommandsCounter; + PollingCounter? _preparedCommandsRatioCounter; - PollingCounter? _multiplexingAverageCommandsPerBatchCounter; - PollingCounter? _multiplexingAverageWaitsPerBatchCounter; - PollingCounter? _multiplexingAverageWriteTimePerBatchCounter; + PollingCounter? _poolsCounter; + readonly object _dataSourcesLock = new(); + readonly Dictionary _dataSources = new(); + PollingCounter? _multiplexingAverageCommandsPerBatchCounter; + PollingCounter? _multiplexingAverageWriteTimePerBatchCounter; #endif - long _bytesWritten; - long _bytesRead; - long _totalCommands; - long _totalPreparedCommands; - long _currentCommands; - long _failedCommands; + long _bytesWritten; + long _bytesRead; - int _pools; + long _totalCommands; + long _totalPreparedCommands; + long _currentCommands; + long _failedCommands; - long _multiplexingBatchesSent; - long _multiplexingCommandsSent; - long _multiplexingWaits; - long _multiplexingTicksWritten; + long _multiplexingBatchesSent; + long _multiplexingCommandsSent; + long _multiplexingTicksWritten; - internal NpgsqlEventSource() : base(EventSourceName) {} + internal NpgsqlEventSource() : base(EventSourceName) {} - // NOTE - // - The 'Start' and 'Stop' suffixes on the following event names have special meaning in EventSource. They - // enable creating 'activities'. 
- // For more information, take a look at the following blog post: - // https://blogs.msdn.microsoft.com/vancem/2015/09/14/exploring-eventsource-activity-correlation-and-causation-features/ - // - A stop event's event id must be next one after its start event. + // NOTE + // - The 'Start' and 'Stop' suffixes on the following event names have special meaning in EventSource. They + // enable creating 'activities'. + // For more information, take a look at the following blog post: + // https://blogs.msdn.microsoft.com/vancem/2015/09/14/exploring-eventsource-activity-correlation-and-causation-features/ + // - A stop event's event id must be next one after its start event. - internal void BytesWritten(long bytesWritten) => Interlocked.Add(ref _bytesWritten, bytesWritten); - internal void BytesRead(long bytesRead) => Interlocked.Add(ref _bytesRead, bytesRead); + internal void BytesWritten(long bytesWritten) + { + if (IsEnabled()) + Interlocked.Add(ref _bytesWritten, bytesWritten); + } - public void CommandStart(string sql) + internal void BytesRead(long bytesRead) + { + if (IsEnabled()) + Interlocked.Add(ref _bytesRead, bytesRead); + } + + public void CommandStart(string sql) + { + if (IsEnabled()) { Interlocked.Increment(ref _totalCommands); Interlocked.Increment(ref _currentCommands); - NpgsqlSqlEventSource.Log.CommandStart(sql); } + NpgsqlSqlEventSource.Log.CommandStart(sql); + } - [MethodImpl(MethodImplOptions.NoInlining)] - public void CommandStop() - { + public void CommandStop() + { + if (IsEnabled()) Interlocked.Decrement(ref _currentCommands); - NpgsqlSqlEventSource.Log.CommandStop(); - } + NpgsqlSqlEventSource.Log.CommandStop(); + } - internal void CommandStartPrepared() => Interlocked.Increment(ref _totalPreparedCommands); + internal void CommandStartPrepared() + { + if (IsEnabled()) + Interlocked.Increment(ref _totalPreparedCommands); + } - internal void CommandFailed() => Interlocked.Increment(ref _failedCommands); + internal void CommandFailed() + { + if 
(IsEnabled()) + Interlocked.Increment(ref _failedCommands); + } - internal void PoolCreated() => Interlocked.Increment(ref _pools); + internal void DataSourceCreated(NpgsqlDataSource dataSource) + { +#if !NETSTANDARD2_0 + lock (_dataSourcesLock) + { + _dataSources.Add(dataSource, null); + } +#endif + } - internal void MultiplexingBatchSent(int numCommands, int waits, Stopwatch stopwatch) + internal void MultiplexingBatchSent(int numCommands, Stopwatch stopwatch) + { + // TODO: CAS loop instead of 3 separate interlocked operations? + if (IsEnabled()) { - // TODO: CAS loop instead of 4 separate interlocked operations? Interlocked.Increment(ref _multiplexingBatchesSent); Interlocked.Add(ref _multiplexingCommandsSent, numCommands); - Interlocked.Add(ref _multiplexingWaits, waits); Interlocked.Add(ref _multiplexingTicksWritten, stopwatch.ElapsedTicks); } + } #if !NETSTANDARD2_0 - static int GetIdleConnections() + double GetDataSourceCount() + { + lock (_dataSourcesLock) { - // Note: there's no attempt here to be coherent in terms of race conditions, especially not with regards - // to different counters. So idle and busy and be unsynchronized, as they're not polled together. - var sum = 0; - foreach (var kv in PoolManager.Pools) - { - var pool = kv.Pool; - if (pool == null) - return sum; - sum += pool.Statistics.Idle; - } - return sum; + return _dataSources.Count; } + } - static int GetBusyConnections() - { - // Note: there's no attempt here to be coherent in terms of race conditions, especially not with regards - // to different counters. So idle and busy and be unsynchronized, as they're not polled together. 
- var sum = 0; - foreach (var kv in PoolManager.Pools) - { - var pool = kv.Pool; - if (pool == null) - return sum; - var (_, _, busy) = pool.Statistics; - sum += busy; - } - return sum; - } + double GetMultiplexingAverageCommandsPerBatch() + { + var batchesSent = Interlocked.Read(ref _multiplexingBatchesSent); + if (batchesSent == 0) + return -1; - protected override void OnEventCommand(EventCommandEventArgs command) - { - if (command.Command == EventCommand.Enable) - { - // Comment taken from RuntimeEventSource in CoreCLR - // NOTE: These counters will NOT be disposed on disable command because we may be introducing - // a race condition by doing that. We still want to create these lazily so that we aren't adding - // overhead by at all times even when counters aren't enabled. - // On disable, PollingCounters will stop polling for values so it should be fine to leave them around. + var commandsSent = (double)Interlocked.Read(ref _multiplexingCommandsSent); + return commandsSent / batchesSent; + } - _bytesWrittenPerSecondCounter = new IncrementingPollingCounter("bytes-written-per-second", this, () => Interlocked.Read(ref _bytesWritten)) - { - DisplayName = "Bytes Written", - DisplayRateTimeScale = TimeSpan.FromSeconds(1) - }; + double GetMultiplexingAverageWriteTimePerBatch() + { + var batchesSent = Interlocked.Read(ref _multiplexingBatchesSent); + if (batchesSent == 0) + return -1; - _bytesReadPerSecondCounter = new IncrementingPollingCounter("bytes-read-per-second", this, () => Interlocked.Read(ref _bytesRead)) - { - DisplayName = "Bytes Read", - DisplayRateTimeScale = TimeSpan.FromSeconds(1) - }; + var ticksWritten = (double)Interlocked.Read(ref _multiplexingTicksWritten); + return ticksWritten / batchesSent / 1000; + } - _commandsPerSecondCounter = new IncrementingPollingCounter("commands-per-second", this, () => Interlocked.Read(ref _totalCommands)) - { - DisplayName = "Command Rate", - DisplayRateTimeScale = TimeSpan.FromSeconds(1) - }; + protected override 
void OnEventCommand(EventCommandEventArgs command) + { + if (command.Command == EventCommand.Enable) + { + // Comment taken from RuntimeEventSource in CoreCLR + // NOTE: These counters will NOT be disposed on disable command because we may be introducing + // a race condition by doing that. We still want to create these lazily so that we aren't adding + // overhead by at all times even when counters aren't enabled. + // On disable, PollingCounters will stop polling for values so it should be fine to leave them around. - _totalCommandsCounter = new PollingCounter("total-commands", this, () => Interlocked.Read(ref _totalCommands)) - { - DisplayName = "Total Commands", - }; + _bytesWrittenPerSecondCounter = new IncrementingPollingCounter("bytes-written-per-second", this, () => Interlocked.Read(ref _bytesWritten)) + { + DisplayName = "Bytes Written", + DisplayRateTimeScale = TimeSpan.FromSeconds(1) + }; - _currentCommandsCounter = new PollingCounter("current-commands", this, () => Interlocked.Read(ref _currentCommands)) - { - DisplayName = "Current Commands" - }; + _bytesReadPerSecondCounter = new IncrementingPollingCounter("bytes-read-per-second", this, () => Interlocked.Read(ref _bytesRead)) + { + DisplayName = "Bytes Read", + DisplayRateTimeScale = TimeSpan.FromSeconds(1) + }; - _failedCommandsCounter = new PollingCounter("failed-commands", this, () => Interlocked.Read(ref _failedCommands)) - { - DisplayName = "Failed Commands" - }; + _commandsPerSecondCounter = new IncrementingPollingCounter("commands-per-second", this, () => Interlocked.Read(ref _totalCommands)) + { + DisplayName = "Command Rate", + DisplayRateTimeScale = TimeSpan.FromSeconds(1) + }; - _preparedCommandsRatioCounter = new PollingCounter( - "prepared-commands-ratio", - this, - () => (double)Interlocked.Read(ref _totalPreparedCommands) / Interlocked.Read(ref _totalCommands)) - { - DisplayName = "Prepared Commands Ratio", - DisplayUnits = "%" - }; + _totalCommandsCounter = new 
PollingCounter("total-commands", this, () => Interlocked.Read(ref _totalCommands)) + { + DisplayName = "Total Commands", + }; - _poolsCounter = new PollingCounter("connection-pools", this, () => _pools) - { - DisplayName = "Connection Pools" - }; + _currentCommandsCounter = new PollingCounter("current-commands", this, () => Interlocked.Read(ref _currentCommands)) + { + DisplayName = "Current Commands" + }; - _idleConnectionsCounter = new PollingCounter("idle-connections", this, () => GetIdleConnections()) - { - DisplayName = "Idle Connections" - }; + _failedCommandsCounter = new PollingCounter("failed-commands", this, () => Interlocked.Read(ref _failedCommands)) + { + DisplayName = "Failed Commands" + }; - _busyConnectionsCounter = new PollingCounter("busy-connections", this, () => GetBusyConnections()) - { - DisplayName = "Busy Connections" - }; + _preparedCommandsRatioCounter = new PollingCounter( + "prepared-commands-ratio", + this, + () => (double)Interlocked.Read(ref _totalPreparedCommands) / Interlocked.Read(ref _totalCommands) * 100) + { + DisplayName = "Prepared Commands Ratio", + DisplayUnits = "%" + }; - _multiplexingAverageCommandsPerBatchCounter = new PollingCounter("multiplexing-average-commands-per-batch", this, () => (double)Interlocked.Read(ref _multiplexingCommandsSent) / Interlocked.Read(ref _multiplexingBatchesSent)) - { - DisplayName = "Average commands per multiplexing batch" - }; + _poolsCounter = new PollingCounter("connection-pools", this, GetDataSourceCount) + { + DisplayName = "Connection Pools" + }; - _multiplexingAverageWaitsPerBatchCounter = new PollingCounter("multiplexing-average-waits-per-batch", this, () => (double)Interlocked.Read(ref _multiplexingWaits) / Interlocked.Read(ref _multiplexingBatchesSent)) - { - DisplayName = "Average waits per multiplexing batch" - }; + _multiplexingAverageCommandsPerBatchCounter = new PollingCounter("multiplexing-average-commands-per-batch", this, GetMultiplexingAverageCommandsPerBatch) + { + 
DisplayName = "Average commands per multiplexing batch" + }; - _multiplexingAverageWriteTimePerBatchCounter = new PollingCounter("multiplexing-average-write-time-per-batch", this, () => (double)Interlocked.Read(ref _multiplexingTicksWritten) / Interlocked.Read(ref _multiplexingBatchesSent) / 1000) + _multiplexingAverageWriteTimePerBatchCounter = new PollingCounter("multiplexing-average-write-time-per-batch", this, GetMultiplexingAverageWriteTimePerBatch) + { + DisplayName = "Average write time per multiplexing batch", + DisplayUnits = "us" + }; + lock (_dataSourcesLock) + { + foreach (var dataSource in _dataSources.Keys) { - DisplayName = "Average write time per multiplexing batch (us)", - DisplayUnits = "us" - }; + if (!_dataSources[dataSource].HasValue) + { + _dataSources[dataSource] = ( + new PollingCounter($"Idle Connections ({dataSource.Settings.ToStringWithoutPassword()}])", this, () => dataSource.Statistics.Idle), + new PollingCounter($"Busy Connections ({dataSource.Settings.ToStringWithoutPassword()}])", this, () => dataSource.Statistics.Busy)); + } + } } } -#endif } + +#endif } diff --git a/src/Npgsql/NpgsqlException.cs b/src/Npgsql/NpgsqlException.cs index bf63305543..57c47a514c 100644 --- a/src/Npgsql/NpgsqlException.cs +++ b/src/Npgsql/NpgsqlException.cs @@ -4,60 +4,76 @@ using System.Net.Sockets; using System.Runtime.Serialization; -namespace Npgsql +namespace Npgsql; + +/// +/// The exception that is thrown when server-related issues occur. +/// +/// +/// PostgreSQL errors (e.g. query SQL issues, constraint violations) are raised via +/// which is a subclass of this class. +/// Purely Npgsql-related issues which aren't related to the server will be raised +/// via the standard CLR exceptions (e.g. ArgumentException). +/// +[Serializable] +public class NpgsqlException : DbException { /// - /// The exception that is thrown when server-related issues occur. + /// Initializes a new instance of the class. 
+ /// + public NpgsqlException() {} + + /// + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified. + public NpgsqlException(string? message, Exception? innerException) + : base(message, innerException) {} + + /// + /// Initializes a new instance of the class with a specified error message. + /// + /// The message that describes the error. + public NpgsqlException(string? message) + : base(message) { } + + /// + /// Specifies whether the exception is considered transient, that is, whether retrying the operation could + /// succeed (e.g. a network error or a timeout). /// - /// - /// PostgreSQL errors (e.g. query SQL issues, constraint violations) are raised via - /// which is a subclass of this class. - /// Purely Npgsql-related issues which aren't related to the server will be raised - /// via the standard CLR exceptions (e.g. ArgumentException). - /// - [Serializable] - public class NpgsqlException : DbException - { - /// - /// Initializes a new instance of the class. - /// - public NpgsqlException() {} - - /// - /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. - /// - /// The error message that explains the reason for the exception. - /// The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified. - public NpgsqlException(string? message, Exception? innerException) - : base(message, innerException) {} - - /// - /// Initializes a new instance of the class with a specified error message. - /// - /// The message that describes the error. - public NpgsqlException(string? 
message) - : base(message) { } - - /// - /// Specifies whether the exception is considered transient, that is, whether retrying the operation could - /// succeed (e.g. a network error or a timeout). - /// -#if NET - public override bool IsTransient +#if NET5_0_OR_GREATER + public override bool IsTransient #else - public virtual bool IsTransient + public virtual bool IsTransient #endif - => InnerException is IOException || InnerException is SocketException || InnerException is TimeoutException; + => InnerException is IOException or SocketException or TimeoutException or NpgsqlException { IsTransient: true }; - #region Serialization +#if NET6_0_OR_GREATER + /// + public new NpgsqlBatchCommand? BatchCommand { get; set; } - /// - /// Initializes a new instance of the class with serialized data. - /// - /// The SerializationInfo that holds the serialized object data about the exception being thrown. - /// The StreamingContext that contains contextual information about the source or destination. - protected internal NpgsqlException(SerializationInfo info, StreamingContext context) : base(info, context) {} + /// + protected override DbBatchCommand? DbBatchCommand => BatchCommand; +#else + /// + /// If the exception was thrown as a result of executing a , references the within + /// the batch which triggered the exception. Otherwise . + /// + public NpgsqlBatchCommand? BatchCommand { get; set; } +#endif + + #region Serialization + + /// + /// Initializes a new instance of the class with serialized data. + /// + /// The SerializationInfo that holds the serialized object data about the exception being thrown. + /// The StreamingContext that contains contextual information about the source or destination. +#if NET8_0_OR_GREATER + [Obsolete("This API supports obsolete formatter-based serialization. 
It should not be called or extended by application code.")] +#endif + protected internal NpgsqlException(SerializationInfo info, StreamingContext context) : base(info, context) {} - #endregion - } + #endregion } diff --git a/src/Npgsql/NpgsqlFactory.cs b/src/Npgsql/NpgsqlFactory.cs index 302aa00576..7d21a917a0 100644 --- a/src/Npgsql/NpgsqlFactory.cs +++ b/src/Npgsql/NpgsqlFactory.cs @@ -1,112 +1,90 @@ using System; using System.Data.Common; -using System.Reflection; -using Npgsql.Logging; +using System.Diagnostics.CodeAnalysis; -namespace Npgsql +namespace Npgsql; + +/// +/// A factory to create instances of various Npgsql objects. +/// +[Serializable] +public sealed class NpgsqlFactory : DbProviderFactory, IServiceProvider { /// - /// A factory to create instances of various Npgsql objects. + /// Gets an instance of the . + /// This can be used to retrieve strongly typed data objects. + /// + public static readonly NpgsqlFactory Instance = new(); + + NpgsqlFactory() {} + + /// + /// Returns a strongly typed instance. + /// + public override DbCommand CreateCommand() => new NpgsqlCommand(); + + /// + /// Returns a strongly typed instance. + /// + public override DbConnection CreateConnection() => new NpgsqlConnection(); + + /// + /// Returns a strongly typed instance. + /// + public override DbParameter CreateParameter() => new NpgsqlParameter(); + + /// + /// Returns a strongly typed instance. /// - [Serializable] - public sealed class NpgsqlFactory : DbProviderFactory, IServiceProvider - { - /// - /// Gets an instance of the . - /// This can be used to retrieve strongly typed data objects. - /// - public static readonly NpgsqlFactory Instance = new NpgsqlFactory(); - - NpgsqlFactory() {} - - /// - /// Returns a strongly typed instance. - /// - public override DbCommand CreateCommand() => new NpgsqlCommand(); - - /// - /// Returns a strongly typed instance. 
- /// - public override DbConnection CreateConnection() => new NpgsqlConnection(); - - /// - /// Returns a strongly typed instance. - /// - public override DbParameter CreateParameter() => new NpgsqlParameter(); - - /// - /// Returns a strongly typed instance. - /// - public override DbConnectionStringBuilder CreateConnectionStringBuilder() => new NpgsqlConnectionStringBuilder(); - - /// - /// Returns a strongly typed instance. - /// - public override DbCommandBuilder CreateCommandBuilder() => new NpgsqlCommandBuilder(); - - /// - /// Returns a strongly typed instance. - /// - public override DbDataAdapter CreateDataAdapter() => new NpgsqlDataAdapter(); + public override DbConnectionStringBuilder CreateConnectionStringBuilder() => new NpgsqlConnectionStringBuilder(); + + /// + /// Returns a strongly typed instance. + /// + public override DbCommandBuilder CreateCommandBuilder() => new NpgsqlCommandBuilder(); + + /// + /// Returns a strongly typed instance. + /// + public override DbDataAdapter CreateDataAdapter() => new NpgsqlDataAdapter(); #if !NETSTANDARD2_0 - /// - /// Specifies whether the specific supports the class. - /// - public override bool CanCreateDataAdapter => true; - - /// - /// Specifies whether the specific supports the class. - /// - public override bool CanCreateCommandBuilder => true; + /// + /// Specifies whether the specific supports the class. + /// + public override bool CanCreateDataAdapter => true; + + /// + /// Specifies whether the specific supports the class. + /// + public override bool CanCreateCommandBuilder => true; #endif - #region IServiceProvider Members - - /// - /// Gets the service object of the specified type. - /// - /// An object that specifies the type of service object to get. - /// A service object of type serviceType, or null if there is no service object of type serviceType. - - public object? 
GetService(Type serviceType) - { - if (serviceType == null) - throw new ArgumentNullException(nameof(serviceType)); - - // In legacy Entity Framework, this is the entry point for obtaining Npgsql's - // implementation of DbProviderServices. We use reflection for all types to - // avoid any dependencies on EF stuff in this project. EF6 (and of course EF Core) do not use this method. - - if (serviceType.FullName != "System.Data.Common.DbProviderServices") - return null; - - // User has requested a legacy EF DbProviderServices implementation. Check our cache first. - if (_legacyEntityFrameworkServices != null) - return _legacyEntityFrameworkServices; - - // First time, attempt to find the EntityFramework5.Npgsql assembly and load the type via reflection - var assemblyName = typeof(NpgsqlFactory).GetTypeInfo().Assembly.GetName(); - assemblyName.Name = "EntityFramework5.Npgsql"; - Assembly npgsqlEfAssembly; - try { - npgsqlEfAssembly = Assembly.Load(new AssemblyName(assemblyName.FullName)); - } catch { - return null; - } - - Type? npgsqlServicesType; - if ((npgsqlServicesType = npgsqlEfAssembly.GetType("Npgsql.NpgsqlServices")) == null || - npgsqlServicesType.GetProperty("Instance") == null) - throw new Exception("EntityFramework5.Npgsql assembly does not seem to contain the correct type!"); - - return _legacyEntityFrameworkServices = npgsqlServicesType - .GetProperty("Instance", BindingFlags.Public | BindingFlags.Static)! - .GetMethod!.Invoke(null, new object[0]); - } - - static object? 
_legacyEntityFrameworkServices; - - #endregion - } +#if NET6_0_OR_GREATER + /// + public override bool CanCreateBatch => true; + + /// + public override DbBatch CreateBatch() => new NpgsqlBatch(); + + /// + public override DbBatchCommand CreateBatchCommand() => new NpgsqlBatchCommand(); +#endif + +#if NET7_0_OR_GREATER + /// + public override DbDataSource CreateDataSource(string connectionString) + => NpgsqlDataSource.Create(connectionString); +#endif + + #region IServiceProvider Members + + /// + /// Gets the service object of the specified type. + /// + /// An object that specifies the type of service object to get. + /// A service object of type serviceType, or null if there is no service object of type serviceType. + public object? GetService(Type serviceType) => null; + + #endregion } diff --git a/src/Npgsql/NpgsqlLargeObjectManager.cs b/src/Npgsql/NpgsqlLargeObjectManager.cs index 0615915ed7..2bc6c02751 100644 --- a/src/Npgsql/NpgsqlLargeObjectManager.cs +++ b/src/Npgsql/NpgsqlLargeObjectManager.cs @@ -1,253 +1,242 @@ using System; +using Npgsql.Util; using System.Data; +using System.Text; using System.Threading; using System.Threading.Tasks; -namespace Npgsql +namespace Npgsql; + +/// +/// Large object manager. This class can be used to store very large files in a PostgreSQL database. +/// +[Obsolete("NpgsqlLargeObjectManager allows manipulating PostgreSQL large objects via publicly available PostgreSQL functions (lo_read, lo_write); call these yourself directly.")] +public class NpgsqlLargeObjectManager { + const int InvWrite = 0x00020000; + const int InvRead = 0x00040000; + + internal NpgsqlConnection Connection { get; } + /// - /// Large object manager. This class can be used to store very large files in a PostgreSQL database. + /// The largest chunk size (in bytes) read and write operations will read/write each roundtrip to the network. Default 4 MB. 
/// - public class NpgsqlLargeObjectManager - { - const int InvWrite = 0x00020000; - const int InvRead = 0x00040000; + public int MaxTransferBlockSize { get; set; } - internal NpgsqlConnection Connection { get; } + /// + /// Creates an NpgsqlLargeObjectManager for this connection. The connection must be opened to perform remote operations. + /// + /// + public NpgsqlLargeObjectManager(NpgsqlConnection connection) + { + Connection = connection; + MaxTransferBlockSize = 4 * 1024 * 1024; // 4MB + } - /// - /// The largest chunk size (in bytes) read and write operations will read/write each roundtrip to the network. Default 4 MB. - /// - public int MaxTransferBlockSize { get; set; } + /// + /// Execute a function + /// + internal async Task ExecuteFunction(bool async, string function, CancellationToken cancellationToken, params object[] arguments) + { + using var command = Connection.CreateCommand(); + var stringBuilder = new StringBuilder("SELECT * FROM ").Append(function).Append('('); - /// - /// Creates an NpgsqlLargeObjectManager for this connection. The connection must be opened to perform remote operations. 
- /// - /// - public NpgsqlLargeObjectManager(NpgsqlConnection connection) + for (var i = 0; i < arguments.Length; i++) { - Connection = connection; - MaxTransferBlockSize = 4 * 1024 * 1024; // 4MB + if (i > 0) + stringBuilder.Append(", "); + stringBuilder.Append('$').Append(i + 1); + command.Parameters.Add(new NpgsqlParameter { Value = arguments[i] }); } - /// - /// Execute a function - /// - internal async Task ExecuteFunction(string function, bool async, CancellationToken cancellationToken, params object[] arguments) - { - using var command = new NpgsqlCommand(function, Connection) - { - CommandType = CommandType.StoredProcedure, - CommandText = function - }; + stringBuilder.Append(')'); + command.CommandText = stringBuilder.ToString(); - foreach (var argument in arguments) - command.Parameters.Add(new NpgsqlParameter { Value = argument }); + return (T)(async ? await command.ExecuteScalarAsync(cancellationToken).ConfigureAwait(false) : command.ExecuteScalar())!; + } - return (T)(async ? 
await command.ExecuteScalarAsync(cancellationToken) : command.ExecuteScalar())!; - } + /// + /// Execute a function that returns a byte array + /// + /// + internal async Task ExecuteFunctionGetBytes( + bool async, string function, byte[] buffer, int offset, int len, CancellationToken cancellationToken, params object[] arguments) + { + using var command = Connection.CreateCommand(); + var stringBuilder = new StringBuilder("SELECT * FROM ").Append(function).Append('('); - /// - /// Execute a function that returns a byte array - /// - /// - internal async Task ExecuteFunctionGetBytes( - string function, byte[] buffer, int offset, int len, bool async, CancellationToken cancellationToken, params object[] arguments) + for (var i = 0; i < arguments.Length; i++) { - using var command = new NpgsqlCommand(function, Connection) - { - CommandType = CommandType.StoredProcedure - }; - - foreach (var argument in arguments) - command.Parameters.Add(new NpgsqlParameter { Value = argument }); + if (i > 0) + stringBuilder.Append(", "); + stringBuilder.Append('$').Append(i + 1); + command.Parameters.Add(new NpgsqlParameter { Value = arguments[i] }); + } - using var reader = async - ? await command.ExecuteReaderAsync(CommandBehavior.SequentialAccess, cancellationToken) - : command.ExecuteReader(CommandBehavior.SequentialAccess); + stringBuilder.Append(')'); + command.CommandText = stringBuilder.ToString(); + var reader = async + ? await command.ExecuteReaderAsync(CommandBehavior.SequentialAccess, cancellationToken).ConfigureAwait(false) + : command.ExecuteReader(CommandBehavior.SequentialAccess); + try + { if (async) - await reader.ReadAsync(cancellationToken); + await reader.ReadAsync(cancellationToken).ConfigureAwait(false); else reader.Read(); return (int)reader.GetBytes(0, 0, buffer, offset, len); } - - /// - /// Create an empty large object in the database. If an oid is specified but is already in use, an PostgresException will be thrown. 
- /// - /// A preferred oid, or specify 0 if one should be automatically assigned - /// The oid for the large object created - /// If an oid is already in use - public uint Create(uint preferredOid = 0) => Create(preferredOid, false).GetAwaiter().GetResult(); - - // Review unused parameters - /// - /// Create an empty large object in the database. If an oid is specified but is already in use, an PostgresException will be thrown. - /// - /// A preferred oid, or specify 0 if one should be automatically assigned - /// The token to monitor for cancellation requests. The default value is . - /// The oid for the large object created - /// If an oid is already in use - public Task CreateAsync(uint preferredOid, CancellationToken cancellationToken = default) - => Create(preferredOid, true, cancellationToken); - - Task Create(uint preferredOid, bool async, CancellationToken cancellationToken = default) - => ExecuteFunction("lo_create", async, cancellationToken, (int)preferredOid); - - /// - /// Opens a large object on the backend, returning a stream controlling this remote object. - /// A transaction snapshot is taken by the backend when the object is opened with only read permissions. - /// When reading from this object, the contents reflects the time when the snapshot was taken. - /// Note that this method, as well as operations on the stream must be wrapped inside a transaction. - /// - /// Oid of the object - /// An NpgsqlLargeObjectStream - public NpgsqlLargeObjectStream OpenRead(uint oid) - => OpenRead(oid, false).GetAwaiter().GetResult(); - - /// - /// Opens a large object on the backend, returning a stream controlling this remote object. - /// A transaction snapshot is taken by the backend when the object is opened with only read permissions. - /// When reading from this object, the contents reflects the time when the snapshot was taken. - /// Note that this method, as well as operations on the stream must be wrapped inside a transaction. 
- /// - /// Oid of the object - /// The token to monitor for cancellation requests. The default value is . - /// An NpgsqlLargeObjectStream - public Task OpenReadAsync(uint oid, CancellationToken cancellationToken = default) + finally { - using (NoSynchronizationContextScope.Enter()) - return OpenRead(oid, true, cancellationToken); + if (async) + await reader.DisposeAsync().ConfigureAwait(false); + else + reader.Dispose(); } + } - async Task OpenRead(uint oid, bool async, CancellationToken cancellationToken = default) - { - var fd = await ExecuteFunction("lo_open", async, cancellationToken, (int)oid, InvRead); - return new NpgsqlLargeObjectStream(this, fd, false); - } + /// + /// Create an empty large object in the database. If an oid is specified but is already in use, an PostgresException will be thrown. + /// + /// A preferred oid, or specify 0 if one should be automatically assigned + /// The oid for the large object created + /// If an oid is already in use + public uint Create(uint preferredOid = 0) => Create(preferredOid, false).GetAwaiter().GetResult(); - /// - /// Opens a large object on the backend, returning a stream controlling this remote object. - /// Note that this method, as well as operations on the stream must be wrapped inside a transaction. - /// - /// Oid of the object - /// An NpgsqlLargeObjectStream - public NpgsqlLargeObjectStream OpenReadWrite(uint oid) - => OpenReadWrite(oid, false).GetAwaiter().GetResult(); - - /// - /// Opens a large object on the backend, returning a stream controlling this remote object. - /// Note that this method, as well as operations on the stream must be wrapped inside a transaction. - /// - /// Oid of the object - /// The token to monitor for cancellation requests. The default value is . 
- /// An NpgsqlLargeObjectStream - public Task OpenReadWriteAsync(uint oid, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return OpenReadWrite(oid, true, cancellationToken); - } + // Review unused parameters + /// + /// Create an empty large object in the database. If an oid is specified but is already in use, an PostgresException will be thrown. + /// + /// A preferred oid, or specify 0 if one should be automatically assigned + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The oid for the large object created + /// If an oid is already in use + public Task CreateAsync(uint preferredOid, CancellationToken cancellationToken = default) + => Create(preferredOid, true, cancellationToken); + + Task Create(uint preferredOid, bool async, CancellationToken cancellationToken = default) + => ExecuteFunction(async, "lo_create", cancellationToken, (int)preferredOid); - async Task OpenReadWrite(uint oid, bool async, CancellationToken cancellationToken = default) - { - var fd = await ExecuteFunction("lo_open", async, cancellationToken, (int)oid, InvRead | InvWrite); - return new NpgsqlLargeObjectStream(this, fd, true); - } + /// + /// Opens a large object on the backend, returning a stream controlling this remote object. + /// A transaction snapshot is taken by the backend when the object is opened with only read permissions. + /// When reading from this object, the contents reflects the time when the snapshot was taken. + /// Note that this method, as well as operations on the stream must be wrapped inside a transaction. + /// + /// Oid of the object + /// An NpgsqlLargeObjectStream + public NpgsqlLargeObjectStream OpenRead(uint oid) + => OpenRead(async: false, oid).GetAwaiter().GetResult(); - /// - /// Deletes a large object on the backend. 
- /// - /// Oid of the object to delete - public void Unlink(uint oid) - => ExecuteFunction("lo_unlink", false, CancellationToken.None, (int)oid).GetAwaiter().GetResult(); - - /// - /// Deletes a large object on the backend. - /// - /// Oid of the object to delete - /// The token to monitor for cancellation requests. The default value is . - public Task UnlinkAsync(uint oid, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return ExecuteFunction("lo_unlink", true, cancellationToken, (int)oid); - } + /// + /// Opens a large object on the backend, returning a stream controlling this remote object. + /// A transaction snapshot is taken by the backend when the object is opened with only read permissions. + /// When reading from this object, the contents reflects the time when the snapshot was taken. + /// Note that this method, as well as operations on the stream must be wrapped inside a transaction. + /// + /// Oid of the object + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// An NpgsqlLargeObjectStream + public Task OpenReadAsync(uint oid, CancellationToken cancellationToken = default) + => OpenRead(async: true, oid, cancellationToken); + + async Task OpenRead(bool async, uint oid, CancellationToken cancellationToken = default) + { + var fd = await ExecuteFunction(async, "lo_open", cancellationToken, (int)oid, InvRead).ConfigureAwait(false); + return new NpgsqlLargeObjectStream(this, fd, false); + } - /// - /// Exports a large object stored in the database to a file on the backend. This requires superuser permissions. - /// - /// Oid of the object to export - /// Path to write the file on the backend - public void ExportRemote(uint oid, string path) - => ExecuteFunction("lo_export", false, CancellationToken.None, (int)oid, path).GetAwaiter().GetResult(); - - /// - /// Exports a large object stored in the database to a file on the backend. 
This requires superuser permissions. - /// - /// Oid of the object to export - /// Path to write the file on the backend - /// The token to monitor for cancellation requests. The default value is . - public Task ExportRemoteAsync(uint oid, string path, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return ExecuteFunction("lo_export", true, cancellationToken, (int)oid, path); - } + /// + /// Opens a large object on the backend, returning a stream controlling this remote object. + /// Note that this method, as well as operations on the stream must be wrapped inside a transaction. + /// + /// Oid of the object + /// An NpgsqlLargeObjectStream + public NpgsqlLargeObjectStream OpenReadWrite(uint oid) + => OpenReadWrite(async: false, oid).GetAwaiter().GetResult(); - /// - /// Imports a large object to be stored as a large object in the database from a file stored on the backend. This requires superuser permissions. - /// - /// Path to read the file on the backend - /// A preferred oid, or specify 0 if one should be automatically assigned - public void ImportRemote(string path, uint oid = 0) - => ExecuteFunction("lo_import", false, CancellationToken.None, path, (int)oid).GetAwaiter().GetResult(); - - /// - /// Imports a large object to be stored as a large object in the database from a file stored on the backend. This requires superuser permissions. - /// - /// Path to read the file on the backend - /// A preferred oid, or specify 0 if one should be automatically assigned - /// The token to monitor for cancellation requests. The default value is . - public Task ImportRemoteAsync(string path, uint oid, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return ExecuteFunction("lo_import", true, cancellationToken, path, (int)oid); - } + /// + /// Opens a large object on the backend, returning a stream controlling this remote object. 
+ /// Note that this method, as well as operations on the stream must be wrapped inside a transaction. + /// + /// Oid of the object + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// An NpgsqlLargeObjectStream + public Task OpenReadWriteAsync(uint oid, CancellationToken cancellationToken = default) + => OpenReadWrite(async: true, oid, cancellationToken); + + async Task OpenReadWrite(bool async, uint oid, CancellationToken cancellationToken = default) + { + var fd = await ExecuteFunction(async, "lo_open", cancellationToken, (int)oid, InvRead | InvWrite).ConfigureAwait(false); + return new NpgsqlLargeObjectStream(this, fd, true); + } - /// - /// Since PostgreSQL 9.3, large objects larger than 2GB can be handled, up to 4TB. - /// This property returns true whether the PostgreSQL version is >= 9.3. - /// - public bool Has64BitSupport => Connection.PostgreSqlVersion >= new Version(9, 3); + /// + /// Deletes a large object on the backend. + /// + /// Oid of the object to delete + public void Unlink(uint oid) + => ExecuteFunction(async: false, "lo_unlink", CancellationToken.None, (int)oid).GetAwaiter().GetResult(); - /* - internal enum Function : uint - { - lo_open = 952, - lo_close = 953, - loread = 954, - lowrite = 955, - lo_lseek = 956, - lo_lseek64 = 3170, // Since PostgreSQL 9.3 - lo_creat = 957, - lo_create = 715, - lo_tell = 958, - lo_tell64 = 3171, // Since PostgreSQL 9.3 - lo_truncate = 1004, - lo_truncate64 = 3172, // Since PostgreSQL 9.3 - - // These four are available since PostgreSQL 9.4 - lo_from_bytea = 3457, - lo_get = 3458, - lo_get_fragment = 3459, - lo_put = 3460, - - lo_unlink = 964, - - lo_import = 764, - lo_import_with_oid = 767, - lo_export = 765, - } - */ - } + /// + /// Deletes a large object on the backend. + /// + /// Oid of the object to delete + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + public Task UnlinkAsync(uint oid, CancellationToken cancellationToken = default) + => ExecuteFunction(async: true, "lo_unlink", cancellationToken, (int)oid); + + /// + /// Exports a large object stored in the database to a file on the backend. This requires superuser permissions. + /// + /// Oid of the object to export + /// Path to write the file on the backend + public void ExportRemote(uint oid, string path) + => ExecuteFunction(async: false, "lo_export", CancellationToken.None, (int)oid, path).GetAwaiter().GetResult(); + + /// + /// Exports a large object stored in the database to a file on the backend. This requires superuser permissions. + /// + /// Oid of the object to export + /// Path to write the file on the backend + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + public Task ExportRemoteAsync(uint oid, string path, CancellationToken cancellationToken = default) + => ExecuteFunction(async: true, "lo_export", cancellationToken, (int)oid, path); + + /// + /// Imports a large object to be stored as a large object in the database from a file stored on the backend. This requires superuser permissions. + /// + /// Path to read the file on the backend + /// A preferred oid, or specify 0 if one should be automatically assigned + public void ImportRemote(string path, uint oid = 0) + => ExecuteFunction(async: false, "lo_import", CancellationToken.None, path, (int)oid).GetAwaiter().GetResult(); + + /// + /// Imports a large object to be stored as a large object in the database from a file stored on the backend. This requires superuser permissions. + /// + /// Path to read the file on the backend + /// A preferred oid, or specify 0 if one should be automatically assigned + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + public Task ImportRemoteAsync(string path, uint oid, CancellationToken cancellationToken = default) + => ExecuteFunction(async: true, "lo_import", cancellationToken, path, (int)oid); + + /// + /// Since PostgreSQL 9.3, large objects larger than 2GB can be handled, up to 4TB. + /// This property returns true whether the PostgreSQL version is >= 9.3. + /// + public bool Has64BitSupport => Connection.PostgreSqlVersion.IsGreaterOrEqual(9, 3); } diff --git a/src/Npgsql/NpgsqlLargeObjectStream.cs b/src/Npgsql/NpgsqlLargeObjectStream.cs index 7cbdc9921a..2f3c8b19b0 100644 --- a/src/Npgsql/NpgsqlLargeObjectStream.cs +++ b/src/Npgsql/NpgsqlLargeObjectStream.cs @@ -1,308 +1,304 @@ -using System; +using Npgsql.Util; +using System; using System.IO; using System.Threading; using System.Threading.Tasks; -namespace Npgsql +namespace Npgsql; + +/// +/// An interface to remotely control the seekable stream for an opened large object on a PostgreSQL server. +/// Note that the OpenRead/OpenReadWrite method as well as all operations performed on this stream must be wrapped inside a database transaction. +/// +[Obsolete("NpgsqlLargeObjectStream allows manipulating PostgreSQL large objects via publicly available PostgreSQL functions (lo_read, lo_write); call these yourself directly.")] +public sealed class NpgsqlLargeObjectStream : Stream { - /// - /// An interface to remotely control the seekable stream for an opened large object on a PostgreSQL server. - /// Note that the OpenRead/OpenReadWrite method as well as all operations performed on this stream must be wrapped inside a database transaction. 
- /// - public sealed class NpgsqlLargeObjectStream : Stream + readonly NpgsqlLargeObjectManager _manager; + readonly int _fd; + long _pos; + readonly bool _writeable; + bool _disposed; + + internal NpgsqlLargeObjectStream(NpgsqlLargeObjectManager manager, int fd, bool writeable) { - readonly NpgsqlLargeObjectManager _manager; - readonly int _fd; - long _pos; - readonly bool _writeable; - bool _disposed; + _manager = manager; + _fd = fd; + _pos = 0; + _writeable = writeable; + } - internal NpgsqlLargeObjectStream(NpgsqlLargeObjectManager manager, int fd, bool writeable) - { - _manager = manager; - _fd = fd; - _pos = 0; - _writeable = writeable; - } + void CheckDisposed() + { + if (_disposed) + throw new InvalidOperationException("Object disposed"); + } - void CheckDisposed() - { - if (_disposed) - throw new InvalidOperationException("Object disposed"); - } + /// + /// Since PostgreSQL 9.3, large objects larger than 2GB can be handled, up to 4TB. + /// This property returns true whether the PostgreSQL version is >= 9.3. + /// + public bool Has64BitSupport => _manager.Connection.PostgreSqlVersion.IsGreaterOrEqual(9, 3); - /// - /// Since PostgreSQL 9.3, large objects larger than 2GB can be handled, up to 4TB. - /// This property returns true whether the PostgreSQL version is >= 9.3. - /// - public bool Has64BitSupport => _manager.Connection.PostgreSqlVersion >= new Version(9, 3); - - /// - /// Reads count bytes from the large object. The only case when fewer bytes are read is when end of stream is reached. - /// - /// The buffer where read data should be stored. - /// The offset in the buffer where the first byte should be read. - /// The maximum number of bytes that should be read. - /// How many bytes actually read, or 0 if end of file was already reached. - public override int Read(byte[] buffer, int offset, int count) - => Read(buffer, offset, count, false).GetAwaiter().GetResult(); - - /// - /// Reads count bytes from the large object. 
The only case when fewer bytes are read is when end of stream is reached. - /// - /// The buffer where read data should be stored. - /// The offset in the buffer where the first byte should be read. - /// The maximum number of bytes that should be read. - /// The token to monitor for cancellation requests. The default value is . - /// How many bytes actually read, or 0 if end of file was already reached. - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - using (NoSynchronizationContextScope.Enter()) - return Read(buffer, offset, count, true, cancellationToken); - } + /// + /// Reads count bytes from the large object. The only case when fewer bytes are read is when end of stream is reached. + /// + /// The buffer where read data should be stored. + /// The offset in the buffer where the first byte should be read. + /// The maximum number of bytes that should be read. + /// How many bytes actually read, or 0 if end of file was already reached. + public override int Read(byte[] buffer, int offset, int count) + => Read(async: false, buffer, offset, count).GetAwaiter().GetResult(); - async Task Read(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - if (buffer == null) - throw new ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentOutOfRangeException(nameof(offset)); - if (count < 0) - throw new ArgumentOutOfRangeException(nameof(count)); - if (buffer.Length - offset < count) - throw new ArgumentException("Invalid offset or count for this buffer"); + /// + /// Reads count bytes from the large object. The only case when fewer bytes are read is when end of stream is reached. + /// + /// The buffer where read data should be stored. + /// The offset in the buffer where the first byte should be read. + /// The maximum number of bytes that should be read. + /// + /// An optional token to cancel the asynchronous operation. 
The default value is . + /// + /// How many bytes actually read, or 0 if end of file was already reached. + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => Read(async: true, buffer, offset, count, cancellationToken); + + async Task Read(bool async, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0) + throw new ArgumentOutOfRangeException(nameof(offset)); + if (count < 0) + throw new ArgumentOutOfRangeException(nameof(count)); + if (buffer.Length - offset < count) + throw new ArgumentException("Invalid offset or count for this buffer"); - CheckDisposed(); + CheckDisposed(); - var chunkCount = Math.Min(count, _manager.MaxTransferBlockSize); - var read = 0; + var chunkCount = Math.Min(count, _manager.MaxTransferBlockSize); + var read = 0; - while (read < count) + while (read < count) + { + var bytesRead = await _manager.ExecuteFunctionGetBytes( + async, "loread", buffer, offset + read, count - read, cancellationToken, _fd, chunkCount).ConfigureAwait(false); + _pos += bytesRead; + read += bytesRead; + if (bytesRead < chunkCount) { - var bytesRead = await _manager.ExecuteFunctionGetBytes( - "loread", buffer, offset + read, count - read, async, cancellationToken, _fd, chunkCount); - _pos += bytesRead; - read += bytesRead; - if (bytesRead < chunkCount) - { - return read; - } + return read; } - return read; } + return read; + } - /// - /// Writes count bytes to the large object. - /// - /// The buffer to write data from. - /// The offset in the buffer at which to begin copying bytes. - /// The number of bytes to write. - public override void Write(byte[] buffer, int offset, int count) - => Write(buffer, offset, count, false).GetAwaiter().GetResult(); - - /// - /// Writes count bytes to the large object. - /// - /// The buffer to write data from. 
- /// The offset in the buffer at which to begin copying bytes. - /// The number of bytes to write. - /// The token to monitor for cancellation requests. The default value is . - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - using (NoSynchronizationContextScope.Enter()) - return Write(buffer, offset, count, true, cancellationToken); - } + /// + /// Writes count bytes to the large object. + /// + /// The buffer to write data from. + /// The offset in the buffer at which to begin copying bytes. + /// The number of bytes to write. + public override void Write(byte[] buffer, int offset, int count) + => Write(async: false, buffer, offset, count).GetAwaiter().GetResult(); - async Task Write(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - if (buffer == null) - throw new ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentOutOfRangeException(nameof(offset)); - if (count < 0) - throw new ArgumentOutOfRangeException(nameof(count)); - if (buffer.Length - offset < count) - throw new ArgumentException("Invalid offset or count for this buffer"); + /// + /// Writes count bytes to the large object. + /// + /// The buffer to write data from. + /// The offset in the buffer at which to begin copying bytes. + /// The number of bytes to write. + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => Write(async: true, buffer, offset, count, cancellationToken); + + async Task Write(bool async, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0) + throw new ArgumentOutOfRangeException(nameof(offset)); + if (count < 0) + throw new ArgumentOutOfRangeException(nameof(count)); + if (buffer.Length - offset < count) + throw new ArgumentException("Invalid offset or count for this buffer"); - CheckDisposed(); + CheckDisposed(); - if (!_writeable) - throw new NotSupportedException("Write cannot be called on a stream opened with no write permissions"); + if (!_writeable) + throw new NotSupportedException("Write cannot be called on a stream opened with no write permissions"); - var totalWritten = 0; + var totalWritten = 0; - while (totalWritten < count) - { - var chunkSize = Math.Min(count - totalWritten, _manager.MaxTransferBlockSize); - var bytesWritten = await _manager.ExecuteFunction("lowrite", async, cancellationToken, _fd, new ArraySegment(buffer, offset + totalWritten, chunkSize)); - totalWritten += bytesWritten; + while (totalWritten < count) + { + var chunkSize = Math.Min(count - totalWritten, _manager.MaxTransferBlockSize); + var bytesWritten = await _manager.ExecuteFunction(async, "lowrite", cancellationToken, _fd, new ArraySegment(buffer, offset + totalWritten, chunkSize)).ConfigureAwait(false); + totalWritten += bytesWritten; - if (bytesWritten != chunkSize) - throw new InvalidOperationException($"Internal Npgsql bug, please report"); + if (bytesWritten != chunkSize) + throw new InvalidOperationException($"Internal Npgsql bug, please report"); - _pos += bytesWritten; - } + _pos += bytesWritten; } + } - /// - /// CanTimeout always returns false. 
- /// - public override bool CanTimeout => false; - - /// - /// CanRead always returns true, unless the stream has been closed. - /// - public override bool CanRead => !_disposed; - - /// - /// CanWrite returns true if the stream was opened with write permissions, and the stream has not been closed. - /// - public override bool CanWrite => _writeable && !_disposed; - - /// - /// CanSeek always returns true, unless the stream has been closed. - /// - public override bool CanSeek => !_disposed; - - /// - /// Returns the current position in the stream. Getting the current position does not need a round-trip to the server, however setting the current position does. - /// - public override long Position - { - get - { - CheckDisposed(); - return _pos; - } - set => Seek(value, SeekOrigin.Begin); - } + /// + /// CanTimeout always returns false. + /// + public override bool CanTimeout => false; + + /// + /// CanRead always returns true, unless the stream has been closed. + /// + public override bool CanRead => !_disposed; - /// - /// Gets the length of the large object. This internally seeks to the end of the stream to retrieve the length, and then back again. - /// - public override long Length => GetLength(false).GetAwaiter().GetResult(); + /// + /// CanWrite returns true if the stream was opened with write permissions, and the stream has not been closed. + /// + public override bool CanWrite => _writeable && !_disposed; - /// - /// Gets the length of the large object. This internally seeks to the end of the stream to retrieve the length, and then back again. - /// - /// The token to monitor for cancellation requests. The default value is . - public Task GetLengthAsync(CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return GetLength(true); - } + /// + /// CanSeek always returns true, unless the stream has been closed. 
+ /// + public override bool CanSeek => !_disposed; - async Task GetLength(bool async) + /// + /// Returns the current position in the stream. Getting the current position does not need a round-trip to the server, however setting the current position does. + /// + public override long Position + { + get { CheckDisposed(); - var old = _pos; - var retval = await Seek(0, SeekOrigin.End, async); - if (retval != old) - await Seek(old, SeekOrigin.Begin, async); - return retval; + return _pos; } + set => Seek(value, SeekOrigin.Begin); + } - /// - /// Seeks in the stream to the specified position. This requires a round-trip to the backend. - /// - /// A byte offset relative to the origin parameter. - /// A value of type SeekOrigin indicating the reference point used to obtain the new position. - /// - public override long Seek(long offset, SeekOrigin origin) - => Seek(offset, origin, false).GetAwaiter().GetResult(); - - /// - /// Seeks in the stream to the specified position. This requires a round-trip to the backend. - /// - /// A byte offset relative to the origin parameter. - /// A value of type SeekOrigin indicating the reference point used to obtain the new position. - /// The token to monitor for cancellation requests. The default value is . - public Task SeekAsync(long offset, SeekOrigin origin, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return Seek(offset, origin, true, cancellationToken); - } + /// + /// Gets the length of the large object. This internally seeks to the end of the stream to retrieve the length, and then back again. 
+ /// + public override long Length => GetLength(false).GetAwaiter().GetResult(); - async Task Seek(long offset, SeekOrigin origin, bool async, CancellationToken cancellationToken = default) - { - if (origin < SeekOrigin.Begin || origin > SeekOrigin.End) - throw new ArgumentException("Invalid origin"); - if (!Has64BitSupport && offset != (int)offset) - throw new ArgumentOutOfRangeException(nameof(offset), "offset must fit in 32 bits for PostgreSQL versions older than 9.3"); + /// + /// Gets the length of the large object. This internally seeks to the end of the stream to retrieve the length, and then back again. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + public Task GetLengthAsync(CancellationToken cancellationToken = default) => GetLength(async: true); - CheckDisposed(); + async Task GetLength(bool async) + { + CheckDisposed(); + var old = _pos; + var retval = await Seek(async, 0, SeekOrigin.End).ConfigureAwait(false); + if (retval != old) + await Seek(async, old, SeekOrigin.Begin).ConfigureAwait(false); + return retval; + } - return _manager.Has64BitSupport - ? _pos = await _manager.ExecuteFunction("lo_lseek64", async, cancellationToken, _fd, offset, (int)origin) - : _pos = await _manager.ExecuteFunction("lo_lseek", async, cancellationToken, _fd, (int)offset, (int)origin); - } + /// + /// Seeks in the stream to the specified position. This requires a round-trip to the backend. + /// + /// A byte offset relative to the origin parameter. + /// A value of type SeekOrigin indicating the reference point used to obtain the new position. + /// + public override long Seek(long offset, SeekOrigin origin) + => Seek(async: false, offset, origin).GetAwaiter().GetResult(); - /// - /// Does nothing. - /// - public override void Flush() {} - - /// - /// Truncates or enlarges the large object to the given size. If enlarging, the large object is extended with null bytes. 
- /// For PostgreSQL versions earlier than 9.3, the value must fit in an Int32. - /// - /// Number of bytes to either truncate or enlarge the large object. - public override void SetLength(long value) - => SetLength(value, false).GetAwaiter().GetResult(); - - /// - /// Truncates or enlarges the large object to the given size. If enlarging, the large object is extended with null bytes. - /// For PostgreSQL versions earlier than 9.3, the value must fit in an Int32. - /// - /// Number of bytes to either truncate or enlarge the large object. - /// Cancellation token. - public Task SetLength(long value, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - using (NoSynchronizationContextScope.Enter()) - return SetLength(value, true, cancellationToken); - } + /// + /// Seeks in the stream to the specified position. This requires a round-trip to the backend. + /// + /// A byte offset relative to the origin parameter. + /// A value of type SeekOrigin indicating the reference point used to obtain the new position. + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + public Task SeekAsync(long offset, SeekOrigin origin, CancellationToken cancellationToken = default) + => Seek(async: true, offset, origin, cancellationToken); + + async Task Seek(bool async, long offset, SeekOrigin origin, CancellationToken cancellationToken = default) + { + if (origin < SeekOrigin.Begin || origin > SeekOrigin.End) + throw new ArgumentException("Invalid origin"); + if (!Has64BitSupport && offset != (int)offset) + throw new ArgumentOutOfRangeException(nameof(offset), "offset must fit in 32 bits for PostgreSQL versions older than 9.3"); - async Task SetLength(long value, bool async, CancellationToken cancellationToken = default) - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value)); - if (!Has64BitSupport && value != (int)value) - throw new ArgumentOutOfRangeException(nameof(value), "offset must fit in 32 bits for PostgreSQL versions older than 9.3"); + CheckDisposed(); - CheckDisposed(); + return _manager.Has64BitSupport + ? _pos = await _manager.ExecuteFunction(async, "lo_lseek64", cancellationToken, _fd, offset, (int)origin).ConfigureAwait(false) + : _pos = await _manager.ExecuteFunction(async, "lo_lseek", cancellationToken, _fd, (int)offset, (int)origin).ConfigureAwait(false); + } - if (!_writeable) - throw new NotSupportedException("SetLength cannot be called on a stream opened with no write permissions"); + /// + /// Does nothing. + /// + public override void Flush() {} - if (_manager.Has64BitSupport) - await _manager.ExecuteFunction("lo_truncate64", async, cancellationToken, _fd, value); - else - await _manager.ExecuteFunction("lo_truncate", async, cancellationToken, _fd, (int)value); - } + /// + /// Truncates or enlarges the large object to the given size. If enlarging, the large object is extended with null bytes. + /// For PostgreSQL versions earlier than 9.3, the value must fit in an Int32. + /// + /// Number of bytes to either truncate or enlarge the large object. 
+ public override void SetLength(long value) + => SetLength(async: false, value).GetAwaiter().GetResult(); - /// - /// Releases resources at the backend allocated for this stream. - /// - public override void Close() + /// + /// Truncates or enlarges the large object to the given size. If enlarging, the large object is extended with null bytes. + /// For PostgreSQL versions earlier than 9.3, the value must fit in an Int32. + /// + /// Number of bytes to either truncate or enlarge the large object. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + public Task SetLength(long value, CancellationToken cancellationToken) + => SetLength(async: true, value, cancellationToken); + + async Task SetLength(bool async, long value, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value)); + if (!Has64BitSupport && value != (int)value) + throw new ArgumentOutOfRangeException(nameof(value), "offset must fit in 32 bits for PostgreSQL versions older than 9.3"); + + CheckDisposed(); + + if (!_writeable) + throw new NotSupportedException("SetLength cannot be called on a stream opened with no write permissions"); + + if (_manager.Has64BitSupport) + await _manager.ExecuteFunction(async, "lo_truncate64", cancellationToken, _fd, value).ConfigureAwait(false); + else + await _manager.ExecuteFunction(async, "lo_truncate", cancellationToken, _fd, (int)value).ConfigureAwait(false); + } + + /// + /// Releases resources at the backend allocated for this stream. 
+ /// + public override void Close() + { + if (!_disposed) { - if (!_disposed) - { - _manager.ExecuteFunction("lo_close", false, CancellationToken.None, _fd).GetAwaiter().GetResult(); - _disposed = true; - } + _manager.ExecuteFunction(async: false, "lo_close", CancellationToken.None, _fd).GetAwaiter().GetResult(); + _disposed = true; } + } - /// - /// Releases resources at the backend allocated for this stream, iff disposing is true. - /// - /// Whether to release resources allocated at the backend. - protected override void Dispose(bool disposing) + /// + /// Releases resources at the backend allocated for this stream, iff disposing is true. + /// + /// Whether to release resources allocated at the backend. + protected override void Dispose(bool disposing) + { + if (disposing) { - if (disposing) - { - Close(); - } + Close(); } } } diff --git a/src/Npgsql/NpgsqlLengthCache.cs b/src/Npgsql/NpgsqlLengthCache.cs deleted file mode 100644 index 7259218788..0000000000 --- a/src/Npgsql/NpgsqlLengthCache.cs +++ /dev/null @@ -1,67 +0,0 @@ -using System.Collections.Generic; -using System.Diagnostics; -using Npgsql.TypeHandling; - -namespace Npgsql -{ - /// - /// An array of cached lengths for the parameters sending process. - /// - /// When sending parameters, lengths need to be calculated more than once (once for Bind, once for - /// an array, once for the string within that array). This cache optimized that. Lengths are added - /// to the cache, and then retrieved at the same order. - /// - public sealed class NpgsqlLengthCache - { - internal bool IsPopulated; - internal int Position; - internal List Lengths; - - internal NpgsqlLengthCache() => Lengths = new List(); - - internal NpgsqlLengthCache(int capacity) => Lengths = new List(capacity); - - /// - /// Stores a length value in the cache, to be fetched later via . - /// Called at the phase. - /// - /// The length parameter. 
- public int Set(int len) - { - Debug.Assert(!IsPopulated); - Lengths.Add(len); - Position++; - return len; - } - - /// - /// Retrieves a length value previously stored in the cache via . - /// Called at the writing phase, after validation has already occurred and the length cache is populated. - /// - /// - public int Get() - { - Debug.Assert(IsPopulated); - return Lengths[Position++]; - } - - internal int GetLast() - { - Debug.Assert(IsPopulated); - return Lengths[Position-1]; - } - - internal void Rewind() - { - Position = 0; - IsPopulated = true; - } - - internal void Clear() - { - Lengths.Clear(); - Position = 0; - IsPopulated = false; - } - } -} diff --git a/src/Npgsql/NpgsqlLoggingConfiguration.cs b/src/Npgsql/NpgsqlLoggingConfiguration.cs new file mode 100644 index 0000000000..745cf476cb --- /dev/null +++ b/src/Npgsql/NpgsqlLoggingConfiguration.cs @@ -0,0 +1,59 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Npgsql; + +/// +/// Configures Npgsql logging +/// +public class NpgsqlLoggingConfiguration +{ + internal static readonly NpgsqlLoggingConfiguration NullConfiguration + = new(NullLoggerFactory.Instance, isParameterLoggingEnabled: false); + + internal static ILoggerFactory GlobalLoggerFactory = NullLoggerFactory.Instance; + internal static bool GlobalIsParameterLoggingEnabled; + + internal NpgsqlLoggingConfiguration(ILoggerFactory loggerFactory, bool isParameterLoggingEnabled) + { + ConnectionLogger = loggerFactory.CreateLogger("Npgsql.Connection"); + CommandLogger = loggerFactory.CreateLogger("Npgsql.Command"); + TransactionLogger = loggerFactory.CreateLogger("Npgsql.Transaction"); + CopyLogger = loggerFactory.CreateLogger("Npgsql.Copy"); + ReplicationLogger = loggerFactory.CreateLogger("Npgsql.Replication"); + ExceptionLogger = loggerFactory.CreateLogger("Npgsql.Exception"); + + IsParameterLoggingEnabled = isParameterLoggingEnabled; + } + + internal ILogger ConnectionLogger { get; } + internal 
ILogger CommandLogger { get; } + internal ILogger TransactionLogger { get; } + internal ILogger CopyLogger { get; } + internal ILogger ReplicationLogger { get; } + internal ILogger ExceptionLogger { get; } + + /// + /// Determines whether parameter contents will be logged alongside SQL statements - this may reveal sensitive information. + /// Defaults to false. + /// + internal bool IsParameterLoggingEnabled { get; } + + /// + /// + /// Globally initializes Npgsql logging to use the provided . + /// Must be called before any Npgsql APIs are used. + /// + /// + /// This is a legacy-only, backwards compatibility API. New applications should set the logger factory on + /// and use the resulting instead. + /// + /// + /// The logging factory to use when logging from Npgsql. + /// + /// Determines whether parameter contents will be logged alongside SQL statements - this may reveal sensitive information. + /// Defaults to . + /// + public static void InitializeLogging(ILoggerFactory loggerFactory, bool parameterLoggingEnabled = false) + => (GlobalLoggerFactory, GlobalIsParameterLoggingEnabled) = (loggerFactory, parameterLoggingEnabled); +} \ No newline at end of file diff --git a/src/Npgsql/NpgsqlMultiHostDataSource.cs b/src/Npgsql/NpgsqlMultiHostDataSource.cs new file mode 100644 index 0000000000..813460b557 --- /dev/null +++ b/src/Npgsql/NpgsqlMultiHostDataSource.cs @@ -0,0 +1,464 @@ +using Npgsql.Internal; +using Npgsql.Util; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using System.Transactions; + +namespace Npgsql; + +/// +/// An which manages connections for multiple hosts, is aware of their states (primary, secondary, +/// offline...) and can perform failover and load balancing across them. +/// +/// +/// See . 
+/// +public sealed class NpgsqlMultiHostDataSource : NpgsqlDataSource +{ + internal override bool OwnsConnectors => false; + + readonly NpgsqlDataSource[] _pools; + + internal NpgsqlDataSource[] Pools => _pools; + + readonly MultiHostDataSourceWrapper[] _wrappers; + + volatile int _roundRobinIndex = -1; + + internal NpgsqlMultiHostDataSource(NpgsqlConnectionStringBuilder settings, NpgsqlDataSourceConfiguration dataSourceConfig) + : base(settings, dataSourceConfig) + { + var hosts = settings.Host!.Split(','); + _pools = new NpgsqlDataSource[hosts.Length]; + for (var i = 0; i < hosts.Length; i++) + { + var poolSettings = settings.Clone(); + Debug.Assert(!poolSettings.Multiplexing); + var host = hosts[i].AsSpan().Trim(); + if (NpgsqlConnectionStringBuilder.TrySplitHostPort(host, out var newHost, out var newPort)) + { + poolSettings.Host = newHost; + poolSettings.Port = newPort; + } + else + poolSettings.Host = host.ToString(); + + _pools[i] = settings.Pooling + ? new PoolingDataSource(poolSettings, dataSourceConfig, this) + : new UnpooledDataSource(poolSettings, dataSourceConfig); + } + +#if NETSTANDARD + var targetSessionAttributeValues = Enum.GetValues(typeof(TargetSessionAttributes)).Cast().ToArray(); +#else + var targetSessionAttributeValues = Enum.GetValues().ToArray(); +#endif + var highestValue = 0; + foreach (var value in targetSessionAttributeValues) + if ((int)value > highestValue) + highestValue = (int)value; + + _wrappers = new MultiHostDataSourceWrapper[highestValue + 1]; + foreach (var targetSessionAttribute in targetSessionAttributeValues) + _wrappers[(int)targetSessionAttribute] = new(this, targetSessionAttribute); + } + + /// + /// Returns a new, unopened connection from this data source. + /// + /// Specifies the server type (e.g. primary, standby). 
+ public NpgsqlConnection CreateConnection(TargetSessionAttributes targetSessionAttributes) + => NpgsqlConnection.FromDataSource(_wrappers[(int)targetSessionAttributes]); + + /// + /// Returns a new, opened connection from this data source. + /// + /// Specifies the server type (e.g. primary, standby). + public NpgsqlConnection OpenConnection(TargetSessionAttributes targetSessionAttributes) + { + var connection = CreateConnection(targetSessionAttributes); + + try + { + connection.Open(); + return connection; + } + catch + { + connection.Dispose(); + throw; + } + } + + /// + /// Returns a new, opened connection from this data source. + /// + /// Specifies the server type (e.g. primary, standby). + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + public async ValueTask OpenConnectionAsync( + TargetSessionAttributes targetSessionAttributes, + CancellationToken cancellationToken = default) + { + var connection = CreateConnection(targetSessionAttributes); + + try + { + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + return connection; + } + catch + { + await connection.DisposeAsync().ConfigureAwait(false); + throw; + } + } + + /// + /// Returns an that wraps this multi-host one with the given server type. + /// + /// Specifies the server type (e.g. primary, standby). 
+ public NpgsqlDataSource WithTargetSession(TargetSessionAttributes targetSessionAttributes) + => _wrappers[(int)targetSessionAttributes]; + + static bool IsPreferred(DatabaseState state, TargetSessionAttributes preferredType) + => state switch + { + DatabaseState.Offline => false, + DatabaseState.Unknown => true, // We will check compatibility again after refreshing the database state + + DatabaseState.PrimaryReadWrite when preferredType is + TargetSessionAttributes.Primary or + TargetSessionAttributes.PreferPrimary or + TargetSessionAttributes.ReadWrite + => true, + + DatabaseState.PrimaryReadOnly when preferredType is + TargetSessionAttributes.Primary or + TargetSessionAttributes.PreferPrimary or + TargetSessionAttributes.ReadOnly + => true, + + DatabaseState.Standby when preferredType is + TargetSessionAttributes.Standby or + TargetSessionAttributes.PreferStandby or + TargetSessionAttributes.ReadOnly + => true, + + _ => preferredType == TargetSessionAttributes.Any + }; + + static bool IsOnline(DatabaseState state, TargetSessionAttributes preferredType) + { + Debug.Assert(preferredType is TargetSessionAttributes.PreferPrimary or TargetSessionAttributes.PreferStandby); + return state != DatabaseState.Offline; + } + + async ValueTask TryGetIdleOrNew( + NpgsqlConnection conn, + TimeSpan timeoutPerHost, + bool async, + TargetSessionAttributes preferredType, Func stateValidator, + int poolIndex, + IList exceptions, + CancellationToken cancellationToken) + { + var pools = _pools; + for (var i = 0; i < pools.Length; i++) + { + var pool = pools[poolIndex]; + poolIndex++; + if (poolIndex == pools.Length) + poolIndex = 0; + + var databaseState = pool.GetDatabaseState(); + if (!stateValidator(databaseState, preferredType)) + continue; + + NpgsqlConnector? 
connector = null; + + try + { + if (pool.TryGetIdleConnector(out connector)) + { + if (databaseState == DatabaseState.Unknown) + { + databaseState = await connector.QueryDatabaseState(new NpgsqlTimeout(timeoutPerHost), async, cancellationToken).ConfigureAwait(false); + Debug.Assert(databaseState != DatabaseState.Unknown); + if (!stateValidator(databaseState, preferredType)) + { + pool.Return(connector); + continue; + } + } + + return connector; + } + else + { + connector = await pool.OpenNewConnector(conn, new NpgsqlTimeout(timeoutPerHost), async, cancellationToken).ConfigureAwait(false); + if (connector is not null) + { + if (databaseState == DatabaseState.Unknown) + { + // While opening a new connector we might have refreshed the database state, check again + databaseState = pool.GetDatabaseState(); + if (databaseState == DatabaseState.Unknown) + databaseState = await connector.QueryDatabaseState(new NpgsqlTimeout(timeoutPerHost), async, cancellationToken).ConfigureAwait(false); + Debug.Assert(databaseState != DatabaseState.Unknown); + if (!stateValidator(databaseState, preferredType)) + { + pool.Return(connector); + continue; + } + } + + return connector; + } + } + } + catch (Exception ex) + { + exceptions.Add(ex); + if (connector is not null) + pool.Return(connector); + } + } + + return null; + } + + async ValueTask TryGet( + NpgsqlConnection conn, + TimeSpan timeoutPerHost, + bool async, + TargetSessionAttributes preferredType, + Func stateValidator, + int poolIndex, + IList exceptions, + CancellationToken cancellationToken) + { + var pools = _pools; + for (var i = 0; i < pools.Length; i++) + { + var pool = pools[poolIndex]; + poolIndex++; + if (poolIndex == pools.Length) + poolIndex = 0; + + var databaseState = pool.GetDatabaseState(); + if (!stateValidator(databaseState, preferredType)) + continue; + + NpgsqlConnector? 
connector = null; + + try + { + connector = await pool.Get(conn, new NpgsqlTimeout(timeoutPerHost), async, cancellationToken).ConfigureAwait(false); + if (databaseState == DatabaseState.Unknown) + { + // Get might have opened a new physical connection and refreshed the database state, check again + databaseState = pool.GetDatabaseState(); + if (databaseState == DatabaseState.Unknown) + databaseState = await connector.QueryDatabaseState(new NpgsqlTimeout(timeoutPerHost), async, cancellationToken).ConfigureAwait(false); + + Debug.Assert(databaseState != DatabaseState.Unknown); + if (!stateValidator(databaseState, preferredType)) + { + pool.Return(connector); + continue; + } + } + + return connector; + } + catch (Exception ex) + { + exceptions.Add(ex); + if (connector is not null) + pool.Return(connector); + } + } + + return null; + } + + internal override async ValueTask Get( + NpgsqlConnection conn, + NpgsqlTimeout timeout, + bool async, + CancellationToken cancellationToken) + { + CheckDisposed(); + + var exceptions = new List(); + + var poolIndex = conn.Settings.LoadBalanceHosts ? GetRoundRobinIndex() : 0; + + var timeoutPerHost = timeout.IsSet ? timeout.CheckAndGetTimeLeft() : TimeSpan.Zero; + var preferredType = GetTargetSessionAttributes(conn); + var checkUnpreferred = preferredType is TargetSessionAttributes.PreferPrimary or TargetSessionAttributes.PreferStandby; + + var connector = await TryGetIdleOrNew(conn, timeoutPerHost, async, preferredType, IsPreferred, poolIndex, exceptions, cancellationToken).ConfigureAwait(false) ?? + (checkUnpreferred ? + await TryGetIdleOrNew(conn, timeoutPerHost, async, preferredType, IsOnline, poolIndex, exceptions, cancellationToken).ConfigureAwait(false) + : null) ?? + await TryGet(conn, timeoutPerHost, async, preferredType, IsPreferred, poolIndex, exceptions, cancellationToken).ConfigureAwait(false) ?? + (checkUnpreferred ? 
+ await TryGet(conn, timeoutPerHost, async, preferredType, IsOnline, poolIndex, exceptions, cancellationToken).ConfigureAwait(false) + : null); + + return connector ?? throw NoSuitableHostsException(exceptions); + } + + static NpgsqlException NoSuitableHostsException(IList exceptions) + { + return exceptions.Count == 0 + ? new NpgsqlException("No suitable host was found.") + : exceptions[0] is PostgresException firstException && AllEqual(firstException, exceptions) + ? firstException + : new NpgsqlException("Unable to connect to a suitable host. Check inner exception for more details.", + new AggregateException(exceptions)); + + static bool AllEqual(PostgresException first, IList exceptions) + { + foreach (var x in exceptions) + if (x is not PostgresException ex || ex.SqlState != first.SqlState) + return false; + return true; + } + } + + int GetRoundRobinIndex() + { + while (true) + { + var index = Interlocked.Increment(ref _roundRobinIndex); + if (index >= 0) + return index % _pools.Length; + + // Worst case scenario - we've wrapped around integer counter + if (index == int.MinValue) + { + // This is the thread which wrapped around the counter - reset it to 0 + _roundRobinIndex = 0; + return 0; + } + + // This is not the thread which wrapped around the counter - just wait until it's 0 or more + var sw = new SpinWait(); + while (_roundRobinIndex < 0) + sw.SpinOnce(); + } + } + + internal override void Return(NpgsqlConnector connector) + => throw new NpgsqlException("Npgsql bug: a connector was returned to " + nameof(NpgsqlMultiHostDataSource)); + + internal override bool TryGetIdleConnector([NotNullWhen(true)] out NpgsqlConnector? 
connector) + => throw new NpgsqlException("Npgsql bug: trying to get an idle connector from " + nameof(NpgsqlMultiHostDataSource)); + + internal override ValueTask OpenNewConnector(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + => throw new NpgsqlException("Npgsql bug: trying to open a new connector from " + nameof(NpgsqlMultiHostDataSource)); + + internal override void Clear() + { + foreach (var pool in _pools) + pool.Clear(); + } + + /// + /// Clears the database state (primary, secondary, offline...) for all data sources managed by this multi-host data source. + /// Can be useful to make Npgsql retry a PostgreSQL instance which was previously detected to be offline. + /// + public void ClearDatabaseStates() + { + foreach (var pool in _pools) + { + pool.UpdateDatabaseState(default, default, default, ignoreTimeStamp: true); + } + } + + internal override (int Total, int Idle, int Busy) Statistics + { + get + { + var numConnectors = 0; + var idleCount = 0; + + foreach (var pool in _pools) + { + var stat = pool.Statistics; + numConnectors += stat.Total; + idleCount += stat.Idle; + } + + return (numConnectors, idleCount, numConnectors - idleCount); + } + } + + internal override bool TryRentEnlistedPending( + Transaction transaction, + NpgsqlConnection connection, + [NotNullWhen(true)] out NpgsqlConnector? connector) + { + lock (_pendingEnlistedConnectors) + { + if (!_pendingEnlistedConnectors.TryGetValue(transaction, out var list)) + { + connector = null; + return false; + } + + var preferredType = GetTargetSessionAttributes(connection); + // First try to get a valid preferred connector. + if (TryGetValidConnector(list, preferredType, IsPreferred, out connector)) + { + return true; + } + + // Can't get valid preferred connector. Try to get an unpreferred connector, if supported. 
+ if ((preferredType == TargetSessionAttributes.PreferPrimary || preferredType == TargetSessionAttributes.PreferStandby) + && TryGetValidConnector(list, preferredType, IsOnline, out connector)) + { + return true; + } + + connector = null; + return false; + } + + bool TryGetValidConnector(List list, TargetSessionAttributes preferredType, + Func validationFunc, [NotNullWhen(true)] out NpgsqlConnector? connector) + { + for (var i = list.Count - 1; i >= 0; i--) + { + connector = list[i]; + var lastKnownState = connector.DataSource.GetDatabaseState(ignoreExpiration: true); + Debug.Assert(lastKnownState != DatabaseState.Unknown); + if (validationFunc(lastKnownState, preferredType)) + { + list.RemoveAt(i); + if (list.Count == 0) + _pendingEnlistedConnectors.Remove(transaction); + return true; + } + } + + connector = null; + return false; + } + } + + static TargetSessionAttributes GetTargetSessionAttributes(NpgsqlConnection connection) + => connection.Settings.TargetSessionAttributesParsed ?? + (PostgresEnvironment.TargetSessionAttributes is { } s + ? NpgsqlConnectionStringBuilder.ParseTargetSessionAttributes(s) + : TargetSessionAttributes.Any); +} diff --git a/src/Npgsql/NpgsqlNestedDataReader.cs b/src/Npgsql/NpgsqlNestedDataReader.cs new file mode 100644 index 0000000000..d3b6e37bfd --- /dev/null +++ b/src/Npgsql/NpgsqlNestedDataReader.cs @@ -0,0 +1,535 @@ +using Npgsql.Internal; +using Npgsql.PostgresTypes; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using Npgsql.Internal.Postgres; + +namespace Npgsql; + +/// +/// Reads a forward-only stream of rows from a nested data source. +/// Can be retrieved using or +/// . +/// +public sealed class NpgsqlNestedDataReader : DbDataReader +{ + readonly NpgsqlDataReader _outermostReader; + readonly NpgsqlNestedDataReader? 
_outerNestedReader; + NpgsqlNestedDataReader? _cachedFreeNestedDataReader; + PostgresCompositeType? _compositeType; + readonly int _depth; + int _numRows; + int _nextRowIndex; + int _nextRowBufferPos; + ReaderState _readerState; + + readonly List _columns = new(); + long _startPos; + + DataFormat Format => DataFormat.Binary; + + readonly struct ColumnInfo + { + readonly DataFormat _format; + public PostgresType PostgresType { get; } + public int BufferPos { get; } + public PgConverterInfo LastConverterInfo { get; init; } + + public PgTypeInfo ObjectOrDefaultTypeInfo { get; } + public PgConverterInfo GetObjectOrDefaultInfo() => ObjectOrDefaultTypeInfo.Bind(Field, _format); + + Field Field => new("?", ObjectOrDefaultTypeInfo.Options.PortableTypeIds ? PostgresType.DataTypeName : (Oid)PostgresType.OID, -1); + + public PgConverterInfo Bind(PgTypeInfo typeInfo) => typeInfo.Bind(Field, _format); + + public ColumnInfo(PostgresType postgresType, int bufferPos, PgTypeInfo objectOrDefaultTypeInfo, DataFormat format) + { + _format = format; + PostgresType = postgresType; + BufferPos = bufferPos; + ObjectOrDefaultTypeInfo = objectOrDefaultTypeInfo; + } + } + + PgReader PgReader => _outermostReader.Buffer.PgReader; + PgSerializerOptions SerializerOptions => _outermostReader.Connector.SerializerOptions; + + internal NpgsqlNestedDataReader(NpgsqlDataReader outermostReader, NpgsqlNestedDataReader? outerNestedReader, + int depth, PostgresCompositeType? compositeType) + { + _outermostReader = outermostReader; + _outerNestedReader = outerNestedReader; + _depth = depth; + _compositeType = compositeType; + _startPos = PgReader.FieldStartPos; + } + + internal void Init(PostgresCompositeType? 
compositeType) + { + _startPos = PgReader.FieldStartPos; + _columns.Clear(); + _numRows = 0; + _nextRowIndex = 0; + _nextRowBufferPos = 0; + _readerState = ReaderState.BeforeFirstRow; + _compositeType = compositeType; + } + + internal void InitArray() + { + var dimensions = PgReader.ReadInt32(); + var containsNulls = PgReader.ReadInt32() == 1; + PgReader.ReadUInt32(); // Element OID. Ignored. + + if (containsNulls) + throw new InvalidOperationException("Record array contains null record"); + + if (dimensions == 0) + return; + + if (dimensions != 1) + throw new InvalidOperationException("Cannot read a multidimensional array with a nested DbDataReader"); + + _numRows = PgReader.ReadInt32(); + PgReader.ReadInt32(); // Lower bound + + if (_numRows > 0) + PgReader.ReadInt32(); // Length of first row + + _nextRowBufferPos = PgReader.FieldOffset; + } + + internal void InitSingleRow() + { + _numRows = 1; + _nextRowBufferPos = PgReader.FieldOffset; + } + + /// + public override object this[int ordinal] => GetValue(ordinal); + + /// + public override object this[string name] => GetValue(GetOrdinal(name)); + + /// + public override int Depth + { + get + { + CheckNotClosed(); + return _depth; + } + } + + /// + public override int FieldCount + { + get + { + CheckNotClosed(); + return _readerState == ReaderState.OnRow ? 
_columns.Count : 0; + } + } + + /// + public override bool HasRows + { + get + { + CheckNotClosed(); + return _numRows > 0; + } + } + + /// + public override bool IsClosed + => _readerState == ReaderState.Closed || _readerState == ReaderState.Disposed + || _outermostReader.IsClosed || PgReader.FieldStartPos != _startPos; + + /// + public override int RecordsAffected => -1; + + /// + public override bool GetBoolean(int ordinal) => GetFieldValue(ordinal); + /// + public override byte GetByte(int ordinal) => GetFieldValue(ordinal); + /// + public override char GetChar(int ordinal) => GetFieldValue(ordinal); + /// + public override DateTime GetDateTime(int ordinal) => GetFieldValue(ordinal); + /// + public override decimal GetDecimal(int ordinal) => GetFieldValue(ordinal); + /// + public override double GetDouble(int ordinal) => GetFieldValue(ordinal); + /// + public override float GetFloat(int ordinal) => GetFieldValue(ordinal); + /// + public override Guid GetGuid(int ordinal) => GetFieldValue(ordinal); + /// + public override short GetInt16(int ordinal) => GetFieldValue(ordinal); + /// + public override int GetInt32(int ordinal) => GetFieldValue(ordinal); + /// + public override long GetInt64(int ordinal) => GetFieldValue(ordinal); + /// + public override string GetString(int ordinal) => GetFieldValue(ordinal); + + /// + public override long GetBytes(int ordinal, long dataOffset, byte[]? 
buffer, int bufferOffset, int length) + { + if (dataOffset is < 0 or > int.MaxValue) + throw new ArgumentOutOfRangeException(nameof(dataOffset), dataOffset, $"dataOffset must be between 0 and {int.MaxValue}"); + if (buffer != null && (bufferOffset < 0 || bufferOffset >= buffer.Length + 1)) + throw new IndexOutOfRangeException($"bufferOffset must be between 0 and {buffer.Length}"); + if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) + throw new IndexOutOfRangeException($"length must be between 0 and {buffer.Length - bufferOffset}"); + + var columnLen = CheckRowAndColumnAndSeek(ordinal, out var column); + if (columnLen is -1) + ThrowHelper.ThrowInvalidCastException_NoValue(); + + if (buffer is null) + return columnLen; + + using var _ = PgReader.BeginNestedRead(columnLen, Size.Zero); + + // Move to offset + PgReader.Seek((int)dataOffset); + + // At offset, read into buffer. + length = Math.Min(length, PgReader.CurrentRemaining); + PgReader.ReadBytes(new Span(buffer, bufferOffset, length)); + return length; + } + /// + public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) + => throw new NotSupportedException(); + + /// + protected override DbDataReader GetDbDataReader(int ordinal) => GetData(ordinal); + + /// + /// Returns a nested data reader for the requested column. + /// The column type must be a record or a to Npgsql known composite type, or an array thereof. + /// + /// The zero-based column ordinal. + /// A data reader. + public new NpgsqlNestedDataReader GetData(int ordinal) + { + var valueLength = CheckRowAndColumnAndSeek(ordinal, out var column); + var type = column.PostgresType; + var isArray = type is PostgresArrayType; + var elementType = isArray ? 
((PostgresArrayType)type).Element : type; + var compositeType = elementType as PostgresCompositeType; + if (elementType.InternalName != "record" && compositeType == null) + throw new InvalidCastException("GetData() not supported for type " + type.DisplayName); + + if (valueLength == -1) + throw new InvalidCastException("field is null"); + + var reader = _cachedFreeNestedDataReader; + if (reader != null) + { + _cachedFreeNestedDataReader = null; + reader.Init(compositeType); + } + else + { + reader = new NpgsqlNestedDataReader(_outermostReader, this, _depth + 1, compositeType); + } + if (isArray) + reader.InitArray(); + else + reader.InitSingleRow(); + return reader; + } + + /// + public override string GetDataTypeName(int ordinal) + { + var column = CheckRowAndColumn(ordinal); + return column.PostgresType.DisplayName; + } + + /// + public override IEnumerator GetEnumerator() => new DbEnumerator(this); + + /// + public override string GetName(int ordinal) + { + CheckRowAndColumn(ordinal); + return _compositeType?.Fields[ordinal].Name ?? 
"?column?"; + } + + /// + public override int GetOrdinal(string name) + { + if (_compositeType == null) + throw new NotSupportedException("GetOrdinal is not supported for the record type"); + + for (var i = 0; i < _compositeType.Fields.Count; i++) + { + if (_compositeType.Fields[i].Name == name) + return i; + } + + for (var i = 0; i < _compositeType.Fields.Count; i++) + { + if (string.Compare(_compositeType.Fields[i].Name, name, CultureInfo.InvariantCulture, + CompareOptions.IgnoreWidth | CompareOptions.IgnoreCase | CompareOptions.IgnoreKanaType) == 0) + return i; + } + + throw new IndexOutOfRangeException("Field not found in row: " + name); + } + + /// + [UnconditionalSuppressMessage("ILLink", "IL2093", Justification = "No members are dynamically accessed by Npgsql via NpgsqlNestedDataReader.GetFieldType.")] + public override Type GetFieldType(int ordinal) + { + var column = CheckRowAndColumn(ordinal); + return column.GetObjectOrDefaultInfo().TypeToConvert; + } + + /// + public override object GetValue(int ordinal) + { + var columnLength = CheckRowAndColumnAndSeek(ordinal, out var column); + var info = column.GetObjectOrDefaultInfo(); + if (columnLength == -1) + return DBNull.Value; + + using var _ = PgReader.BeginNestedRead(columnLength, info.BufferRequirement); + return info.Converter.ReadAsObject(PgReader); + } + + /// + public override int GetValues(object[] values) + { + if (values == null) + throw new ArgumentNullException(nameof(values)); + CheckOnRow(); + + var count = Math.Min(FieldCount, values.Length); + for (var i = 0; i < count; i++) + values[i] = GetValue(i); + return count; + } + + /// + public override bool IsDBNull(int ordinal) + => CheckRowAndColumnAndSeek(ordinal, out _) == -1; + + /// + public override T GetFieldValue(int ordinal) + { + if (typeof(T) == typeof(Stream)) + return (T)(object)GetStream(ordinal); + + if (typeof(T) == typeof(TextReader)) + return (T)(object)GetTextReader(ordinal); + + var columnLength = 
CheckRowAndColumnAndSeek(ordinal, out var column); + var info = GetOrAddConverterInfo(typeof(T), column, ordinal, out var asObject); + + if (columnLength == -1) + { + // When T is a Nullable (and only in that case), we support returning null + if (default(T) is null && typeof(T).IsValueType) + return default!; + + if (typeof(T) == typeof(object)) + return (T)(object)DBNull.Value; + + ThrowHelper.ThrowInvalidCastException_NoValue(); + } + + using var _ = PgReader.BeginNestedRead(columnLength, info.BufferRequirement); + return asObject + ? (T)info.Converter.ReadAsObject(PgReader)! + : info.Converter.UnsafeDowncast().Read(PgReader); + } + + /// + public override bool Read() + { + CheckResultSet(); + + PgReader.Seek(_nextRowBufferPos); + if (_nextRowIndex == _numRows) + { + _readerState = ReaderState.AfterRows; + return false; + } + + if (_nextRowIndex++ != 0) + PgReader.ReadInt32(); // Length of record + + var numColumns = PgReader.ReadInt32(); + + for (var i = 0; i < numColumns; i++) + { + var typeOid = PgReader.ReadUInt32(); + var bufferPos = PgReader.FieldOffset; + if (i >= _columns.Count) + { + var pgType = SerializerOptions.DatabaseInfo.GetPostgresType(typeOid); + _columns.Add(new ColumnInfo(pgType, bufferPos, AdoSerializerHelpers.GetTypeInfoForReading(typeof(object), pgType, SerializerOptions), Format)); + } + else + { + var pgType = _columns[i].PostgresType.OID == typeOid + ? 
_columns[i].PostgresType + : SerializerOptions.DatabaseInfo.GetPostgresType(typeOid); + _columns[i] = new ColumnInfo(pgType, bufferPos, AdoSerializerHelpers.GetTypeInfoForReading(typeof(object), pgType, SerializerOptions), Format); + } + + var columnLen = PgReader.ReadInt32(); + if (columnLen >= 0) + PgReader.Consume(columnLen); + } + _columns.RemoveRange(numColumns, _columns.Count - numColumns); + + _nextRowBufferPos = PgReader.FieldOffset; + + _readerState = ReaderState.OnRow; + return true; + } + + /// + public override bool NextResult() + { + CheckNotClosed(); + + _numRows = 0; + _nextRowBufferPos = 0; + _nextRowIndex = 0; + _readerState = ReaderState.AfterResult; + return false; + } + + /// + public override void Close() + { + if (_readerState != ReaderState.Disposed) + { + _readerState = ReaderState.Closed; + } + } + + /// + protected override void Dispose(bool disposing) + { + if (disposing && _readerState != ReaderState.Disposed) + { + Close(); + _readerState = ReaderState.Disposed; + if (_outerNestedReader != null) + { + _outerNestedReader._cachedFreeNestedDataReader ??= this; + } + else + { + _outermostReader.CachedFreeNestedDataReader ??= this; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void CheckNotClosed() + { + if (IsClosed) + throw new InvalidOperationException("The reader is closed"); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void CheckResultSet() + { + CheckNotClosed(); + switch (_readerState) + { + case ReaderState.BeforeFirstRow: + case ReaderState.OnRow: + case ReaderState.AfterRows: + break; + default: + throw new InvalidOperationException("No resultset is currently being traversed"); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void CheckOnRow() + { + CheckResultSet(); + if (_readerState != ReaderState.OnRow) + throw new InvalidOperationException("No row is available"); + } + + ColumnInfo CheckRowAndColumn(int column) + { + CheckOnRow(); + + if (column < 0 || column >= 
_columns.Count) + throw new IndexOutOfRangeException($"Column must be between {0} and {_columns.Count - 1}"); + + return _columns[column]; + } + + int CheckRowAndColumnAndSeek(int ordinal, out ColumnInfo column) + { + column = CheckRowAndColumn(ordinal); + PgReader.Seek(column.BufferPos); + return PgReader.ReadInt32(); + } + + PgConverterInfo GetOrAddConverterInfo(Type type, ColumnInfo column, int ordinal, out bool asObject) + { + if (column.LastConverterInfo is { IsDefault: false } lastInfo && lastInfo.TypeToConvert == type) + { + // As TypeInfoMappingCollection is always adding object mappings for + // default/datatypename mappings, we'll also check Converter.TypeToConvert. + // If we have an exact match we are still able to use e.g. a converter for ints in an unboxed fashion. + asObject = lastInfo.IsBoxingConverter && lastInfo.Converter.TypeToConvert != type; + return lastInfo; + } + + if (column.GetObjectOrDefaultInfo() is { IsDefault: false } odfInfo) + { + if (typeof(object) == type) + { + asObject = true; + return odfInfo; + } + + if (odfInfo.TypeToConvert == type) + { + // As TypeInfoMappingCollection is always adding object mappings for + // default/datatypename mappings, we'll also check Converter.TypeToConvert. + // If we have an exact match we are still able to use e.g. a converter for ints in an unboxed fashion. 
+ asObject = odfInfo.IsBoxingConverter && odfInfo.Converter.TypeToConvert != type; + return odfInfo; + } + } + + var converterInfo = column.Bind(AdoSerializerHelpers.GetTypeInfoForReading(type, column.PostgresType, SerializerOptions)); + _columns[ordinal] = column with { LastConverterInfo = converterInfo }; + asObject = converterInfo.IsBoxingConverter; + return converterInfo; + } + + enum ReaderState + { + BeforeFirstRow, + OnRow, + AfterRows, + AfterResult, + Closed, + Disposed + } +} diff --git a/src/Npgsql/NpgsqlNotificationEventArgs.cs b/src/Npgsql/NpgsqlNotificationEventArgs.cs index 5e62ee89f5..82e00b18a6 100644 --- a/src/Npgsql/NpgsqlNotificationEventArgs.cs +++ b/src/Npgsql/NpgsqlNotificationEventArgs.cs @@ -1,47 +1,35 @@ using System; +using Npgsql.Internal; -namespace Npgsql +namespace Npgsql; + +/// +/// Provides information on a PostgreSQL notification. Notifications are sent when your connection has registered for +/// notifications on a specific channel via the LISTEN command. NOTIFY can be used to generate such notifications, +/// allowing for an inter-connection communication channel. +/// +public sealed class NpgsqlNotificationEventArgs : EventArgs { /// - /// Provides information on a PostgreSQL notification. Notifications are sent when your connection has registered for - /// notifications on a specific channel via the LISTEN command. NOTIFY can be used to generate such notifications, - /// allowing for an inter-connection communication channel. + /// Process ID of the PostgreSQL backend that sent this notification. /// - public sealed class NpgsqlNotificationEventArgs : EventArgs - { - /// - /// Process ID of the PostgreSQL backend that sent this notification. - /// - // ReSharper disable once InconsistentNaming - public int PID { get; } - - /// - /// The channel on which the notification was sent. 
- /// - public string Channel { get; } + // ReSharper disable once InconsistentNaming + public int PID { get; } - /// - /// An optional payload string that was sent with this notification. - /// - public string Payload { get; } - - /// - /// The channel on which the notification was sent. - /// - [Obsolete("Use Channel instead")] - public string Condition => Channel; + /// + /// The channel on which the notification was sent. + /// + public string Channel { get; } - /// - /// An optional payload string that was sent with this notification. - /// - [Obsolete("Use Payload instead")] - public string AdditionalInformation => Payload; + /// + /// An optional payload string that was sent with this notification. + /// + public string Payload { get; } - internal NpgsqlNotificationEventArgs(NpgsqlReadBuffer buf) - { - PID = buf.ReadInt32(); - Channel = buf.ReadNullTerminatedString(); - Payload = buf.ReadNullTerminatedString(); - } + internal NpgsqlNotificationEventArgs(NpgsqlReadBuffer buf) + { + PID = buf.ReadInt32(); + Channel = buf.ReadNullTerminatedString(); + Payload = buf.ReadNullTerminatedString(); } } diff --git a/src/Npgsql/NpgsqlOperationInProgressException.cs b/src/Npgsql/NpgsqlOperationInProgressException.cs index 09db836010..eb7377afcd 100644 --- a/src/Npgsql/NpgsqlOperationInProgressException.cs +++ b/src/Npgsql/NpgsqlOperationInProgressException.cs @@ -1,33 +1,34 @@ -namespace Npgsql +using Npgsql.Internal; + +namespace Npgsql; + +/// +/// Thrown when trying to use a connection that is already busy performing some other operation. +/// Provides information on the already-executing operation to help with debugging. +/// +public sealed class NpgsqlOperationInProgressException : NpgsqlException { /// - /// Thrown when trying to use a connection that is already busy performing some other operation. - /// Provides information on the already-executing operation to help with debugging. + /// Creates a new instance of . 
/// - public sealed class NpgsqlOperationInProgressException : NpgsqlException + /// + /// A command which was in progress when the operation which triggered this exception was executed. + /// + public NpgsqlOperationInProgressException(NpgsqlCommand command) + : base("A command is already in progress: " + command.CommandText) { - /// - /// Creates a new instance of . - /// - /// - /// A command which was in progress when the operation which triggered this exception was executed. - /// - public NpgsqlOperationInProgressException(NpgsqlCommand command) - : base("A command is already in progress: " + command.CommandText) - { - CommandInProgress = command; - } - - internal NpgsqlOperationInProgressException(ConnectorState state) - : base($"The connection is already in state '{state}'") - { - } + CommandInProgress = command; + } - /// - /// If the connection is busy with another command, this will contain a reference to that command. - /// Otherwise, if the connection if busy with another type of operation (e.g. COPY), contains - /// . - /// - public NpgsqlCommand? CommandInProgress { get; } + internal NpgsqlOperationInProgressException(ConnectorState state) + : base($"The connection is already in state '{state}'") + { } -} + + /// + /// If the connection is busy with another command, this will contain a reference to that command. + /// Otherwise, if the connection if busy with another type of operation (e.g. COPY), contains + /// . + /// + public NpgsqlCommand? 
CommandInProgress { get; } +} \ No newline at end of file diff --git a/src/Npgsql/NpgsqlParameter.cs b/src/Npgsql/NpgsqlParameter.cs index bc6726b6e8..06022e1dae 100644 --- a/src/Npgsql/NpgsqlParameter.cs +++ b/src/Npgsql/NpgsqlParameter.cs @@ -2,566 +2,829 @@ using System.ComponentModel; using System.Data; using System.Data.Common; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using JetBrains.Annotations; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; using Npgsql.TypeMapping; using Npgsql.Util; using NpgsqlTypes; -namespace Npgsql +namespace Npgsql; + +/// +/// This class represents a parameter to a command that will be sent to server +/// +public class NpgsqlParameter : DbParameter, IDbDataParameter, ICloneable { - /// - /// This class represents a parameter to a command that will be sent to server - /// - public class NpgsqlParameter : DbParameter, IDbDataParameter, ICloneable + #region Fields and Properties + + private protected byte _precision; + private protected byte _scale; + private protected int _size; + + internal NpgsqlDbType? _npgsqlDbType; + internal string? _dataTypeName; + + private protected string _name = string.Empty; + object? _value; + private protected bool _useSubStream; + private protected SubReadStream? _subStream; + private protected string _sourceColumn; + + internal string TrimmedName { get; private protected set; } = PositionalName; + internal const string PositionalName = ""; + + private protected PgTypeInfo? TypeInfo { get; private set; } + + internal PgTypeId PgTypeId { get; private set; } + private protected PgConverter? Converter { get; private set; } + + internal DataFormat Format { get; private protected set; } + private protected Size? WriteSize { get; set; } + private protected object? 
_writeState; + private protected Size _bufferRequirement; + private protected bool _asObject; + + #endregion + + #region Constructors + + /// + /// Initializes a new instance of the class. + /// + public NpgsqlParameter() { - #region Fields and Properties + _sourceColumn = string.Empty; + Direction = ParameterDirection.Input; + SourceVersion = DataRowVersion.Current; + } - byte _precision; - byte _scale; - int _size; + /// + /// Initializes a new instance of the class with the parameter name and a value. + /// + /// The name of the parameter to map. + /// The value of the . + /// + ///

+ /// When you specify an in the value parameter, the is + /// inferred from the CLR type. + ///

+ ///

+ /// When using this constructor, you must be aware of a possible misuse of the constructor which takes a + /// parameter. This happens when calling this constructor passing an int 0 and the compiler thinks you are passing a value of + /// . Use for example to have compiler calling the correct constructor. + ///

+ ///
+ public NpgsqlParameter(string? parameterName, object? value) + : this() + { + ParameterName = parameterName; + // ReSharper disable once VirtualMemberCallInConstructor + Value = value; + } - // ReSharper disable InconsistentNaming - private protected NpgsqlDbType? _npgsqlDbType; - private protected string? _dataTypeName; - // ReSharper restore InconsistentNaming + /// + /// Initializes a new instance of the class with the parameter name and the data type. + /// + /// The name of the parameter to map. + /// One of the values. + public NpgsqlParameter(string? parameterName, NpgsqlDbType parameterType) + : this(parameterName, parameterType, 0, string.Empty) + { + } - DbType? _cachedDbType; - string _name = string.Empty; - object? _value; - string _sourceColumn; + /// + /// Initializes a new instance of the . + /// + /// The name of the parameter to map. + /// One of the values. + public NpgsqlParameter(string? parameterName, DbType parameterType) + : this(parameterName, parameterType, 0, string.Empty) + { + } - internal string TrimmedName { get; private set; } = string.Empty; + /// + /// Initializes a new instance of the . + /// + /// The name of the parameter to map. + /// One of the values. + /// The length of the parameter. + public NpgsqlParameter(string? parameterName, NpgsqlDbType parameterType, int size) + : this(parameterName, parameterType, size, string.Empty) + { + } - /// - /// Can be used to communicate a value from the validation phase to the writing phase. - /// To be used by type handlers only. - /// - public object? ConvertedValue { get; set; } + /// + /// Initializes a new instance of the . + /// + /// The name of the parameter to map. + /// One of the values. + /// The length of the parameter. + public NpgsqlParameter(string? parameterName, DbType parameterType, int size) + : this(parameterName, parameterType, size, string.Empty) + { + } - internal NpgsqlLengthCache? 
LengthCache { get; set; } + /// + /// Initializes a new instance of the + /// + /// The name of the parameter to map. + /// One of the values. + /// The length of the parameter. + /// The name of the source column. + public NpgsqlParameter(string? parameterName, NpgsqlDbType parameterType, int size, string? sourceColumn) + { + ParameterName = parameterName; + NpgsqlDbType = parameterType; + _size = size; + _sourceColumn = sourceColumn ?? string.Empty; + Direction = ParameterDirection.Input; + SourceVersion = DataRowVersion.Current; + } - internal NpgsqlTypeHandler? Handler { get; set; } + /// + /// Initializes a new instance of the . + /// + /// The name of the parameter to map. + /// One of the values. + /// The length of the parameter. + /// The name of the source column. + public NpgsqlParameter(string? parameterName, DbType parameterType, int size, string? sourceColumn) + { + ParameterName = parameterName; + DbType = parameterType; + _size = size; + _sourceColumn = sourceColumn ?? string.Empty; + Direction = ParameterDirection.Input; + SourceVersion = DataRowVersion.Current; + } - internal FormatCode FormatCode { get; private set; } + /// + /// Initializes a new instance of the . + /// + /// The name of the parameter to map. + /// One of the values. + /// The length of the parameter. + /// The name of the source column. + /// One of the values. + /// + /// if the value of the field can be , otherwise . + /// + /// + /// The total number of digits to the left and right of the decimal point to which is resolved. + /// + /// The total number of decimal places to which is resolved. + /// One of the values. + /// An that is the value of the . + public NpgsqlParameter(string parameterName, NpgsqlDbType parameterType, int size, string? sourceColumn, + ParameterDirection direction, bool isNullable, byte precision, byte scale, + DataRowVersion sourceVersion, object value) + { + ParameterName = parameterName; + Size = size; + _sourceColumn = sourceColumn ?? 
string.Empty; + Direction = direction; + IsNullable = isNullable; + Precision = precision; + Scale = scale; + SourceVersion = sourceVersion; + // ReSharper disable once VirtualMemberCallInConstructor + Value = value; + + NpgsqlDbType = parameterType; + } - #endregion + /// + /// Initializes a new instance of the . + /// + /// The name of the parameter to map. + /// One of the values. + /// The length of the parameter. + /// The name of the source column. + /// One of the values. + /// + /// if the value of the field can be , otherwise . + /// + /// + /// The total number of digits to the left and right of the decimal point to which is resolved. + /// + /// The total number of decimal places to which is resolved. + /// One of the values. + /// An that is the value of the . + public NpgsqlParameter(string parameterName, DbType parameterType, int size, string? sourceColumn, + ParameterDirection direction, bool isNullable, byte precision, byte scale, + DataRowVersion sourceVersion, object value) + { + ParameterName = parameterName; + Size = size; + _sourceColumn = sourceColumn ?? string.Empty; + Direction = direction; + IsNullable = isNullable; + Precision = precision; + Scale = scale; + SourceVersion = sourceVersion; + // ReSharper disable once VirtualMemberCallInConstructor + Value = value; + DbType = parameterType; + } + #endregion - #region Constructors + #region Name - /// - /// Initializes a new instance of the NpgsqlParameter class. - /// - public NpgsqlParameter() + /// + /// Gets or sets The name of the . + /// + /// The name of the . + /// The default is an empty string. 
+ [AllowNull, DefaultValue("")] + public sealed override string ParameterName + { + get => _name; + set { - _sourceColumn = string.Empty; - Direction = ParameterDirection.Input; - SourceVersion = DataRowVersion.Current; + if (Collection is not null) + Collection.ChangeParameterName(this, value); + else + ChangeParameterName(value); } + } - /// - /// Initializes a new instance of the NpgsqlParameter - /// class with the parameter name and a value of the new NpgsqlParameter. - /// - /// The name of the parameter to map. - /// An Object that is the value of the NpgsqlParameter. - /// - ///

When you specify an Object - /// in the value parameter, the DbType is - /// inferred from the .NET Framework type of the Object.

- ///

When using this constructor, you must be aware of a possible misuse of the constructor which takes a DbType parameter. - /// This happens when calling this constructor passing an int 0 and the compiler thinks you are passing a value of DbType. - /// Use Convert.ToInt32(value) for example to have compiler calling the correct constructor.

- ///
- public NpgsqlParameter(string? parameterName, object? value) - : this() - { - ParameterName = parameterName; - // ReSharper disable once VirtualMemberCallInConstructor - Value = value; - } + internal void ChangeParameterName(string? value) + { + if (value is null) + _name = TrimmedName = PositionalName; + else if (value.Length > 0 && (value[0] == ':' || value[0] == '@')) + TrimmedName = (_name = value).Substring(1); + else + _name = TrimmedName = value; + } - /// - /// Initializes a new instance of the NpgsqlParameter - /// class with the parameter name and the data type. - /// - /// The name of the parameter to map. - /// One of the NpgsqlDbType values. - public NpgsqlParameter(string? parameterName, NpgsqlDbType parameterType) - : this(parameterName, parameterType, 0, string.Empty) - { - } + internal bool IsPositional => ParameterName.Length == 0; - /// - /// Initializes a new instance of the NpgsqlParameter. - /// - /// The name of the parameter to map. - /// One of the DbType values. - public NpgsqlParameter(string? parameterName, DbType parameterType) - : this(parameterName, parameterType, 0, string.Empty) - { - } + #endregion Name - /// - /// Initializes a new instance of the NpgsqlParameter. - /// - /// The name of the parameter to map. - /// One of the NpgsqlDbType values. - /// The length of the parameter. - public NpgsqlParameter(string? parameterName, NpgsqlDbType parameterType, int size) - : this(parameterName, parameterType, size, string.Empty) - { - } + #region Value - /// - /// Initializes a new instance of the NpgsqlParameter. - /// - /// The name of the parameter to map. - /// One of the DbType values. - /// The length of the parameter. - public NpgsqlParameter(string? parameterName, DbType parameterType, int size) - : this(parameterName, parameterType, size, string.Empty) + /// + [TypeConverter(typeof(StringConverter)), Category("Data")] + public override object? 
Value + { + get => _value; + set { + if (ShouldResetObjectTypeInfo(value)) + ResetTypeInfo(); + else + ResetBindingInfo(); + _value = value; } + } - /// - /// Initializes a new instance of the NpgsqlParameter - /// - /// The name of the parameter to map. - /// One of the NpgsqlDbType values. - /// The length of the parameter. - /// The name of the source column. - public NpgsqlParameter(string? parameterName, NpgsqlDbType parameterType, int size, string? sourceColumn) - { - ParameterName = parameterName; - NpgsqlDbType = parameterType; - _size = size; - _sourceColumn = sourceColumn ?? string.Empty; - Direction = ParameterDirection.Input; - SourceVersion = DataRowVersion.Current; - } + /// + /// Gets or sets the value of the parameter. + /// + /// + /// An that is the value of the parameter. + /// The default value is . + /// + [Category("Data")] + [TypeConverter(typeof(StringConverter))] + public object? NpgsqlValue + { + get => Value; + set => Value = value; + } + + #endregion Value + + #region Type - /// - /// Initializes a new instance of the NpgsqlParameter. - /// - /// The name of the parameter to map. - /// One of the DbType values. - /// The length of the parameter. - /// The name of the source column. - public NpgsqlParameter(string? parameterName, DbType parameterType, int size, string? sourceColumn) + /// + /// Gets or sets the of the parameter. + /// + /// One of the values. The default is . + [DefaultValue(DbType.Object)] + [Category("Data"), RefreshProperties(RefreshProperties.All)] + public sealed override DbType DbType + { + get { - ParameterName = parameterName; - DbType = parameterType; - _size = size; - _sourceColumn = sourceColumn ?? string.Empty; - Direction = ParameterDirection.Input; - SourceVersion = DataRowVersion.Current; - } + if (_npgsqlDbType is { } npgsqlDbType) + return npgsqlDbType.ToDbType(); + + if (_dataTypeName is not null) + return Internal.Postgres.DataTypeName.FromDisplayName(_dataTypeName).ToNpgsqlDbType()?.ToDbType() ?? 
DbType.Object; - /// - /// Initializes a new instance of the NpgsqlParameter. - /// - /// The name of the parameter to map. - /// One of the NpgsqlDbType values. - /// The length of the parameter. - /// The name of the source column. - /// One of the ParameterDirection values. - /// true if the value of the field can be null, otherwise false. - /// The total number of digits to the left and right of the decimal point to which - /// Value is resolved. - /// The total number of decimal places to which - /// Value is resolved. - /// One of the DataRowVersion values. - /// An Object that is the value - /// of the NpgsqlParameter. - public NpgsqlParameter(string parameterName, NpgsqlDbType parameterType, int size, string? sourceColumn, - ParameterDirection direction, bool isNullable, byte precision, byte scale, - DataRowVersion sourceVersion, object value) + // Infer from value but don't cache + if (Value is not null) + // We pass ValueType here for the generic derived type, where we should respect T and not the runtime type. + return GlobalTypeMapper.Instance.FindDataTypeName(GetValueType(StaticValueType)!, Value)?.ToNpgsqlDbType()?.ToDbType() ?? DbType.Object; + + return DbType.Object; + } + set { - ParameterName = parameterName; - Size = size; - _sourceColumn = sourceColumn ?? string.Empty; - Direction = direction; - IsNullable = isNullable; - Precision = precision; - Scale = scale; - SourceVersion = sourceVersion; - // ReSharper disable once VirtualMemberCallInConstructor - Value = value; - - NpgsqlDbType = parameterType; + ResetTypeInfo(); + _npgsqlDbType = value == DbType.Object + ? null + : value.ToNpgsqlDbType() + ?? throw new NotSupportedException($"The parameter type DbType.{value} isn't supported by PostgreSQL or Npgsql"); } + } - /// - /// Initializes a new instance of the NpgsqlParameter. - /// - /// The name of the parameter to map. - /// One of the DbType values. - /// The length of the parameter. - /// The name of the source column. 
- /// One of the ParameterDirection values. - /// true if the value of the field can be null, otherwise false. - /// The total number of digits to the left and right of the decimal point to which - /// Value is resolved. - /// The total number of decimal places to which - /// Value is resolved. - /// One of the DataRowVersion values. - /// An Object that is the value - /// of the NpgsqlParameter. - public NpgsqlParameter(string parameterName, DbType parameterType, int size, string? sourceColumn, - ParameterDirection direction, bool isNullable, byte precision, byte scale, - DataRowVersion sourceVersion, object value) + /// + /// Gets or sets the of the parameter. + /// + /// One of the values. The default is . + [DefaultValue(NpgsqlDbType.Unknown)] + [Category("Data"), RefreshProperties(RefreshProperties.All)] + [DbProviderSpecificTypeProperty(true)] + public NpgsqlDbType NpgsqlDbType + { + get { - ParameterName = parameterName; - Size = size; - _sourceColumn = sourceColumn ?? string.Empty; - Direction = direction; - IsNullable = isNullable; - Precision = precision; - Scale = scale; - SourceVersion = sourceVersion; - // ReSharper disable once VirtualMemberCallInConstructor - Value = value; - DbType = parameterType; + if (_npgsqlDbType.HasValue) + return _npgsqlDbType.Value; + + if (_dataTypeName is not null) + return Internal.Postgres.DataTypeName.FromDisplayName(_dataTypeName).ToNpgsqlDbType() ?? NpgsqlDbType.Unknown; + + // Infer from value but don't cache + if (Value is not null) + // We pass ValueType here for the generic derived type (NpgsqlParameter) where we should respect T and not the runtime type. + return GlobalTypeMapper.Instance.FindDataTypeName(GetValueType(StaticValueType)!, Value)?.ToNpgsqlDbType() ?? NpgsqlDbType.Unknown; + + return NpgsqlDbType.Unknown; } - #endregion + set + { + if (value == NpgsqlDbType.Array) + throw new ArgumentOutOfRangeException(nameof(value), "Cannot set NpgsqlDbType to just Array, Binary-Or with the element type (e.g. 
Array of Box is NpgsqlDbType.Array | NpgsqlDbType.Box)."); + if (value == NpgsqlDbType.Range) + throw new ArgumentOutOfRangeException(nameof(value), "Cannot set NpgsqlDbType to just Range, Binary-Or with the element type (e.g. Range of integer is NpgsqlDbType.Range | NpgsqlDbType.Integer)"); - #region Name + ResetTypeInfo(); + _npgsqlDbType = value; + } + } - /// - /// Gets or sets The name of the NpgsqlParameter. - /// - /// The name of the NpgsqlParameter. - /// The default is an empty string. - [AllowNull, DefaultValue("")] - public sealed override string ParameterName + /// + /// Used to specify which PostgreSQL type will be sent to the database for this parameter. + /// + public string? DataTypeName + { + get { - get => _name; - set + if (_dataTypeName != null) + return _dataTypeName; + + // Map it to a display name. + if (_npgsqlDbType is { } npgsqlDbType) { - // ReSharper disable once ConditionIsAlwaysTrueOrFalse - if (value == null) - _name = TrimmedName = string.Empty; - else if (value.Length > 0 && (value[0] == ':' || value[0] == '@')) - TrimmedName = (_name = value).Substring(1); - else - _name = TrimmedName = value; - - Collection?.InvalidateHashLookups(); + var unqualifiedName = npgsqlDbType.ToUnqualifiedDataTypeName(); + return unqualifiedName is null ? null : Internal.Postgres.DataTypeName.ValidatedName( + "pg_catalog." + unqualifiedName).UnqualifiedDisplayName; } - } - - #endregion Name - #region Value + // Infer from value but don't cache + if (Value is not null) + // We pass ValueType here for the generic derived type, where we should respect T and not the runtime type. + return GlobalTypeMapper.Instance.FindDataTypeName(GetValueType(StaticValueType)!, Value)?.DisplayName; - /// - [TypeConverter(typeof(StringConverter)), Category("Data")] - public override object? 
Value + return null; + } + set { - get => _value; - set - { - if (_value == null || value == null || _value.GetType() != value.GetType()) - Handler = null; - _value = value; - ConvertedValue = null; - } + ResetTypeInfo(); + _dataTypeName = value; } + } + + #endregion Type + + #region Other Properties + + /// + public sealed override bool IsNullable { get; set; } + + /// + [DefaultValue(ParameterDirection.Input)] + [Category("Data")] + public sealed override ParameterDirection Direction { get; set; } + +#pragma warning disable CS0109 + /// + /// Gets or sets the maximum number of digits used to represent the property. + /// + /// + /// The maximum number of digits used to represent the property. + /// The default value is 0, which indicates that the data provider sets the precision for . + [DefaultValue((byte)0)] + [Category("Data")] + public new byte Precision + { + get => _precision; + set => _precision = value; + } + + /// + /// Gets or sets the number of decimal places to which is resolved. + /// + /// The number of decimal places to which is resolved. The default is 0. + [DefaultValue((byte)0)] + [Category("Data")] + public new byte Scale + { + get => _scale; + set => _scale = value; + } +#pragma warning restore CS0109 - /// - /// Gets or sets the value of the parameter. - /// - /// - /// An that is the value of the parameter. - /// The default value is . - /// - [Category("Data")] - [TypeConverter(typeof(StringConverter))] - public object? NpgsqlValue + /// + [DefaultValue(0)] + [Category("Data")] + public sealed override int Size + { + get => _size; + set { - get => Value; - set => Value = value; + if (value < -1) + throw new ArgumentException($"Invalid parameter Size value '{value}'. The value must be greater than or equal to 0."); + + ResetBindingInfo(); + _size = value; } + } + + /// + [AllowNull, DefaultValue("")] + [Category("Data")] + public sealed override string SourceColumn + { + get => _sourceColumn; + set => _sourceColumn = value ?? 
string.Empty; + } + + /// + [Category("Data"), DefaultValue(DataRowVersion.Current)] + public sealed override DataRowVersion SourceVersion { get; set; } + + /// + public sealed override bool SourceColumnNullMapping { get; set; } + +#pragma warning disable CA2227 + /// + /// The collection to which this parameter belongs, if any. + /// + public NpgsqlParameterCollection? Collection { get; set; } +#pragma warning restore CA2227 - #endregion Value + /// + /// The PostgreSQL data type, such as int4 or text, as discovered from pg_type. + /// This property is automatically set if parameters have been derived via + /// and can be used to + /// acquire additional information about the parameters' data type. + /// + public PostgresType? PostgresType { get; internal set; } - #region Type + #endregion Other Properties - /// - /// Gets or sets the DbType of the parameter. - /// - /// One of the DbType values. The default is Object. - [DefaultValue(DbType.Object)] - [Category("Data"), RefreshProperties(RefreshProperties.All)] - public sealed override DbType DbType + #region Internals + + private protected virtual Type StaticValueType => typeof(object); + + Type? GetValueType(Type staticValueType) => staticValueType != typeof(object) ? staticValueType : Value?.GetType(); + + internal bool ShouldResetObjectTypeInfo(object? value) + { + var currentType = TypeInfo?.Type; + if (currentType is null || value is null) + return false; + + var valueType = value.GetType(); + // We don't want to reset the type info when the value is a DBNull, we're able to write it out with any type info. + return valueType != typeof(DBNull) && currentType != valueType; + } + + internal void GetResolutionInfo(out PgTypeInfo? typeInfo, out PgConverter? 
converter, out PgTypeId pgTypeId) + { + typeInfo = TypeInfo; + converter = Converter; + pgTypeId = PgTypeId; + } + + internal void SetResolutionInfo(PgTypeInfo typeInfo, PgConverter converter, PgTypeId pgTypeId) + { + if (WriteSize is not null) + ResetBindingInfo(); + + TypeInfo = typeInfo; + Converter = converter; + PgTypeId = pgTypeId; + } + + /// Attempt to resolve a type info based on available (postgres) type information on the parameter. + internal void ResolveTypeInfo(PgSerializerOptions options) + { + var typeInfo = TypeInfo; + var previouslyResolved = ReferenceEquals(typeInfo?.Options, options); + if (!previouslyResolved) { - get + var dataTypeName = + _npgsqlDbType is { } npgsqlDbType + ? npgsqlDbType.ToDataTypeName() ?? npgsqlDbType.ToUnqualifiedDataTypeNameOrThrow() + : _dataTypeName is not null + ? Internal.Postgres.DataTypeName.NormalizeName(_dataTypeName) + : null; + + PgTypeId? pgTypeId = null; + if (dataTypeName is not null) { - if (_cachedDbType.HasValue) - return _cachedDbType.Value; - if (_npgsqlDbType.HasValue) - return _cachedDbType ??= GlobalTypeMapper.Instance.ToDbType(_npgsqlDbType.Value); - if (_value != null) // Infer from value but don't cache - return GlobalTypeMapper.Instance.ToDbType(_value.GetType()); - - return DbType.Object; + if (!options.DatabaseInfo.TryGetPostgresTypeByName(dataTypeName, out var pgType)) + { + ThrowNotSupported(dataTypeName); + return; + } + + pgTypeId = options.ToCanonicalTypeId(pgType.GetRepresentationalType()); } - set + + var unspecifiedDBNull = false; + var valueType = StaticValueType; + if (valueType == typeof(object)) { - Handler = null; - if (value == DbType.Object) + valueType = Value?.GetType(); + if (valueType is null && pgTypeId is null) { - _cachedDbType = null; - _npgsqlDbType = null; + ThrowNoTypeInfo(); + return; } - else + + // We treat object typed DBNull values as default info. 
+ // Unless we don't have a pgTypeId either, at which point we'll use an 'unspecified' PgTypeInfo to help us write a NULL. + if (valueType == typeof(DBNull)) { - _cachedDbType = value; - _npgsqlDbType = GlobalTypeMapper.Instance.ToNpgsqlDbType(value); + if (pgTypeId is null) + { + unspecifiedDBNull = true; + typeInfo = options.UnspecifiedDBNullTypeInfo; + } + else + valueType = null; } } + + if (!unspecifiedDBNull) + typeInfo = AdoSerializerHelpers.GetTypeInfoForWriting(valueType, pgTypeId, options, _npgsqlDbType); + + TypeInfo = typeInfo; } - /// - /// Gets or sets the NpgsqlDbType of the parameter. - /// - /// One of the NpgsqlDbType values. The default is Unknown. - [DefaultValue(NpgsqlDbType.Unknown)] - [Category("Data"), RefreshProperties(RefreshProperties.All)] - [DbProviderSpecificTypeProperty(true)] - public NpgsqlDbType NpgsqlDbType + // This step isn't part of BindValue because we need to know the PgTypeId beforehand for things like SchemaOnly with null values. + // We never reuse resolutions for resolvers across executions as a mutable value itself may influence the result. + // TODO we could expose a property on a Converter/TypeInfo to indicate whether it's immutable, at that point we can reuse. + if (!previouslyResolved || typeInfo!.IsResolverInfo) { - get - { - if (_npgsqlDbType.HasValue) - return _npgsqlDbType.Value; - if (_value != null) // Infer from value - return GlobalTypeMapper.Instance.ToNpgsqlDbType(_value.GetType()); - return NpgsqlDbType.Unknown; - } - set - { - if (value == NpgsqlDbType.Array) - throw new ArgumentOutOfRangeException(nameof(value), "Cannot set NpgsqlDbType to just Array, Binary-Or with the element type (e.g. Array of Box is NpgsqlDbType.Array | NpgsqlDbType.Box)."); - if (value == NpgsqlDbType.Range) - throw new ArgumentOutOfRangeException(nameof(value), "Cannot set NpgsqlDbType to just Range, Binary-Or with the element type (e.g. 
Range of integer is NpgsqlDbType.Range | NpgsqlDbType.Integer)"); - - Handler = null; - _npgsqlDbType = value; - _cachedDbType = null; - } + ResetBindingInfo(); + var resolution = ResolveConverter(typeInfo!); + Converter = resolution.Converter; + PgTypeId = resolution.PgTypeId; + } + + void ThrowNoTypeInfo() + => ThrowHelper.ThrowInvalidOperationException( + $"Parameter '{(!string.IsNullOrEmpty(ParameterName) ? ParameterName : $"${Collection?.IndexOf(this) + 1}")}' must have either its NpgsqlDbType or its DataTypeName or its Value set."); + + void ThrowNotSupported(string dataTypeName) + { + throw new NotSupportedException(_npgsqlDbType is not null + ? $"The NpgsqlDbType '{_npgsqlDbType}' isn't present in your database. You may need to install an extension or upgrade to a newer version." + : $"The data type name '{dataTypeName}' isn't present in your database. You may need to install an extension or upgrade to a newer version."); + } + } + + // Pull from Value so we also support object typed generic params. + private protected virtual PgConverterResolution ResolveConverter(PgTypeInfo typeInfo) + { + _asObject = true; + return typeInfo.GetObjectResolution(Value); + } + + /// Bind the current value to the type info, truncate (if applicable), take its size, and do any final validation before writing. + internal void Bind(out DataFormat format, out Size size, DataFormat? requiredFormat = null) + { + if (TypeInfo is null) + ThrowHelper.ThrowInvalidOperationException($"Missing type info, {nameof(ResolveTypeInfo)} needs to be called before {nameof(Bind)}."); + + if (!TypeInfo.SupportsWriting) + ThrowHelper.ThrowNotSupportedException($"Cannot write values for parameters of type '{TypeInfo.Type}' and postgres type '{TypeInfo.Options.DatabaseInfo.GetDataTypeName(PgTypeId).DisplayName}'."); + + // We might call this twice, once during validation and once during WriteBind, only compute things once. 
+ if (WriteSize is null) + { + if (_size > 0) + HandleSizeTruncation(); + + BindCore(requiredFormat); } - /// - /// Used to specify which PostgreSQL type will be sent to the database for this parameter. - /// - public string? DataTypeName + format = Format; + size = WriteSize!.Value; + if (requiredFormat is not null && format != requiredFormat) + ThrowHelper.ThrowNotSupportedException($"Parameter '{ParameterName}' must be written in {requiredFormat} format, but does not support this format."); + + // Handle Size truncate behavior for a predetermined set of types and pg types. + // Doesn't matter if we 'box' Value, all supported types are reference types. + [MethodImpl(MethodImplOptions.NoInlining)] + void HandleSizeTruncation() { - get + var type = Converter!.TypeToConvert; + if ((type != typeof(string) && type != typeof(char[]) && type != typeof(byte[]) && type != typeof(Stream)) || Value is not { } value) + return; + + var dataTypeName = TypeInfo!.Options.GetDataTypeName(PgTypeId); + if (dataTypeName == DataTypeNames.Text || dataTypeName == DataTypeNames.Varchar || dataTypeName == DataTypeNames.Bpchar) { - if (_dataTypeName != null) - return _dataTypeName; - throw new NotImplementedException("Infer from others"); + if (value is string s && s.Length > _size) + Value = s.Substring(0, _size); + else if (value is char[] chars && chars.Length > _size) + { + var truncated = new char[_size]; + Array.Copy(chars, truncated, _size); + Value = truncated; + } } - set + else if (dataTypeName == DataTypeNames.Bytea) { - _dataTypeName = value; - Handler = null; + if (value is byte[] bytes && bytes.Length > _size) + { + var truncated = new byte[_size]; + Array.Copy(bytes, truncated, _size); + Value = truncated; + } + else if (value is Stream) + { + _asObject = true; + _useSubStream = true; + } } } + } - #endregion Type + private protected virtual void BindCore(DataFormat? 
formatPreference, bool allowNullReference = false) + { + // Pull from Value so we also support object typed generic params. + var value = Value; + if (value is null && !allowNullReference) + ThrowHelper.ThrowInvalidOperationException($"Parameter '{ParameterName}' cannot be null, DBNull.Value should be used instead."); - #region Other Properties + if (_useSubStream && value is not null) + value = _subStream = new SubReadStream((Stream)value, _size); - /// - public sealed override bool IsNullable { get; set; } + if (TypeInfo!.BindObject(Converter!, value, out var size, out _writeState, out var dataFormat, formatPreference) is { } info) + { + WriteSize = size; + _bufferRequirement = info.BufferRequirement; + } + else + { + WriteSize = -1; + _bufferRequirement = default; + } - /// - [DefaultValue(ParameterDirection.Input)] - [Category("Data")] - public sealed override ParameterDirection Direction { get; set; } + Format = dataFormat; + } -#pragma warning disable CS0109 - /// - /// Gets or sets the maximum number of digits used to represent the - /// Value property. - /// - /// The maximum number of digits used to represent the - /// Value property. - /// The default value is 0, which indicates that the data provider - /// sets the precision for Value. - [DefaultValue((byte)0)] - [Category("Data")] - public new byte Precision + internal async ValueTask Write(bool async, PgWriter writer, CancellationToken cancellationToken) + { + if (WriteSize is not { } writeSize) { - get => _precision; - set - { - _precision = value; - Handler = null; - } + ThrowHelper.ThrowInvalidOperationException("Missing type info or binding info."); + return; } - /// - /// Gets or sets the number of decimal places to which - /// Value is resolved. - /// - /// The number of decimal places to which - /// Value is resolved. The default is 0. 
- [DefaultValue((byte)0)] - [Category("Data")] - public new byte Scale + try { - get => _scale; - set + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + writer.WriteInt32(writeSize.Value); + if (writeSize.Value is -1) { - _scale = value; - Handler = null; + writer.Commit(sizeof(int)); + return; } - } -#pragma warning restore CS0109 - /// - [DefaultValue(0)] - [Category("Data")] - public sealed override int Size - { - get => _size; - set + var current = new ValueMetadata { - if (value < -1) - throw new ArgumentException($"Invalid parameter Size value '{value}'. The value must be greater than or equal to 0."); - - _size = value; - Handler = null; - } + Format = Format, + BufferRequirement = _bufferRequirement, + Size = writeSize, + WriteState = _writeState + }; + await writer.BeginWrite(async, current, cancellationToken).ConfigureAwait(false); + await WriteValue(async, writer, cancellationToken).ConfigureAwait(false); + writer.Commit(writeSize.Value + sizeof(int)); } - - /// - [AllowNull, DefaultValue("")] - [Category("Data")] - public sealed override string SourceColumn + finally { - get => _sourceColumn; - set => _sourceColumn = value ?? string.Empty; + ResetBindingInfo(); } + } - /// - [Category("Data"), DefaultValue(DataRowVersion.Current)] - public sealed override DataRowVersion SourceVersion { get; set; } - - /// - public sealed override bool SourceColumnNullMapping { get; set; } - -#pragma warning disable CA2227 - /// - /// The collection to which this parameter belongs, if any. - /// - public NpgsqlParameterCollection? Collection { get; set; } -#pragma warning restore CA2227 + private protected virtual ValueTask WriteValue(bool async, PgWriter writer, CancellationToken cancellationToken) + { + // Pull from Value so we also support base calls from generic parameters. + var value = (_useSubStream ? 
_subStream : Value)!; + if (async) + return Converter!.WriteAsObjectAsync(writer, value, cancellationToken); - /// - /// The PostgreSQL data type, such as int4 or text, as discovered from pg_type. - /// This property is automatically set if parameters have been derived via - /// and can be used to - /// acquire additional information about the parameters' data type. - /// - public PostgresType? PostgresType { get; internal set; } + Converter!.WriteAsObject(writer, value); + return new(); + } - #endregion Other Properties + /// + public override void ResetDbType() + { + _npgsqlDbType = null; + _dataTypeName = null; + ResetTypeInfo(); + } - #region Internals + private protected void ResetTypeInfo() + { + TypeInfo = null; + _asObject = false; + Converter = null; + PgTypeId = default; + ResetBindingInfo(); + } - internal virtual void ResolveHandler(ConnectorTypeMapper typeMapper) + private protected void ResetBindingInfo() + { + if (WriteSize is null) { - if (Handler != null) - return; - - if (_npgsqlDbType.HasValue) - Handler = typeMapper.GetByNpgsqlDbType(_npgsqlDbType.Value); - else if (_dataTypeName != null) - Handler = typeMapper.GetByDataTypeName(_dataTypeName); - else if (_value != null) - Handler = typeMapper.GetByClrType(_value.GetType()); - else - throw new InvalidOperationException($"Parameter '{ParameterName}' must have its value set"); + Debug.Assert(_writeState == default && _useSubStream == default && Format == default && _bufferRequirement == default); + return; } - internal void Bind(ConnectorTypeMapper typeMapper) + if (_writeState is not null) { - ResolveHandler(typeMapper); - FormatCode = Handler!.PreferTextWrite ? 
FormatCode.Text : FormatCode.Binary; + TypeInfo?.DisposeWriteState(_writeState); + _writeState = null; } - - internal virtual int ValidateAndGetLength() + if (_useSubStream) { - if (_value == null) - throw new InvalidCastException($"Parameter {ParameterName} must be set"); - if (_value is DBNull) - return 0; - - var lengthCache = LengthCache; - var len = Handler!.ValidateObjectAndGetLength(_value, ref lengthCache, this); - LengthCache = lengthCache; - return len; + _useSubStream = false; + _subStream?.Dispose(); + _subStream = null; } + WriteSize = null; + Format = default; + _bufferRequirement = default; + } - internal virtual Task WriteWithLength(NpgsqlWriteBuffer buf, bool async, CancellationToken cancellationToken = default) - => Handler!.WriteObjectWithLength(_value!, buf, LengthCache, this, async, cancellationToken); - - /// - public override void ResetDbType() - { - _cachedDbType = null; - _npgsqlDbType = null; - _dataTypeName = null; - Handler = null; - } + internal bool IsInputDirection => Direction == ParameterDirection.InputOutput || Direction == ParameterDirection.Input; - internal bool IsInputDirection => Direction == ParameterDirection.InputOutput || Direction == ParameterDirection.Input; + internal bool IsOutputDirection => Direction == ParameterDirection.InputOutput || Direction == ParameterDirection.Output; - internal bool IsOutputDirection => Direction == ParameterDirection.InputOutput || Direction == ParameterDirection.Output; + #endregion - #endregion + #region Clone - #region Clone + /// + /// Creates a new that is a copy of the current instance. + /// + /// A new that is a copy of this instance. + public NpgsqlParameter Clone() => CloneCore(); - /// - /// Creates a new NpgsqlParameter that - /// is a copy of the current instance. - /// - /// A new NpgsqlParameter that is a copy of this instance. 
- public NpgsqlParameter Clone() + private protected virtual NpgsqlParameter CloneCore() => + // use fields instead of properties + // to avoid auto-initializing something like type_info + new() { - // use fields instead of properties - // to avoid auto-initializing something like type_info - var clone = new NpgsqlParameter - { - _precision = _precision, - _scale = _scale, - _size = _size, - _cachedDbType = _cachedDbType, - _npgsqlDbType = _npgsqlDbType, - Direction = Direction, - IsNullable = IsNullable, - _name = _name, - TrimmedName = TrimmedName, - SourceColumn = SourceColumn, - SourceVersion = SourceVersion, - _value = _value, - SourceColumnNullMapping = SourceColumnNullMapping, - }; - return clone; - } - - object ICloneable.Clone() => Clone(); - - #endregion - } + _precision = _precision, + _scale = _scale, + _size = _size, + _npgsqlDbType = _npgsqlDbType, + _dataTypeName = _dataTypeName, + Direction = Direction, + IsNullable = IsNullable, + _name = _name, + TrimmedName = TrimmedName, + SourceColumn = SourceColumn, + SourceVersion = SourceVersion, + _value = _value, + SourceColumnNullMapping = SourceColumnNullMapping, + }; + + object ICloneable.Clone() => Clone(); + + #endregion } diff --git a/src/Npgsql/NpgsqlParameterCollection.cs b/src/Npgsql/NpgsqlParameterCollection.cs index 58182cf397..a10f9dceb0 100644 --- a/src/Npgsql/NpgsqlParameterCollection.cs +++ b/src/Npgsql/NpgsqlParameterCollection.cs @@ -1,561 +1,788 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Data; using System.Data.Common; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using JetBrains.Annotations; -using Npgsql.Util; +using Npgsql.Internal; using NpgsqlTypes; -namespace Npgsql +namespace Npgsql; + +/// +/// Represents a collection of parameters relevant to a as well as their respective mappings to columns in +/// a . 
+/// +public sealed class NpgsqlParameterCollection : DbParameterCollection, IList { + internal const int LookupThreshold = 5; + + internal List InternalList { get; } = new(5); +#if DEBUG + internal static bool TwoPassCompatMode; +#else + internal static readonly bool TwoPassCompatMode; +#endif + + static NpgsqlParameterCollection() + => TwoPassCompatMode = AppContext.TryGetSwitch("Npgsql.EnableLegacyCaseInsensitiveDbParameters", out var enabled) + && enabled; + + // Dictionary lookups for GetValue to improve performance. _caseSensitiveLookup is only ever used in legacy two-pass mode. + Dictionary? _caseInsensitiveLookup; + Dictionary? _caseSensitiveLookup; + /// - /// Represents a collection of parameters relevant to a NpgsqlCommand - /// as well as their respective mappings to columns in a DataSet. - /// This class cannot be inherited. + /// Initializes a new instance of the NpgsqlParameterCollection class. /// - public sealed class NpgsqlParameterCollection : DbParameterCollection, IList + internal NpgsqlParameterCollection() { } + + bool LookupEnabled => InternalList.Count >= LookupThreshold; + + void LookupClear() { - readonly List _internalList = new List(5); - - // Dictionary lookups for GetValue to improve performance - Dictionary? _lookup; - Dictionary? _lookupIgnoreCase; - - /// - /// Initializes a new instance of the NpgsqlParameterCollection class. - /// - internal NpgsqlParameterCollection() => InvalidateHashLookups(); - - /// - /// Invalidate the hash lookup tables. This should be done any time a change - /// may throw the lookups out of sync with the list. 
- /// - internal void InvalidateHashLookups() - { - _lookup = null; - _lookupIgnoreCase = null; - } + _caseInsensitiveLookup?.Clear(); + _caseSensitiveLookup?.Clear(); + } + + void LookupAdd(string name, int index) + { + if (_caseInsensitiveLookup is null) + return; - #region NpgsqlParameterCollection Member + if (TwoPassCompatMode) + _caseSensitiveLookup!.TryAdd(name, index); - /// - /// Gets the NpgsqlParameter with the specified name. - /// - /// The name of the NpgsqlParameter to retrieve. - /// The NpgsqlParameter with the specified name, or a null reference if the parameter is not found. - public new NpgsqlParameter this[string parameterName] + _caseInsensitiveLookup.TryAdd(name, index); + } + + void LookupInsert(string name, int index) + { + if (_caseInsensitiveLookup is null) + return; + + if (TwoPassCompatMode && + (!_caseSensitiveLookup!.TryGetValue(name, out var indexCs) || index < indexCs)) { - get + for (var i = index + 1; i < InternalList.Count; i++) { - if (parameterName is null) - throw new ArgumentNullException(nameof(parameterName)); + var parameterName = InternalList[i].TrimmedName; + if (_caseSensitiveLookup.TryGetValue(parameterName, out var currentI) && currentI + 1 == i) + _caseSensitiveLookup[parameterName] = i; + } - var index = IndexOf(parameterName); - if (index == -1) - throw new ArgumentException("Parameter not found"); + _caseSensitiveLookup[name] = index; + } - return _internalList[index]; - } - set + if (!_caseInsensitiveLookup.TryGetValue(name, out var indexCi) || index < indexCi) + { + for (var i = index + 1; i < InternalList.Count; i++) { - if (parameterName is null) - throw new ArgumentNullException(nameof(parameterName)); - if (value is null) - throw new ArgumentNullException(nameof(value)); - - var index = IndexOf(parameterName); + var parameterName = InternalList[i].TrimmedName; + if (_caseInsensitiveLookup.TryGetValue(parameterName, out var currentI) && currentI + 1 == i) + _caseInsensitiveLookup[parameterName] = i; + } - if 
(index == -1) - throw new ArgumentException("Parameter not found"); + _caseInsensitiveLookup[name] = index; + } + } - var oldValue = _internalList[index]; - if (oldValue.ParameterName != value.ParameterName) - InvalidateHashLookups(); + void LookupRemove(string name, int index) + { + if (_caseInsensitiveLookup is null) + return; - _internalList[index] = value; + if (TwoPassCompatMode && _caseSensitiveLookup!.Remove(name)) + { + for (var i = index; i < InternalList.Count; i++) + { + var parameterName = InternalList[i].TrimmedName; + if (_caseSensitiveLookup.TryGetValue(parameterName, out var currentI) && currentI - 1 == i) + _caseSensitiveLookup[parameterName] = i; } } - /// - /// Gets the NpgsqlParameter at the specified index. - /// - /// The zero-based index of the NpgsqlParameter to retrieve. - /// The NpgsqlParameter at the specified index. - public new NpgsqlParameter this[int index] + if (_caseInsensitiveLookup.Remove(name)) { - get => _internalList[index]; - set + for (var i = index; i < InternalList.Count; i++) + { + var parameterName = InternalList[i].TrimmedName; + if (_caseInsensitiveLookup.TryGetValue(parameterName, out var currentI) && currentI - 1 == i) + _caseInsensitiveLookup[parameterName] = i; + } + + // Fix-up the case-insensitive lookup to point to the next match, if any. 
+ for (var i = 0; i < InternalList.Count; i++) { - if (value is null) - throw new ArgumentNullException(nameof(value)); - if (value.Collection != null) - throw new InvalidOperationException("The parameter already belongs to a collection"); + var value = InternalList[i]; + if (string.Equals(name, value.TrimmedName, StringComparison.OrdinalIgnoreCase)) + { + _caseInsensitiveLookup[value.TrimmedName] = i; + break; + } + } + } - var oldValue = _internalList[index]; + } - if (oldValue == value) - return; + void LookupChangeName(NpgsqlParameter parameter, string oldName, string oldTrimmedName, int index) + { + if (string.Equals(oldTrimmedName, parameter.TrimmedName, StringComparison.OrdinalIgnoreCase)) + return; - if (value.ParameterName != oldValue.ParameterName) - InvalidateHashLookups(); + if (oldName.Length != 0) + LookupRemove(oldTrimmedName, index); + if (!parameter.IsPositional) + LookupAdd(parameter.TrimmedName, index); + } - _internalList[index] = value; - value.Collection = this; - oldValue.Collection = null; - } - } + internal void ChangeParameterName(NpgsqlParameter parameter, string? value) + { + var oldName = parameter.ParameterName; + var oldTrimmedName = parameter.TrimmedName; + parameter.ChangeParameterName(value); + + if (_caseInsensitiveLookup is null || _caseInsensitiveLookup.Count == 0) + return; + + var index = IndexOf(parameter); + if (index == -1) // This would be weird. + return; + + LookupChangeName(parameter, oldName, oldTrimmedName, index); + } - /// - /// Adds the specified NpgsqlParameter object to the NpgsqlParameterCollection. - /// - /// The NpgsqlParameter to add to the collection. - /// The index of the new NpgsqlParameter object. - public NpgsqlParameter Add(NpgsqlParameter value) + #region NpgsqlParameterCollection Member + + /// + /// Gets the with the specified name. + /// + /// The name of the to retrieve. + /// + /// The with the specified name, or a reference if the parameter is not found. 
+ /// + public new NpgsqlParameter this[string parameterName] + { + get + { + if (parameterName is null) + throw new ArgumentNullException(nameof(parameterName)); + + var index = IndexOf(parameterName); + if (index == -1) + throw new ArgumentException("Parameter not found"); + + return InternalList[index]; + } + set { + if (parameterName is null) + throw new ArgumentNullException(nameof(parameterName)); if (value is null) throw new ArgumentNullException(nameof(value)); - if (value.Collection != null) - throw new InvalidOperationException("The parameter already belongs to a collection"); - _internalList.Add(value); - value.Collection = this; - InvalidateHashLookups(); - return value; + var index = IndexOf(parameterName); + if (index == -1) + throw new ArgumentException("Parameter not found"); + + if (!string.Equals(parameterName, value.TrimmedName, StringComparison.OrdinalIgnoreCase)) + throw new ArgumentException("Parameter name must be a case-insensitive match with the property 'ParameterName' on the given NpgsqlParameter", nameof(parameterName)); + + var oldValue = InternalList[index]; + LookupChangeName(value, oldValue.ParameterName, oldValue.TrimmedName, index); + + InternalList[index] = value; } + } - /// - void ICollection.Add(NpgsqlParameter item) - => Add(item); - - /// - /// Adds a NpgsqlParameter to the NpgsqlParameterCollection given the specified parameter name and value. - /// - /// The name of the NpgsqlParameter. - /// The Value of the NpgsqlParameter to add to the collection. - /// The parameter that was added. - public NpgsqlParameter AddWithValue(string parameterName, object value) - => Add(new NpgsqlParameter(parameterName, value)); - - /// - /// Adds a NpgsqlParameter to the NpgsqlParameterCollection given the specified parameter name, data type and value. - /// - /// The name of the NpgsqlParameter. - /// One of the NpgsqlDbType values. - /// The Value of the NpgsqlParameter to add to the collection. - /// The parameter that was added. 
- public NpgsqlParameter AddWithValue(string parameterName, NpgsqlDbType parameterType, object value) - => Add(new NpgsqlParameter(parameterName, parameterType) { Value = value }); - - /// - /// Adds a NpgsqlParameter to the NpgsqlParameterCollection given the specified parameter name and value. - /// - /// The name of the NpgsqlParameter. - /// The Value of the NpgsqlParameter to add to the collection. - /// One of the NpgsqlDbType values. - /// The length of the column. - /// The parameter that was added. - public NpgsqlParameter AddWithValue(string parameterName, NpgsqlDbType parameterType, int size, object value) - => Add(new NpgsqlParameter(parameterName, parameterType, size) { Value = value }); - - /// - /// Adds a NpgsqlParameter to the NpgsqlParameterCollection given the specified parameter name and value. - /// - /// The name of the NpgsqlParameter. - /// The Value of the NpgsqlParameter to add to the collection. - /// One of the NpgsqlDbType values. - /// The length of the column. - /// The name of the source column. - /// The parameter that was added. - public NpgsqlParameter AddWithValue(string parameterName, NpgsqlDbType parameterType, int size, string? sourceColumn, object value) - => Add(new NpgsqlParameter(parameterName, parameterType, size, sourceColumn) { Value = value }); - - /// - /// Adds a NpgsqlParameter to the NpgsqlParameterCollection given the specified value. - /// - /// The Value of the NpgsqlParameter to add to the collection. - /// The parameter that was added. - public NpgsqlParameter AddWithValue(object value) - => Add(new NpgsqlParameter { Value = value }); - - /// - /// Adds a NpgsqlParameter to the NpgsqlParameterCollection given the specified data type and value. - /// - /// One of the NpgsqlDbType values. - /// The Value of the NpgsqlParameter to add to the collection. - /// The parameter that was added. 
- public NpgsqlParameter AddWithValue(NpgsqlDbType parameterType, object value) - => Add(new NpgsqlParameter { NpgsqlDbType = parameterType, Value = value }); - - /// - /// Adds a NpgsqlParameter to the NpgsqlParameterCollection given the parameter name and the data type. - /// - /// The name of the parameter. - /// One of the DbType values. - /// The index of the new NpgsqlParameter object. - public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType) - => Add(new NpgsqlParameter(parameterName, parameterType)); - - /// - /// Adds a NpgsqlParameter to the NpgsqlParameterCollection with the parameter name, the data type, and the column length. - /// - /// The name of the parameter. - /// One of the DbType values. - /// The length of the column. - /// The index of the new NpgsqlParameter object. - public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType, int size) - => Add(new NpgsqlParameter(parameterName, parameterType, size)); - - /// - /// Adds a NpgsqlParameter to the NpgsqlParameterCollection with the parameter name, the data type, the column length, and the source column name. - /// - /// The name of the parameter. - /// One of the DbType values. - /// The length of the column. - /// The name of the source column. - /// The index of the new NpgsqlParameter object. - public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType, int size, string sourceColumn) - => Add(new NpgsqlParameter(parameterName, parameterType, size, sourceColumn)); - - #endregion - - #region IDataParameterCollection Member - - /// - // ReSharper disable once ImplicitNotNullOverridesUnknownExternalMember - public override void RemoveAt(string parameterName) - => RemoveAt(IndexOf(parameterName ?? throw new ArgumentNullException(nameof(parameterName)))); - - /// - public override bool Contains(string parameterName) - => IndexOf(parameterName ?? 
throw new ArgumentNullException(nameof(parameterName))) != -1; - - /// - public override int IndexOf(string parameterName) + /// + /// Gets the at the specified index. + /// + /// The zero-based index of the to retrieve. + /// The at the specified index. + public new NpgsqlParameter this[int index] + { + get => InternalList[index]; + set { - if (parameterName is null) - return -1; + if (value is null) + ThrowHelper.ThrowArgumentNullException(nameof(value)); + if (value.Collection is not null) + ThrowHelper.ThrowInvalidOperationException("The parameter already belongs to a collection"); - if (parameterName.Length > 0 && (parameterName[0] == ':' || parameterName[0] == '@')) - parameterName = parameterName.Remove(0, 1); + var oldValue = InternalList[index]; - // Using a dictionary is much faster for 5 or more items - if (_internalList.Count >= 5) - { - if (_lookup == null) - { - _lookup = new Dictionary(); - for (var i = 0 ; i < _internalList.Count ; i++) - { - var item = _internalList[i]; - - // Store only the first of each distinct value - if (!_lookup.ContainsKey(item.TrimmedName)) - _lookup.Add(item.TrimmedName, i); - } - } + if (ReferenceEquals(oldValue, value)) + return; - // Try to access the case sensitive parameter name first - if (_lookup.TryGetValue(parameterName, out var retIndex)) - return retIndex; + LookupChangeName(value, oldValue.ParameterName, oldValue.TrimmedName, index); - // Case sensitive lookup failed, generate a case insensitive lookup - if (_lookupIgnoreCase == null) - { - _lookupIgnoreCase = new Dictionary(PGUtil.InvariantCaseIgnoringStringComparer); - for (var i = 0 ; i < _internalList.Count ; i++) - { - var item = _internalList[i]; - - // Store only the first of each distinct value - if (!_lookupIgnoreCase.ContainsKey(item.TrimmedName)) - _lookupIgnoreCase.Add(item.TrimmedName, i); - } - } + InternalList[index] = value; + value.Collection = this; + oldValue.Collection = null; + } + } - // Then try to access the case insensitive parameter 
name - if (_lookupIgnoreCase.TryGetValue(parameterName, out retIndex)) - return retIndex; + /// + /// Adds the specified object to the . + /// + /// The to add to the collection. + /// The index of the new object. + public NpgsqlParameter Add(NpgsqlParameter value) + { + if (value is null) + ThrowHelper.ThrowArgumentNullException(nameof(value)); + if (value.Collection is not null) + ThrowHelper.ThrowInvalidOperationException("The parameter already belongs to a collection"); + + InternalList.Add(value); + value.Collection = this; + if (!value.IsPositional) + LookupAdd(value.TrimmedName, InternalList.Count - 1); + return value; + } - return -1; - } + /// + void ICollection.Add(NpgsqlParameter item) + => Add(item); - // First try a case-sensitive match - for (var i = 0; i < _internalList.Count; i++) - if (parameterName == _internalList[i].TrimmedName) - return i; + /// + /// Adds a to the given the specified parameter name and + /// value. + /// + /// The name of the . + /// The value of the to add to the collection. + /// The parameter that was added. + public NpgsqlParameter AddWithValue(string parameterName, object value) + => Add(new NpgsqlParameter(parameterName, value)); - // If not fond, try a case-insensitive match - for (var i = 0; i < _internalList.Count; i++) - if (string.Equals(parameterName, _internalList[i].TrimmedName, StringComparison.OrdinalIgnoreCase)) - return i; + /// + /// Adds a to the given the specified parameter name, + /// data type and value. + /// + /// The name of the . + /// One of the NpgsqlDbType values. + /// The value of the to add to the collection. + /// The parameter that was added. + public NpgsqlParameter AddWithValue(string parameterName, NpgsqlDbType parameterType, object value) + => Add(new NpgsqlParameter(parameterName, parameterType) { Value = value }); - return -1; - } + /// + /// Adds a to the given the specified parameter name and + /// value. + /// + /// The name of the . + /// The value of the to add to the collection. 
+ /// One of the values. + /// The length of the column. + /// The parameter that was added. + public NpgsqlParameter AddWithValue(string parameterName, NpgsqlDbType parameterType, int size, object value) + => Add(new NpgsqlParameter(parameterName, parameterType, size) { Value = value }); - #endregion + /// + /// Adds a to the given the specified parameter name and + /// value. + /// + /// The name of the . + /// The value of the to add to the collection. + /// One of the values. + /// The length of the column. + /// The name of the source column. + /// The parameter that was added. + public NpgsqlParameter AddWithValue(string parameterName, NpgsqlDbType parameterType, int size, string? sourceColumn, object value) + => Add(new NpgsqlParameter(parameterName, parameterType, size, sourceColumn) { Value = value }); - #region IList Member + /// + /// Adds a to the given the specified value. + /// + /// The value of the to add to the collection. + /// The parameter that was added. + public NpgsqlParameter AddWithValue(object value) + => Add(new NpgsqlParameter { Value = value }); - /// - public override bool IsReadOnly => false; + /// + /// Adds a to the given the specified data type and value. + /// + /// One of the values. + /// The value of the to add to the collection. + /// The parameter that was added. + public NpgsqlParameter AddWithValue(NpgsqlDbType parameterType, object value) + => Add(new NpgsqlParameter { NpgsqlDbType = parameterType, Value = value }); - /// - /// Removes the specified NpgsqlParameter from the collection using a specific index. - /// - /// The zero-based index of the parameter. - public override void RemoveAt(int index) - { - if (_internalList.Count - 1 < index) - throw new ArgumentOutOfRangeException(nameof(index)); + /// + /// Adds a to the given the parameter name and the data type. + /// + /// The name of the parameter. + /// One of the values. + /// The index of the new object. 
+ public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType) + => Add(new NpgsqlParameter(parameterName, parameterType)); - Remove(_internalList[index]); - } + /// + /// Adds a to the with the parameter name, the data type, + /// and the column length. + /// + /// The name of the parameter. + /// One of the values. + /// The length of the column. + /// The index of the new object. + public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType, int size) + => Add(new NpgsqlParameter(parameterName, parameterType, size)); - /// - public override void Insert(int index, object value) - => Insert(index, Cast(value)); + /// + /// Adds a to the with the parameter name, the data type, the + /// column length, and the source column name. + /// + /// The name of the parameter. + /// One of the values. + /// The length of the column. + /// The name of the source column. + /// The index of the new object. + public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType, int size, string sourceColumn) + => Add(new NpgsqlParameter(parameterName, parameterType, size, sourceColumn)); - /// - /// Removes the specified NpgsqlParameter from the collection. - /// - /// The name of the NpgsqlParameter to remove from the collection. - public void Remove(string parameterName) - { - if (parameterName is null) - throw new ArgumentNullException(nameof(parameterName)); + #endregion - var index = IndexOf(parameterName); - if (index < 0) - throw new InvalidOperationException("No parameter with the specified name exists in the collection"); + #region IDataParameterCollection Member - RemoveAt(index); - } + /// + // ReSharper disable once ImplicitNotNullOverridesUnknownExternalMember + public override void RemoveAt(string parameterName) + => RemoveAt(IndexOf(parameterName ?? throw new ArgumentNullException(nameof(parameterName)))); + + /// + public override bool Contains(string parameterName) + => IndexOf(parameterName ?? 
throw new ArgumentNullException(nameof(parameterName))) != -1; - /// - /// Removes the specified NpgsqlParameter from the collection. - /// - /// The NpgsqlParameter to remove from the collection. - public override void Remove(object value) - => Remove(Cast(value)); - - /// - public override bool Contains(object value) - => value is NpgsqlParameter param && _internalList.Contains(param); - - /// - /// Gets a value indicating whether a NpgsqlParameter with the specified parameter name exists in the collection. - /// - /// The name of the NpgsqlParameter object to find. - /// A reference to the requested parameter is returned in this out param if it is found in the list. This value is null if the parameter is not found. - /// true if the collection contains the parameter and param will contain the parameter; otherwise, false. - public bool TryGetValue(string parameterName, [NotNullWhen(true)] out NpgsqlParameter? parameter) + /// + public override int IndexOf(string parameterName) + { + if (parameterName is null) + return -1; + + if (parameterName.Length > 0 && (parameterName[0] == ':' || parameterName[0] == '@')) + parameterName = parameterName.Remove(0, 1); + + // Using a dictionary is always faster after around 10 items when matched against reference equality. + // For string equality this is the case after ~3 items so we take a decent compromise going with 5. + if (LookupEnabled && parameterName.Length != 0) { - if (parameterName is null) - throw new ArgumentNullException(nameof(parameterName)); + if (_caseInsensitiveLookup is null) + BuildLookup(); - var index = IndexOf(parameterName); + if (TwoPassCompatMode && _caseSensitiveLookup!.TryGetValue(parameterName, out var indexCs)) + return indexCs; + + if (_caseInsensitiveLookup!.TryGetValue(parameterName, out var indexCi)) + return indexCi; + + return -1; + } - if (index != -1) + // Start with case-sensitive search in two pass mode. 
+ if (TwoPassCompatMode) + { + for (var i = 0; i < InternalList.Count; i++) { - parameter = _internalList[index]; - return true; + var name = InternalList[i].TrimmedName; + if (string.Equals(parameterName, name)) + return i; } + } - parameter = null; - return false; + // Then do case-insensitive search. + for (var i = 0; i < InternalList.Count; i++) + { + var name = InternalList[i].TrimmedName; + if (ReferenceEquals(parameterName, name) || string.Equals(parameterName, name, StringComparison.OrdinalIgnoreCase)) + return i; } - /// - /// Removes all items from the collection. - /// - public override void Clear() + return -1; + + void BuildLookup() { - // clean up parameters so they can be added to another command if required. - foreach (var toRemove in _internalList) - toRemove.Collection = null; + if (TwoPassCompatMode) + _caseSensitiveLookup = new Dictionary(InternalList.Count); + + _caseInsensitiveLookup = new Dictionary(InternalList.Count, StringComparer.OrdinalIgnoreCase); - _internalList.Clear(); - InvalidateHashLookups(); + for (var i = 0; i < InternalList.Count; i++) + { + var item = InternalList[i]; + if (!item.IsPositional) + LookupAdd(item.TrimmedName, i); + } } + } + + #endregion + + #region IList Member + + /// + public override bool IsReadOnly => false; + + /// + /// Removes the specified from the collection using a specific index. + /// + /// The zero-based index of the parameter. + public override void RemoveAt(int index) + { + if (InternalList.Count - 1 < index) + throw new ArgumentOutOfRangeException(nameof(index)); + + Remove(InternalList[index]); + } + + /// + public override void Insert(int index, object value) + => Insert(index, Cast(value)); + + /// + /// Removes the specified from the collection. + /// + /// The name of the to remove from the collection. 
+ public void Remove(string parameterName) + { + if (parameterName is null) + ThrowHelper.ThrowArgumentNullException(nameof(parameterName)); - /// - public override int IndexOf(object value) - => IndexOf(Cast(value)); + var index = IndexOf(parameterName); + if (index < 0) + ThrowHelper.ThrowInvalidOperationException("No parameter with the specified name exists in the collection"); - /// - public override int Add(object value) + RemoveAt(index); + } + + /// + /// Removes the specified from the collection. + /// + /// The to remove from the collection. + public override void Remove(object value) + => Remove(Cast(value)); + + /// + public override bool Contains(object value) + => value is NpgsqlParameter param && InternalList.Contains(param); + + /// + /// Gets a value indicating whether a with the specified parameter name exists in the collection. + /// + /// The name of the object to find. + /// + /// A reference to the requested parameter is returned in this out param if it is found in the list. + /// This value is if the parameter is not found. + /// + /// + /// if the collection contains the parameter and param will contain the parameter; + /// otherwise, . + /// + public bool TryGetValue(string parameterName, [NotNullWhen(true)] out NpgsqlParameter? parameter) + { + if (parameterName is null) + throw new ArgumentNullException(nameof(parameterName)); + + var index = IndexOf(parameterName); + + if (index != -1) { - Add(Cast(value)); - return Count - 1; + parameter = InternalList[index]; + return true; } - /// - public override bool IsFixedSize => false; + parameter = null; + return false; + } - #endregion + /// + /// Removes all items from the collection. + /// + public override void Clear() + { + // clean up parameters so they can be added to another command if required. 
+ foreach (var toRemove in InternalList) + toRemove.Collection = null; - #region ICollection Member + InternalList.Clear(); + LookupClear(); + } - /// - public override bool IsSynchronized => (_internalList as ICollection).IsSynchronized; + /// + public override int IndexOf(object value) + => IndexOf(Cast(value)); - /// - /// Gets the number of NpgsqlParameter objects in the collection. - /// - /// The number of NpgsqlParameter objects in the collection. - public override int Count => _internalList.Count; + /// + public override int Add(object value) + { + Add(Cast(value)); + return Count - 1; + } - /// - public override void CopyTo(Array array, int index) - => ((ICollection)_internalList).CopyTo(array, index); + /// + public override bool IsFixedSize => false; - /// - bool ICollection.IsReadOnly => false; + #endregion - /// - public override object SyncRoot => ((ICollection)_internalList).SyncRoot; + #region ICollection Member - #endregion + /// + public override bool IsSynchronized => (InternalList as ICollection).IsSynchronized; - #region IEnumerable Member + /// + /// Gets the number of objects in the collection. + /// + /// The number of objects in the collection. + public override int Count => InternalList.Count; - IEnumerator IEnumerable.GetEnumerator() - => _internalList.GetEnumerator(); + /// + public override void CopyTo(Array array, int index) + => ((ICollection)InternalList).CopyTo(array, index); - /// - public override IEnumerator GetEnumerator() => _internalList.GetEnumerator(); + /// + bool ICollection.IsReadOnly => false; - #endregion + /// + public override object SyncRoot => ((ICollection)InternalList).SyncRoot; - /// - public override void AddRange(Array values) - { - if (values is null) - throw new ArgumentNullException(nameof(values)); + #endregion - foreach (var parameter in values) - Add(Cast(parameter) ?? 
throw new ArgumentException("Collection contains a null value.", nameof(values))); - } + #region IEnumerable Member + + IEnumerator IEnumerable.GetEnumerator() + => InternalList.GetEnumerator(); + + /// + public override IEnumerator GetEnumerator() => InternalList.GetEnumerator(); + + #endregion + + /// + public override void AddRange(Array values) + { + if (values is null) + throw new ArgumentNullException(nameof(values)); + + foreach (var parameter in values) + Add(Cast(parameter)); + } - /// - protected override DbParameter GetParameter(string parameterName) - => this[parameterName]; - - /// - protected override DbParameter GetParameter(int index) - => this[index]; - - /// - protected override void SetParameter(string parameterName, DbParameter value) - => this[parameterName] = Cast(value); - - /// - protected override void SetParameter(int index, DbParameter value) - => this[index] = Cast(value); - - /// - /// Report the offset within the collection of the given parameter. - /// - /// Parameter to find. - /// Index of the parameter, or -1 if the parameter is not present. - public int IndexOf(NpgsqlParameter item) - => _internalList.IndexOf(item); - - /// - /// Insert the specified parameter into the collection. - /// - /// Index of the existing parameter before which to insert the new one. - /// Parameter to insert. - public void Insert(int index, NpgsqlParameter item) + /// + protected override DbParameter GetParameter(string parameterName) + => this[parameterName]; + + /// + protected override DbParameter GetParameter(int index) + => this[index]; + + /// + protected override void SetParameter(string parameterName, DbParameter value) + => this[parameterName] = Cast(value); + + /// + protected override void SetParameter(int index, DbParameter value) + => this[index] = Cast(value); + + /// + /// Report the offset within the collection of the given parameter. + /// + /// Parameter to find. + /// Index of the parameter, or -1 if the parameter is not present. 
+ public int IndexOf(NpgsqlParameter item) + => InternalList.IndexOf(item); + + /// + /// Insert the specified parameter into the collection. + /// + /// Index of the existing parameter before which to insert the new one. + /// Parameter to insert. + public void Insert(int index, NpgsqlParameter item) + { + if (item is null) + throw new ArgumentNullException(nameof(item)); + if (item.Collection != null) + throw new Exception("The parameter already belongs to a collection"); + + InternalList.Insert(index, item); + item.Collection = this; + if (!item.IsPositional) + LookupInsert(item.TrimmedName, index); + } + + /// + /// Report whether the specified parameter is present in the collection. + /// + /// Parameter to find. + /// True if the parameter was found, otherwise false. + public bool Contains(NpgsqlParameter item) => InternalList.Contains(item); + + /// + /// Remove the specified parameter from the collection. + /// + /// Parameter to remove. + /// True if the parameter was found and removed, otherwise false. + public bool Remove(NpgsqlParameter item) + { + if (item == null) + ThrowHelper.ThrowArgumentNullException(nameof(item)); + if (item.Collection != this) + ThrowHelper.ThrowInvalidOperationException("The item does not belong to this collection"); + + var index = IndexOf(item); + if (index >= 0) { - if (item is null) - throw new ArgumentNullException(nameof(item)); - if (item.Collection != null) - throw new Exception("The parameter already belongs to a collection"); - - _internalList.Insert(index, item); - item.Collection = this; - InvalidateHashLookups(); + InternalList.RemoveAt(index); + if (!LookupEnabled) + LookupClear(); + if (!item.IsPositional) + LookupRemove(item.TrimmedName, index); + item.Collection = null; + return true; } - /// - /// Report whether the specified parameter is present in the collection. - /// - /// Parameter to find. - /// True if the parameter was found, otherwise false. 
- public bool Contains(NpgsqlParameter item) => _internalList.Contains(item); - - /// - /// Remove the specified parameter from the collection. - /// - /// Parameter to remove. - /// True if the parameter was found and removed, otherwise false. - public bool Remove(NpgsqlParameter item) - { - if (item == null) - throw new ArgumentNullException(nameof(item)); - if (item.Collection != this) - throw new InvalidOperationException("The item does not belong to this collection"); + return false; + } - if (_internalList.Remove(item)) - { - item.Collection = null; - InvalidateHashLookups(); - return true; - } + /// + /// Convert collection to a System.Array. + /// + /// Destination array. + /// Starting index in destination array. + public void CopyTo(NpgsqlParameter[] array, int arrayIndex) + => InternalList.CopyTo(array, arrayIndex); - return false; + /// + /// Convert collection to a System.Array. + /// + /// NpgsqlParameter[] + public NpgsqlParameter[] ToArray() => InternalList.ToArray(); + + internal void CloneTo(NpgsqlParameterCollection other) + { + other.InternalList.Clear(); + foreach (var param in InternalList) + { + var newParam = param.Clone(); + newParam.Collection = this; + other.InternalList.Add(newParam); } - /// - /// Convert collection to a System.Array. - /// - /// Destination array. - /// Starting index in destination array. - public void CopyTo(NpgsqlParameter[] array, int arrayIndex) - => _internalList.CopyTo(array, arrayIndex); - - /// - /// Convert collection to a System.Array. 
- /// - /// NpgsqlParameter[] - public NpgsqlParameter[] ToArray() => _internalList.ToArray(); - - internal void CloneTo(NpgsqlParameterCollection other) + if (LookupEnabled && _caseInsensitiveLookup is not null) { - other._internalList.Clear(); - foreach (var param in _internalList) + other._caseInsensitiveLookup = new Dictionary(_caseInsensitiveLookup, StringComparer.OrdinalIgnoreCase); + if (TwoPassCompatMode) { - var newParam = param.Clone(); - newParam.Collection = this; - other._internalList.Add(newParam); + Debug.Assert(_caseSensitiveLookup is not null); + other._caseSensitiveLookup = new Dictionary(_caseSensitiveLookup); } - other._lookup = _lookup; - other._lookupIgnoreCase = _lookupIgnoreCase; } + } + + internal void ProcessParameters(PgSerializerOptions options, bool validateValues, CommandType commandType) + { + HasOutputParameters = false; + PlaceholderType = PlaceholderType.NoParameters; - internal bool HasOutputParameters + var list = InternalList; + for (var i = 0; i < list.Count; i++) { - get + var p = list[i]; + + switch (PlaceholderType) { - foreach (var p in _internalList) - if (p.IsOutputDirection) - return true; - return false; + case PlaceholderType.NoParameters: + PlaceholderType = p.IsPositional ? PlaceholderType.Positional : PlaceholderType.Named; + break; + case PlaceholderType.Named: + if (p.IsPositional) + PlaceholderType = PlaceholderType.Mixed; + break; + case PlaceholderType.Positional: + if (!p.IsPositional) + PlaceholderType = PlaceholderType.Mixed; + break; + case PlaceholderType.Mixed: + break; + default: + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(PlaceholderType), $"Unknown {nameof(PlaceholderType)} value: {{0}}", PlaceholderType); + break; } - } - static NpgsqlParameter Cast(object? 
value) - { - try + switch (p.Direction) { - return (NpgsqlParameter)value!; + case ParameterDirection.Input: + break; + + case ParameterDirection.InputOutput: + if (PlaceholderType == PlaceholderType.Positional && commandType != CommandType.StoredProcedure) + ThrowHelper.ThrowNotSupportedException("Output parameters are not supported in positional mode (unless used with CommandType.StoredProcedure)"); + HasOutputParameters = true; + break; + + case ParameterDirection.Output: + if (PlaceholderType == PlaceholderType.Positional && commandType != CommandType.StoredProcedure) + ThrowHelper.ThrowNotSupportedException("Output parameters are not supported in positional mode (unless used with CommandType.StoredProcedure)"); + HasOutputParameters = true; + continue; + + case ParameterDirection.ReturnValue: + // Simply ignored + continue; + + default: + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(ParameterDirection), + $"Unhandled {nameof(ParameterDirection)} value: {{0}}", p.Direction); + break; } - catch (Exception) + + p.ResolveTypeInfo(options); + + if (validateValues) { - throw new InvalidCastException($"The value \"{value}\" is not of type \"{nameof(NpgsqlParameter)}\" and cannot be used in this parameter collection."); + p.Bind(out _, out _); } } } + + internal bool HasOutputParameters { get; set; } + internal PlaceholderType PlaceholderType { get; set; } + + static NpgsqlParameter Cast(object? value) + { + var castedValue = value as NpgsqlParameter; + if (castedValue is null) + ThrowInvalidCastException(value); + + return castedValue; + } + + [DoesNotReturn] + static void ThrowInvalidCastException(object? value) => + throw new InvalidCastException( + $"The value \"{value}\" is not of type \"{nameof(NpgsqlParameter)}\" and cannot be used in this parameter collection."); +} + +enum PlaceholderType +{ + /// + /// The parameter collection includes no parameters. + /// + NoParameters, + + /// + /// The parameter collection includes only named parameters. 
+ /// + Named, + + /// + /// The parameter collection includes only positional parameters. + /// + Positional, + + /// + /// The parameter collection includes both named and positional parameters. + /// This is only supported when is set to . + /// + Mixed } diff --git a/src/Npgsql/NpgsqlParameter`.cs b/src/Npgsql/NpgsqlParameter`.cs index e0ee4be70f..e50618a510 100644 --- a/src/Npgsql/NpgsqlParameter`.cs +++ b/src/Npgsql/NpgsqlParameter`.cs @@ -1,101 +1,149 @@ using System; using System.Data; -using System.Diagnostics.CodeAnalysis; +using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using Npgsql.TypeMapping; +using Npgsql.Internal; using NpgsqlTypes; -namespace Npgsql +namespace Npgsql; + +/// +/// A generic version of which provides more type safety and +/// avoids boxing of value types. Use instead of . +/// +/// The type of the value that will be stored in the parameter. +public sealed class NpgsqlParameter : NpgsqlParameter { + T? _typedValue; + /// - /// A generic version of which provides more type safety and - /// avoids boxing of value types. Use instead of . + /// Gets or sets the strongly-typed value of the parameter. /// - /// The type of the value that will be stored in the parameter. - public sealed class NpgsqlParameter : NpgsqlParameter + public T? TypedValue { - /// - /// Gets or sets the strongly-typed value of the parameter. - /// - [MaybeNull, AllowNull] - public T TypedValue { get; set; } = default!; - - /// - /// Gets or sets the value of the parameter. This delegates to . - /// - public override object? Value + get => _typedValue; + set { - get => TypedValue; - set => TypedValue = (T)value!; + if (typeof(T) == typeof(object) && ShouldResetObjectTypeInfo(value)) + ResetTypeInfo(); + else + ResetBindingInfo(); + _typedValue = value; } + } - #region Constructors + /// + /// Gets or sets the value of the parameter. This delegates to . + /// + public override object? 
Value + { + get => TypedValue; + set => TypedValue = (T)value!; + } - /// - /// Initializes a new instance of . - /// - public NpgsqlParameter() {} + private protected override Type StaticValueType => typeof(T); - /// - /// Initializes a new instance of with a parameter name and value. - /// - public NpgsqlParameter(string parameterName, T value) - { - ParameterName = parameterName; - TypedValue = value; - } + #region Constructors + + /// + /// Initializes a new instance of . + /// + public NpgsqlParameter() { } - /// - /// Initializes a new instance of with a parameter name and type. - /// - public NpgsqlParameter(string parameterName, NpgsqlDbType npgsqlDbType) + /// + /// Initializes a new instance of with a parameter name and value. + /// + public NpgsqlParameter(string parameterName, T value) + { + ParameterName = parameterName; + TypedValue = value; + } + + /// + /// Initializes a new instance of with a parameter name and type. + /// + public NpgsqlParameter(string parameterName, NpgsqlDbType npgsqlDbType) + { + ParameterName = parameterName; + NpgsqlDbType = npgsqlDbType; + } + + /// + /// Initializes a new instance of with a parameter name and type. + /// + public NpgsqlParameter(string parameterName, DbType dbType) + { + ParameterName = parameterName; + DbType = dbType; + } + + #endregion Constructors + + private protected override PgConverterResolution ResolveConverter(PgTypeInfo typeInfo) + { + if (typeof(T) == typeof(object) || TypeInfo!.IsBoxing) + return base.ResolveConverter(typeInfo); + + _asObject = false; + return typeInfo.GetResolution(TypedValue); + } + + // We ignore allowNullReference, it's just there to control the base implementation. + private protected override void BindCore(DataFormat? formatPreference, bool allowNullReference = false) + { + if (_asObject) { - ParameterName = parameterName; - NpgsqlDbType = npgsqlDbType; + // If we're object typed we should not support null. 
+ base.BindCore(formatPreference, typeof(T) != typeof(object)); + return; } - /// - /// Initializes a new instance of with a parameter name and type. - /// - public NpgsqlParameter(string parameterName, DbType dbType) + var value = TypedValue; + if (TypeInfo!.Bind(Converter!.UnsafeDowncast(), value, out var size, out _writeState, out var dataFormat, formatPreference) is { } info) { - ParameterName = parameterName; - DbType = dbType; + WriteSize = size; + _bufferRequirement = info.BufferRequirement; } - - #endregion Constructors - - internal override void ResolveHandler(ConnectorTypeMapper typeMapper) + else { - if (Handler != null) - return; - - // TODO: Better exceptions in case of cast failure etc. - if (_npgsqlDbType.HasValue) - Handler = typeMapper.GetByNpgsqlDbType(_npgsqlDbType.Value); - else if (_dataTypeName != null) - Handler = typeMapper.GetByDataTypeName(_dataTypeName); - else - Handler = typeMapper.GetByClrType(typeof(T)); + WriteSize = -1; + _bufferRequirement = default; } - internal override int ValidateAndGetLength() - { - if (TypedValue == null) - return 0; + Format = dataFormat; + } - // TODO: Why do it like this rather than a handler? 
- if (typeof(T) == typeof(DBNull)) - return 0; + private protected override ValueTask WriteValue(bool async, PgWriter writer, CancellationToken cancellationToken) + { + if (_asObject) + return base.WriteValue(async, writer, cancellationToken); - var lengthCache = LengthCache; - var len = Handler!.ValidateAndGetLength(TypedValue, ref lengthCache, this); - LengthCache = lengthCache; - return len; - } + if (async) + return Converter!.UnsafeDowncast().WriteAsync(writer, TypedValue!, cancellationToken); - internal override Task WriteWithLength(NpgsqlWriteBuffer buf, bool async, CancellationToken cancellationToken = default) - => Handler!.WriteWithLengthInternal(TypedValue, buf, LengthCache, this, async, cancellationToken); + Converter!.UnsafeDowncast().Write(writer, TypedValue!); + return new(); } + + private protected override NpgsqlParameter CloneCore() => + // use fields instead of properties + // to avoid auto-initializing something like type_info + new NpgsqlParameter + { + _precision = _precision, + _scale = _scale, + _size = _size, + _npgsqlDbType = _npgsqlDbType, + _dataTypeName = _dataTypeName, + Direction = Direction, + IsNullable = IsNullable, + _name = _name, + TrimmedName = TrimmedName, + SourceColumn = SourceColumn, + SourceVersion = SourceVersion, + TypedValue = TypedValue, + SourceColumnNullMapping = SourceColumnNullMapping, + }; } diff --git a/src/Npgsql/NpgsqlRawCopyStream.cs b/src/Npgsql/NpgsqlRawCopyStream.cs index f26f901d91..d963e411c8 100644 --- a/src/Npgsql/NpgsqlRawCopyStream.cs +++ b/src/Npgsql/NpgsqlRawCopyStream.cs @@ -3,554 +3,548 @@ using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; -using Npgsql.Logging; +using Npgsql.Internal; using static Npgsql.Util.Statics; #pragma warning disable 1591 -namespace Npgsql +namespace Npgsql; + +/// +/// Provides an API for a raw binary COPY operation, a high-performance data import/export mechanism to +/// a PostgreSQL 
table. Initiated by +/// +/// +/// See https://www.postgresql.org/docs/current/static/sql-copy.html. +/// +public sealed class NpgsqlRawCopyStream : Stream, ICancelable { + #region Fields and Properties + + NpgsqlConnector _connector; + NpgsqlReadBuffer _readBuf; + NpgsqlWriteBuffer _writeBuf; + + int _leftToReadInDataMsg; + bool _isDisposed, _isConsumed; + + bool _canRead; + bool _canWrite; + + internal bool IsBinary { get; private set; } + + public override bool CanWrite => _canWrite; + public override bool CanRead => _canRead; + + public override bool CanTimeout => true; + public override int WriteTimeout + { + get => (int) _writeBuf.Timeout.TotalMilliseconds; + set => _writeBuf.Timeout = TimeSpan.FromMilliseconds(value); + } + public override int ReadTimeout + { + get => (int) _readBuf.Timeout.TotalMilliseconds; + set => _readBuf.Timeout = TimeSpan.FromMilliseconds(value); + } + /// - /// Provides an API for a raw binary COPY operation, a high-performance data import/export mechanism to - /// a PostgreSQL table. Initiated by + /// The copy binary format header signature /// - /// - /// See https://www.postgresql.org/docs/current/static/sql-copy.html. 
- /// - public sealed class NpgsqlRawCopyStream : Stream, ICancelable + internal static readonly byte[] BinarySignature = { - #region Fields and Properties + (byte)'P',(byte)'G',(byte)'C',(byte)'O',(byte)'P',(byte)'Y', + (byte)'\n', 255, (byte)'\r', (byte)'\n', 0 + }; - NpgsqlConnector _connector; - NpgsqlReadBuffer _readBuf; - NpgsqlWriteBuffer _writeBuf; + readonly ILogger _copyLogger; - int _leftToReadInDataMsg; - bool _isDisposed, _isConsumed; + #endregion - readonly bool _canRead; - readonly bool _canWrite; + #region Constructor / Initializer - internal bool IsBinary { get; private set; } + internal NpgsqlRawCopyStream(NpgsqlConnector connector) + { + _connector = connector; + _readBuf = connector.ReadBuffer; + _writeBuf = connector.WriteBuffer; + _copyLogger = connector.LoggingConfiguration.CopyLogger; + } - public override bool CanWrite => _canWrite; - public override bool CanRead => _canRead; + internal async Task Init(string copyCommand, bool async, CancellationToken cancellationToken = default) + { + await _connector.WriteQuery(copyCommand, async, cancellationToken).ConfigureAwait(false); + await _connector.Flush(async, cancellationToken).ConfigureAwait(false); - public override bool CanTimeout => true; - public override int WriteTimeout - { - get => (int) _writeBuf.Timeout.TotalMilliseconds; - set => _writeBuf.Timeout = TimeSpan.FromMilliseconds(value); - } - public override int ReadTimeout - { - get => (int) _readBuf.Timeout.TotalMilliseconds; - set - { - _readBuf.Timeout = TimeSpan.FromMilliseconds(value); - // While calling the connector it will overwrite our read buffer timeout - _connector.UserTimeout = value; - } - } + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - /// - /// The copy binary format header signature - /// - internal static readonly byte[] BinarySignature = + var msg = await _connector.ReadMessage(async).ConfigureAwait(false); + switch (msg.Code) { - 
(byte)'P',(byte)'G',(byte)'C',(byte)'O',(byte)'P',(byte)'Y', - (byte)'\n', 255, (byte)'\r', (byte)'\n', 0 - }; + case BackendMessageCode.CopyInResponse: + var copyInResponse = (CopyInResponseMessage) msg; + IsBinary = copyInResponse.IsBinary; + _canWrite = true; + _writeBuf.StartCopyMode(); + break; + case BackendMessageCode.CopyOutResponse: + var copyOutResponse = (CopyOutResponseMessage) msg; + IsBinary = copyOutResponse.IsBinary; + _canRead = true; + break; + case BackendMessageCode.CommandComplete: + throw new InvalidOperationException( + "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " + + "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + + "Note that your data has been successfully imported/exported."); + default: + throw _connector.UnexpectedMessageReceived(msg.Code); + } + } - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlRawCopyStream)); + #endregion - #endregion + #region Write - #region Constructor + public override void Write(byte[] buffer, int offset, int count) + { + ValidateArguments(buffer, offset, count); + Write(new ReadOnlySpan(buffer, offset, count)); + } - internal NpgsqlRawCopyStream(NpgsqlConnector connector, string copyCommand) - { - _connector = connector; - _readBuf = connector.ReadBuffer; - _writeBuf = connector.WriteBuffer; + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateArguments(buffer, offset, count); + return WriteAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } - _connector.WriteQuery(copyCommand); - _connector.Flush(); +#if NETSTANDARD2_0 + public void Write(ReadOnlySpan buffer) +#else + public override void Write(ReadOnlySpan buffer) +#endif + { + CheckDisposed(); + if (!CanWrite) + throw new InvalidOperationException("Stream not open for writing"); - using var registration = 
connector.StartNestedCancellableOperation(attemptPgCancellation: false); + if (buffer.Length == 0) { return; } - var msg = _connector.ReadMessage(async: false).GetAwaiter().GetResult(); - switch (msg.Code) - { - case BackendMessageCode.CopyInResponse: - var copyInResponse = (CopyInResponseMessage) msg; - IsBinary = copyInResponse.IsBinary; - _canWrite = true; - _writeBuf.StartCopyMode(); - break; - case BackendMessageCode.CopyOutResponse: - var copyOutResponse = (CopyOutResponseMessage) msg; - IsBinary = copyOutResponse.IsBinary; - _canRead = true; - break; - case BackendMessageCode.CommandComplete: - throw new InvalidOperationException( - "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " + - "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + - "Note that your data has been successfully imported/exported."); - default: - throw _connector.UnexpectedMessageReceived(msg.Code); - } + if (buffer.Length <= _writeBuf.WriteSpaceLeft) + { + _writeBuf.WriteBytes(buffer); + return; } - #endregion - - #region Write + // Value is too big, flush. + Flush(); - public override void Write(byte[] buffer, int offset, int count) + if (buffer.Length <= _writeBuf.WriteSpaceLeft) { - ValidateArguments(buffer, offset, count); - Write(new ReadOnlySpan(buffer, offset, count)); + _writeBuf.WriteBytes(buffer); + return; } - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - ValidateArguments(buffer, offset, count); - return WriteAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); - } + // Value is too big even after a flush - bypass the buffer and write directly. 
+ _writeBuf.DirectWrite(buffer); + } #if NETSTANDARD2_0 - public void Write(ReadOnlySpan buffer) + public ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) #else - public override void Write(ReadOnlySpan buffer) + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) #endif - { - CheckDisposed(); - if (!CanWrite) - throw new InvalidOperationException("Stream not open for writing"); + { + CheckDisposed(); + if (!CanWrite) + throw new InvalidOperationException("Stream not open for writing"); + cancellationToken.ThrowIfCancellationRequested(); - if (buffer.Length == 0) { return; } + return WriteAsyncInternal(buffer, cancellationToken); + + async ValueTask WriteAsyncInternal(ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + if (buffer.Length == 0) + return; if (buffer.Length <= _writeBuf.WriteSpaceLeft) { - _writeBuf.WriteBytes(buffer); + _writeBuf.WriteBytes(buffer.Span); return; } - try - { - // Value is too big, flush. - Flush(); - - if (buffer.Length <= _writeBuf.WriteSpaceLeft) - { - _writeBuf.WriteBytes(buffer); - return; - } + // Value is too big, flush. + await FlushAsync(true, cancellationToken).ConfigureAwait(false); - // Value is too big even after a flush - bypass the buffer and write directly. 
- _writeBuf.DirectWrite(buffer); - } - catch (Exception e) + if (buffer.Length <= _writeBuf.WriteSpaceLeft) { - _connector.Break(e); - Cleanup(); - throw; + _writeBuf.WriteBytes(buffer.Span); + return; } - } - -#if NETSTANDARD2_0 - public ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) -#else - public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) -#endif - { - CheckDisposed(); - if (!CanWrite) - throw new InvalidOperationException("Stream not open for writing"); - cancellationToken.ThrowIfCancellationRequested(); - using (NoSynchronizationContextScope.Enter()) - return WriteAsyncInternal(); - - async ValueTask WriteAsyncInternal() - { - if (buffer.Length == 0) - return; - - if (buffer.Length <= _writeBuf.WriteSpaceLeft) - { - _writeBuf.WriteBytes(buffer.Span); - return; - } - - try - { - // Value is too big, flush. - await FlushAsync(true, cancellationToken); - - if (buffer.Length <= _writeBuf.WriteSpaceLeft) - { - _writeBuf.WriteBytes(buffer.Span); - return; - } - // Value is too big even after a flush - bypass the buffer and write directly. - await _writeBuf.DirectWrite(buffer, true, cancellationToken); - } - catch (Exception e) - { - _connector.Break(e); - Cleanup(); - throw; - } - } + // Value is too big even after a flush - bypass the buffer and write directly. 
+ await _writeBuf.DirectWrite(buffer, true, cancellationToken).ConfigureAwait(false); } + } - public override void Flush() => FlushAsync(false).GetAwaiter().GetResult(); + public override void Flush() => FlushAsync(async: false).GetAwaiter().GetResult(); - public override Task FlushAsync(CancellationToken cancellationToken) - { - if (cancellationToken.IsCancellationRequested) - return Task.FromCanceled(cancellationToken); - using (NoSynchronizationContextScope.Enter()) - return FlushAsync(true, cancellationToken); - } + public override Task FlushAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + return Task.FromCanceled(cancellationToken); - Task FlushAsync(bool async, CancellationToken cancellationToken = default) - { - CheckDisposed(); - return _writeBuf.Flush(async, cancellationToken); - } + return FlushAsync(async: true, cancellationToken); + } - #endregion + Task FlushAsync(bool async, CancellationToken cancellationToken = default) + { + CheckDisposed(); + return _writeBuf.Flush(async, cancellationToken); + } - #region Read + #endregion - public override int Read(byte[] buffer, int offset, int count) - { - ValidateArguments(buffer, offset, count); - return Read(new Span(buffer, offset, count)); - } + #region Read - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - ValidateArguments(buffer, offset, count); - return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); - } + public override int Read(byte[] buffer, int offset, int count) + { + ValidateArguments(buffer, offset, count); + return Read(new Span(buffer, offset, count)); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateArguments(buffer, offset, count); + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } #if NETSTANDARD2_0 - public int Read(Span span) + public int 
Read(Span span) #else - public override int Read(Span span) + public override int Read(Span span) #endif - { - CheckDisposed(); - if (!CanRead) - throw new InvalidOperationException("Stream not open for reading"); - - var count = ReadCore(span.Length, false).GetAwaiter().GetResult(); - if (count > 0) - _readBuf.ReadBytes(span.Slice(0, count)); - return count; - } + { + CheckDisposed(); + if (!CanRead) + throw new InvalidOperationException("Stream not open for reading"); + + var count = ReadCore(span.Length, false).GetAwaiter().GetResult(); + if (count > 0) + _readBuf.ReadBytes(span.Slice(0, count)); + return count; + } #if NETSTANDARD2_0 - public ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) + public ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) #else - public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) #endif + { + CheckDisposed(); + if (!CanRead) + throw new InvalidOperationException("Stream not open for reading"); + cancellationToken.ThrowIfCancellationRequested(); + + return ReadAsyncInternal(); + + async ValueTask ReadAsyncInternal() { - CheckDisposed(); - if (!CanRead) - throw new InvalidOperationException("Stream not open for reading"); - cancellationToken.ThrowIfCancellationRequested(); - using (NoSynchronizationContextScope.Enter()) - return ReadAsyncInternal(); - - async ValueTask ReadAsyncInternal() - { - var count = await ReadCore(buffer.Length, true, cancellationToken); - if (count > 0) - _readBuf.ReadBytes(buffer.Slice(0, count).Span); - return count; - } + var count = await ReadCore(buffer.Length, true, cancellationToken).ConfigureAwait(false); + if (count > 0) + _readBuf.ReadBytes(buffer.Slice(0, count).Span); + return count; } + } - async ValueTask ReadCore(int count, bool async, CancellationToken cancellationToken = default) - { - if (_isConsumed) - return 0; + async 
ValueTask ReadCore(int count, bool async, CancellationToken cancellationToken = default) + { + if (_isConsumed) + return 0; - using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - if (_leftToReadInDataMsg == 0) + if (_leftToReadInDataMsg == 0) + { + IBackendMessage msg; + try { - IBackendMessage msg; - try - { - // We've consumed the current DataMessage (or haven't yet received the first), - // read the next message - msg = await _connector.ReadMessage(async); - } - catch - { + // We've consumed the current DataMessage (or haven't yet received the first), + // read the next message + msg = await _connector.ReadMessage(async).ConfigureAwait(false); + } + catch + { + if (!_isDisposed) Cleanup(); - throw; - } + throw; + } - switch (msg.Code) - { - case BackendMessageCode.CopyData: - _leftToReadInDataMsg = ((CopyDataMessage)msg).Length; - break; - case BackendMessageCode.CopyDone: - Expect(await _connector.ReadMessage(async), _connector); - Expect(await _connector.ReadMessage(async), _connector); - _isConsumed = true; - return 0; - default: - throw _connector.UnexpectedMessageReceived(msg.Code); - } + switch (msg.Code) + { + case BackendMessageCode.CopyData: + _leftToReadInDataMsg = ((CopyDataMessage)msg).Length; + break; + case BackendMessageCode.CopyDone: + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + _isConsumed = true; + return 0; + default: + throw _connector.UnexpectedMessageReceived(msg.Code); } + } - Debug.Assert(_leftToReadInDataMsg > 0); + Debug.Assert(_leftToReadInDataMsg > 0); - // If our buffer is empty, read in more. 
Otherwise return whatever is there, even if the - // user asked for more (normal socket behavior) - if (_readBuf.ReadBytesLeft == 0) - await _readBuf.ReadMore(async); + // If our buffer is empty, read in more. Otherwise return whatever is there, even if the + // user asked for more (normal socket behavior) + if (_readBuf.ReadBytesLeft == 0) + await _readBuf.ReadMore(async).ConfigureAwait(false); - Debug.Assert(_readBuf.ReadBytesLeft > 0); + Debug.Assert(_readBuf.ReadBytesLeft > 0); - var maxCount = Math.Min(_readBuf.ReadBytesLeft, _leftToReadInDataMsg); - if (count > maxCount) - count = maxCount; + var maxCount = Math.Min(_readBuf.ReadBytesLeft, _leftToReadInDataMsg); + if (count > maxCount) + count = maxCount; - _leftToReadInDataMsg -= count; - return count; - } + _leftToReadInDataMsg -= count; + return count; + } - #endregion + #endregion - #region Cancel + #region Cancel - /// - /// Cancels and terminates an ongoing operation. Any data already written will be discarded. - /// - public void Cancel() => Cancel(false).GetAwaiter().GetResult(); + /// + /// Cancels and terminates an ongoing operation. Any data already written will be discarded. + /// + public void Cancel() => Cancel(async: false).GetAwaiter().GetResult(); - /// - /// Cancels and terminates an ongoing operation. Any data already written will be discarded. - /// - public Task CancelAsync() - { - using (NoSynchronizationContextScope.Enter()) - return Cancel(true); - } + /// + /// Cancels and terminates an ongoing operation. Any data already written will be discarded. 
+ /// + public Task CancelAsync() => Cancel(async: true); - async Task Cancel(bool async) - { - CheckDisposed(); + async Task Cancel(bool async) + { + CheckDisposed(); - if (CanWrite) + if (CanWrite) + { + _writeBuf.EndCopyMode(); + _writeBuf.Clear(); + await _connector.WriteCopyFail(async).ConfigureAwait(false); + await _connector.Flush(async).ConfigureAwait(false); + try { - _writeBuf.EndCopyMode(); - _writeBuf.Clear(); - await _connector.WriteCopyFail(async); - await _connector.Flush(async); - try - { - var msg = await _connector.ReadMessage(async); - // The CopyFail should immediately trigger an exception from the read above. - throw _connector.Break( - new NpgsqlException("Expected ErrorResponse when cancelling COPY but got: " + msg.Code)); - } - catch (PostgresException e) - { - _connector.EndUserAction(); - Cleanup(); - - if (e.SqlState == PostgresErrorCodes.QueryCanceled) - return; - throw; - } + var msg = await _connector.ReadMessage(async).ConfigureAwait(false); + // The CopyFail should immediately trigger an exception from the read above. 
+ throw _connector.Break( + new NpgsqlException("Expected ErrorResponse when cancelling COPY but got: " + msg.Code)); } - else + catch (PostgresException e) { - _connector.PerformPostgresCancellation(); + Cleanup(); + + if (e.SqlState != PostgresErrorCodes.QueryCanceled) + throw; } } + else + { + _connector.PerformPostgresCancellation(); + } + } + + #endregion + + #region Dispose - #endregion + protected override void Dispose(bool disposing) => DisposeAsync(disposing, false).GetAwaiter().GetResult(); - #region Dispose +#if NETSTANDARD2_0 + public ValueTask DisposeAsync() +#else + public override ValueTask DisposeAsync() +#endif + => DisposeAsync(disposing: true, async: true); - protected override void Dispose(bool disposing) => DisposeAsync(disposing, false).GetAwaiter().GetResult(); - async ValueTask DisposeAsync(bool disposing, bool async) + async ValueTask DisposeAsync(bool disposing, bool async) + { + if (_isDisposed || !disposing) + return; + + try { - if (_isDisposed || !disposing) { return; } + _connector.CurrentCopyOperation = null; - try + if (CanWrite) { - if (CanWrite) - { - await FlushAsync(async); - _writeBuf.EndCopyMode(); - await _connector.WriteCopyDone(async); - await _connector.Flush(async); - Expect(await _connector.ReadMessage(async), _connector); - Expect(await _connector.ReadMessage(async), _connector); - } - else + await FlushAsync(async).ConfigureAwait(false); + _writeBuf.EndCopyMode(); + await _connector.WriteCopyDone(async).ConfigureAwait(false); + await _connector.Flush(async).ConfigureAwait(false); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + } + else + { + if (!_isConsumed) { - if (!_isConsumed) + try { - try - { - if (_leftToReadInDataMsg > 0) - { - await _readBuf.Skip(_leftToReadInDataMsg, async); - } - _connector.SkipUntil(BackendMessageCode.ReadyForQuery); - } - catch (OperationCanceledException e) when 
(e.InnerException is PostgresException pg && pg.SqlState == PostgresErrorCodes.QueryCanceled) + if (_leftToReadInDataMsg > 0) { - Log.Debug($"Caught an exception while disposing the {nameof(NpgsqlRawCopyStream)}, indicating that it was cancelled.", e, _connector.Id); - } - catch (Exception e) - { - Log.Error($"Caught an exception while disposing the {nameof(NpgsqlRawCopyStream)}.", e, _connector.Id); + await _readBuf.Skip(_leftToReadInDataMsg, async).ConfigureAwait(false); } + _connector.SkipUntil(BackendMessageCode.ReadyForQuery); + } + catch (OperationCanceledException e) when (e.InnerException is PostgresException pg && pg.SqlState == PostgresErrorCodes.QueryCanceled) + { + LogMessages.CopyOperationCancelled(_copyLogger, _connector.Id); + } + catch (Exception e) + { + LogMessages.ExceptionWhenDisposingCopyOperation(_copyLogger, _connector.Id, e); } } } - finally - { - _connector.EndUserAction(); - Cleanup(); - } } - -#pragma warning disable CS8625 - void Cleanup() + finally { - Log.Debug("COPY operation ended", _connector.Id); - _connector.CurrentCopyOperation = null; - _connector.Connection!.EndBindingScope(ConnectorBindingScope.Copy); - _connector = null; - _readBuf = null; - _writeBuf = null; - _isDisposed = true; + Cleanup(); } + } + +#pragma warning disable CS8625 + void Cleanup() + { + Debug.Assert(!_isDisposed); + LogMessages.CopyOperationCompleted(_copyLogger, _connector.Id); + _connector.EndUserAction(); + _connector.CurrentCopyOperation = null; + _connector.Connection?.EndBindingScope(ConnectorBindingScope.Copy); + _connector = null; + _readBuf = null; + _writeBuf = null; + _isDisposed = true; + } #pragma warning restore CS8625 - void CheckDisposed() - { - if (_isDisposed) { - throw new ObjectDisposedException(GetType().FullName, "The COPY operation has already ended."); - } + void CheckDisposed() + { + if (_isDisposed) { + throw new ObjectDisposedException(nameof(NpgsqlRawCopyStream), "The COPY operation has already ended."); } + } - #endregion + 
#endregion - #region Unsupported + #region Unsupported - public override bool CanSeek => false; + public override bool CanSeek => false; - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } - public override void SetLength(long value) - { - throw new NotSupportedException(); - } + public override void SetLength(long value) + { + throw new NotSupportedException(); + } - public override long Length => throw new NotSupportedException(); + public override long Length => throw new NotSupportedException(); - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } - #endregion + #endregion - #region Input validation - static void ValidateArguments(byte[] buffer, int offset, int count) - { - if (buffer == null) - throw new ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentNullException(nameof(offset)); - if (count < 0) - throw new ArgumentNullException(nameof(count)); - if (buffer.Length - offset < count) - throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); - } - #endregion + #region Input validation + static void ValidateArguments(byte[] buffer, int offset, int count) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0) + throw new ArgumentNullException(nameof(offset)); + if (count < 0) + throw new ArgumentNullException(nameof(count)); + if (buffer.Length - offset < count) + throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the 
source collection."); + } + #endregion +} + +/// +/// Writer for a text import, initiated by . +/// +/// +/// See https://www.postgresql.org/docs/current/static/sql-copy.html. +/// +public sealed class NpgsqlCopyTextWriter : StreamWriter, ICancelable +{ + internal NpgsqlCopyTextWriter(NpgsqlConnector connector, NpgsqlRawCopyStream underlying) : base(underlying) + { + if (underlying.IsBinary) + throw connector.Break(new Exception("Can't use a binary copy stream for text writing")); } /// - /// Writer for a text import, initiated by . + /// Cancels and terminates an ongoing import. Any data already written will be discarded. /// - /// - /// See https://www.postgresql.org/docs/current/static/sql-copy.html. - /// - public sealed class NpgsqlCopyTextWriter : StreamWriter, ICancelable - { - internal NpgsqlCopyTextWriter(NpgsqlConnector connector, NpgsqlRawCopyStream underlying) : base(underlying) - { - if (underlying.IsBinary) - throw connector.Break(new Exception("Can't use a binary copy stream for text writing")); - } + public void Cancel() + => ((NpgsqlRawCopyStream)BaseStream).Cancel(); - /// - /// Cancels and terminates an ongoing import. Any data already written will be discarded. - /// - public void Cancel() - => ((NpgsqlRawCopyStream)BaseStream).Cancel(); + /// + /// Cancels and terminates an ongoing import. Any data already written will be discarded. + /// + public Task CancelAsync() => ((NpgsqlRawCopyStream)BaseStream).CancelAsync(); - /// - /// Cancels and terminates an ongoing import. Any data already written will be discarded. - /// - public Task CancelAsync() - { - using (NoSynchronizationContextScope.Enter()) - return ((NpgsqlRawCopyStream)BaseStream).CancelAsync(); - } +#if NETSTANDARD2_0 + public ValueTask DisposeAsync() + { + Dispose(); + return default; + } +#endif +} + +/// +/// Reader for a text export, initiated by . +/// +/// +/// See https://www.postgresql.org/docs/current/static/sql-copy.html. 
+/// +public sealed class NpgsqlCopyTextReader : StreamReader, ICancelable +{ + internal NpgsqlCopyTextReader(NpgsqlConnector connector, NpgsqlRawCopyStream underlying) : base(underlying) + { + if (underlying.IsBinary) + throw connector.Break(new Exception("Can't use a binary copy stream for text reading")); } /// - /// Reader for a text export, initiated by . + /// Cancels and terminates an ongoing export. /// - /// - /// See https://www.postgresql.org/docs/current/static/sql-copy.html. - /// - public sealed class NpgsqlCopyTextReader : StreamReader, ICancelable - { - internal NpgsqlCopyTextReader(NpgsqlConnector connector, NpgsqlRawCopyStream underlying) : base(underlying) - { - if (underlying.IsBinary) - throw connector.Break(new Exception("Can't use a binary copy stream for text reading")); - } + public void Cancel() + => ((NpgsqlRawCopyStream)BaseStream).Cancel(); - /// - /// Cancels and terminates an ongoing import. - /// - public void Cancel() - => ((NpgsqlRawCopyStream)BaseStream).Cancel(); + /// + /// Asynchronously cancels and terminates an ongoing export. + /// + public Task CancelAsync() => ((NpgsqlRawCopyStream)BaseStream).CancelAsync(); - /// - /// Cancels and terminates an ongoing import. Any data already written will be discarded. 
- /// - public Task CancelAsync() - { - using (NoSynchronizationContextScope.Enter()) - return ((NpgsqlRawCopyStream)BaseStream).CancelAsync(); - } + public ValueTask DisposeAsync() + { + Dispose(); + return default; } } diff --git a/src/Npgsql/NpgsqlReadBuffer.Stream.cs b/src/Npgsql/NpgsqlReadBuffer.Stream.cs deleted file mode 100644 index ddd46da7d2..0000000000 --- a/src/Npgsql/NpgsqlReadBuffer.Stream.cs +++ /dev/null @@ -1,225 +0,0 @@ -using System; -using System.Diagnostics; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace Npgsql -{ - public sealed partial class NpgsqlReadBuffer - { - internal sealed class ColumnStream : Stream - { - readonly NpgsqlConnector _connector; - readonly NpgsqlReadBuffer _buf; - int _start, _len, _read; - bool _canSeek; - bool _startCancellableOperations; - internal bool IsDisposed { get; private set; } - - internal ColumnStream(NpgsqlConnector connector, bool startCancellableOperations = true) - { - _connector = connector; - _buf = connector.ReadBuffer; - _startCancellableOperations = startCancellableOperations; - } - - internal void Init(int len, bool canSeek) - { - Debug.Assert(!canSeek || _buf.ReadBytesLeft >= len, - "Seekable stream constructed but not all data is in buffer (sequential)"); - _start = _buf.ReadPosition; - _len = len; - _read = 0; - _canSeek = canSeek; - IsDisposed = false; - } - - public override bool CanRead => true; - - public override bool CanWrite => false; - - public override bool CanSeek => _canSeek; - - public override long Length - { - get - { - CheckDisposed(); - return _len; - } - } - - public override void SetLength(long value) - => throw new NotSupportedException(); - - public override long Position - { - get - { - CheckDisposed(); - return _read; - } - set - { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), "Non - negative number required."); - Seek(_start + value, SeekOrigin.Begin); - } - } - - public override long Seek(long offset, 
SeekOrigin origin) - { - CheckDisposed(); - - if (!_canSeek) - throw new NotSupportedException(); - if (offset > int.MaxValue) - throw new ArgumentOutOfRangeException(nameof(offset), "Stream length must be non-negative and less than 2^31 - 1 - origin."); - - const string seekBeforeBegin = "An attempt was made to move the position before the beginning of the stream."; - - switch (origin) - { - case SeekOrigin.Begin: - { - var tempPosition = unchecked(_start + (int)offset); - if (offset < 0 || tempPosition < _start) - throw new IOException(seekBeforeBegin); - _buf.ReadPosition = _start; - return tempPosition; - } - case SeekOrigin.Current: - { - var tempPosition = unchecked(_buf.ReadPosition + (int)offset); - if (unchecked(_buf.ReadPosition + offset) < _start || tempPosition < _start) - throw new IOException(seekBeforeBegin); - _buf.ReadPosition = tempPosition; - return tempPosition; - } - case SeekOrigin.End: - { - var tempPosition = unchecked(_len + (int)offset); - if (unchecked(_len + offset) < _start || tempPosition < _start) - throw new IOException(seekBeforeBegin); - _buf.ReadPosition = tempPosition; - return tempPosition; - } - default: - throw new ArgumentOutOfRangeException(nameof(origin), "Invalid seek origin."); - } - } - - public override void Flush() - => throw new NotSupportedException(); - - public override Task FlushAsync(CancellationToken cancellationToken) - => throw new NotSupportedException(); - - public override int Read(byte[] buffer, int offset, int count) - { - ValidateArguments(buffer, offset, count); - return Read(new Span(buffer, offset, count)); - } - - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - ValidateArguments(buffer, offset, count); - - using (NoSynchronizationContextScope.Enter()) - return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); - } - -#if NETSTANDARD2_0 - public int Read(Span span) -#else - public override int Read(Span span) 
-#endif - { - CheckDisposed(); - - var count = Math.Min(span.Length, _len - _read); - - if (count == 0) - return 0; - - _buf.Read(span.Slice(0, count)); - _read += count; - - return count; - } - -#if NETSTANDARD2_0 - public ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) -#else - public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) -#endif - { - CheckDisposed(); - - var count = Math.Min(buffer.Length, _len - _read); - - if (count == 0) - return new ValueTask(0); - - using (NoSynchronizationContextScope.Enter()) - return ReadLong(this, buffer.Slice(0, count), cancellationToken); - - static async ValueTask ReadLong(ColumnStream stream, Memory buffer, CancellationToken cancellationToken = default) - { - using var registration = stream._startCancellableOperations - ? stream._connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false) - : default; - var read = await stream._buf.ReadAsync(buffer, cancellationToken); - stream._read += read; - return read; - } - } - - public override void Write(byte[] buffer, int offset, int count) - => throw new NotSupportedException(); - - void CheckDisposed() - { - if (IsDisposed) - throw new ObjectDisposedException(null); - } - - protected override void Dispose(bool disposing) - => DisposeAsync(disposing, async: false).GetAwaiter().GetResult(); - -#if !NETSTANDARD2_0 - public override ValueTask DisposeAsync() - => DisposeAsync(disposing: true, async: true); -#endif - - async ValueTask DisposeAsync(bool disposing, bool async) - { - if (IsDisposed || !disposing) - return; - - var leftToSkip = _len - _read; - if (leftToSkip > 0) - { - if (async) - await _buf.Skip(leftToSkip, async); - else - _buf.Skip(leftToSkip, async).GetAwaiter().GetResult(); - } - IsDisposed = true; - } - } - - static void ValidateArguments(byte[] buffer, int offset, int count) - { - if (buffer == null) - throw new ArgumentNullException(nameof(buffer)); - if 
(offset < 0) - throw new ArgumentNullException(nameof(offset)); - if (count < 0) - throw new ArgumentNullException(nameof(count)); - if (buffer.Length - offset < count) - throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); - } - } -} diff --git a/src/Npgsql/NpgsqlReadBuffer.cs b/src/Npgsql/NpgsqlReadBuffer.cs deleted file mode 100644 index 323ef62527..0000000000 --- a/src/Npgsql/NpgsqlReadBuffer.cs +++ /dev/null @@ -1,653 +0,0 @@ -using System; -using System.Buffers; -using System.Buffers.Binary; -using System.Diagnostics; -using System.IO; -using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.Util; -using static System.Threading.Timeout; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - -namespace Npgsql -{ - /// - /// A buffer used by Npgsql to read data from the socket efficiently. - /// Provides methods which decode different values types and tracks the current position. - /// - public sealed partial class NpgsqlReadBuffer : IDisposable - { - #region Fields and Properties - - public NpgsqlConnection Connection => Connector.Connection!; - - internal readonly NpgsqlConnector Connector; - - internal Stream Underlying { private get; set; } - - readonly Socket? 
_underlyingSocket; - - internal ResettableCancellationTokenSource Cts { get; } - - TimeSpan _preTranslatedTimeout = TimeSpan.Zero; - - /// - /// Timeout for sync and async reads - /// - internal TimeSpan Timeout - { - get => Cts.Timeout; - set - { - if (_preTranslatedTimeout != value) - { - _preTranslatedTimeout = value; - - if (value == TimeSpan.Zero) - value = InfiniteTimeSpan; - else if (value < TimeSpan.Zero) - value = TimeSpan.Zero; - - Debug.Assert(_underlyingSocket != null); - - _underlyingSocket.ReceiveTimeout = (int)value.TotalMilliseconds; - Cts.Timeout = value; - } - } - } - - /// - /// The total byte length of the buffer. - /// - internal int Size { get; } - - internal Encoding TextEncoding { get; } - - /// - /// Same as , except that it does not throw an exception if an invalid char is - /// encountered (exception fallback), but rather replaces it with a question mark character (replacement - /// fallback). - /// - internal Encoding RelaxedTextEncoding { get; } - - internal int ReadPosition { get; set; } - internal int ReadBytesLeft => FilledBytes - ReadPosition; - - internal readonly byte[] Buffer; - internal int FilledBytes; - - ColumnStream? _columnStream; - - bool _disposed; - - /// - /// The minimum buffer size possible. - /// - internal const int MinimumSize = 4096; - internal const int DefaultSize = 8192; - - #endregion - - #region Constructors - - internal NpgsqlReadBuffer( - NpgsqlConnector connector, - Stream stream, - Socket? 
socket, - int size, - Encoding textEncoding, - Encoding relaxedTextEncoding) - { - if (size < MinimumSize) - { - throw new ArgumentOutOfRangeException(nameof(size), size, "Buffer size must be at least " + MinimumSize); - } - - Connector = connector; - Underlying = stream; - _underlyingSocket = socket; - Cts = new ResettableCancellationTokenSource(); - Size = size; - Buffer = ArrayPool.Shared.Rent(size); - TextEncoding = textEncoding; - RelaxedTextEncoding = relaxedTextEncoding; - } - - #endregion - - #region I/O - - internal void Ensure(int count) - { - if (count <= ReadBytesLeft) - return; - Ensure(count, false).GetAwaiter().GetResult(); - } - - public Task Ensure(int count, bool async) - => Ensure(count, async, readingNotifications: false); - - public Task EnsureAsync(int count) - => Ensure(count, async: true, readingNotifications: false); - - /// - /// Ensures that bytes are available in the buffer, and if - /// not, reads from the socket until enough is available. - /// - internal Task Ensure(int count, bool async, bool readingNotifications) - { - return count <= ReadBytesLeft ? Task.CompletedTask : EnsureLong(this, count, async, readingNotifications); - - static async Task EnsureLong( - NpgsqlReadBuffer buffer, - int count, - bool async, - bool readingNotifications) - { - Debug.Assert(count <= buffer.Size); - Debug.Assert(count > buffer.ReadBytesLeft); - count -= buffer.ReadBytesLeft; - if (count <= 0) { return; } - - if (buffer.ReadPosition == buffer.FilledBytes) - { - buffer.Clear(); - } - else if (count > buffer.Size - buffer.FilledBytes) - { - Array.Copy(buffer.Buffer, buffer.ReadPosition, buffer.Buffer, 0, buffer.ReadBytesLeft); - buffer.FilledBytes = buffer.ReadBytesLeft; - buffer.ReadPosition = 0; - } - - var finalCt = async && buffer.Timeout >= TimeSpan.Zero - ? buffer.Cts.Start() - : buffer.Cts.Reset(); - - var totalRead = 0; - while (count > 0) - { - try - { - var toRead = buffer.Size - buffer.FilledBytes; - var read = async - ? 
await buffer.Underlying.ReadAsync(buffer.Buffer, buffer.FilledBytes, toRead, finalCt) - : buffer.Underlying.Read(buffer.Buffer, buffer.FilledBytes, toRead); - - if (read == 0) - throw new EndOfStreamException(); - count -= read; - buffer.FilledBytes += read; - totalRead += read; - - // Most of the time, it should be fine to reset cancellation token source, so we can use it again - // It's still possible for cancellation token to cancel between reading and resetting (although highly improbable) - // In this case, we consider it as timed out and fail with OperationCancelledException on next ReadAsync - // Or we consider it not timed out if we have already read everything (count == 0) - // In which case we reinitialize it on the next call to EnsureLong() - if (async) - buffer.Cts.RestartTimeoutWithoutReset(); - } - catch (Exception e) - { - var connector = buffer.Connector; - - // Stopping twice (in case the previous Stop() call succeeded) doesn't hurt. - // Not stopping will cause an assertion failure in debug mode when we call Start() the next time. - // We can't stop in a finally block because Connector.Break() will dispose the buffer and the contained - // _timeoutCts - buffer.Cts.Stop(); - - switch (e) - { - // Read timeout - case OperationCanceledException _: - // Note that mono throws SocketException with the wrong error (see #1330) - case IOException _ when (e.InnerException as SocketException)?.SocketErrorCode == - (Type.GetType("Mono.Runtime") == null ? SocketError.TimedOut : SocketError.WouldBlock): - { - Debug.Assert(e is OperationCanceledException ? async : !async); - - // When reading notifications (Wait), just throw TimeoutException or OperationCanceledException immediately. - // Nothing to cancel, and no breaking of the connection. - if (readingNotifications) - { - if (connector.UserCancellationRequested) - throw; - throw NpgsqlTimeoutException(); - } - - // If we should attempt PostgreSQL cancellation, do it the first time we get a timeout. 
- // TODO: As an optimization, we can still attempt to send a cancellation request, but after that immediately break the connection - if (connector.AttemptPostgresCancellation && - !connector.PostgresCancellationPerformed && - connector.PerformPostgresCancellation()) - { - // Note that if the cancellation timeout is negative, we flow down and break the connection immediately - var cancellationTimeout = connector.Settings.CancellationTimeout; - if (cancellationTimeout >= 0) - { - if (cancellationTimeout > 0) - buffer.Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); - - if (async) - finalCt = buffer.Cts.Start(); - - continue; - } - } - - // If we're here, the PostgreSQL cancellation either failed or skipped entirely. - // Break the connection, bubbling up the correct exception type (cancellation or timeout) - throw connector.Break(!buffer.Connector.UserCancellationRequested - ? NpgsqlTimeoutException() - : connector.PostgresCancellationPerformed - ? new OperationCanceledException("Query was cancelled", TimeoutException(), connector.UserCancellationToken) - : new OperationCanceledException("Query was cancelled", connector.UserCancellationToken)); - } - - default: - throw connector.Break(new NpgsqlException("Exception while reading from stream", e)); - } - } - } - - buffer.Cts.Stop(); - NpgsqlEventSource.Log.BytesRead(totalRead); - - static Exception NpgsqlTimeoutException() => new NpgsqlException("Exception while reading from stream", TimeoutException()); - - static Exception TimeoutException() => new TimeoutException("Timeout during reading attempt"); - } - } - - internal void ReadMore() => ReadMore(false).GetAwaiter().GetResult(); - - internal Task ReadMore(bool async) => Ensure(ReadBytesLeft + 1, async); - - internal NpgsqlReadBuffer AllocateOversize(int count) - { - Debug.Assert(count > Size); - var tempBuf = new NpgsqlReadBuffer(Connector, Underlying, _underlyingSocket, count, TextEncoding, RelaxedTextEncoding); - if (_underlyingSocket != null) - 
tempBuf.Timeout = Timeout; - CopyTo(tempBuf); - Clear(); - return tempBuf; - } - - /// - /// Does not perform any I/O - assuming that the bytes to be skipped are in the memory buffer. - /// - /// - internal void Skip(long len) - { - Debug.Assert(ReadBytesLeft >= len); - ReadPosition += (int)len; - } - - /// - /// Skip a given number of bytes. - /// - public async Task Skip(long len, bool async) - { - Debug.Assert(len >= 0); - - if (len > ReadBytesLeft) - { - len -= ReadBytesLeft; - while (len > Size) - { - Clear(); - await Ensure(Size, async); - len -= Size; - } - Clear(); - await Ensure((int)len, async); - } - - ReadPosition += (int)len; - } - - #endregion - - #region Read Simple - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public sbyte ReadSByte() => Read(); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public byte ReadByte() => Read(); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public short ReadInt16() - => ReadInt16(false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public short ReadInt16(bool littleEndian) - { - var result = Read(); - return littleEndian == BitConverter.IsLittleEndian - ? result : BinaryPrimitives.ReverseEndianness(result); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ushort ReadUInt16() - => ReadUInt16(false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ushort ReadUInt16(bool littleEndian) - { - var result = Read(); - return littleEndian == BitConverter.IsLittleEndian - ? result : BinaryPrimitives.ReverseEndianness(result); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public int ReadInt32() - => ReadInt32(false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public int ReadInt32(bool littleEndian) - { - var result = Read(); - return littleEndian == BitConverter.IsLittleEndian - ? 
result : BinaryPrimitives.ReverseEndianness(result); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public uint ReadUInt32() - => ReadUInt32(false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public uint ReadUInt32(bool littleEndian) - { - var result = Read(); - return littleEndian == BitConverter.IsLittleEndian - ? result : BinaryPrimitives.ReverseEndianness(result); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public long ReadInt64() - => ReadInt64(false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public long ReadInt64(bool littleEndian) - { - var result = Read(); - return littleEndian == BitConverter.IsLittleEndian - ? result : BinaryPrimitives.ReverseEndianness(result); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ulong ReadUInt64() - => ReadUInt64(false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ulong ReadUInt64(bool littleEndian) - { - var result = Read(); - return littleEndian == BitConverter.IsLittleEndian - ? 
result : BinaryPrimitives.ReverseEndianness(result); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public float ReadSingle() - => ReadSingle(false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public float ReadSingle(bool littleEndian) - { - var result = ReadInt32(littleEndian); - return Unsafe.As(ref result); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public double ReadDouble() - => ReadDouble(false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public double ReadDouble(bool littleEndian) - { - var result = ReadInt64(littleEndian); - return Unsafe.As(ref result); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - T Read() - { - if (Unsafe.SizeOf() > ReadBytesLeft) - ThrowNotSpaceLeft(); - - var result = Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); - ReadPosition += Unsafe.SizeOf(); - return result; - } - - [MethodImpl(MethodImplOptions.NoInlining)] - static void ThrowNotSpaceLeft() - => throw new InvalidOperationException("There is not enough space left in the buffer."); - - public string ReadString(int byteLen) - { - Debug.Assert(byteLen <= ReadBytesLeft); - var result = TextEncoding.GetString(Buffer, ReadPosition, byteLen); - ReadPosition += byteLen; - return result; - } - - public char[] ReadChars(int byteLen) - { - Debug.Assert(byteLen <= ReadBytesLeft); - var result = TextEncoding.GetChars(Buffer, ReadPosition, byteLen); - ReadPosition += byteLen; - return result; - } - - public void ReadBytes(Span output) - { - Debug.Assert(output.Length <= ReadBytesLeft); - new Span(Buffer, ReadPosition, output.Length).CopyTo(output); - ReadPosition += output.Length; - } - - public void ReadBytes(byte[] output, int outputOffset, int len) - => ReadBytes(new Span(output, outputOffset, len)); - - public ReadOnlySpan ReadSpan(int len) - { - Debug.Assert(len <= ReadBytesLeft); - return new ReadOnlySpan(Buffer, ReadPosition, len); - } - - public ReadOnlyMemory ReadMemory(int len) - { - Debug.Assert(len <= 
ReadBytesLeft); - return new ReadOnlyMemory(Buffer, ReadPosition, len); - } - - #endregion - - #region Read Complex - - public int Read(Span output) - { - var readFromBuffer = Math.Min(ReadBytesLeft, output.Length); - if (readFromBuffer > 0) - { - new Span(Buffer, ReadPosition, readFromBuffer).CopyTo(output); - ReadPosition += readFromBuffer; - return readFromBuffer; - } - - if (output.Length == 0) - return 0; - - Debug.Assert(ReadPosition == 0); - Clear(); - try - { - var read = Underlying.Read(output); - if (read == 0) - throw new EndOfStreamException(); - return read; - } - catch (Exception e) - { - throw Connector.Break(new NpgsqlException("Exception while reading from stream", e)); - } - } - - public ValueTask ReadAsync(Memory output, CancellationToken cancellationToken = default) - { - if (output.Length == 0) - return new ValueTask(0); - - var readFromBuffer = Math.Min(ReadBytesLeft, output.Length); - if (readFromBuffer > 0) - { - new Span(Buffer, ReadPosition, readFromBuffer).CopyTo(output.Span); - ReadPosition += readFromBuffer; - return new ValueTask(readFromBuffer); - } - - return ReadAsyncLong(this, output, cancellationToken); - - static async ValueTask ReadAsyncLong(NpgsqlReadBuffer buffer, Memory output, CancellationToken cancellationToken) - { - Debug.Assert(buffer.ReadBytesLeft == 0); - buffer.Clear(); - try - { - var read = await buffer.Underlying.ReadAsync(output, cancellationToken); - if (read == 0) - throw new EndOfStreamException(); - return read; - } - catch (Exception e) - { - throw buffer.Connector.Break(new NpgsqlException("Exception while reading from stream", e)); - } - } - } - - public Stream GetStream(int len, bool canSeek) - { - if (_columnStream == null) - _columnStream = new ColumnStream(Connector); - - _columnStream.Init(len, canSeek); - return _columnStream; - } - - /// - /// Seeks the first null terminator (\0) and returns the string up to it. The buffer must already - /// contain the entire string and its terminator. 
- /// - public string ReadNullTerminatedString() - => ReadNullTerminatedString(TextEncoding, async: false).GetAwaiter().GetResult(); - - /// - /// Seeks the first null terminator (\0) and returns the string up to it. The buffer must already - /// contain the entire string and its terminator. If any character could not be decoded, a question - /// mark character is returned instead of throwing an exception. - /// - public string ReadNullTerminatedStringRelaxed() - => ReadNullTerminatedString(RelaxedTextEncoding, async: false).GetAwaiter().GetResult(); - - public ValueTask ReadNullTerminatedString(bool async, CancellationToken cancellationToken = default) - => ReadNullTerminatedString(TextEncoding, async, cancellationToken); - - /// - /// Seeks the first null terminator (\0) and returns the string up to it. Reads additional data from the network if a null - /// terminator isn't found in the buffered data. - /// - ValueTask ReadNullTerminatedString(Encoding encoding, bool async, CancellationToken cancellationToken = default) - { - return ReadFromBuffer(this, encoding, out var s) - ? 
new ValueTask(s) - : ReadLong(this, async, encoding, s); - - static bool ReadFromBuffer(NpgsqlReadBuffer buffer, Encoding encoding, out string s) - { - var start = buffer.ReadPosition; - while (buffer.ReadPosition < buffer.FilledBytes) - { - if (buffer.Buffer[buffer.ReadPosition++] == 0) - { - s = encoding.GetString(buffer.Buffer, start, buffer.ReadPosition - start - 1); - return true; - } - } - - s = encoding.GetString(buffer.Buffer, start, buffer.ReadPosition - start); - return false; - } - - static async ValueTask ReadLong(NpgsqlReadBuffer buffer, bool async, Encoding encoding, string s) - { - var builder = new StringBuilder(s); - bool complete; - do - { - await buffer.ReadMore(async); - complete = ReadFromBuffer(buffer, encoding, out s); - builder.Append(s); - } - while (!complete); - - return builder.ToString(); - } - } - - public ReadOnlySpan GetNullTerminatedBytes() - { - int i; - for (i = ReadPosition; Buffer[i] != 0; i++) - Debug.Assert(i <= ReadPosition + ReadBytesLeft); - Debug.Assert(i >= ReadPosition); - - var result = new ReadOnlySpan(Buffer, ReadPosition, i - ReadPosition); - ReadPosition = i + 1; - return result; - } - - #endregion - - #region Dispose - - public void Dispose() - { - if (_disposed) - return; - - ArrayPool.Shared.Return(Buffer); - - Cts.Dispose(); - _disposed = true; - } - - #endregion - - #region Misc - - internal void Clear() - { - ReadPosition = 0; - FilledBytes = 0; - } - - internal void CopyTo(NpgsqlReadBuffer other) - { - Debug.Assert(other.Size - other.FilledBytes >= ReadBytesLeft); - Array.Copy(Buffer, ReadPosition, other.Buffer, other.FilledBytes, ReadBytesLeft); - other.FilledBytes += ReadBytesLeft; - } - - #endregion - } -} diff --git a/src/Npgsql/NpgsqlSchema.cs b/src/Npgsql/NpgsqlSchema.cs index 2dd1d2869b..f9688744ec 100644 --- a/src/Npgsql/NpgsqlSchema.cs +++ b/src/Npgsql/NpgsqlSchema.cs @@ -1,315 +1,468 @@ using System; +using System.Collections.Generic; using System.Data; using System.Data.Common; using 
System.Globalization; using System.Text; using System.Threading; using System.Threading.Tasks; +using Npgsql.Internal; using Npgsql.PostgresTypes; using NpgsqlTypes; -namespace Npgsql +namespace Npgsql; + +/// +/// Provides the underlying mechanism for reading schema information. +/// +static class NpgsqlSchema { - /// - /// Provides the underlying mechanism for reading schema information. - /// - static class NpgsqlSchema + public static Task GetSchema(bool async, NpgsqlConnection conn, string? collectionName, string?[]? restrictions, CancellationToken cancellationToken = default) { - public static Task GetSchema(NpgsqlConnection conn, string? collectionName, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + if (collectionName is null) + throw new ArgumentNullException(nameof(collectionName)); + if (collectionName.Length == 0) + throw new ArgumentException("Collection name cannot be empty.", nameof(collectionName)); + + return collectionName.ToUpperInvariant() switch { - if (collectionName is null) - throw new ArgumentNullException(nameof(collectionName)); - if (collectionName.Length == 0) - throw new ArgumentException("Collection name cannot be empty.", nameof(collectionName)); + "METADATACOLLECTIONS" => Task.FromResult(GetMetaDataCollections()), + "RESTRICTIONS" => Task.FromResult(GetRestrictions()), + "DATASOURCEINFORMATION" => Task.FromResult(GetDataSourceInformation(conn)), + "DATATYPES" => Task.FromResult(GetDataTypes(conn)), + "RESERVEDWORDS" => Task.FromResult(GetReservedWords()), + // custom collections for npgsql + "DATABASES" => GetDatabases(conn, restrictions, async, cancellationToken), + "SCHEMATA" => GetSchemata(conn, restrictions, async, cancellationToken), + "TABLES" => GetTables(conn, restrictions, async, cancellationToken), + "COLUMNS" => GetColumns(conn, restrictions, async, cancellationToken), + "VIEWS" => GetViews(conn, restrictions, async, cancellationToken), + "MATERIALIZEDVIEWS" => 
GetMaterializedViews(conn, restrictions, async, cancellationToken), + "USERS" => GetUsers(conn, restrictions, async, cancellationToken), + "INDEXES" => GetIndexes(conn, restrictions, async, cancellationToken), + "INDEXCOLUMNS" => GetIndexColumns(conn, restrictions, async, cancellationToken), + "CONSTRAINTS" => GetConstraints(conn, restrictions, collectionName, async, cancellationToken), + "PRIMARYKEY" => GetConstraints(conn, restrictions, collectionName, async, cancellationToken), + "UNIQUEKEYS" => GetConstraints(conn, restrictions, collectionName, async, cancellationToken), + "FOREIGNKEYS" => GetConstraints(conn, restrictions, collectionName, async, cancellationToken), + "CONSTRAINTCOLUMNS" => GetConstraintColumns(conn, restrictions, async, cancellationToken), + _ => throw new ArgumentOutOfRangeException(nameof(collectionName), collectionName, "Invalid collection name.") + }; + } - return collectionName.ToUpperInvariant() switch - { - "METADATACOLLECTIONS" => Task.FromResult(GetMetaDataCollections()), - "RESTRICTIONS" => Task.FromResult(GetRestrictions()), - "DATASOURCEINFORMATION" => Task.FromResult(GetDataSourceInformation(conn)), - "DATATYPES" => Task.FromResult(GetDataTypes(conn)), - "RESERVEDWORDS" => Task.FromResult(GetReservedWords()), - // custom collections for npgsql - "DATABASES" => GetDatabases(conn, restrictions, async, cancellationToken), - "SCHEMATA" => GetSchemata(conn, restrictions, async, cancellationToken), - "TABLES" => GetTables(conn, restrictions, async, cancellationToken), - "COLUMNS" => GetColumns(conn, restrictions, async, cancellationToken), - "VIEWS" => GetViews(conn, restrictions, async, cancellationToken), - "USERS" => GetUsers(conn, restrictions, async, cancellationToken), - "INDEXES" => GetIndexes(conn, restrictions, async, cancellationToken), - "INDEXCOLUMNS" => GetIndexColumns(conn, restrictions, async, cancellationToken), - "CONSTRAINTS" => GetConstraints(conn, restrictions, collectionName, async, cancellationToken), - 
"PRIMARYKEY" => GetConstraints(conn, restrictions, collectionName, async, cancellationToken), - "UNIQUEKEYS" => GetConstraints(conn, restrictions, collectionName, async, cancellationToken), - "FOREIGNKEYS" => GetConstraints(conn, restrictions, collectionName, async, cancellationToken), - "CONSTRAINTCOLUMNS" => GetConstraintColumns(conn, restrictions, async, cancellationToken), - _ => throw new ArgumentOutOfRangeException(nameof(collectionName), collectionName, "Invalid collection name.") - }; - } + /// + /// Returns the MetaDataCollections that lists all possible collections. + /// + /// The MetaDataCollections + static DataTable GetMetaDataCollections() + { + var table = new DataTable("MetaDataCollections"); + table.Columns.Add("CollectionName", typeof(string)); + table.Columns.Add("NumberOfRestrictions", typeof(int)); + table.Columns.Add("NumberOfIdentifierParts", typeof(int)); + + table.Rows.Add("MetaDataCollections", 0, 0); + table.Rows.Add("DataSourceInformation", 0, 0); + table.Rows.Add("Restrictions", 0, 0); + table.Rows.Add("DataTypes", 0, 0); // TODO: Support type name restriction + table.Rows.Add("Databases", 1, 1); + table.Rows.Add("Tables", 4, 3); + table.Rows.Add("Columns", 4, 4); + table.Rows.Add("Views", 3, 3); + table.Rows.Add("Users", 1, 1); + table.Rows.Add("Indexes", 4, 4); + table.Rows.Add("IndexColumns", 5, 5); + + return table; + } - /// - /// Returns the MetaDataCollections that lists all possible collections. 
- /// - /// The MetaDataCollections - static DataTable GetMetaDataCollections() - { - var table = new DataTable("MetaDataCollections"); - table.Columns.Add("CollectionName", typeof(string)); - table.Columns.Add("NumberOfRestrictions", typeof(int)); - table.Columns.Add("NumberOfIdentifierParts", typeof(int)); - - table.Rows.Add("MetaDataCollections", 0, 0); - table.Rows.Add("DataSourceInformation", 0, 0); - table.Rows.Add("Restrictions", 0, 0); - table.Rows.Add("DataTypes", 0, 0); // TODO: Support type name restriction - table.Rows.Add("Databases", 1, 1); - table.Rows.Add("Tables", 4, 3); - table.Rows.Add("Columns", 4, 4); - table.Rows.Add("Views", 3, 3); - table.Rows.Add("Users", 1, 1); - table.Rows.Add("Indexes", 4, 4); - table.Rows.Add("IndexColumns", 5, 5); - - return table; - } + /// + /// Returns the Restrictions that contains the meaning and position of the values in the restrictions array. + /// + /// The Restrictions + static DataTable GetRestrictions() + { + var table = new DataTable("Restrictions"); + + table.Columns.Add("CollectionName", typeof(string)); + table.Columns.Add("RestrictionName", typeof(string)); + table.Columns.Add("RestrictionDefault", typeof(string)); + table.Columns.Add("RestrictionNumber", typeof(int)); + + table.Rows.Add("Database", "Name", "Name", 1); + table.Rows.Add("Tables", "Catalog", "table_catalog", 1); + table.Rows.Add("Tables", "Schema", "schema_catalog", 2); + table.Rows.Add("Tables", "Table", "table_name", 3); + table.Rows.Add("Tables", "TableType", "table_type", 4); + table.Rows.Add("Columns", "Catalog", "table_catalog", 1); + table.Rows.Add("Columns", "Schema", "table_schema", 2); + table.Rows.Add("Columns", "TableName", "table_name", 3); + table.Rows.Add("Columns", "Column", "column_name", 4); + table.Rows.Add("Views", "Catalog", "table_catalog", 1); + table.Rows.Add("Views", "Schema", "table_schema", 2); + table.Rows.Add("Views", "Table", "table_name", 3); + table.Rows.Add("Users", "User", "user_name", 1); + 
table.Rows.Add("Indexes", "Catalog", "table_catalog", 1); + table.Rows.Add("Indexes", "Schema", "table_schema", 2); + table.Rows.Add("Indexes", "Table", "table_name", 3); + table.Rows.Add("Indexes", "Index", "index_name", 4); + table.Rows.Add("IndexColumns", "Catalog", "table_catalog", 1); + table.Rows.Add("IndexColumns", "Schema", "table_schema", 2); + table.Rows.Add("IndexColumns", "Table", "table_name", 3); + table.Rows.Add("IndexColumns", "Index", "index_name", 4); + table.Rows.Add("IndexColumns", "Column", "column_name", 5); + + return table; + } - /// - /// Returns the Restrictions that contains the meaning and position of the values in the restrictions array. - /// - /// The Restrictions - static DataTable GetRestrictions() - { - var table = new DataTable("Restrictions"); - - table.Columns.Add("CollectionName", typeof(string)); - table.Columns.Add("RestrictionName", typeof(string)); - table.Columns.Add("RestrictionDefault", typeof(string)); - table.Columns.Add("RestrictionNumber", typeof(int)); - - table.Rows.Add("Database", "Name", "Name", 1); - table.Rows.Add("Tables", "Catalog", "table_catalog", 1); - table.Rows.Add("Tables", "Schema", "schema_catalog", 2); - table.Rows.Add("Tables", "Table", "table_name", 3); - table.Rows.Add("Tables", "TableType", "table_type", 4); - table.Rows.Add("Columns", "Catalog", "table_catalog", 1); - table.Rows.Add("Columns", "Schema", "table_schema", 2); - table.Rows.Add("Columns", "TableName", "table_name", 3); - table.Rows.Add("Columns", "Column", "column_name", 4); - table.Rows.Add("Views", "Catalog", "table_catalog", 1); - table.Rows.Add("Views", "Schema", "table_schema", 2); - table.Rows.Add("Views", "Table", "table_name", 3); - table.Rows.Add("Users", "User", "user_name", 1); - table.Rows.Add("Indexes", "Catalog", "table_catalog", 1); - table.Rows.Add("Indexes", "Schema", "table_schema", 2); - table.Rows.Add("Indexes", "Table", "table_name", 3); - table.Rows.Add("Indexes", "Index", "index_name", 4); - 
table.Rows.Add("IndexColumns", "Catalog", "table_catalog", 1); - table.Rows.Add("IndexColumns", "Schema", "table_schema", 2); - table.Rows.Add("IndexColumns", "Table", "table_name", 3); - table.Rows.Add("IndexColumns", "Index", "index_name", 4); - table.Rows.Add("IndexColumns", "Column", "column_name", 5); - - return table; - } + static NpgsqlCommand BuildCommand(NpgsqlConnection conn, StringBuilder query, string?[]? restrictions, params string[]? names) + => BuildCommand(conn, query, restrictions, true, names); - static NpgsqlCommand BuildCommand(NpgsqlConnection conn, StringBuilder query, string?[]? restrictions, params string[]? names) - => BuildCommand(conn, query, restrictions, true, names); + static NpgsqlCommand BuildCommand(NpgsqlConnection conn, StringBuilder query, string?[]? restrictions, bool addWhere, params string[]? names) + { + var command = new NpgsqlCommand(); - static NpgsqlCommand BuildCommand(NpgsqlConnection conn, StringBuilder query, string?[]? restrictions, bool addWhere, params string[]? 
names) + if (restrictions != null && names != null) { - var command = new NpgsqlCommand(); - - if (restrictions != null && names != null) + for (var i = 0; i < restrictions.Length && i < names.Length; ++i) { - for (var i = 0; i < restrictions.Length && i < names.Length; ++i) + if (restrictions[i] is { Length: > 0 } restriction) { - if (restrictions[i] is string restriction && restriction.Length != 0) + if (addWhere) + { + query.Append(" WHERE "); + addWhere = false; + } + else { - if (addWhere) - { - query.Append(" WHERE "); - addWhere = false; - } - else - { - query.Append(" AND "); - } + query.Append(" AND "); + } - var paramName = RemoveSpecialChars(names[i]); + var paramName = RemoveSpecialChars(names[i]); - query.AppendFormat("{0} = :{1}", names[i], paramName); + query.AppendFormat("{0} = :{1}", names[i], paramName); - command.Parameters.Add(new NpgsqlParameter(paramName, restriction)); - } + command.Parameters.Add(new NpgsqlParameter(paramName, restriction)); } } - command.CommandText = query.ToString(); - command.Connection = conn; - - return command; } + command.CommandText = query.ToString(); + command.Connection = conn; + + return command; + } - static string RemoveSpecialChars(string paramName) - => paramName.Replace("(", "").Replace(")", "").Replace(".", ""); + static string RemoveSpecialChars(string paramName) + => paramName.Replace("(", "").Replace(")", "").Replace(".", ""); - static async Task GetDatabases(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) - { - var databases = new DataTable("Databases") { Locale = CultureInfo.InvariantCulture }; - databases.Columns.AddRange(new[] { + static Task GetDatabases(NpgsqlConnection conn, string?[]? 
restrictions, bool async, CancellationToken cancellationToken = default) + { + var dataTable = new DataTable("Databases") + { + Locale = CultureInfo.InvariantCulture, + Columns = + { new DataColumn("database_name"), new DataColumn("owner"), new DataColumn("encoding") - }); - - var getDatabases = new StringBuilder(); - - getDatabases.Append("SELECT d.datname AS database_name, u.usename AS owner, pg_catalog.pg_encoding_to_char(d.encoding) AS encoding FROM pg_catalog.pg_database d LEFT JOIN pg_catalog.pg_user u ON d.datdba = u.usesysid"); + } + }; - using var command = BuildCommand(conn, getDatabases, restrictions, "datname"); - using var adapter = new NpgsqlDataAdapter(command); - await adapter.Fill(databases, async, cancellationToken); + var sql = new StringBuilder(); - return databases; - } + sql.Append( + """ +SELECT d.datname, u.usename, pg_catalog.pg_encoding_to_char(d.encoding) +FROM pg_catalog.pg_database d +LEFT JOIN pg_catalog.pg_user u ON d.datdba = u.usesysid +"""); - static async Task GetSchemata(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, "datname"), + dataTable, + (reader, row) => { - var schemata = new DataTable("Schemata") { Locale = CultureInfo.InvariantCulture }; + row["database_name"] = GetFieldValueOrDBNull(reader, 0); + row["owner"] = GetFieldValueOrDBNull(reader, 1); + row["encoding"] = GetFieldValueOrDBNull(reader, 2); + }, cancellationToken); + } - schemata.Columns.AddRange(new[] { + static Task GetSchemata(NpgsqlConnection conn, string?[]? 
restrictions, bool async, CancellationToken cancellationToken = default) + { + var dataTable = new DataTable("Schemata") + { + Locale = CultureInfo.InvariantCulture, + Columns = + { new DataColumn("catalog_name"), new DataColumn("schema_name"), new DataColumn("schema_owner") - }); + } + }; - var getSchemata = new StringBuilder(@" + var sql = new StringBuilder( + """ SELECT * FROM ( - SELECT current_database() AS catalog_name, - nspname AS schema_name, - r.rolname AS schema_owner - FROM - pg_catalog.pg_namespace LEFT JOIN pg_catalog.pg_roles r ON r.oid = nspowner - ) tmp"); - - using var command = BuildCommand(conn, getSchemata, restrictions, "catalog_name", "schema_name", "schema_owner"); - using var adapter = new NpgsqlDataAdapter(command); - await adapter.Fill(schemata, async, cancellationToken); - - return schemata; - } - + SELECT current_database(), nspname, r.rolname + FROM pg_catalog.pg_namespace + LEFT JOIN pg_catalog.pg_roles r ON r.oid = nspowner +) tmp +"""); + + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, "catalog_name", "schema_name", "schema_owner"), + dataTable, + (reader, row) => + { + row["catalog_name"] = GetFieldValueOrDBNull(reader, 0); + row["schema_name"] = GetFieldValueOrDBNull(reader, 1); + row["schema_owner"] = GetFieldValueOrDBNull(reader, 2); + }, cancellationToken); + } - static async Task GetTables(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + static Task GetTables(NpgsqlConnection conn, string?[]? 
restrictions, bool async, CancellationToken cancellationToken = default) + { + var dataTable = new DataTable("Tables") { - var tables = new DataTable("Tables") { Locale = CultureInfo.InvariantCulture }; - - tables.Columns.AddRange(new[] { + Locale = CultureInfo.InvariantCulture, + Columns = + { new DataColumn("table_catalog"), new DataColumn("table_schema"), new DataColumn("table_name"), new DataColumn("table_type") - }); + } + }; - var getTables = new StringBuilder(); + var sql = new StringBuilder(); - getTables.Append(@" + sql.Append( + """ SELECT table_catalog, table_schema, table_name, table_type FROM information_schema.tables WHERE table_type IN ('BASE TABLE', 'FOREIGN', 'FOREIGN TABLE') AND - table_schema NOT IN ('pg_catalog', 'information_schema')"); - - using var command = BuildCommand(conn, getTables, restrictions, false, "table_catalog", "table_schema", "table_name", "table_type"); - using var adapter = new NpgsqlDataAdapter(command); - await adapter.Fill(tables, async, cancellationToken); - - return tables; - } + table_schema NOT IN ('pg_catalog', 'information_schema') +"""); + + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, false, "table_catalog", "table_schema", "table_name", "table_type"), + dataTable, + (reader, row) => + { + row["table_catalog"] = GetFieldValueOrDBNull(reader, 0); + row["table_schema"] = GetFieldValueOrDBNull(reader, 1); + row["table_name"] = GetFieldValueOrDBNull(reader, 2); + row["table_type"] = GetFieldValueOrDBNull(reader, 3); + }, cancellationToken); + } - static async Task GetColumns(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + static Task GetColumns(NpgsqlConnection conn, string?[]? 
restrictions, bool async, CancellationToken cancellationToken = default) + { + var dataTable = new DataTable("Columns") { - var columns = new DataTable("Columns") { Locale = CultureInfo.InvariantCulture }; - - columns.Columns.AddRange(new[] { - new DataColumn("table_catalog"), new DataColumn("table_schema"), new DataColumn("table_name"), - new DataColumn("column_name"), new DataColumn("ordinal_position", typeof(int)), new DataColumn("column_default"), - new DataColumn("is_nullable"), new DataColumn("data_type"), - new DataColumn("character_maximum_length", typeof(int)), new DataColumn("character_octet_length", typeof(int)), - new DataColumn("numeric_precision", typeof(int)), new DataColumn("numeric_precision_radix", typeof(int)), - new DataColumn("numeric_scale", typeof(int)), new DataColumn("datetime_precision", typeof(int)), - new DataColumn("character_set_catalog"), new DataColumn("character_set_schema"), - new DataColumn("character_set_name"), new DataColumn("collation_catalog") - }); - - var getColumns = new StringBuilder(@" -SELECT - table_catalog, table_schema, table_name, column_name, ordinal_position, column_default, is_nullable, - udt_name::regtype::text AS data_type, character_maximum_length, character_octet_length, numeric_precision, - numeric_precision_radix, numeric_scale, datetime_precision, character_set_catalog, character_set_schema, - character_set_name, collation_catalog -FROM information_schema.columns"); - - using var command = BuildCommand(conn, getColumns, restrictions, "table_catalog", "table_schema", "table_name", "column_name"); - using var adapter = new NpgsqlDataAdapter(command); - await adapter.Fill(columns, async, cancellationToken); + Locale = CultureInfo.InvariantCulture, + Columns = + { + new DataColumn("table_catalog"), + new DataColumn("table_schema"), + new DataColumn("table_name"), + new DataColumn("column_name"), + new DataColumn("ordinal_position", typeof(int)), + new DataColumn("column_default"), + new 
DataColumn("is_nullable"), + new DataColumn("data_type"), + new DataColumn("character_maximum_length", typeof(int)), + new DataColumn("character_octet_length", typeof(int)), + new DataColumn("numeric_precision", typeof(int)), + new DataColumn("numeric_precision_radix", typeof(int)), + new DataColumn("numeric_scale", typeof(int)), + new DataColumn("datetime_precision", typeof(int)), + new DataColumn("character_set_catalog"), + new DataColumn("character_set_schema"), + new DataColumn("character_set_name"), + new DataColumn("collation_catalog") + } + }; - return columns; - } + var sql = new StringBuilder( + """ +SELECT + table_catalog, + table_schema, + table_name, + column_name, + ordinal_position, + column_default, + is_nullable, + CASE WHEN udt_schema is NULL THEN udt_name ELSE format_type(typ.oid, NULL) END, + character_maximum_length, + character_octet_length, + numeric_precision, + numeric_precision_radix, + numeric_scale, + datetime_precision, + character_set_catalog, + character_set_schema, + character_set_name, + collation_catalog +FROM information_schema.columns +JOIN pg_namespace AS ns ON ns.nspname = udt_schema +JOIN pg_type AS typ ON typnamespace = ns.oid AND typname = udt_name +"""); + + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, "table_catalog", "table_schema", "table_name", "column_name"), + dataTable, + (reader, row) => + { + row["table_catalog"] = GetFieldValueOrDBNull(reader, 0); + row["table_schema"] = GetFieldValueOrDBNull(reader, 1); + row["table_name"] = GetFieldValueOrDBNull(reader, 2); + row["column_name"] = GetFieldValueOrDBNull(reader, 3); + row["ordinal_position"] = GetFieldValueOrDBNull(reader, 4); + row["column_default"] = GetFieldValueOrDBNull(reader, 5); + row["is_nullable"] = GetFieldValueOrDBNull(reader, 6); + row["data_type"] = GetFieldValueOrDBNull(reader, 7); + row["character_maximum_length"] = GetFieldValueOrDBNull(reader, 8); + row["character_octet_length"] = GetFieldValueOrDBNull(reader, 9); + 
row["numeric_precision"] = GetFieldValueOrDBNull(reader, 10); + row["numeric_precision_radix"] = GetFieldValueOrDBNull(reader, 11); + row["numeric_scale"] = GetFieldValueOrDBNull(reader, 12); + row["datetime_precision"] = GetFieldValueOrDBNull(reader, 13); + row["character_set_catalog"] = GetFieldValueOrDBNull(reader, 14); + row["character_set_schema"] = GetFieldValueOrDBNull(reader, 15); + row["character_set_name"] = GetFieldValueOrDBNull(reader, 16); + row["collation_catalog"] = GetFieldValueOrDBNull(reader, 17); + }, cancellationToken); + } - static async Task GetViews(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + static Task GetViews(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + { + var dataTable = new DataTable("Views") { - var views = new DataTable("Views") { Locale = CultureInfo.InvariantCulture }; - - views.Columns.AddRange(new[] { - new DataColumn("table_catalog"), new DataColumn("table_schema"), new DataColumn("table_name"), - new DataColumn("check_option"), new DataColumn("is_updatable") - }); + Locale = CultureInfo.InvariantCulture, + Columns = + { + new DataColumn("table_catalog"), + new DataColumn("table_schema"), + new DataColumn("table_name"), + new DataColumn("check_option"), + new DataColumn("is_updatable") + } + }; - var getViews = new StringBuilder(@" + var sql = new StringBuilder( + """ SELECT table_catalog, table_schema, table_name, check_option, is_updatable FROM information_schema.views -WHERE table_schema NOT IN ('pg_catalog', 'information_schema')"); +WHERE table_schema NOT IN ('pg_catalog', 'information_schema') +"""); + + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, false, "table_catalog", "table_schema", "table_name"), + dataTable, + (reader, row) => + { + row["table_catalog"] = GetFieldValueOrDBNull(reader, 0); + row["table_schema"] = GetFieldValueOrDBNull(reader, 1); + row["table_name"] 
= GetFieldValueOrDBNull(reader, 2); + row["check_option"] = GetFieldValueOrDBNull(reader, 3); + row["is_updatable"] = GetFieldValueOrDBNull(reader, 3); + }, cancellationToken); + } - using var command = BuildCommand(conn, getViews, restrictions, false, "table_catalog", "table_schema", "table_name"); - using var adapter = new NpgsqlDataAdapter(command); - await adapter.Fill(views, async, cancellationToken); + static Task GetMaterializedViews(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + { + var dataTable = new DataTable("MaterializedViews") + { + Locale = CultureInfo.InvariantCulture, + Columns = + { + new DataColumn("table_catalog"), + new DataColumn("table_schema"), + new DataColumn("table_name"), + new DataColumn("table_owner"), + new DataColumn("has_indexes", typeof(bool)), + new DataColumn("is_populated", typeof(bool)) + } + }; - return views; - } + var sql = new StringBuilder(); - static async Task GetUsers(NpgsqlConnection conn, string?[]? 
restrictions, bool async, CancellationToken cancellationToken = default) - { - var users = new DataTable("Users") { Locale = CultureInfo.InvariantCulture }; + sql.Append("""SELECT current_database(), schemaname, matviewname, matviewowner, hasindexes, ispopulated FROM pg_catalog.pg_matviews"""); - users.Columns.AddRange(new[] { new DataColumn("user_name"), new DataColumn("user_sysid", typeof(int)) }); + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, "current_database()", "schemaname", "matviewname", "matviewowner"), + dataTable, + (reader, row) => + { + row["table_catalog"] = GetFieldValueOrDBNull(reader, 0); + row["table_schema"] = GetFieldValueOrDBNull(reader, 1); + row["table_name"] = GetFieldValueOrDBNull(reader, 2); + row["table_owner"] = GetFieldValueOrDBNull(reader, 3); + row["has_indexes"] = GetFieldValueOrDBNull(reader, 4); + row["is_populated"] = GetFieldValueOrDBNull(reader, 5); + }, cancellationToken); + } - var getUsers = new StringBuilder(); + static Task GetUsers(NpgsqlConnection conn, string?[]? 
restrictions, bool async, CancellationToken cancellationToken = default) + { + var dataTable = new DataTable("Users") + { + Locale = CultureInfo.InvariantCulture, + Columns = + { + new DataColumn("user_name"), + new DataColumn("user_sysid", typeof(uint)) + } + }; - getUsers.Append("SELECT usename as user_name, usesysid as user_sysid FROM pg_catalog.pg_user"); + var sql = new StringBuilder(); - using var command = BuildCommand(conn, getUsers, restrictions, "usename"); - using var adapter = new NpgsqlDataAdapter(command); - await adapter.Fill(users, async, cancellationToken); + sql.Append("SELECT usename, usesysid FROM pg_catalog.pg_user"); - return users; - } + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, "usename"), + dataTable, + (reader, row) => + { + row["user_name"] = GetFieldValueOrDBNull(reader, 0); + row["user_sysid"] = GetFieldValueOrDBNull(reader, 1); + }, cancellationToken); + } - static async Task GetIndexes(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + static Task GetIndexes(NpgsqlConnection conn, string?[]? 
restrictions, bool async, CancellationToken cancellationToken = default) + { + var dataTable = new DataTable("Indexes") { - var indexes = new DataTable("Indexes") { Locale = CultureInfo.InvariantCulture }; - - indexes.Columns.AddRange(new[] { - new DataColumn("table_catalog"), new DataColumn("table_schema"), new DataColumn("table_name"), - new DataColumn("index_name"), new DataColumn("type_desc") - }); - - var getIndexes = new StringBuilder(@" -SELECT current_database() AS table_catalog, - n.nspname AS table_schema, - t.relname AS table_name, - i.relname AS index_name, - '' AS type_desc + Locale = CultureInfo.InvariantCulture, + Columns = + { + new DataColumn("table_catalog"), + new DataColumn("table_schema"), + new DataColumn("table_name"), + new DataColumn("index_name"), + new DataColumn("type_desc") + } + }; + + var sql = new StringBuilder( + """ +SELECT current_database(), + n.nspname, + t.relname, + i.relname, + '' FROM pg_catalog.pg_class i JOIN pg_catalog.pg_index ix ON ix.indexrelid = i.oid @@ -319,35 +472,52 @@ pg_catalog.pg_class i WHERE i.relkind = 'i' AND n.nspname NOT IN ('pg_catalog', 'pg_toast') AND - t.relkind = 'r'"); - - using var command = BuildCommand(conn, getIndexes, restrictions, false, "current_database()", "n.nspname", "t.relname", "i.relname"); - using var adapter = new NpgsqlDataAdapter(command); - await adapter.Fill(indexes, async, cancellationToken); - - return indexes; - } + t.relkind = 'r' +"""); + + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, false, "current_database()", "n.nspname", "t.relname", "i.relname"), + dataTable, + (reader, row) => + { + row["table_catalog"] = GetFieldValueOrDBNull(reader, 0); + row["table_schema"] = GetFieldValueOrDBNull(reader, 1); + row["table_name"] = GetFieldValueOrDBNull(reader, 2); + row["index_name"] = GetFieldValueOrDBNull(reader, 3); + row["type_desc"] = GetFieldValueOrDBNull(reader, 4); + }, cancellationToken); + } - static async Task GetIndexColumns(NpgsqlConnection 
conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + static Task GetIndexColumns(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + { + var dataTable = new DataTable("IndexColumns") { - var indexColumns = new DataTable("IndexColumns") { Locale = CultureInfo.InvariantCulture }; - - indexColumns.Columns.AddRange(new[] { - new DataColumn("constraint_catalog"), new DataColumn("constraint_schema"), new DataColumn("constraint_name"), - new DataColumn("table_catalog"), new DataColumn("table_schema"), new DataColumn("table_name"), - new DataColumn("column_name"), new DataColumn("index_name") - }); + Locale = CultureInfo.InvariantCulture, + Columns = + { + new DataColumn("constraint_catalog"), + new DataColumn("constraint_schema"), + new DataColumn("constraint_name"), + new DataColumn("table_catalog"), + new DataColumn("table_schema"), + new DataColumn("table_name"), + new DataColumn("column_name"), + new DataColumn("index_name") + } + }; - var getIndexColumns = new StringBuilder(@" + var sql = new StringBuilder( + """ SELECT - current_database() AS constraint_catalog, - t_ns.nspname AS constraint_schema, - ix_cls.relname AS constraint_name, - current_database() AS table_catalog, - ix_ns.nspname AS table_schema, - t.relname AS table_name, - a.attname AS column_name, - ix_cls.relname AS index_name + current_database(), + t_ns.nspname, + ix_cls.relname, + current_database(), + ix_ns.nspname, + t.relname, + a.attname, + ix_cls.relname FROM pg_class t JOIN pg_index ix ON t.oid = ix.indrelid @@ -359,69 +529,117 @@ pg_class t ix_cls.relkind = 'i' AND t_ns.nspname NOT IN ('pg_catalog', 'pg_toast') AND a.attnum = ANY(ix.indkey) AND - t.relkind = 'r'"); - - using var command = BuildCommand(conn, getIndexColumns, restrictions, false, "current_database()", "n.nspname", "t.relname", "i.relname", "a.attname"); - using var adapter = new NpgsqlDataAdapter(command); - await 
adapter.Fill(indexColumns, async, cancellationToken); - - return indexColumns; - } + t.relkind = 'r' +"""); + + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, false, "current_database()", "t_ns.nspname", "t.relname", "ix_cls.relname", "a.attname"), + dataTable, + (reader, row) => + { + row["constraint_catalog"] = GetFieldValueOrDBNull(reader, 0); + row["constraint_schema"] = GetFieldValueOrDBNull(reader, 1); + row["constraint_name"] = GetFieldValueOrDBNull(reader, 2); + row["table_catalog"] = GetFieldValueOrDBNull(reader, 3); + row["table_schema"] = GetFieldValueOrDBNull(reader, 4); + row["table_name"] = GetFieldValueOrDBNull(reader, 5); + row["column_name"] = GetFieldValueOrDBNull(reader, 6); + row["index_name"] = GetFieldValueOrDBNull(reader, 7); + }, cancellationToken); + } - static async Task GetConstraints(NpgsqlConnection conn, string?[]? restrictions, string? constraintType, bool async, CancellationToken cancellationToken = default) - { - var getConstraints = new StringBuilder(@" + static Task GetConstraints(NpgsqlConnection conn, string?[]? restrictions, string? 
constraintType, bool async, CancellationToken cancellationToken = default) + { + var sql = new StringBuilder( + """ SELECT - current_database() AS ""CONSTRAINT_CATALOG"", - pgn.nspname AS ""CONSTRAINT_SCHEMA"", - pgc.conname AS ""CONSTRAINT_NAME"", - current_database() AS ""TABLE_CATALOG"", - pgtn.nspname AS ""TABLE_SCHEMA"", - pgt.relname AS ""TABLE_NAME"", - ""CONSTRAINT_TYPE"", - pgc.condeferrable AS ""IS_DEFERRABLE"", - pgc.condeferred AS ""INITIALLY_DEFERRED"" + current_database(), + pgn.nspname, + pgc.conname, + current_database(), + pgtn.nspname, + pgt.relname, + constraint_type, + pgc.condeferrable, + pgc.condeferred FROM pg_catalog.pg_constraint pgc JOIN pg_catalog.pg_namespace pgn ON pgc.connamespace = pgn.oid JOIN pg_catalog.pg_class pgt ON pgc.conrelid = pgt.oid JOIN pg_catalog.pg_namespace pgtn ON pgt.relnamespace = pgtn.oid JOIN ( - SELECT 'PRIMARY KEY' AS ""CONSTRAINT_TYPE"", 'p' AS ""contype"" + SELECT 'PRIMARY KEY' AS constraint_type, 'p' AS contype UNION ALL - SELECT 'FOREIGN KEY' AS ""CONSTRAINT_TYPE"", 'f' AS ""contype"" + SELECT 'FOREIGN KEY' AS constraint_type, 'f' AS contype UNION ALL - SELECT 'UNIQUE KEY' AS ""CONSTRAINT_TYPE"", 'u' AS ""contype"" -) mapping_table ON mapping_table.contype = pgc.contype"); - if ("ForeignKeys".Equals(constraintType)) - getConstraints.Append(" and pgc.contype='f'"); - else if ("PrimaryKey".Equals(constraintType)) - getConstraints.Append(" and pgc.contype='p'"); - else if ("UniqueKeys".Equals(constraintType)) - getConstraints.Append(" and pgc.contype='u'"); - else - constraintType = "Constraints"; - - using var command = BuildCommand(conn, getConstraints, restrictions, false, "current_database()", "pgtn.nspname", "pgt.relname", "pgc.conname"); - using var adapter = new NpgsqlDataAdapter(command); - var table = new DataTable(constraintType) { Locale = CultureInfo.InvariantCulture }; - - await adapter.Fill(table, async, cancellationToken); + SELECT 'UNIQUE KEY' AS constraint_type, 'u' AS contype +) mapping_table 
ON mapping_table.contype = pgc.contype +"""); - return table; + switch (constraintType) + { + case "ForeignKeys": + sql.Append(" and pgc.contype='f'"); + break; + case "PrimaryKey": + sql.Append(" and pgc.contype='p'"); + break; + case "UniqueKeys": + sql.Append(" and pgc.contype='u'"); + break; + default: + constraintType = "Constraints"; + break; } - static async Task GetConstraintColumns(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + var dataTable = new DataTable(constraintType) { - var getConstraintColumns = new StringBuilder(@" -SELECT current_database() AS constraint_catalog, - n.nspname AS constraint_schema, - c.conname AS constraint_name, - current_database() AS table_catalog, - n.nspname AS table_schema, - t.relname AS table_name, - a.attname AS column_name, - a.attnum AS ordinal_number, + Locale = CultureInfo.InvariantCulture, + Columns = + { + new DataColumn("CONSTRAINT_CATALOG"), + new DataColumn("CONSTRAINT_SCHEMA"), + new DataColumn("CONSTRAINT_NAME"), + new DataColumn("TABLE_CATALOG"), + new DataColumn("TABLE_SCHEMA"), + new DataColumn("TABLE_NAME"), + new DataColumn("CONSTRAINT_TYPE"), + new DataColumn("IS_DEFERRABLE", typeof(bool)), + new DataColumn("INITIALLY_DEFERRED", typeof(bool)) + } + }; + + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, false, "current_database()", "pgtn.nspname", "pgt.relname", "pgc.conname"), + dataTable, + (reader, row) => + { + row["CONSTRAINT_CATALOG"] = GetFieldValueOrDBNull(reader, 0); + row["CONSTRAINT_SCHEMA"] = GetFieldValueOrDBNull(reader, 1); + row["CONSTRAINT_NAME"] = GetFieldValueOrDBNull(reader, 2); + row["TABLE_CATALOG"] = GetFieldValueOrDBNull(reader, 3); + row["TABLE_SCHEMA"] = GetFieldValueOrDBNull(reader, 4); + row["TABLE_NAME"] = GetFieldValueOrDBNull(reader, 5); + row["CONSTRAINT_TYPE"] = GetFieldValueOrDBNull(reader, 6); + row["IS_DEFERRABLE"] = GetFieldValueOrDBNull(reader, 7); + row["INITIALLY_DEFERRED"] = 
GetFieldValueOrDBNull(reader, 8); + }, cancellationToken); + } + + static Task GetConstraintColumns(NpgsqlConnection conn, string?[]? restrictions, bool async, CancellationToken cancellationToken = default) + { + var sql = new StringBuilder( + """ +SELECT current_database(), + n.nspname, + c.conname, + current_database(), + n.nspname, + t.relname, + a.attname, + a.attnum, mapping_table.constraint_type FROM pg_constraint c JOIN pg_namespace n on n.oid = c.connamespace @@ -435,112 +653,156 @@ UNION ALL SELECT 'UNIQUE KEY' AS constraint_type, 'u' AS contype ) mapping_table ON mapping_table.contype = c.contype - AND n.nspname NOT IN ('pg_catalog', 'pg_toast')"); + AND n.nspname NOT IN ('pg_catalog', 'pg_toast') +"""); - using var command = BuildCommand(conn, getConstraintColumns, restrictions, false, "current_database()", "n.nspname", "t.relname", "c.conname", "a.attname"); - using var adapter = new NpgsqlDataAdapter(command); - var table = new DataTable("ConstraintColumns") { Locale = CultureInfo.InvariantCulture }; - - await adapter.Fill(table, async, cancellationToken); + var dataTable = new DataTable("ConstraintColumns") + { + Locale = CultureInfo.InvariantCulture, + Columns = + { + new DataColumn("constraint_catalog"), + new DataColumn("constraint_schema"), + new DataColumn("constraint_name"), + new DataColumn("table_catalog"), + new DataColumn("table_schema"), + new DataColumn("table_name"), + new DataColumn("column_name"), + new DataColumn("ordinal_number", typeof(int)), + new DataColumn("constraint_type") + } + }; - return table; - } + return ParseResults( + async, + BuildCommand(conn, sql, restrictions, false, "current_database()", "n.nspname", "t.relname", "c.conname", "a.attname"), + dataTable, + (reader, row) => + { + row["constraint_catalog"] = GetFieldValueOrDBNull(reader, 0); + row["constraint_schema"] = GetFieldValueOrDBNull(reader, 1); + row["constraint_name"] = GetFieldValueOrDBNull(reader, 2); + row["table_catalog"] = GetFieldValueOrDBNull(reader, 
3); + row["table_schema"] = GetFieldValueOrDBNull(reader, 4); + row["table_name"] = GetFieldValueOrDBNull(reader, 5); + row["column_name"] = GetFieldValueOrDBNull(reader, 6); + row["ordinal_number"] = GetFieldValueOrDBNull(reader, 7); + row["constraint_type"] = GetFieldValueOrDBNull(reader, 8); + }, cancellationToken); + } - static DataTable GetDataSourceInformation(NpgsqlConnection conn) + static DataTable GetDataSourceInformation(NpgsqlConnection conn) + { + var table = new DataTable("DataSourceInformation"); + var row = table.Rows.Add(); + + table.Columns.Add("CompositeIdentifierSeparatorPattern", typeof(string)); + // TODO: DefaultCatalog? Was in XML (unfilled) but isn't in docs + table.Columns.Add("DataSourceProductName", typeof(string)); + table.Columns.Add("DataSourceProductVersion", typeof(string)); + table.Columns.Add("DataSourceProductVersionNormalized", typeof(string)); + table.Columns.Add("GroupByBehavior", typeof(GroupByBehavior)); + table.Columns.Add("IdentifierPattern", typeof(string)); + table.Columns.Add("IdentifierCase", typeof(IdentifierCase)); + table.Columns.Add("OrderByColumnsInSelect", typeof(bool)); + table.Columns.Add("ParameterMarkerFormat", typeof(string)); + table.Columns.Add("ParameterMarkerPattern", typeof(string)); + table.Columns.Add("ParameterNameMaxLength", typeof(int)); + table.Columns.Add("QuotedIdentifierPattern", typeof(string)); + table.Columns.Add("QuotedIdentifierCase", typeof(IdentifierCase)); + table.Columns.Add("ParameterNamePattern", typeof(string)); + table.Columns.Add("StatementSeparatorPattern", typeof(string)); + table.Columns.Add("StringLiteralPattern", typeof(string)); + table.Columns.Add("SupportedJoinOperators", typeof(SupportedJoinOperators)); + + var version = conn.PostgreSqlVersion; + var normalizedVersion = $"{version.Major:00}.{version.Minor:00}"; + if (version.Build >= 0) + normalizedVersion += $".{version.Build:00}"; + + row["CompositeIdentifierSeparatorPattern"] = @"\."; + row["DataSourceProductName"] = 
"Npgsql"; + row["DataSourceProductVersion"] = version.ToString(); + row["DataSourceProductVersionNormalized"] = normalizedVersion; + row["GroupByBehavior"] = GroupByBehavior.Unrelated; + row["IdentifierPattern"] = @"(^\[\p{Lo}\p{Lu}\p{Ll}_@#][\p{Lo}\p{Lu}\p{Ll}\p{Nd}@$#_]*$)|(^\[[^\]\0]|\]\]+\]$)|(^\""[^\""\0]|\""\""+\""$)"; + row["IdentifierCase"] = IdentifierCase.Insensitive; + row["OrderByColumnsInSelect"] = false; + row["QuotedIdentifierPattern"] = @"""(([^\""]|\""\"")*)"""; + row["QuotedIdentifierCase"] = IdentifierCase.Sensitive; + row["StatementSeparatorPattern"] = ";"; + row["StringLiteralPattern"] = @"'(([^']|'')*)'"; + row["SupportedJoinOperators"] = + SupportedJoinOperators.FullOuter | + SupportedJoinOperators.Inner | + SupportedJoinOperators.LeftOuter | + SupportedJoinOperators.RightOuter; + + row["ParameterNameMaxLength"] = 63; // For function out parameters + row["ParameterMarkerFormat"] = @"{0}"; // TODO: Not sure + + if (NpgsqlCommand.EnableSqlRewriting) { - var table = new DataTable("DataSourceInformation"); - var row = table.Rows.Add(); - - table.Columns.Add("CompositeIdentifierSeparatorPattern", typeof(string)); - // TODO: DefaultCatalog? 
Was in XML (unfilled) but isn't in docs - table.Columns.Add("DataSourceProductName", typeof(string)); - table.Columns.Add("DataSourceProductVersion", typeof(string)); - table.Columns.Add("DataSourceProductVersionNormalized", typeof(string)); - table.Columns.Add("GroupByBehavior", typeof(GroupByBehavior)); - table.Columns.Add("IdentifierPattern", typeof(string)); - table.Columns.Add("IdentifierCase", typeof(IdentifierCase)); - table.Columns.Add("OrderByColumnsInSelect", typeof(bool)); - table.Columns.Add("ParameterMarkerFormat", typeof(string)); - table.Columns.Add("ParameterMarkerPattern", typeof(string)); - table.Columns.Add("ParameterNameMaxLength", typeof(int)); - table.Columns.Add("QuotedIdentifierPattern", typeof(string)); - table.Columns.Add("QuotedIdentifierCase", typeof(IdentifierCase)); - table.Columns.Add("ParameterNamePattern", typeof(string)); - table.Columns.Add("StatementSeparatorPattern", typeof(string)); - table.Columns.Add("StringLiteralPattern", typeof(string)); - table.Columns.Add("SupportedJoinOperators", typeof(SupportedJoinOperators)); - - var version = conn.PostgreSqlVersion; - var normalizedVersion = $"{version.Major:00}.{version.Minor:00}"; - if (version.Build >= 0) - normalizedVersion += $".{version.Build:00}"; - - row["CompositeIdentifierSeparatorPattern"] = @"\."; - row["DataSourceProductName"] = "Npgsql"; - row["DataSourceProductVersion"] = version.ToString(); - row["DataSourceProductVersionNormalized"] = normalizedVersion; - row["GroupByBehavior"] = GroupByBehavior.Unrelated; - row["IdentifierPattern"] = @"(^\[\p{Lo}\p{Lu}\p{Ll}_@#][\p{Lo}\p{Lu}\p{Ll}\p{Nd}@$#_]*$)|(^\[[^\]\0]|\]\]+\]$)|(^\""[^\""\0]|\""\""+\""$)"; - row["IdentifierCase"] = IdentifierCase.Insensitive; - row["OrderByColumnsInSelect"] = false; - row["ParameterMarkerFormat"] = @"{0}"; // TODO: Not sure row["ParameterMarkerPattern"] = @"@[\p{Lo}\p{Lu}\p{Ll}\p{Lm}_@#][\p{Lo}\p{Lu}\p{Ll}\p{Lm}\p{Nd}\uff3f_@#\$]*(?=\s+|$)"; - row["ParameterNameMaxLength"] = 63; // For 
function out parameters - row["QuotedIdentifierPattern"] = @"""(([^\""]|\""\"")*)"""; - row["QuotedIdentifierCase"] = IdentifierCase.Sensitive; row["ParameterNamePattern"] = @"^[\p{Lo}\p{Lu}\p{Ll}\p{Lm}_@#][\p{Lo}\p{Lu}\p{Ll}\p{Lm}\p{Nd}\uff3f_@#\$]*(?=\s+|$)"; - row["StatementSeparatorPattern"] = ";"; - row["StringLiteralPattern"] = @"'(([^']|'')*)'"; - row["SupportedJoinOperators"] = - SupportedJoinOperators.FullOuter | - SupportedJoinOperators.Inner | - SupportedJoinOperators.LeftOuter | - SupportedJoinOperators.RightOuter; - - return table; + } + else + { + row["ParameterMarkerPattern"] = @"$\d+"; + row["ParameterNamePattern"] = @"\d+"; } - #region DataTypes + return table; + } - static DataTable GetDataTypes(NpgsqlConnection conn) + #region DataTypes + + static DataTable GetDataTypes(NpgsqlConnection conn) + { + using var _ = conn.StartTemporaryBindingScope(out var connector); + + var table = new DataTable("DataTypes"); + + table.Columns.Add("TypeName", typeof(string)); + table.Columns.Add("ColumnSize", typeof(long)); + table.Columns.Add("CreateFormat", typeof(string)); + table.Columns.Add("CreateParameters", typeof(string)); + table.Columns.Add("DataType", typeof(string)); + table.Columns.Add("IsAutoIncrementable", typeof(bool)); + table.Columns.Add("IsBestMatch", typeof(bool)); + table.Columns.Add("IsCaseSensitive", typeof(bool)); + table.Columns.Add("IsConcurrencyType", typeof(bool)); + table.Columns.Add("IsFixedLength", typeof(bool)); + table.Columns.Add("IsFixedPrecisionAndScale", typeof(bool)); + table.Columns.Add("IsLiteralSupported", typeof(bool)); + table.Columns.Add("IsLong", typeof(bool)); + table.Columns.Add("IsNullable", typeof(bool)); + table.Columns.Add("IsSearchable", typeof(bool)); + table.Columns.Add("IsSearchableWithLike", typeof(bool)); + table.Columns.Add("IsUnsigned", typeof(bool)); + table.Columns.Add("LiteralPrefix", typeof(string)); + table.Columns.Add("LiteralSuffix", typeof(string)); + table.Columns.Add("MaximumScale", 
typeof(short)); + table.Columns.Add("MinimumScale", typeof(short)); + table.Columns.Add("NativeDataType", typeof(string)); + table.Columns.Add("ProviderDbType", typeof(int)); + + // Npgsql-specific + table.Columns.Add("OID", typeof(uint)); + + + // TODO: Support type name restriction + try { - using var _ = conn.StartTemporaryBindingScope(out var connector); - - var table = new DataTable("DataTypes"); - - table.Columns.Add("TypeName", typeof(string)); - table.Columns.Add("ColumnSize", typeof(long)); - table.Columns.Add("CreateFormat", typeof(string)); - table.Columns.Add("CreateParameters", typeof(string)); - table.Columns.Add("DataType", typeof(string)); - table.Columns.Add("IsAutoIncrementable", typeof(bool)); - table.Columns.Add("IsBestMatch", typeof(bool)); - table.Columns.Add("IsCaseSensitive", typeof(bool)); - table.Columns.Add("IsConcurrencyType", typeof(bool)); - table.Columns.Add("IsFixedLength", typeof(bool)); - table.Columns.Add("IsFixedPrecisionAndScale", typeof(bool)); - table.Columns.Add("IsLiteralSupported", typeof(bool)); - table.Columns.Add("IsLong", typeof(bool)); - table.Columns.Add("IsNullable", typeof(bool)); - table.Columns.Add("IsSearchable", typeof(bool)); - table.Columns.Add("IsSearchableWithLike", typeof(bool)); - table.Columns.Add("IsUnsigned", typeof(bool)); - table.Columns.Add("LiteralPrefix", typeof(string)); - table.Columns.Add("LiteralSuffix", typeof(string)); - table.Columns.Add("MaximumScale", typeof(short)); - table.Columns.Add("MinimumScale", typeof(short)); - table.Columns.Add("NativeDataType", typeof(string)); - table.Columns.Add("ProviderDbType", typeof(int)); - - // Npgsql-specific - table.Columns.Add("OID", typeof(uint)); - - // TODO: Support type name restriction - - foreach (var baseType in connector.DatabaseInfo.BaseTypes) + PgSerializerOptions.IntrospectionCaller = true; + + var types = new List(); + types.AddRange(connector.DatabaseInfo.BaseTypes); + types.AddRange(connector.DatabaseInfo.EnumTypes); + 
types.AddRange(connector.DatabaseInfo.CompositeTypes); + foreach (var baseType in types) { - if (!connector.TypeMapper.Mappings.TryGetValue(baseType.Name, out var mapping) && - !connector.TypeMapper.Mappings.TryGetValue(baseType.FullName, out mapping)) + if (connector.SerializerOptions.GetDefaultTypeInfo(baseType) is not { } info) continue; var row = table.Rows.Add(); @@ -548,16 +810,14 @@ static DataTable GetDataTypes(NpgsqlConnection conn) PopulateDefaultDataTypeInfo(row, baseType); PopulateHardcodedDataTypeInfo(row, baseType); - if (mapping.ClrTypes.Length > 0) - row["DataType"] = mapping.ClrTypes[0].FullName; - if (mapping.NpgsqlDbType.HasValue) - row["ProviderDbType"] = (int)mapping.NpgsqlDbType.Value; + row["DataType"] = info.Type.FullName; + if (baseType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; } foreach (var arrayType in connector.DatabaseInfo.ArrayTypes) { - if (!connector.TypeMapper.Mappings.TryGetValue(arrayType.Element.Name, out var elementMapping) && - !connector.TypeMapper.Mappings.TryGetValue(arrayType.Element.FullName, out elementMapping)) + if (connector.SerializerOptions.GetDefaultTypeInfo(arrayType) is not { } info) continue; var row = table.Rows.Add(); @@ -569,294 +829,325 @@ static DataTable GetDataTypes(NpgsqlConnection conn) row["TypeName"] = arrayType.DisplayName; row["OID"] = arrayType.OID; row["CreateFormat"] += "[]"; - if (elementMapping.ClrTypes.Length > 0) - row["DataType"] = elementMapping.ClrTypes[0].MakeArrayType().FullName; - if (elementMapping.NpgsqlDbType.HasValue) - row["ProviderDbType"] = (int)(elementMapping.NpgsqlDbType.Value | NpgsqlDbType.Array); + row["DataType"] = info.Type.FullName; + if (arrayType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; } foreach (var rangeType in connector.DatabaseInfo.RangeTypes) { - if (!connector.TypeMapper.Mappings.TryGetValue(rangeType.Subtype.Name, out var elementMapping) && - 
!connector.TypeMapper.Mappings.TryGetValue(rangeType.Subtype.FullName, out elementMapping)) + if (connector.SerializerOptions.GetDefaultTypeInfo(rangeType) is not { } info) continue; var row = table.Rows.Add(); PopulateDefaultDataTypeInfo(row, rangeType.Subtype); - // Populate hardcoded values based on the element type (e.g. citext[] is case-insensitive). + // Populate hardcoded values based on the subtype type (e.g. citext[] is case-insensitive). PopulateHardcodedDataTypeInfo(row, rangeType.Subtype); row["TypeName"] = rangeType.DisplayName; row["OID"] = rangeType.OID; row["CreateFormat"] = rangeType.DisplayName.ToUpperInvariant(); - if (elementMapping.ClrTypes.Length > 0) - row["DataType"] = typeof(NpgsqlRange<>).MakeGenericType(elementMapping.ClrTypes[0]).FullName; - if (elementMapping.NpgsqlDbType.HasValue) - row["ProviderDbType"] = (int)(elementMapping.NpgsqlDbType.Value | NpgsqlDbType.Range); - } - - foreach (var enumType in connector.DatabaseInfo.EnumTypes) - { - if (!connector.TypeMapper.Mappings.TryGetValue(enumType.Name, out var mapping) && - !connector.TypeMapper.Mappings.TryGetValue(enumType.FullName, out mapping)) - continue; - - var row = table.Rows.Add(); - - PopulateDefaultDataTypeInfo(row, enumType); - PopulateHardcodedDataTypeInfo(row, enumType); - - if (mapping.ClrTypes.Length > 0) - row["DataType"] = mapping.ClrTypes[0].FullName; + row["DataType"] = info.Type.FullName; + if (rangeType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; } - foreach (var compositeType in connector.DatabaseInfo.CompositeTypes) + foreach (var multirangeType in connector.DatabaseInfo.MultirangeTypes) { - if (!connector.TypeMapper.Mappings.TryGetValue(compositeType.Name, out var mapping) && - !connector.TypeMapper.Mappings.TryGetValue(compositeType.FullName, out mapping)) + var subtypeType = multirangeType.Subrange.Subtype; + if (connector.SerializerOptions.GetDefaultTypeInfo(multirangeType) is not { } info) continue; var row 
= table.Rows.Add(); - PopulateDefaultDataTypeInfo(row, compositeType); - PopulateHardcodedDataTypeInfo(row, compositeType); + PopulateDefaultDataTypeInfo(row, subtypeType); + // Populate hardcoded values based on the subtype type (e.g. citext[] is case-insensitive). + PopulateHardcodedDataTypeInfo(row, subtypeType); - if (mapping.ClrTypes.Length > 0) - row["DataType"] = mapping.ClrTypes[0].FullName; + row["TypeName"] = multirangeType.DisplayName; + row["OID"] = multirangeType.OID; + row["CreateFormat"] = multirangeType.DisplayName.ToUpperInvariant(); + row["DataType"] = info.Type.FullName; + if (multirangeType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; } foreach (var domainType in connector.DatabaseInfo.DomainTypes) { - if (!connector.TypeMapper.Mappings.TryGetValue(domainType.BaseType.Name, out var baseMapping) && - !connector.TypeMapper.Mappings.TryGetValue(domainType.BaseType.FullName, out baseMapping)) + var representationalType = domainType.GetRepresentationalType(); + if (connector.SerializerOptions.GetDefaultTypeInfo(representationalType) is not { } info) continue; var row = table.Rows.Add(); - PopulateDefaultDataTypeInfo(row, domainType.BaseType); + PopulateDefaultDataTypeInfo(row, representationalType); // Populate hardcoded values based on the element type (e.g. citext[] is case-insensitive). 
- PopulateHardcodedDataTypeInfo(row, domainType.BaseType); + PopulateHardcodedDataTypeInfo(row, representationalType); row["TypeName"] = domainType.DisplayName; row["OID"] = domainType.OID; // A domain is never the best match, since its underlying base type is row["IsBestMatch"] = false; - if (baseMapping.ClrTypes.Length > 0) - row["DataType"] = baseMapping.ClrTypes[0].FullName; - if (baseMapping.NpgsqlDbType.HasValue) - row["ProviderDbType"] = (int)baseMapping.NpgsqlDbType.Value; + row["DataType"] = info.Type.FullName; + if (representationalType.DataTypeName.ToNpgsqlDbType() is { } npgsqlDbType) + row["ProviderDbType"] = (int)npgsqlDbType; } - - return table; } - - /// - /// Populates some generic type information that is common for base types, arrays, enums, etc. Some will - /// be overridden later. - /// - static void PopulateDefaultDataTypeInfo(DataRow row, PostgresType type) + finally { - row["TypeName"] = type.DisplayName; - // Skipping ColumnSize at least for now, not very meaningful - row["CreateFormat"] = type.DisplayName.ToUpperInvariant(); - row["CreateParameters"] = ""; - row["IsAutoIncrementable"] = false; - // We populate the DataType above from mapping.ClrTypes, which means we take the .NET type from - // which we *infer* the PostgreSQL type. Since only a single PostgreSQL type gets inferred from a given - // .NET type, we never have the same DataType in more than one row - so the mapping is always the - // best match. See the hardcoding override below for some exceptions. 
- row["IsBestMatch"] = true; - row["IsCaseSensitive"] = true; - row["IsConcurrencyType"] = false; - row["IsFixedLength"] = false; - row["IsFixedPrecisionAndScale"] = false; - row["IsLiteralSupported"] = false; // See hardcoding override below - row["IsLong"] = false; - row["IsNullable"] = true; - row["IsSearchable"] = true; - row["IsSearchableWithLike"] = false; - row["IsUnsigned"] = DBNull.Value; // See hardcoding override below - // LiteralPrefix/Suffix: no literal for now except for strings, see hardcoding override below - row["MaximumScale"] = DBNull.Value; - row["MinimumScale"] = DBNull.Value; - // NativeDataType is unset - row["OID"] = type.OID; + PgSerializerOptions.IntrospectionCaller = false; } - /// - /// Sets some custom, hardcoded info on a DataType row that cannot be loaded/inferred from PostgreSQL - /// - static void PopulateHardcodedDataTypeInfo(DataRow row, PostgresType type) + return table; + } + + /// + /// Populates some generic type information that is common for base types, arrays, enums, etc. Some will + /// be overridden later. + /// + static void PopulateDefaultDataTypeInfo(DataRow row, PostgresType type) + { + row["TypeName"] = type.DisplayName; + // Skipping ColumnSize at least for now, not very meaningful + row["CreateFormat"] = type.DisplayName.ToUpperInvariant(); + row["CreateParameters"] = ""; + row["IsAutoIncrementable"] = false; + // We populate the DataType above from mapping.ClrTypes, which means we take the .NET type from + // which we *infer* the PostgreSQL type. Since only a single PostgreSQL type gets inferred from a given + // .NET type, we never have the same DataType in more than one row - so the mapping is always the + // best match. See the hardcoding override below for some exceptions. 
+ row["IsBestMatch"] = true; + row["IsCaseSensitive"] = true; + row["IsConcurrencyType"] = false; + row["IsFixedLength"] = false; + row["IsFixedPrecisionAndScale"] = false; + row["IsLiteralSupported"] = false; // See hardcoding override below + row["IsLong"] = false; + row["IsNullable"] = true; + row["IsSearchable"] = true; + row["IsSearchableWithLike"] = false; + row["IsUnsigned"] = DBNull.Value; // See hardcoding override below + // LiteralPrefix/Suffix: no literal for now except for strings, see hardcoding override below + row["MaximumScale"] = DBNull.Value; + row["MinimumScale"] = DBNull.Value; + // NativeDataType is unset + row["OID"] = type.OID; + } + + /// + /// Sets some custom, hardcoded info on a DataType row that cannot be loaded/inferred from PostgreSQL + /// + static void PopulateHardcodedDataTypeInfo(DataRow row, PostgresType type) + { + switch (type.Name) { - switch (type.Name) - { - case "varchar": - case "char": - row["DataType"] = "String"; - row["IsBestMatch"] = false; - goto case "text"; - case "text": - row["CreateFormat"] += "({0})"; - row["CreateParameters"] = "size"; - row["IsSearchableWithLike"] = true; - row["IsLiteralSupported"] = true; - row["LiteralPrefix"] = "'"; - row["LiteralSuffix"] = "'"; - return; - case "numeric": - row["CreateFormat"] += "({0},{1})"; - row["CreateParameters"] = "precision, scale"; - row["MaximumScale"] = 16383; - row["MinimumScale"] = 16383; - row["IsUnsigned"] = false; - return; - case "bytea": - row["IsLong"] = true; - return; - case "citext": - row["IsCaseSensitive"] = false; - return; - case "integer": - case "smallint": - case "bigint": - case "double precision": - case "real": - case "money": - row["IsUnsigned"] = false; - return; - case "oid": - case "cid": - case "regtype": - case "regconfig": - row["IsUnsigned"] = true; - return; - case "xid": - row["IsUnsigned"] = true; - row["IsConcurrencyType"] = true; - return; - } + case "varchar": + case "char": + row["DataType"] = "String"; + row["IsBestMatch"] = 
false; + goto case "text"; + case "text": + row["CreateFormat"] += "({0})"; + row["CreateParameters"] = "size"; + row["IsSearchableWithLike"] = true; + row["IsLiteralSupported"] = true; + row["LiteralPrefix"] = "'"; + row["LiteralSuffix"] = "'"; + return; + case "numeric": + row["CreateFormat"] += "({0},{1})"; + row["CreateParameters"] = "precision, scale"; + row["MaximumScale"] = 16383; + row["MinimumScale"] = 16383; + row["IsUnsigned"] = false; + return; + case "bytea": + row["IsLong"] = true; + return; + case "citext": + row["IsCaseSensitive"] = false; + return; + case "integer": + case "smallint": + case "bigint": + case "double precision": + case "real": + case "money": + row["IsUnsigned"] = false; + return; + case "oid": + case "cid": + case "regtype": + case "regconfig": + row["IsUnsigned"] = true; + return; + case "xid": + row["IsUnsigned"] = true; + row["IsConcurrencyType"] = true; + return; } + } - #endregion DataTypes + #endregion DataTypes - #region Reserved Keywords + #region Reserved Keywords - static DataTable GetReservedWords() + static DataTable GetReservedWords() + { + var table = new DataTable("ReservedWords") { Locale = CultureInfo.InvariantCulture }; + table.Columns.Add("ReservedWord", typeof(string)); + foreach (var keyword in ReservedKeywords) + table.Rows.Add(keyword); + return table; + } + + /// + /// List of keywords taken from PostgreSQL 9.0 reserved words documentation. 
+ /// + static readonly string[] ReservedKeywords = + { + "ALL", + "ANALYSE", + "ANALYZE", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASYMMETRIC", + "AUTHORIZATION", + "BINARY", + "BOTH", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "CONCURRENTLY", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT_CATALOG", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_SCHEMA", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DESC", + "DISTINCT", + "DO", + "ELSE", + "END", + "EXCEPT", + "FALSE", + "FETCH", + "FOR", + "FOREIGN", + "FREEZE", + "FROM", + "FULL", + "GRANT", + "GROUP", + "HAVING", + "ILIKE", + "IN", + "INITIALLY", + "INNER", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "LATERAL", + "LEADING", + "LEFT", + "LIKE", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NATURAL", + "NOT", + "NOTNULL", + "NULL", + "OFFSET", + "ON", + "ONLY", + "OR", + "ORDER", + "OUTER", + "OVER", + "OVERLAPS", + "PLACING", + "PRIMARY", + "REFERENCES", + "RETURNING", + "RIGHT", + "SELECT", + "SESSION_USER", + "SIMILAR", + "SOME", + "SYMMETRIC", + "TABLE", + "THEN", + "TO", + "TRAILING", + "TRUE", + "UNION", + "UNIQUE", + "USER", + "USING", + "VARIADIC", + "VERBOSE", + "WHEN", + "WHERE", + "WINDOW", + "WITH" + }; + + #endregion Reserved Keywords + + static async Task ParseResults(bool async, NpgsqlCommand command, DataTable dataTable, Action populateRow, CancellationToken cancellationToken) + { + NpgsqlDataReader? reader = null; + try { - var table = new DataTable("ReservedWords") { Locale = CultureInfo.InvariantCulture }; - table.Columns.Add("ReservedWord", typeof(string)); - foreach (var keyword in ReservedKeywords) - table.Rows.Add(keyword); - return table; - } + reader = async + ? await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false) + : command.ExecuteReader(); + + dataTable.BeginLoadData(); + + while (async ? 
await reader.ReadAsync(cancellationToken).ConfigureAwait(false) : reader.Read()) + populateRow(reader, dataTable.Rows.Add()); - /// - /// List of keywords taken from PostgreSQL 9.0 reserved words documentation. - /// - static readonly string[] ReservedKeywords = + return dataTable; + } + finally { - "ALL", - "ANALYSE", - "ANALYZE", - "AND", - "ANY", - "ARRAY", - "AS", - "ASC", - "ASYMMETRIC", - "AUTHORIZATION", - "BINARY", - "BOTH", - "CASE", - "CAST", - "CHECK", - "COLLATE", - "COLUMN", - "CONCURRENTLY", - "CONSTRAINT", - "CREATE", - "CROSS", - "CURRENT_CATALOG", - "CURRENT_DATE", - "CURRENT_ROLE", - "CURRENT_SCHEMA", - "CURRENT_TIME", - "CURRENT_TIMESTAMP", - "CURRENT_USER", - "DEFAULT", - "DEFERRABLE", - "DESC", - "DISTINCT", - "DO", - "ELSE", - "END", - "EXCEPT", - "FALSE", - "FETCH", - "FOR", - "FOREIGN", - "FREEZE", - "FROM", - "FULL", - "GRANT", - "GROUP", - "HAVING", - "ILIKE", - "IN", - "INITIALLY", - "INNER", - "INTERSECT", - "INTO", - "IS", - "ISNULL", - "JOIN", - "LATERAL", - "LEADING", - "LEFT", - "LIKE", - "LIMIT", - "LOCALTIME", - "LOCALTIMESTAMP", - "NATURAL", - "NOT", - "NOTNULL", - "NULL", - "OFFSET", - "ON", - "ONLY", - "OR", - "ORDER", - "OUTER", - "OVER", - "OVERLAPS", - "PLACING", - "PRIMARY", - "REFERENCES", - "RETURNING", - "RIGHT", - "SELECT", - "SESSION_USER", - "SIMILAR", - "SOME", - "SYMMETRIC", - "TABLE", - "THEN", - "TO", - "TRAILING", - "TRUE", - "UNION", - "UNIQUE", - "USER", - "USING", - "VARIADIC", - "VERBOSE", - "WHEN", - "WHERE", - "WINDOW", - "WITH" - }; + dataTable.EndLoadData(); - #endregion Reserved Keywords + if (async) + { + if (reader is not null) + await reader.DisposeAsync().ConfigureAwait(false); +#if NETSTANDARD2_0 + command.Dispose(); +#else + await command.DisposeAsync().ConfigureAwait(false); +#endif + } + else + { + reader?.Dispose(); + command.Dispose(); + } + } } + + static object GetFieldValueOrDBNull(NpgsqlDataReader reader, int ordinal) + => reader.IsDBNull(ordinal) ? 
DBNull.Value : reader.GetFieldValue(ordinal)!; } diff --git a/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs b/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs new file mode 100644 index 0000000000..72cfeb4949 --- /dev/null +++ b/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs @@ -0,0 +1,687 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Net.Security; +using System.Security.Cryptography.X509Certificates; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; +using Npgsql.Internal.ResolverFactories; +using Npgsql.Properties; +using Npgsql.TypeMapping; +using NpgsqlTypes; + +namespace Npgsql; + +/// +/// Provides a simple API for configuring and creating an , from which database connections can be obtained. +/// +/// +/// On this builder, various features are disabled by default; unless you're looking to save on code size (e.g. when publishing with +/// NativeAOT), use instead. +/// +public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper +{ + static UnsupportedTypeInfoResolver UnsupportedTypeInfoResolver { get; } = new(); + + ILoggerFactory? _loggerFactory; + bool _sensitiveDataLoggingEnabled; + + TransportSecurityHandler _transportSecurityHandler = new(); + RemoteCertificateValidationCallback? _userCertificateValidationCallback; + Action? _clientCertificatesCallback; + + IntegratedSecurityHandler _integratedSecurityHandler = new(); + + Func? _passwordProvider; + Func>? _passwordProviderAsync; + + Func>? _periodicPasswordProvider; + TimeSpan _periodicPasswordSuccessRefreshInterval, _periodicPasswordFailureRefreshInterval; + + PgTypeInfoResolverChainBuilder _resolverChainBuilder = new(); // mutable struct, don't make readonly. + + readonly UserTypeMapper _userTypeMapper; + + Action? _connectionInitializer; + Func? _connectionInitializerAsync; + + internal JsonSerializerOptions? 
JsonSerializerOptions { get; private set; } + + internal Action ConfigureDefaultFactories { get; set; } + + /// + /// A connection string builder that can be used to configured the connection string on the builder. + /// + public NpgsqlConnectionStringBuilder ConnectionStringBuilder { get; } + + /// + /// Returns the connection string, as currently configured on the builder. + /// + public string ConnectionString => ConnectionStringBuilder.ToString(); + + static NpgsqlSlimDataSourceBuilder() + => GlobalTypeMapper.Instance.AddGlobalTypeMappingResolvers(new PgTypeInfoResolverFactory[] { new AdoTypeInfoResolverFactory() }); + + /// + /// A diagnostics name used by Npgsql when generating tracing, logging and metrics. + /// + public string? Name { get; set; } + + /// + /// Constructs a new , optionally starting out from the given + /// . + /// + public NpgsqlSlimDataSourceBuilder(string? connectionString = null) + : this(new NpgsqlConnectionStringBuilder(connectionString)) + {} + + internal NpgsqlSlimDataSourceBuilder(NpgsqlConnectionStringBuilder connectionStringBuilder) + { + ConnectionStringBuilder = connectionStringBuilder; + _userTypeMapper = new() { DefaultNameTranslator = GlobalTypeMapper.Instance.DefaultNameTranslator }; + ConfigureDefaultFactories = static instance => instance.AppendDefaultFactories(); + ConfigureResolverChain = static chain => chain.Add(UnsupportedTypeInfoResolver); + } + + /// + /// Sets the that will be used for logging. + /// + /// The logger factory to be used. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseLoggerFactory(ILoggerFactory? loggerFactory) + { + _loggerFactory = loggerFactory; + return this; + } + + /// + /// Enables parameters to be included in logging. This includes potentially sensitive information from data sent to PostgreSQL. 
+ /// You should only enable this flag in development, or if you have the appropriate security measures in place based on the + /// sensitivity of this data. + /// + /// If , then sensitive data is logged. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableParameterLogging(bool parameterLoggingEnabled = true) + { + _sensitiveDataLoggingEnabled = parameterLoggingEnabled; + return this; + } + + /// + /// Configures the JSON serializer options used when reading and writing all System.Text.Json data. + /// + /// Options to customize JSON serialization and deserialization. + /// + public NpgsqlSlimDataSourceBuilder ConfigureJsonOptions(JsonSerializerOptions serializerOptions) + { + JsonSerializerOptions = serializerOptions; + return this; + } + + #region Authentication + + /// + /// When using SSL/TLS, this is a callback that allows customizing how the PostgreSQL-provided certificate is verified. This is an + /// advanced API, consider using or instead. + /// + /// The callback containing custom callback verification logic. + /// + /// + /// Cannot be used in conjunction with , or + /// . + /// + /// + /// See . + /// + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseUserCertificateValidationCallback( + RemoteCertificateValidationCallback userCertificateValidationCallback) + { + _userCertificateValidationCallback = userCertificateValidationCallback; + + return this; + } + + /// + /// Specifies an SSL/TLS certificate which Npgsql will send to PostgreSQL for certificate-based authentication. + /// + /// The client certificate to be sent to PostgreSQL when opening a connection. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseClientCertificate(X509Certificate? 
clientCertificate) + { + if (clientCertificate is null) + return UseClientCertificatesCallback(null); + + var clientCertificates = new X509CertificateCollection { clientCertificate }; + return UseClientCertificates(clientCertificates); + } + + /// + /// Specifies a collection of SSL/TLS certificates which Npgsql will send to PostgreSQL for certificate-based authentication. + /// + /// The client certificate collection to be sent to PostgreSQL when opening a connection. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseClientCertificates(X509CertificateCollection? clientCertificates) + => UseClientCertificatesCallback(clientCertificates is null ? null : certs => certs.AddRange(clientCertificates)); + + /// + /// Specifies a callback to modify the collection of SSL/TLS client certificates which Npgsql will send to PostgreSQL for + /// certificate-based authentication. This is an advanced API, consider using or + /// instead. + /// + /// The callback to modify the client certificate collection. + /// + /// + /// The callback is invoked every time a physical connection is opened, and is therefore suitable for rotating short-lived client + /// certificates. Simply make sure the certificate collection argument has the up-to-date certificate(s). + /// + /// + /// The callback's collection argument already includes any client certificates specified via the connection string or environment + /// variables. + /// + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseClientCertificatesCallback(Action? clientCertificatesCallback) + { + _clientCertificatesCallback = clientCertificatesCallback; + + return this; + } + + /// + /// Sets the that will be used to validate SSL certificate, received from the server. + /// + /// The CA certificate. + /// The same builder instance so that multiple calls can be chained. 
+ public NpgsqlSlimDataSourceBuilder UseRootCertificate(X509Certificate2? rootCertificate) + => rootCertificate is null + ? UseRootCertificateCallback(null) + : UseRootCertificateCallback(() => rootCertificate); + + /// + /// Specifies a callback that will be used to validate SSL certificate, received from the server. + /// + /// The callback to get CA certificate. + /// The same builder instance so that multiple calls can be chained. + /// + /// This overload, which accepts a callback, is suitable for scenarios where the certificate rotates + /// and might change during the lifetime of the application. + /// When that's not the case, use the overload which directly accepts the certificate. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseRootCertificateCallback(Func? rootCertificateCallback) + { + _transportSecurityHandler.RootCertificateCallback = rootCertificateCallback; + + return this; + } + + /// + /// Configures a periodic password provider, which is automatically called by the data source at some regular interval. This is the + /// recommended way to fetch a rotating access token. + /// + /// A callback which returns the password to be sent to PostgreSQL. + /// How long to cache the password before re-invoking the callback. + /// + /// If a password refresh attempt fails, it will be re-attempted with this interval. + /// This should typically be much lower than . + /// + /// The same builder instance so that multiple calls can be chained. + /// + /// + /// The provided callback is invoked in a timer, and not when opening connections. It therefore doesn't affect opening time. + /// + /// + /// The provided cancellation token is only triggered when the entire data source is disposed. If you'd like to apply a timeout to the + /// token fetching, do so within the provided callback. + /// + /// + /// The same builder instance so that multiple calls can be chained. 
+ public NpgsqlSlimDataSourceBuilder UsePeriodicPasswordProvider( + Func>? passwordProvider, + TimeSpan successRefreshInterval, + TimeSpan failureRefreshInterval) + { + if (successRefreshInterval < TimeSpan.Zero) + throw new ArgumentException( + string.Format(NpgsqlStrings.ArgumentMustBePositive, nameof(successRefreshInterval)), nameof(successRefreshInterval)); + if (failureRefreshInterval < TimeSpan.Zero) + throw new ArgumentException( + string.Format(NpgsqlStrings.ArgumentMustBePositive, nameof(failureRefreshInterval)), nameof(failureRefreshInterval)); + + _periodicPasswordProvider = passwordProvider; + _periodicPasswordSuccessRefreshInterval = successRefreshInterval; + _periodicPasswordFailureRefreshInterval = failureRefreshInterval; + + return this; + } + + /// + /// Configures a password provider, which is called by the data source when opening connections. + /// + /// + /// A callback that may be invoked during which returns the password to be sent to PostgreSQL. + /// + /// + /// A callback that may be invoked during which returns the password to be sent to PostgreSQL. + /// + /// The same builder instance so that multiple calls can be chained. + /// + /// + /// The provided callback is invoked when opening connections. Therefore its important the callback internally depends on cached + /// data or returns quickly otherwise. Any unnecessary delay will affect connection opening time. + /// + /// + public NpgsqlSlimDataSourceBuilder UsePasswordProvider( + Func? passwordProvider, + Func>? 
passwordProviderAsync) + { + if (passwordProvider is null != passwordProviderAsync is null) + throw new ArgumentException(NpgsqlStrings.SyncAndAsyncPasswordProvidersRequired); + + _passwordProvider = passwordProvider; + _passwordProviderAsync = passwordProviderAsync; + return this; + } + + #endregion Authentication + + #region Type mapping + + /// + public INpgsqlNameTranslator DefaultNameTranslator + { + get => _userTypeMapper.DefaultNameTranslator; + set => _userTypeMapper.DefaultNameTranslator = value; + } + + /// + public INpgsqlTypeMapper MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum + { + _userTypeMapper.MapEnum(pgName, nameTranslator); + return this; + } + + /// + public bool UnmapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum + => _userTypeMapper.UnmapEnum(pgName, nameTranslator); + + /// + [RequiresDynamicCode("Calling MapEnum with a Type can require creating new generic types or methods. This may not work when AOT compiling.")] + public INpgsqlTypeMapper MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + _userTypeMapper.MapEnum(clrType, pgName, nameTranslator); + return this; + } + + /// + public bool UnmapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) + => _userTypeMapper.UnmapEnum(clrType, pgName, nameTranslator); + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public INpgsqlTypeMapper MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + _userTypeMapper.MapComposite(typeof(T), pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public bool UnmapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => _userTypeMapper.UnmapComposite(typeof(T), pgName, nameTranslator); + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public INpgsqlTypeMapper MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) + { + _userTypeMapper.MapComposite(clrType, pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public bool UnmapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => _userTypeMapper.UnmapComposite(clrType, pgName, nameTranslator); + + + /// + public void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) => _resolverChainBuilder.PrependResolverFactory(factory); + + /// + void INpgsqlTypeMapper.Reset() => _resolverChainBuilder.Clear(); + + internal Action> ConfigureResolverChain { get; set; } + internal void AppendResolverFactory(PgTypeInfoResolverFactory factory) + => _resolverChainBuilder.AppendResolverFactory(factory); + internal void AppendResolverFactory(Func factory) where T : PgTypeInfoResolverFactory + => _resolverChainBuilder.AppendResolverFactory(factory); + + internal void AppendDefaultFactories() + { + // When used publicly we start off with our slim defaults. + _resolverChainBuilder.AppendResolverFactory(_userTypeMapper); + if (GlobalTypeMapper.Instance.GetUserMappingsResolverFactory() is { } userMappingsResolverFactory) + _resolverChainBuilder.AppendResolverFactory(userMappingsResolverFactory); + foreach (var factory in GlobalTypeMapper.Instance.GetPluginResolverFactories()) + _resolverChainBuilder.AppendResolverFactory(factory); + _resolverChainBuilder.AppendResolverFactory(new AdoTypeInfoResolverFactory()); + } + + #endregion Type mapping + + #region Optional opt-ins + + /// + /// Sets up mappings for the PostgreSQL array types. 
+ /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableArrays() + { + _resolverChainBuilder.EnableArrays(); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL range types. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableRanges() + { + _resolverChainBuilder.EnableRanges(); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL multirange types. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableMultiranges() + { + _resolverChainBuilder.EnableMultiranges(); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL record type as a .NET object[]. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableRecords() + { + AddTypeInfoResolverFactory(new RecordTypeInfoResolverFactory()); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL tsquery and tsvector types. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableFullTextSearch() + { + AddTypeInfoResolverFactory(new FullTextSearchTypeInfoResolverFactory()); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL ltree extension types. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableLTree() + { + AddTypeInfoResolverFactory(new LTreeTypeInfoResolverFactory()); + return this; + } + + /// + /// Sets up mappings for extra conversions from PostgreSQL to .NET types. + /// + /// The same builder instance so that multiple calls can be chained. 
+ public NpgsqlSlimDataSourceBuilder EnableExtraConversions() + { + AddTypeInfoResolverFactory(new ExtraConversionResolverFactory()); + return this; + } + + /// + /// Enables the possibility to use TLS/SSL encryption for connections to PostgreSQL. This does not guarantee that encryption will + /// actually be used; see for more details. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableTransportSecurity() + { + _transportSecurityHandler = new RealTransportSecurityHandler(); + return this; + } + + /// + /// Enables the possibility to use GSS/SSPI authentication for connections to PostgreSQL. This does not guarantee that it will + /// actually be used; see for more details. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableIntegratedSecurity() + { + _integratedSecurityHandler = new RealIntegratedSecurityHandler(); + return this; + } + + /// + /// Sets up dynamic System.Text.Json mappings. This allows mapping arbitrary .NET types to PostgreSQL json and jsonb + /// types, as well as and its derived types. + /// + /// + /// A list of CLR types to map to PostgreSQL jsonb (no need to specify ). + /// + /// + /// A list of CLR types to map to PostgreSQL json (no need to specify ). + /// + /// + /// Due to the dynamic nature of these mappings, they are not compatible with NativeAOT or trimming. + /// + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + public NpgsqlSlimDataSourceBuilder EnableDynamicJson( + Type[]? jsonbClrTypes = null, + Type[]? 
jsonClrTypes = null) + { + _resolverChainBuilder.AppendResolverFactory(() => new JsonDynamicTypeInfoResolverFactory(jsonbClrTypes, jsonClrTypes, JsonSerializerOptions)); + return this; + } + + /// + /// Sets up mappings for the PostgreSQL record type as a .NET or . + /// + /// The same builder instance so that multiple calls can be chained. + [RequiresUnreferencedCode("The mapping of PostgreSQL records as .NET tuples requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The mapping of PostgreSQL records as .NET tuples requires dynamic code usage which is incompatible with NativeAOT.")] + public NpgsqlSlimDataSourceBuilder EnableRecordsAsTuples() + { + AddTypeInfoResolverFactory(new TupledRecordTypeInfoResolverFactory()); + return this; + } + + /// + /// Sets up mappings allowing the use of unmapped enum, range and multirange types. + /// + /// The same builder instance so that multiple calls can be chained. + [RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + public NpgsqlSlimDataSourceBuilder EnableUnmappedTypes() + { + AddTypeInfoResolverFactory(new UnmappedTypeInfoResolverFactory()); + return this; + } + + #endregion Optional opt-ins + + /// + /// Register a connection initializer, which allows executing arbitrary commands when a physical database connection is first opened. + /// + /// + /// A synchronous connection initialization lambda, which will be called from when a new physical + /// connection is opened. + /// + /// + /// An asynchronous connection initialization lambda, which will be called from + /// when a new physical connection is opened. + /// + /// + /// If an initializer is registered, both sync and async versions must be provided. 
If you do not use sync APIs in your code, simply + /// throw , which would also catch accidental cases of sync opening. + /// + /// + /// Take care that the setting you apply in the initializer does not get reverted when the connection is returned to the pool, since + /// Npgsql sends DISCARD ALL by default. The option can be used to + /// turn this off. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UsePhysicalConnectionInitializer( + Action? connectionInitializer, + Func? connectionInitializerAsync) + { + if (connectionInitializer is null != connectionInitializerAsync is null) + throw new ArgumentException(NpgsqlStrings.SyncAndAsyncConnectionInitializersRequired); + + _connectionInitializer = connectionInitializer; + _connectionInitializerAsync = connectionInitializerAsync; + + return this; + } + + /// + /// Builds and returns an which is ready for use. + /// + public NpgsqlDataSource Build() + { + var config = PrepareConfiguration(); + var connectionStringBuilder = ConnectionStringBuilder.Clone(); + + if (ConnectionStringBuilder.Host!.Contains(",")) + { + ValidateMultiHost(); + + return new NpgsqlMultiHostDataSource(connectionStringBuilder, config); + } + + return ConnectionStringBuilder.Multiplexing + ? new MultiplexingDataSource(connectionStringBuilder, config) + : ConnectionStringBuilder.Pooling + ? new PoolingDataSource(connectionStringBuilder, config) + : new UnpooledDataSource(connectionStringBuilder, config); + } + + /// + /// Builds and returns a which is ready for use for load-balancing and failover scenarios. 
+ /// + public NpgsqlMultiHostDataSource BuildMultiHost() + { + var config = PrepareConfiguration(); + + ValidateMultiHost(); + + return new(ConnectionStringBuilder.Clone(), config); + } + + NpgsqlDataSourceConfiguration PrepareConfiguration() + { + ConnectionStringBuilder.PostProcessAndValidate(); + + if (!_transportSecurityHandler.SupportEncryption && (_userCertificateValidationCallback is not null || _clientCertificatesCallback is not null)) + { + throw new InvalidOperationException(NpgsqlStrings.TransportSecurityDisabled); + } + + if (_passwordProvider is not null && _periodicPasswordProvider is not null) + { + throw new NotSupportedException(NpgsqlStrings.CannotSetMultiplePasswordProviderKinds); + } + + if ((_passwordProvider is not null || _periodicPasswordProvider is not null) && + (ConnectionStringBuilder.Password is not null || ConnectionStringBuilder.Passfile is not null)) + { + throw new NotSupportedException(NpgsqlStrings.CannotSetBothPasswordProviderAndPassword); + } + + ConfigureDefaultFactories(this); + + return new( + Name, + _loggerFactory is null + ? 
NpgsqlLoggingConfiguration.NullConfiguration + : new NpgsqlLoggingConfiguration(_loggerFactory, _sensitiveDataLoggingEnabled), + _transportSecurityHandler, + _integratedSecurityHandler, + _userCertificateValidationCallback, + _clientCertificatesCallback, + _passwordProvider, + _passwordProviderAsync, + _periodicPasswordProvider, + _periodicPasswordSuccessRefreshInterval, + _periodicPasswordFailureRefreshInterval, + _resolverChainBuilder.Build(ConfigureResolverChain), + HackyEnumMappings(), + DefaultNameTranslator, + _connectionInitializer, + _connectionInitializerAsync); + + List HackyEnumMappings() + { + var mappings = new List(); + + if (_userTypeMapper.Items.Count > 0) + foreach (var userTypeMapping in _userTypeMapper.Items) + if (userTypeMapping is UserTypeMapper.EnumMapping enumMapping) + mappings.Add(new(enumMapping.ClrType, enumMapping.PgTypeName, enumMapping.NameTranslator)); + + if (GlobalTypeMapper.Instance.HackyEnumTypeMappings.Count > 0) + mappings.AddRange(GlobalTypeMapper.Instance.HackyEnumTypeMappings); + + return mappings; + } + } + + void ValidateMultiHost() + { + if (ConnectionStringBuilder.TargetSessionAttributes is not null) + throw new InvalidOperationException(NpgsqlStrings.CannotSpecifyTargetSessionAttributes); + if (ConnectionStringBuilder.Multiplexing) + throw new NotSupportedException("Multiplexing is not supported with multiple hosts"); + if (ConnectionStringBuilder.ReplicationMode != ReplicationMode.Off) + throw new NotSupportedException("Replication is not supported with multiple hosts"); + } + + INpgsqlTypeMapper INpgsqlTypeMapper.ConfigureJsonOptions(JsonSerializerOptions serializerOptions) + => ConfigureJsonOptions(serializerOptions); + + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode( + "Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. 
This may not work when AOT compiling.")] + INpgsqlTypeMapper INpgsqlTypeMapper.EnableDynamicJson(Type[]? jsonbClrTypes, Type[]? jsonClrTypes) + => EnableDynamicJson(jsonbClrTypes, jsonClrTypes); + + [RequiresUnreferencedCode( + "The mapping of PostgreSQL records as .NET tuples requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode( + "The mapping of PostgreSQL records as .NET tuples requires dynamic code usage which is incompatible with NativeAOT.")] + INpgsqlTypeMapper INpgsqlTypeMapper.EnableRecordsAsTuples() + => EnableRecordsAsTuples(); + + [RequiresUnreferencedCode( + "The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode( + "The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + INpgsqlTypeMapper INpgsqlTypeMapper.EnableUnmappedTypes() + => EnableUnmappedTypes(); +} diff --git a/src/Npgsql/NpgsqlSqlEventSource.cs b/src/Npgsql/NpgsqlSqlEventSource.cs index 6daac26276..1e37a2355f 100644 --- a/src/Npgsql/NpgsqlSqlEventSource.cs +++ b/src/Npgsql/NpgsqlSqlEventSource.cs @@ -1,31 +1,28 @@ using System.Diagnostics.Tracing; -using System.Runtime.CompilerServices; -namespace Npgsql +namespace Npgsql; + +sealed class NpgsqlSqlEventSource : EventSource { - sealed class NpgsqlSqlEventSource : EventSource - { - public static readonly NpgsqlSqlEventSource Log = new NpgsqlSqlEventSource(); + public static readonly NpgsqlSqlEventSource Log = new(); - const string EventSourceName = "Npgsql.Sql"; + const string EventSourceName = "Npgsql.Sql"; - const int CommandStartId = 3; - const int CommandStopId = 4; + const int CommandStartId = 3; + const int CommandStopId = 4; - internal NpgsqlSqlEventSource() : base(EventSourceName) {} + internal NpgsqlSqlEventSource() : base(EventSourceName) {} - // NOTE - // - The 'Start' and 'Stop' suffixes on the following event names have special meaning in 
EventSource. They - // enable creating 'activities'. - // For more information, take a look at the following blog post: - // https://blogs.msdn.microsoft.com/vancem/2015/09/14/exploring-eventsource-activity-correlation-and-causation-features/ - // - A stop event's event id must be next one after its start event. + // NOTE + // - The 'Start' and 'Stop' suffixes on the following event names have special meaning in EventSource. They + // enable creating 'activities'. + // For more information, take a look at the following blog post: + // https://blogs.msdn.microsoft.com/vancem/2015/09/14/exploring-eventsource-activity-correlation-and-causation-features/ + // - A stop event's event id must be next one after its start event. - [Event(CommandStartId, Level = EventLevel.Informational)] - public void CommandStart(string sql) => Log.WriteEvent(CommandStartId, sql); + [Event(CommandStartId, Level = EventLevel.Informational)] + public void CommandStart(string sql) => WriteEvent(CommandStartId, sql); - [MethodImpl(MethodImplOptions.NoInlining)] - [Event(CommandStopId, Level = EventLevel.Informational)] - public void CommandStop() => Log.WriteEvent(CommandStopId); - } + [Event(CommandStopId, Level = EventLevel.Informational)] + public void CommandStop() => WriteEvent(CommandStopId); } diff --git a/src/Npgsql/NpgsqlStatement.cs b/src/Npgsql/NpgsqlStatement.cs deleted file mode 100644 index 00ddb5ffbf..0000000000 --- a/src/Npgsql/NpgsqlStatement.cs +++ /dev/null @@ -1,124 +0,0 @@ -using System.Collections.Generic; -using Npgsql.BackendMessages; - -namespace Npgsql -{ - /// - /// Represents a single SQL statement within Npgsql. - /// - /// Instances aren't constructed directly; users should construct an - /// object and populate its property as in standard ADO.NET. - /// Npgsql will analyze that property and constructed instances of - /// internally. - /// - /// Users can retrieve instances from - /// and access information about statement execution (e.g. affected rows). 
- /// - public sealed class NpgsqlStatement - { - /// - /// The SQL text of the statement. - /// - public string SQL { get; set; } = string.Empty; - - /// - /// Specifies the type of query, e.g. SELECT. - /// - public StatementType StatementType { get; internal set; } - - /// - /// The number of rows affected or retrieved. - /// - /// - /// See the command tag in the CommandComplete message, - /// https://www.postgresql.org/docs/current/static/protocol-message-formats.html - /// - public uint Rows => (uint)LongRows; - - /// - /// The number of rows affected or retrieved. - /// - /// - /// See the command tag in the CommandComplete message, - /// https://www.postgresql.org/docs/current/static/protocol-message-formats.html - /// - public ulong LongRows { get; internal set; } - - /// - /// For an INSERT, the object ID of the inserted row if is 1 and - /// the target table has OIDs; otherwise 0. - /// - public uint OID { get; internal set; } - - /// - /// The input parameters sent with this statement. - /// - public List InputParameters { get; } = new List(); - - /// - /// The RowDescription message for this query. If null, the query does not return rows (e.g. INSERT) - /// - internal RowDescriptionMessage? Description - { - get => PreparedStatement == null ? _description : PreparedStatement.Description; - set - { - if (PreparedStatement == null) - _description = value; - else - PreparedStatement.Description = value; - } - } - - RowDescriptionMessage? _description; - - /// - /// If this statement has been automatically prepared, references the . - /// Null otherwise. - /// - internal PreparedStatement? PreparedStatement - { - get => _preparedStatement != null && _preparedStatement.State == PreparedState.Unprepared - ? _preparedStatement = null - : _preparedStatement; - set => _preparedStatement = value; - } - - PreparedStatement? _preparedStatement; - - internal bool IsPreparing; - - /// - /// Holds the server-side (prepared) statement name. 
Empty string for non-prepared statements. - /// - internal string StatementName => PreparedStatement?.Name ?? ""; - - /// - /// Whether this statement has already been prepared (including automatic preparation). - /// - internal bool IsPrepared => PreparedStatement?.IsPrepared == true; - - internal void Reset() - { - SQL = string.Empty; - StatementType = StatementType.Select; - _description = null; - LongRows = 0; - OID = 0; - InputParameters.Clear(); - PreparedStatement = null; - } - - internal void ApplyCommandComplete(CommandCompleteMessage msg) - { - StatementType = msg.StatementType; - LongRows = msg.Rows; - OID = msg.OID; - } - - /// - /// Returns the SQL text of the statement. - /// - public override string ToString() => SQL ?? ""; - } -} diff --git a/src/Npgsql/NpgsqlTracingOptions.cs b/src/Npgsql/NpgsqlTracingOptions.cs new file mode 100644 index 0000000000..4aa61beec6 --- /dev/null +++ b/src/Npgsql/NpgsqlTracingOptions.cs @@ -0,0 +1,9 @@ +namespace Npgsql; + +/// +/// Options to configure Npgsql's support for OpenTelemetry tracing. +/// Currently no options are available. +/// +public class NpgsqlTracingOptions +{ +} \ No newline at end of file diff --git a/src/Npgsql/NpgsqlTransaction.cs b/src/Npgsql/NpgsqlTransaction.cs index 0a7921b753..1c4e01e049 100644 --- a/src/Npgsql/NpgsqlTransaction.cs +++ b/src/Npgsql/NpgsqlTransaction.cs @@ -2,456 +2,493 @@ using System.Data; using System.Data.Common; using System.Diagnostics; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using Npgsql.Logging; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; -namespace Npgsql +namespace Npgsql; + +/// +/// Represents a transaction to be made in a PostgreSQL database. This class cannot be inherited. +/// +public sealed class NpgsqlTransaction : DbTransaction { + #region Fields and Properties + /// - /// Represents a transaction to be made in a PostgreSQL database. This class cannot be inherited. 
+ /// Specifies the object associated with the transaction. /// - public sealed class NpgsqlTransaction : DbTransaction + /// The object associated with the transaction. + public new NpgsqlConnection? Connection { - #region Fields and Properties - - /// - /// Specifies the object associated with the transaction. - /// - /// The object associated with the transaction. - public new NpgsqlConnection? Connection + get { - get - { - CheckReady(); - return _connector.Connection; - } + CheckDisposed(); + return _connector?.Connection; } + } + + // Note that with ambient transactions, it's possible for a transaction to be pending after its connection + // is already closed. So we capture the connector and perform everything directly on it. + NpgsqlConnector _connector; + + /// + /// Specifies the object associated with the transaction. + /// + /// The object associated with the transaction. + protected override DbConnection? DbConnection => Connection; - // Note that with ambient transactions, it's possible for a transaction to be pending after its connection - // is already closed. So we capture the connector and perform everything directly on it. - NpgsqlConnector _connector; - - /// - /// Specifies the object associated with the transaction. - /// - /// The object associated with the transaction. - protected override DbConnection? DbConnection => Connection; - - /// - /// If true, the transaction has been committed/rolled back, but not disposed. - /// - internal bool IsCompleted => _connector is null || _connector.TransactionStatus == TransactionStatus.Idle; - - internal bool IsDisposed; - - /// - /// Specifies the IsolationLevel for this transaction. - /// - /// The IsolationLevel for this transaction. - /// The default is ReadCommitted. - public override IsolationLevel IsolationLevel + /// + /// If true, the transaction has been committed/rolled back, but not disposed. 
+ /// + internal bool IsCompleted => _connector is null || _connector.TransactionStatus == TransactionStatus.Idle; + + internal bool IsDisposed; + + Exception? _disposeReason; + + /// + /// Specifies the isolation level for this transaction. + /// + /// The isolation level for this transaction. The default is . + public override IsolationLevel IsolationLevel + { + get { - get - { - CheckReady(); - return _isolationLevel; - } + CheckReady(); + return _isolationLevel; } - IsolationLevel _isolationLevel; - - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlTransaction)); + } + IsolationLevel _isolationLevel; - const IsolationLevel DefaultIsolationLevel = IsolationLevel.ReadCommitted; + readonly ILogger _transactionLogger; - #endregion + const IsolationLevel DefaultIsolationLevel = IsolationLevel.ReadCommitted; - #region Initialization + #endregion - internal NpgsqlTransaction(NpgsqlConnector connector) - => _connector = connector; + #region Initialization - internal void Init(IsolationLevel isolationLevel = DefaultIsolationLevel) - { - Debug.Assert(isolationLevel != IsolationLevel.Chaos); + internal NpgsqlTransaction(NpgsqlConnector connector) + { + _connector = connector; + _transactionLogger = connector.TransactionLogger; + } - if (!_connector.DatabaseInfo.SupportsTransactions) - return; + internal void Init(IsolationLevel isolationLevel = DefaultIsolationLevel) + { + Debug.Assert(isolationLevel != IsolationLevel.Chaos); - Log.Debug($"Beginning transaction with isolation level {isolationLevel}", _connector.Id); - switch (isolationLevel) - { - case IsolationLevel.RepeatableRead: - case IsolationLevel.Snapshot: - _connector.PrependInternalMessage(PregeneratedMessages.BeginTransRepeatableRead, 2); - break; - case IsolationLevel.Serializable: - _connector.PrependInternalMessage(PregeneratedMessages.BeginTransSerializable, 2); - break; - case IsolationLevel.ReadUncommitted: - // PG doesn't really support ReadUncommitted, it's the same as 
ReadCommitted. But we still - // send as if. - _connector.PrependInternalMessage(PregeneratedMessages.BeginTransReadUncommitted, 2); - break; - case IsolationLevel.ReadCommitted: - _connector.PrependInternalMessage(PregeneratedMessages.BeginTransReadCommitted, 2); - break; - case IsolationLevel.Unspecified: - isolationLevel = DefaultIsolationLevel; - goto case DefaultIsolationLevel; - default: - throw new NotSupportedException("Isolation level not supported: " + isolationLevel); - } + if (!_connector.DatabaseInfo.SupportsTransactions) + return; - _connector.TransactionStatus = TransactionStatus.Pending; - _isolationLevel = isolationLevel; - IsDisposed = false; + switch (isolationLevel) + { + case IsolationLevel.RepeatableRead: + case IsolationLevel.Snapshot: + _connector.PrependInternalMessage(PregeneratedMessages.BeginTransRepeatableRead, 2); + break; + case IsolationLevel.Serializable: + _connector.PrependInternalMessage(PregeneratedMessages.BeginTransSerializable, 2); + break; + case IsolationLevel.ReadUncommitted: + // PG doesn't really support ReadUncommitted, it's the same as ReadCommitted. But we still + // send as if. + _connector.PrependInternalMessage(PregeneratedMessages.BeginTransReadUncommitted, 2); + break; + case IsolationLevel.ReadCommitted: + _connector.PrependInternalMessage(PregeneratedMessages.BeginTransReadCommitted, 2); + break; + case IsolationLevel.Unspecified: + isolationLevel = DefaultIsolationLevel; + goto case DefaultIsolationLevel; + default: + throw new NotSupportedException("Isolation level not supported: " + isolationLevel); } - #endregion + _connector.TransactionStatus = TransactionStatus.Pending; + _isolationLevel = isolationLevel; + IsDisposed = false; - #region Commit + LogMessages.StartedTransaction(_transactionLogger, isolationLevel, _connector.Id); + } - /// - /// Commits the database transaction. 
- /// - public override void Commit() => Commit(false).GetAwaiter().GetResult(); + #endregion - async Task Commit(bool async, CancellationToken cancellationToken = default) - { - CheckReady(); + #region Commit - if (!_connector.DatabaseInfo.SupportsTransactions) - return; + /// + /// Commits the database transaction. + /// + public override void Commit() => Commit(false).GetAwaiter().GetResult(); - using (_connector.StartUserAction(cancellationToken)) - { - Log.Debug("Committing transaction", _connector.Id); - await _connector.ExecuteInternalCommand(PregeneratedMessages.CommitTransaction, async, cancellationToken); - } + async Task Commit(bool async, CancellationToken cancellationToken = default) + { + CheckReady(); + + if (!_connector.DatabaseInfo.SupportsTransactions) + return; + + using (_connector.StartUserAction(cancellationToken)) + { + await _connector.ExecuteInternalCommand(PregeneratedMessages.CommitTransaction, async, cancellationToken).ConfigureAwait(false); + LogMessages.CommittedTransaction(_transactionLogger, _connector.Id); } + } - /// - /// Commits the database transaction. - /// - /// The token to monitor for cancellation requests. The default value is . + /// + /// Commits the database transaction. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// #if NETSTANDARD2_0 - public Task CommitAsync(CancellationToken cancellationToken = default) + public Task CommitAsync(CancellationToken cancellationToken = default) #else - public override Task CommitAsync(CancellationToken cancellationToken = default) + public override Task CommitAsync(CancellationToken cancellationToken = default) #endif - { - using (NoSynchronizationContextScope.Enter()) - return Commit(true, cancellationToken); - } + => Commit(async: true, cancellationToken); + + #endregion - #endregion + #region Rollback + + /// + /// Rolls back a transaction from a pending state. 
+ /// + public override void Rollback() => Rollback(false).GetAwaiter().GetResult(); - #region Rollback + async Task Rollback(bool async, CancellationToken cancellationToken = default) + { + CheckReady(); - /// - /// Rolls back a transaction from a pending state. - /// - public override void Rollback() => Rollback(false).GetAwaiter().GetResult(); + if (!_connector.DatabaseInfo.SupportsTransactions) + return; - Task Rollback(bool async, CancellationToken cancellationToken = default) + using (_connector.StartUserAction(cancellationToken)) { - CheckReady(); - return _connector.DatabaseInfo.SupportsTransactions - ? _connector.Rollback(async, cancellationToken) - : Task.CompletedTask; + await _connector.Rollback(async, cancellationToken).ConfigureAwait(false); + LogMessages.RolledBackTransaction(_transactionLogger, _connector.Id); } + } - /// - /// Rolls back a transaction from a pending state. - /// - /// The token to monitor for cancellation requests. The default value is . + /// + /// Rolls back a transaction from a pending state. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// #if NETSTANDARD2_0 - public Task RollbackAsync(CancellationToken cancellationToken = default) + public Task RollbackAsync(CancellationToken cancellationToken = default) #else - public override Task RollbackAsync(CancellationToken cancellationToken = default) + public override Task RollbackAsync(CancellationToken cancellationToken = default) #endif - { - using (NoSynchronizationContextScope.Enter()) - return Rollback(true, cancellationToken); - } + => Rollback(async: true, cancellationToken); - #endregion + #endregion - #region Savepoints + #region Savepoints - /// - /// Creates a transaction save point. - /// - /// The name of the savepoint. - /// - /// This method does not cause a database roundtrip to be made. The savepoint creation statement will instead be sent along with - /// the next command. 
- /// -#if NET - public override void Save(string name) + /// + /// Creates a transaction save point. + /// + /// The name of the savepoint. + /// + /// This method does not cause a database roundtrip to be made. The savepoint creation statement will instead be sent along with + /// the next command. + /// +#if NET5_0_OR_GREATER + public override void Save(string name) #else - public void Save(string name) + public void Save(string name) #endif - { - if (name == null) - throw new ArgumentNullException(nameof(name)); - if (string.IsNullOrWhiteSpace(name)) - throw new ArgumentException("name can't be empty", nameof(name)); - - CheckReady(); - if (!_connector.DatabaseInfo.SupportsTransactions) - return; + { + if (name == null) + throw new ArgumentNullException(nameof(name)); + if (string.IsNullOrWhiteSpace(name)) + throw new ArgumentException("name can't be empty", nameof(name)); - // Note that creating a savepoint doesn't actually send anything to the backend (only prepends), so strictly speaking we don't - // have to start a user action. However, we do this for consistency as if we did (for the checks and exceptions) - using var _ = _connector.StartUserAction(); + CheckReady(); + if (!_connector.DatabaseInfo.SupportsTransactions) + return; - Log.Debug($"Creating savepoint {name}", _connector.Id); + // Note that creating a savepoint doesn't actually send anything to the backend (only prepends), so strictly speaking we don't + // have to start a user action. However, we do this for consistency as if we did (for the checks and exceptions) + using var _ = _connector.StartUserAction(); - if (RequiresQuoting(name)) - name = $"\"{name.Replace("\"", "\"\"")}\""; + LogMessages.CreatingSavepoint(_transactionLogger, name, _connector.Id); - // Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters. - // Since we are prepending, we assume below that the statement will always fit in the buffer. 
- _connector.WriteBuffer.WriteByte(FrontendMessageCode.Query); - _connector.WriteBuffer.WriteInt32( - sizeof(int) + // Message length (including self excluding code) - _connector.TextEncoding.GetByteCount("SAVEPOINT ") + - _connector.TextEncoding.GetByteCount(name) + - sizeof(byte)); // Null terminator + if (RequiresQuoting(name)) + name = $"\"{name.Replace("\"", "\"\"")}\""; - _connector.WriteBuffer.WriteString("SAVEPOINT "); - _connector.WriteBuffer.WriteString(name); - _connector.WriteBuffer.WriteByte(0); + // Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters. + // Since we are prepending, we assume below that the statement will always fit in the buffer. + _connector.WriteQuery("SAVEPOINT " + name, async: false).GetAwaiter().GetResult(); - _connector.PendingPrependedResponses += 2; - } + _connector.PendingPrependedResponses += 2; + } - /// - /// Creates a transaction save point. - /// - /// The name of the savepoint. - /// The token to monitor for cancellation requests. The default value is . - /// - /// This method does not cause a database roundtrip to be made, and will therefore always complete synchronously. - /// The savepoint creation statement will instead be sent along with the next command. - /// -#if NET - public override Task SaveAsync(string name, CancellationToken cancellationToken = default) + /// + /// Creates a transaction save point. + /// + /// The name of the savepoint. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// + /// This method does not cause a database roundtrip to be made, and will therefore always complete synchronously. + /// The savepoint creation statement will instead be sent along with the next command. 
+ /// +#if NET5_0_OR_GREATER + public override Task SaveAsync(string name, CancellationToken cancellationToken = default) #else - public Task SaveAsync(string name, CancellationToken cancellationToken = default) + public Task SaveAsync(string name, CancellationToken cancellationToken = default) #endif - { - Save(name); - return Task.CompletedTask; - } + { + Save(name); + return Task.CompletedTask; + } - async Task Rollback(string name, bool async, CancellationToken cancellationToken = default) + async Task Rollback(bool async, string name, CancellationToken cancellationToken = default) + { + if (name == null) + throw new ArgumentNullException(nameof(name)); + if (string.IsNullOrWhiteSpace(name)) + throw new ArgumentException("name can't be empty", nameof(name)); + + CheckReady(); + if (!_connector.DatabaseInfo.SupportsTransactions) + return; + using (_connector.StartUserAction(cancellationToken)) { - if (name == null) - throw new ArgumentNullException(nameof(name)); - if (string.IsNullOrWhiteSpace(name)) - throw new ArgumentException("name can't be empty", nameof(name)); - - CheckReady(); - if (!_connector.DatabaseInfo.SupportsTransactions) - return; - using (_connector.StartUserAction(cancellationToken)) - { - Log.Debug($"Rolling back savepoint {name}", _connector.Id); - - if (RequiresQuoting(name)) - name = $"\"{name.Replace("\"", "\"\"")}\""; - - await _connector.ExecuteInternalCommand($"ROLLBACK TO SAVEPOINT {name}", async, cancellationToken); - } + var quotedName = RequiresQuoting(name) ? $"\"{name.Replace("\"", "\"\"")}\"" : name; + await _connector.ExecuteInternalCommand($"ROLLBACK TO SAVEPOINT {quotedName}", async, cancellationToken).ConfigureAwait(false); + LogMessages.RolledBackToSavepoint(_transactionLogger, name, _connector.Id); } + } - /// - /// Rolls back a transaction from a pending savepoint state. - /// - /// The name of the savepoint. 
-#if NET - public override void Rollback(string name) + /// + /// Rolls back a transaction from a pending savepoint state. + /// + /// The name of the savepoint. +#if NET5_0_OR_GREATER + public override void Rollback(string name) #else - public void Rollback(string name) + public void Rollback(string name) #endif - => Rollback(name, false).GetAwaiter().GetResult(); - - /// - /// Rolls back a transaction from a pending savepoint state. - /// - /// The name of the savepoint. - /// The token to monitor for cancellation requests. The default value is . -#if NET - public override Task RollbackAsync(string name, CancellationToken cancellationToken = default) + => Rollback(async: false, name).GetAwaiter().GetResult(); + + /// + /// Rolls back a transaction from a pending savepoint state. + /// + /// The name of the savepoint. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// +#if NET5_0_OR_GREATER + public override Task RollbackAsync(string name, CancellationToken cancellationToken = default) #else - public Task RollbackAsync(string name, CancellationToken cancellationToken = default) + public Task RollbackAsync(string name, CancellationToken cancellationToken = default) #endif - { - using (NoSynchronizationContextScope.Enter()) - return Rollback(name, true, cancellationToken); - } + => Rollback(async: true, name, cancellationToken); - async Task Release(string name, bool async, CancellationToken cancellationToken = default) + async Task Release(bool async, string name, CancellationToken cancellationToken = default) + { + if (name == null) + throw new ArgumentNullException(nameof(name)); + if (string.IsNullOrWhiteSpace(name)) + throw new ArgumentException("name can't be empty", nameof(name)); + + CheckReady(); + if (!_connector.DatabaseInfo.SupportsTransactions) + return; + using (_connector.StartUserAction(cancellationToken)) { - if (name == null) - throw new ArgumentNullException(nameof(name)); - if 
(string.IsNullOrWhiteSpace(name)) - throw new ArgumentException("name can't be empty", nameof(name)); - - CheckReady(); - if (!_connector.DatabaseInfo.SupportsTransactions) - return; - using (_connector.StartUserAction(cancellationToken)) - { - Log.Debug($"Releasing savepoint {name}", _connector.Id); - - if (RequiresQuoting(name)) - name = $"\"{name.Replace("\"", "\"\"")}\""; - - await _connector.ExecuteInternalCommand($"RELEASE SAVEPOINT {name}", async, cancellationToken); - } + var quotedName = RequiresQuoting(name) ? $"\"{name.Replace("\"", "\"\"")}\"" : name; + await _connector.ExecuteInternalCommand($"RELEASE SAVEPOINT {quotedName}", async, cancellationToken).ConfigureAwait(false); + LogMessages.ReleasedSavepoint(_transactionLogger, name, _connector.Id); } + } - /// - /// Releases a transaction from a pending savepoint state. - /// - /// The name of the savepoint. -#if NET - public override void Release(string name) => Release(name, false).GetAwaiter().GetResult(); + /// + /// Releases a transaction from a pending savepoint state. + /// + /// The name of the savepoint. +#if NET5_0_OR_GREATER + public override void Release(string name) #else - public void Release(string name) => Release(name, false).GetAwaiter().GetResult(); + public void Release(string name) #endif + => Release(async: false, name).GetAwaiter().GetResult(); - /// - /// Releases a transaction from a pending savepoint state. - /// - /// The name of the savepoint. - /// The token to monitor for cancellation requests. The default value is . -#if NET - public override Task ReleaseAsync(string name, CancellationToken cancellationToken = default) + /// + /// Releases a transaction from a pending savepoint state. + /// + /// The name of the savepoint. + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// +#if NET5_0_OR_GREATER + public override Task ReleaseAsync(string name, CancellationToken cancellationToken = default) #else - public Task ReleaseAsync(string name, CancellationToken cancellationToken = default) + public Task ReleaseAsync(string name, CancellationToken cancellationToken = default) #endif - { - using (NoSynchronizationContextScope.Enter()) - return Release(name, true, cancellationToken); - } + => Release(async: false, name, cancellationToken); - #endregion + /// + /// Indicates whether this transaction supports database savepoints. + /// +#if NET5_0_OR_GREATER + public override bool SupportsSavepoints +#else + public bool SupportsSavepoints +#endif + { + get => _connector.DatabaseInfo.SupportsTransactions; + } - #region Dispose + #endregion - /// - /// Disposes the transaction, rolling it back if it is still pending. - /// - protected override void Dispose(bool disposing) - { - if (IsDisposed) - return; + #region Dispose - if (disposing) + /// + /// Disposes the transaction, rolling it back if it is still pending. + /// + protected override void Dispose(bool disposing) + { + if (IsDisposed) + return; + + if (disposing) + { + if (!IsCompleted) { - if (!IsCompleted) + try { - // We're disposing, so no cancellation token _connector.CloseOngoingOperations(async: false).GetAwaiter().GetResult(); Rollback(); } - - IsDisposed = true; - _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction); + catch + { + Debug.Assert(_connector.IsBroken); + } } + + IsDisposed = true; + _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction); } + } - /// - /// Disposes the transaction, rolling it back if it is still pending. - /// + /// + /// Disposes the transaction, rolling it back if it is still pending. 
+ /// #if NETSTANDARD2_0 - public ValueTask DisposeAsync() + public ValueTask DisposeAsync() #else - public override ValueTask DisposeAsync() + public override ValueTask DisposeAsync() #endif + { + if (!IsDisposed) { - if (!IsDisposed) + if (!IsCompleted) { - if (!IsCompleted) - { - using (NoSynchronizationContextScope.Enter()) - return DisposeAsyncInternal(); - } - - IsDisposed = true; - _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction); + return DisposeAsyncInternal(); } - return default; - async ValueTask DisposeAsyncInternal() + IsDisposed = true; + _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction); + } + return default; + + async ValueTask DisposeAsyncInternal() + { + // We're disposing, so no cancellation token + try { - // We're disposing, so no cancellation token - await _connector.CloseOngoingOperations(async: true); - await Rollback(async: true); - IsDisposed = true; - _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction); + await _connector.CloseOngoingOperations(async: true).ConfigureAwait(false); + await Rollback(async: true).ConfigureAwait(false); } + catch (Exception ex) + { + Debug.Assert(_connector.IsBroken); + LogMessages.ExceptionDuringTransactionDispose(_transactionLogger, _connector.Id, ex); + } + + IsDisposed = true; + _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction); } + } - /// - /// Disposes the transaction, without rolling back. Used only in special circumstances, e.g. when - /// the connection is broken. - /// - internal void DisposeImmediately() => IsDisposed = true; + /// + /// Disposes the transaction, without rolling back. Used only in special circumstances, e.g. when + /// the connection is broken. + /// + internal void DisposeImmediately(Exception? 
disposeReason) + { + IsDisposed = true; + _disposeReason = disposeReason; + } - #endregion + #endregion - #region Checks + #region Checks - void CheckReady() - { - if (IsDisposed) - throw new ObjectDisposedException(typeof(NpgsqlTransaction).Name); - if (IsCompleted) - throw new InvalidOperationException("This NpgsqlTransaction has completed; it is no longer usable."); - } + void CheckReady() + { + CheckDisposed(); + if (IsCompleted) + ThrowHelper.ThrowInvalidOperationException("This NpgsqlTransaction has completed; it is no longer usable."); + } - static bool RequiresQuoting(string identifier) - { - Debug.Assert(identifier.Length > 0); + void CheckDisposed() + { + if (IsDisposed) + ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlTransaction), _disposeReason); + } - var first = identifier[0]; - if (first != '_' && !char.IsLower(first)) - return true; + static bool RequiresQuoting(string identifier) + { + Debug.Assert(identifier.Length > 0); - foreach (var c in identifier.AsSpan(1)) - if (c != '_' && c != '$' && !char.IsLower(c) && !char.IsDigit(c)) - return true; + var first = identifier[0]; + if (first != '_' && !char.IsLower(first)) + return true; - return false; - } + foreach (var c in identifier.AsSpan(1)) + if (c != '_' && c != '$' && !char.IsLower(c) && !char.IsDigit(c)) + return true; - #endregion + return false; + } - #region Misc + #endregion - /// - /// Unbinds transaction from the connector. - /// Should be called before the connector is returned to the pool. - /// - internal void UnbindIfNecessary() + #region Misc + + /// + /// Unbinds transaction from the connector. + /// Should be called before the connector is returned to the pool. 
+ /// + internal void UnbindIfNecessary() + { + // We're closing the connection, but transaction is not yet disposed + // We have to unbind the transaction from the connector, otherwise there could be a concurrency issues + // See #3306 + if (!IsDisposed) { - // We're closing the connection, but transaction is not yet disposed - // We have to unbind the transaction from the connector, otherwise there could be a concurency issues - // See #3306 - if (!IsDisposed) + if (_connector.UnboundTransaction is { IsDisposed: true } previousTransaction) { - _connector.Transaction = null; - _connector = null!; + previousTransaction._connector = _connector; + _connector.Transaction = previousTransaction; } - } + else + _connector.Transaction = null; - #endregion + _connector.UnboundTransaction = this; + _connector = null!; + } } + + #endregion } diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlDate.cs b/src/Npgsql/NpgsqlTypes/NpgsqlDate.cs deleted file mode 100644 index eed1828a19..0000000000 --- a/src/Npgsql/NpgsqlTypes/NpgsqlDate.cs +++ /dev/null @@ -1,453 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Text; -using JetBrains.Annotations; - -#pragma warning disable 1591 - -// ReSharper disable once CheckNamespace -namespace NpgsqlTypes -{ - [Serializable] - public readonly struct NpgsqlDate : IEquatable, IComparable, IComparable, - IComparer, IComparer - { - //Number of days since January 1st CE (January 1st EV). 1 Jan 1 CE = 0, 2 Jan 1 CE = 1, 31 Dec 1 BCE = -1, etc. 
- readonly int _daysSinceEra; - readonly InternalType _type; - - #region Constants - - static readonly int[] CommonYearDays = { 0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365 }; - static readonly int[] LeapYearDays = { 0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366 }; - static readonly int[] CommonYearMaxes = { 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31 }; - static readonly int[] LeapYearMaxes = { 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31 }; - - /// - /// Represents the date 1970-01-01 - /// - public static readonly NpgsqlDate Epoch = new NpgsqlDate(1970, 1, 1); - - /// - /// Represents the date 0001-01-01 - /// - public static readonly NpgsqlDate Era = new NpgsqlDate(0); - - public const int MaxYear = 5874897; - public const int MinYear = -4714; - public static readonly NpgsqlDate MaxCalculableValue = new NpgsqlDate(MaxYear, 12, 31); - public static readonly NpgsqlDate MinCalculableValue = new NpgsqlDate(MinYear, 11, 24); - - public static readonly NpgsqlDate Infinity = new NpgsqlDate(InternalType.Infinity); - public static readonly NpgsqlDate NegativeInfinity = new NpgsqlDate(InternalType.NegativeInfinity); - - const int DaysInYear = 365; //Common years - const int DaysIn4Years = 4 * DaysInYear + 1; //Leap year every 4 years. - const int DaysInCentury = 25 * DaysIn4Years - 1; //Except no leap year every 100. - const int DaysIn4Centuries = 4 * DaysInCentury + 1; //Except leap year every 400. 
- - #endregion - - #region Constructors - - NpgsqlDate(InternalType type) - { - _type = type; - _daysSinceEra = 0; - } - - internal NpgsqlDate(int days) - { - _type = InternalType.Finite; - _daysSinceEra = days; - } - - public NpgsqlDate(DateTime dateTime) : this((int)(dateTime.Ticks / TimeSpan.TicksPerDay)) {} - - public NpgsqlDate(NpgsqlDate copyFrom) : this(copyFrom._daysSinceEra) {} - - public NpgsqlDate(int year, int month, int day) - { - _type = InternalType.Finite; - if (year == 0 || year < MinYear || year > MaxYear || month < 1 || month > 12 || day < 1 || - (day > (IsLeap(year) ? 366 : 365))) - { - throw new ArgumentOutOfRangeException(); - } - - _daysSinceEra = DaysForYears(year) + (IsLeap(year) ? LeapYearDays : CommonYearDays)[month - 1] + day - 1; - } - - #endregion - - #region String Conversions - - public override string ToString() - => _type switch - { - InternalType.Infinity => "infinity", - InternalType.NegativeInfinity => "-infinity", - //Format of yyyy-MM-dd with " BC" for BCE and optional " AD" for CE which we omit here. - _ => new StringBuilder(Math.Abs(Year).ToString("D4")) - .Append('-').Append(Month.ToString("D2")) - .Append('-').Append(Day.ToString("D2")) - .Append(_daysSinceEra < 0 ? 
" BC" : "").ToString() - }; - - public static NpgsqlDate Parse(string str) - { - - if (str == null) { - throw new ArgumentNullException(nameof(str)); - } - - if (str == "infinity") - return Infinity; - - if (str == "-infinity") - return NegativeInfinity; - - str = str.Trim(); - try { - var idx = str.IndexOf('-'); - if (idx == -1) { - throw new FormatException(); - } - var year = int.Parse(str.Substring(0, idx)); - var idxLast = idx + 1; - if ((idx = str.IndexOf('-', idxLast)) == -1) { - throw new FormatException(); - } - var month = int.Parse(str.Substring(idxLast, idx - idxLast)); - idxLast = idx + 1; - if ((idx = str.IndexOf(' ', idxLast)) == -1) { - idx = str.Length; - } - var day = int.Parse(str.Substring(idxLast, idx - idxLast)); - if (str.Contains("BC")) { - year = -year; - } - return new NpgsqlDate(year, month, day); - } catch (OverflowException) { - throw; - } catch (Exception) { - throw new FormatException(); - } - } - - public static bool TryParse(string str, out NpgsqlDate date) - { - try { - date = Parse(str); - return true; - } catch { - date = Era; - return false; - } - } - - #endregion - - #region Public Properties - - public static NpgsqlDate Now => new NpgsqlDate(DateTime.Now); - public static NpgsqlDate Today => Now; - public static NpgsqlDate Yesterday => Now.AddDays(-1); - public static NpgsqlDate Tomorrow => Now.AddDays(1); - - public int DayOfYear => _daysSinceEra - DaysForYears(Year) + 1; - - public int Year - { - get - { - var guess = (int)Math.Round(_daysSinceEra/365.2425); - var test = guess - 1; - while (DaysForYears(++test) <= _daysSinceEra) {} - return test - 1; - } - } - - public int Month - { - get - { - var i = 1; - var target = DayOfYear; - var array = IsLeapYear ? LeapYearDays : CommonYearDays; - while (target > array[i]) - { - ++i; - } - return i; - } - } - - public int Day => DayOfYear - (IsLeapYear ? 
LeapYearDays : CommonYearDays)[Month - 1]; - - public DayOfWeek DayOfWeek => (DayOfWeek) ((_daysSinceEra + 1)%7); - - internal int DaysSinceEra => _daysSinceEra; - - public bool IsLeapYear => IsLeap(Year); - - public bool IsInfinity => _type == InternalType.Infinity; - public bool IsNegativeInfinity => _type == InternalType.NegativeInfinity; - - public bool IsFinite - => _type switch { - InternalType.Finite => true, - InternalType.Infinity => false, - InternalType.NegativeInfinity => false, - _ => throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(NpgsqlDate)}.{nameof(InternalType)}. Please file a bug.") - }; - - #endregion - - #region Internals - - static int DaysForYears(int years) - { - //Number of years after 1CE (0 for 1CE, -1 for 1BCE, 1 for 2CE). - var calcYear = years < 1 ? years : years - 1; - - return calcYear / 400 * DaysIn4Centuries //Blocks of 400 years with their leap and common years - + calcYear % 400 / 100 * DaysInCentury //Remaining blocks of 100 years with their leap and common years - + calcYear % 100 / 4 * DaysIn4Years //Remaining blocks of 4 years with their leap and common years - + calcYear % 4 * DaysInYear //Remaining years, all common - + (calcYear < 0 ? -1 : 0); //And 1BCE is leap. - } - - static bool IsLeap(int year) - { - //Every 4 years is a leap year - //Except every 100 years isn't a leap year. - //Except every 400 years is. - if (year < 1) - { - year = year + 1; - } - return (year%4 == 0) && ((year%100 != 0) || (year%400 == 0)); - } - - #endregion - - #region Arithmetic - - public NpgsqlDate AddDays(int days) - => _type switch - { - InternalType.Infinity => Infinity, - InternalType.NegativeInfinity => NegativeInfinity, - InternalType.Finite => new NpgsqlDate(_daysSinceEra + days), - _ => throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(NpgsqlDate)}.{nameof(InternalType)}. 
Please file a bug.") - }; - - public NpgsqlDate AddYears(int years) - { - switch (_type) { - case InternalType.Infinity: - return Infinity; - case InternalType.NegativeInfinity: - return NegativeInfinity; - case InternalType.Finite: - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(NpgsqlDate)}.{nameof(InternalType)}. Please file a bug."); - } - - var newYear = Year + years; - if (newYear >= 0 && _daysSinceEra < 0) //cross 1CE/1BCE divide going up - { - ++newYear; - } - else if (newYear <= 0 && _daysSinceEra >= 0) //cross 1CE/1BCE divide going down - { - --newYear; - } - return new NpgsqlDate(newYear, Month, Day); - } - - public NpgsqlDate AddMonths(int months) - { - switch (_type) { - case InternalType.Infinity: - return Infinity; - case InternalType.NegativeInfinity: - return NegativeInfinity; - case InternalType.Finite: - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(NpgsqlDate)}.{nameof(InternalType)}. Please file a bug."); - } - - var newYear = Year; - var newMonth = Month + months; - - while (newMonth > 12) - { - newMonth -= 12; - newYear += 1; - } - while (newMonth < 1) - { - newMonth += 12; - newYear -= 1; - } - var maxDay = (IsLeap(newYear) ? LeapYearMaxes : CommonYearMaxes)[newMonth - 1]; - var newDay = Day > maxDay ? maxDay : Day; - return new NpgsqlDate(newYear, newMonth, newDay); - - } - - public NpgsqlDate Add(in NpgsqlTimeSpan interval) - { - switch (_type) { - case InternalType.Infinity: - return Infinity; - case InternalType.NegativeInfinity: - return NegativeInfinity; - case InternalType.Finite: - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(NpgsqlDate)}.{nameof(InternalType)}. 
Please file a bug."); - } - - return AddMonths(interval.Months).AddDays(interval.Days); - } - - internal NpgsqlDate Add(in NpgsqlTimeSpan interval, int carriedOverflow) - { - switch (_type) { - case InternalType.Infinity: - return Infinity; - case InternalType.NegativeInfinity: - return NegativeInfinity; - case InternalType.Finite: - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(NpgsqlDate)}.{nameof(InternalType)}. Please file a bug."); - } - - return AddMonths(interval.Months).AddDays(interval.Days + carriedOverflow); - } - - #endregion - - #region Comparison - - public int Compare(NpgsqlDate x, NpgsqlDate y) => x.CompareTo(y); - - public int Compare(object? x, object? y) - { - if (x == null) - { - return y == null ? 0 : -1; - } - if (y == null) - { - return 1; - } - if (!(x is IComparable) || !(y is IComparable)) - { - throw new ArgumentException(); - } - return ((IComparable) x).CompareTo(y); - } - - public bool Equals(NpgsqlDate other) - => _type switch - { - InternalType.Infinity => other._type == InternalType.Infinity, - InternalType.NegativeInfinity => other._type == InternalType.NegativeInfinity, - InternalType.Finite => other._type == InternalType.Finite && _daysSinceEra == other._daysSinceEra, - _ => false - }; - - public override bool Equals(object? obj) => obj is NpgsqlDate date && Equals(date); - - public int CompareTo(NpgsqlDate other) - => _type switch - { - InternalType.Infinity => other._type == InternalType.Infinity ? 0 : 1, - InternalType.NegativeInfinity => other._type == InternalType.NegativeInfinity ? 0 : -1, - _ => other._type switch - { - InternalType.Infinity => -1, - InternalType.NegativeInfinity => 1, - _ => _daysSinceEra.CompareTo(other._daysSinceEra) - } - }; - - public int CompareTo(object? o) - => o == null - ? 1 - : o is NpgsqlDate npgsqlDate - ? 
CompareTo(npgsqlDate) - : throw new ArgumentException(); - - public override int GetHashCode() => _daysSinceEra; - - #endregion - - #region Operators - - public static bool operator ==(NpgsqlDate x, NpgsqlDate y) => x.Equals(y); - public static bool operator !=(NpgsqlDate x, NpgsqlDate y) => !(x == y); - public static bool operator <(NpgsqlDate x, NpgsqlDate y) => x.CompareTo(y) < 0; - public static bool operator >(NpgsqlDate x, NpgsqlDate y) => x.CompareTo(y) > 0; - public static bool operator <=(NpgsqlDate x, NpgsqlDate y) => x.CompareTo(y) <= 0; - public static bool operator >=(NpgsqlDate x, NpgsqlDate y) => x.CompareTo(y) >= 0; - - public static DateTime ToDateTime(NpgsqlDate date) - { - switch (date._type) - { - case InternalType.Infinity: - case InternalType.NegativeInfinity: - throw new InvalidCastException("Infinity values can't be cast to DateTime"); - case InternalType.Finite: - try { return new DateTime(date._daysSinceEra * NpgsqlTimeSpan.TicksPerDay); } - catch { throw new InvalidCastException(); } - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {date._type} of enum {nameof(NpgsqlDate)}.{nameof(InternalType)}. 
Please file a bug."); - } - } - - public static explicit operator DateTime(NpgsqlDate date) => ToDateTime(date); - - public static NpgsqlDate ToNpgsqlDate(DateTime date) - => new NpgsqlDate((int)(date.Ticks / NpgsqlTimeSpan.TicksPerDay)); - - public static explicit operator NpgsqlDate(DateTime date) => ToNpgsqlDate(date); - - public static NpgsqlDate operator +(NpgsqlDate date, NpgsqlTimeSpan interval) - => date.Add(interval); - - public static NpgsqlDate operator +(NpgsqlTimeSpan interval, NpgsqlDate date) - => date.Add(interval); - - public static NpgsqlDate operator -(NpgsqlDate date, NpgsqlTimeSpan interval) - => date.Subtract(interval); - - public NpgsqlDate Subtract(in NpgsqlTimeSpan interval) => Add(-interval); - - public static NpgsqlTimeSpan operator -(NpgsqlDate dateX, NpgsqlDate dateY) - { - if (dateX._type != InternalType.Finite || dateY._type != InternalType.Finite) - throw new ArgumentException("Can't subtract infinity date values"); - - return new NpgsqlTimeSpan(0, dateX._daysSinceEra - dateY._daysSinceEra, 0); - } - - #endregion - - enum InternalType - { - Finite, - Infinity, - NegativeInfinity - } - } -} diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlDateTime.cs b/src/Npgsql/NpgsqlTypes/NpgsqlDateTime.cs deleted file mode 100644 index 4f25ac680c..0000000000 --- a/src/Npgsql/NpgsqlTypes/NpgsqlDateTime.cs +++ /dev/null @@ -1,478 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using Npgsql.Util; - -#pragma warning disable 1591 - -// ReSharper disable once CheckNamespace -namespace NpgsqlTypes -{ - /// - /// A struct similar to .NET DateTime but capable of storing PostgreSQL's timestamp and timestamptz types. - /// DateTime is capable of storing values from year 1 to 9999 at 100-nanosecond precision, - /// while PostgreSQL's timestamps store values from 4713BC to 5874897AD with 1-microsecond precision. 
- /// - [Serializable] - public readonly struct NpgsqlDateTime : IEquatable, IComparable, IComparable, - IComparer, IComparer - { - #region Fields - - readonly NpgsqlDate _date; - readonly TimeSpan _time; - readonly InternalType _type; - - #endregion - - #region Constants - - public static readonly NpgsqlDateTime Epoch = new NpgsqlDateTime(NpgsqlDate.Epoch); - public static readonly NpgsqlDateTime Era = new NpgsqlDateTime(NpgsqlDate.Era); - - public static readonly NpgsqlDateTime Infinity = - new NpgsqlDateTime(InternalType.Infinity, NpgsqlDate.Era, TimeSpan.Zero); - - public static readonly NpgsqlDateTime NegativeInfinity = - new NpgsqlDateTime(InternalType.NegativeInfinity, NpgsqlDate.Era, TimeSpan.Zero); - - // 9999-12-31 - const int MaxDateTimeDay = 3652058; - - #endregion - - #region Constructors - - NpgsqlDateTime(InternalType type, NpgsqlDate date, TimeSpan time) - { - if (!date.IsFinite && type != InternalType.Infinity && type != InternalType.NegativeInfinity) - throw new ArgumentException("Can't construct an NpgsqlDateTime with a non-finite date, use Infinity and NegativeInfinity instead", nameof(date)); - - _type = type; - _date = date; - _time = time; - } - - public NpgsqlDateTime(NpgsqlDate date, TimeSpan time, DateTimeKind kind = DateTimeKind.Unspecified) - : this(KindToInternalType(kind), date, time) {} - - public NpgsqlDateTime(NpgsqlDate date) - : this(date, TimeSpan.Zero) {} - - public NpgsqlDateTime(int year, int month, int day, int hours, int minutes, int seconds, DateTimeKind kind=DateTimeKind.Unspecified) - : this(new NpgsqlDate(year, month, day), new TimeSpan(0, hours, minutes, seconds), kind) {} - - public NpgsqlDateTime(int year, int month, int day, int hours, int minutes, int seconds, int milliseconds, DateTimeKind kind = DateTimeKind.Unspecified) - : this(new NpgsqlDate(year, month, day), new TimeSpan(0, hours, minutes, seconds, milliseconds), kind) { } - - public NpgsqlDateTime(DateTime dateTime) - : this(new NpgsqlDate(dateTime.Date), 
dateTime.TimeOfDay, dateTime.Kind) {} - - public NpgsqlDateTime(long ticks, DateTimeKind kind) - : this(new DateTime(ticks, kind)) { } - - public NpgsqlDateTime(long ticks) - : this(new DateTime(ticks, DateTimeKind.Unspecified)) { } - - #endregion - - #region Public Properties - - public NpgsqlDate Date => _date; - public TimeSpan Time => _time; - public int DayOfYear => _date.DayOfYear; - public int Year => _date.Year; - public int Month => _date.Month; - public int Day => _date.Day; - public DayOfWeek DayOfWeek => _date.DayOfWeek; - public bool IsLeapYear => _date.IsLeapYear; - - public long Ticks => _date.DaysSinceEra * NpgsqlTimeSpan.TicksPerDay + _time.Ticks; - public int Millisecond => _time.Milliseconds; - public int Second => _time.Seconds; - public int Minute => _time.Minutes; - public int Hour => _time.Hours; - public bool IsInfinity => _type == InternalType.Infinity; - public bool IsNegativeInfinity => _type == InternalType.NegativeInfinity; - - public bool IsFinite - => _type switch - { - InternalType.FiniteUnspecified => true, - InternalType.FiniteUtc => true, - InternalType.FiniteLocal => true, - InternalType.Infinity => false, - InternalType.NegativeInfinity => false, - _ => throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(NpgsqlDateTime)}.{nameof(InternalType)}. Please file a bug.") - }; - - public DateTimeKind Kind - => _type switch - { - InternalType.FiniteUtc => DateTimeKind.Utc, - InternalType.FiniteLocal => DateTimeKind.Local, - InternalType.FiniteUnspecified => DateTimeKind.Unspecified, - InternalType.Infinity => DateTimeKind.Unspecified, - InternalType.NegativeInfinity => DateTimeKind.Unspecified, - _ => throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(DateTimeKind)}. Please file a bug.") - }; - - /// - /// Cast of an to a . - /// - /// An equivalent . 
- public DateTime ToDateTime() - { - if (!IsFinite) - throw new InvalidCastException("Can't convert infinite timestamp values to DateTime"); - - if (_date.DaysSinceEra < 0 || _date.DaysSinceEra > MaxDateTimeDay) - throw new InvalidCastException("Out of the range of DateTime (year must be between 1 and 9999)"); - - return new DateTime(Ticks, Kind); - } - - /// - /// Converts the value of the current object to Coordinated Universal Time (UTC). - /// - /// - /// See the MSDN documentation for DateTime.ToUniversalTime(). - /// Note: this method only takes into account the time zone's base offset, and does - /// not respect daylight savings. See https://github.com/npgsql/npgsql/pull/684 for more - /// details. - /// - public NpgsqlDateTime ToUniversalTime() - { - switch (_type) - { - case InternalType.FiniteUnspecified: - // Treat as Local - case InternalType.FiniteLocal: - if (_date.DaysSinceEra >= 1 && _date.DaysSinceEra <= MaxDateTimeDay - 1) - { - // Day between 0001-01-02 and 9999-12-30, so we can use DateTime and it will always succeed - return new NpgsqlDateTime(Subtract(TimeZoneInfo.Local.GetUtcOffset(new DateTime(ToDateTime().Ticks, DateTimeKind.Local))).Ticks, DateTimeKind.Utc); - } - // Else there are no DST rules available in the system for outside the DateTime range, so just use the base offset - return new NpgsqlDateTime(Subtract(TimeZoneInfo.Local.BaseUtcOffset).Ticks, DateTimeKind.Utc); - case InternalType.FiniteUtc: - case InternalType.Infinity: - case InternalType.NegativeInfinity: - return this; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(NpgsqlDateTime)}.{nameof(InternalType)}. Please file a bug."); - } - } - - /// - /// Converts the value of the current object to local time. - /// - /// - /// See the MSDN documentation for DateTime.ToLocalTime(). - /// Note: this method only takes into account the time zone's base offset, and does - /// not respect daylight savings. 
See https://github.com/npgsql/npgsql/pull/684 for more - /// details. - /// - public NpgsqlDateTime ToLocalTime() - { - switch (_type) { - case InternalType.FiniteUnspecified: - // Treat as UTC - case InternalType.FiniteUtc: - if (_date.DaysSinceEra >= 1 && _date.DaysSinceEra <= MaxDateTimeDay - 1) - { - // Day between 0001-01-02 and 9999-12-30, so we can use DateTime and it will always succeed - return new NpgsqlDateTime(TimeZoneInfo.ConvertTime(new DateTime(ToDateTime().Ticks, DateTimeKind.Utc), TimeZoneInfo.Local)); - } - // Else there are no DST rules available in the system for outside the DateTime range, so just use the base offset - return new NpgsqlDateTime(Add(TimeZoneInfo.Local.BaseUtcOffset).Ticks, DateTimeKind.Local); - case InternalType.FiniteLocal: - case InternalType.Infinity: - case InternalType.NegativeInfinity: - return this; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {_type} of enum {nameof(NpgsqlDateTime)}.{nameof(InternalType)}. 
Please file a bug."); - } - } - - public static NpgsqlDateTime Now => new NpgsqlDateTime(DateTime.Now); - - #endregion - - #region String Conversions - - public override string ToString() - => _type switch - { - InternalType.Infinity => "infinity", - InternalType.NegativeInfinity => "-infinity", - _ => $"{_date} {_time}" - }; - - public static NpgsqlDateTime Parse(string str) - { - if (str == null) { - throw new NullReferenceException(); - } - switch (str = str.Trim().ToLowerInvariant()) { - case "infinity": - return Infinity; - case "-infinity": - return NegativeInfinity; - default: - try { - var idxSpace = str.IndexOf(' '); - var datePart = str.Substring(0, idxSpace); - if (str.Contains("bc")) { - datePart += " BC"; - } - var idxSecond = str.IndexOf(' ', idxSpace + 1); - if (idxSecond == -1) { - idxSecond = str.Length; - } - var timePart = str.Substring(idxSpace + 1, idxSecond - idxSpace - 1); - return new NpgsqlDateTime(NpgsqlDate.Parse(datePart), TimeSpan.Parse(timePart)); - } catch (OverflowException) { - throw; - } catch { - throw new FormatException(); - } - } - } - - #endregion - - #region Comparisons - - public bool Equals(NpgsqlDateTime other) - => _type switch - { - InternalType.Infinity => other._type == InternalType.Infinity, - InternalType.NegativeInfinity => other._type == InternalType.NegativeInfinity, - _ => other._type == _type && _date.Equals(other._date) && _time.Equals(other._time) - }; - - public override bool Equals(object? obj) - => obj is NpgsqlDateTime time && Equals(time); - - public override int GetHashCode() - => _type switch - { - InternalType.Infinity => int.MaxValue, - InternalType.NegativeInfinity => int.MinValue, - _ => _date.GetHashCode() ^ PGUtil.RotateShift(_time.GetHashCode(), 16) - }; - - public int CompareTo(NpgsqlDateTime other) - { - switch (_type) { - case InternalType.Infinity: - return other._type == InternalType.Infinity ? 
0 : 1; - case InternalType.NegativeInfinity: - return other._type == InternalType.NegativeInfinity ? 0 : -1; - default: - switch (other._type) { - case InternalType.Infinity: - return -1; - case InternalType.NegativeInfinity: - return 1; - default: - var cmp = _date.CompareTo(other._date); - return cmp == 0 ? _time.CompareTo(other._time) : cmp; - } - } - } - - public int CompareTo(object? o) - => o == null - ? 1 - : o is NpgsqlDateTime npgsqlDateTime - ? CompareTo(npgsqlDateTime) - : throw new ArgumentException(); - - public int Compare(NpgsqlDateTime x, NpgsqlDateTime y) => x.CompareTo(y); - - public int Compare(object? x, object? y) - { - if (x == null) - return y == null ? 0 : -1; - if (y == null) - return 1; - if (!(x is IComparable) || !(y is IComparable)) - throw new ArgumentException(); - return ((IComparable)x).CompareTo(y); - } - - #endregion - - #region Arithmetic - - /// - /// Returns a new that adds the value of the specified to the value of this instance. - /// - /// An NpgsqlTimeSpan interval. - /// An object whose value is the sum of the date and time represented by this instance and the time interval represented by value. - public NpgsqlDateTime Add(in NpgsqlTimeSpan value) => AddTicks(value.UnjustifyInterval().TotalTicks); - - /// - /// Returns a new that adds the value of the specified TimeSpan to the value of this instance. - /// - /// A positive or negative time interval. - /// An object whose value is the sum of the date and time represented by this instance and the time interval represented by value. - public NpgsqlDateTime Add(TimeSpan value) { return AddTicks(value.Ticks); } - - /// - /// Returns a new that adds the specified number of years to the value of this instance. - /// - /// A number of years. The value parameter can be negative or positive. - /// An object whose value is the sum of the date and time represented by this instance and the number of years represented by value. 
- public NpgsqlDateTime AddYears(int value) - => _type switch - { - InternalType.Infinity => this, - InternalType.NegativeInfinity => this, - _ => new NpgsqlDateTime(_type, _date.AddYears(value), _time) - }; - - /// - /// Returns a new that adds the specified number of months to the value of this instance. - /// - /// A number of months. The months parameter can be negative or positive. - /// An object whose value is the sum of the date and time represented by this instance and months. - public NpgsqlDateTime AddMonths(int value) - => _type switch - { - InternalType.Infinity => this, - InternalType.NegativeInfinity => this, - _ => new NpgsqlDateTime(_type, _date.AddMonths(value), _time) - }; - - /// - /// Returns a new that adds the specified number of days to the value of this instance. - /// - /// A number of whole and fractional days. The value parameter can be negative or positive. - /// An object whose value is the sum of the date and time represented by this instance and the number of days represented by value. - public NpgsqlDateTime AddDays(double value) => Add(TimeSpan.FromDays(value)); - - /// - /// Returns a new that adds the specified number of hours to the value of this instance. - /// - /// A number of whole and fractional hours. The value parameter can be negative or positive. - /// An object whose value is the sum of the date and time represented by this instance and the number of hours represented by value. - public NpgsqlDateTime AddHours(double value) => Add(TimeSpan.FromHours(value)); - - /// - /// Returns a new that adds the specified number of minutes to the value of this instance. - /// - /// A number of whole and fractional minutes. The value parameter can be negative or positive. - /// An object whose value is the sum of the date and time represented by this instance and the number of minutes represented by value. 
- public NpgsqlDateTime AddMinutes(double value) => Add(TimeSpan.FromMinutes(value)); - - /// - /// Returns a new that adds the specified number of minutes to the value of this instance. - /// - /// A number of whole and fractional minutes. The value parameter can be negative or positive. - /// An object whose value is the sum of the date and time represented by this instance and the number of minutes represented by value. - public NpgsqlDateTime AddSeconds(double value) => Add(TimeSpan.FromSeconds(value)); - - /// - /// Returns a new that adds the specified number of milliseconds to the value of this instance. - /// - /// A number of whole and fractional milliseconds. The value parameter can be negative or positive. Note that this value is rounded to the nearest integer. - /// An object whose value is the sum of the date and time represented by this instance and the number of milliseconds represented by value. - public NpgsqlDateTime AddMilliseconds(double value) => Add(TimeSpan.FromMilliseconds(value)); - - /// - /// Returns a new that adds the specified number of ticks to the value of this instance. - /// - /// A number of 100-nanosecond ticks. The value parameter can be positive or negative. - /// An object whose value is the sum of the date and time represented by this instance and the time represented by value. 
- public NpgsqlDateTime AddTicks(long value) - => _type switch - { - InternalType.Infinity => this, - InternalType.NegativeInfinity => this, - _ => new NpgsqlDateTime(Ticks + value, Kind), - }; - - public NpgsqlDateTime Subtract(in NpgsqlTimeSpan interval) => Add(-interval); - - public NpgsqlTimeSpan Subtract(NpgsqlDateTime timestamp) - { - switch (_type) { - case InternalType.Infinity: - case InternalType.NegativeInfinity: - throw new InvalidOperationException("You cannot subtract infinity timestamps"); - } - switch (timestamp._type) { - case InternalType.Infinity: - case InternalType.NegativeInfinity: - throw new InvalidOperationException("You cannot subtract infinity timestamps"); - } - return new NpgsqlTimeSpan(0, _date.DaysSinceEra - timestamp._date.DaysSinceEra, _time.Ticks - timestamp._time.Ticks); - } - - #endregion - - #region Operators - - public static NpgsqlDateTime operator +(NpgsqlDateTime timestamp, NpgsqlTimeSpan interval) - => timestamp.Add(interval); - - public static NpgsqlDateTime operator +(NpgsqlTimeSpan interval, NpgsqlDateTime timestamp) - => timestamp.Add(interval); - - public static NpgsqlDateTime operator -(NpgsqlDateTime timestamp, NpgsqlTimeSpan interval) - => timestamp.Subtract(interval); - - public static NpgsqlTimeSpan operator -(NpgsqlDateTime x, NpgsqlDateTime y) => x.Subtract(y); - public static bool operator ==(NpgsqlDateTime x, NpgsqlDateTime y) => x.Equals(y); - public static bool operator !=(NpgsqlDateTime x, NpgsqlDateTime y) => !(x == y); - public static bool operator <(NpgsqlDateTime x, NpgsqlDateTime y) => x.CompareTo(y) < 0; - public static bool operator >(NpgsqlDateTime x, NpgsqlDateTime y) => x.CompareTo(y) > 0; - public static bool operator <=(NpgsqlDateTime x, NpgsqlDateTime y) => x.CompareTo(y) <= 0; - public static bool operator >=(NpgsqlDateTime x, NpgsqlDateTime y) => x.CompareTo(y) >= 0; - - #endregion - - #region Casts - - /// - /// Implicit cast of a to an - /// - /// A - /// An equivalent . 
- public static implicit operator NpgsqlDateTime(DateTime dateTime) => ToNpgsqlDateTime(dateTime); - public static NpgsqlDateTime ToNpgsqlDateTime(DateTime dateTime) => new NpgsqlDateTime(dateTime); - - /// - /// Explicit cast of an to a . - /// - /// An . - /// An equivalent . - public static explicit operator DateTime(NpgsqlDateTime npgsqlDateTime) - => npgsqlDateTime.ToDateTime(); - - #endregion - - public NpgsqlDateTime Normalize() => Add(NpgsqlTimeSpan.Zero); - - static InternalType KindToInternalType(DateTimeKind kind) - => kind switch - { - DateTimeKind.Unspecified => InternalType.FiniteUnspecified, - DateTimeKind.Utc => InternalType.FiniteUtc, - DateTimeKind.Local => InternalType.FiniteLocal, - _ => throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {kind} of enum {nameof(NpgsqlDateTime)}.{nameof(InternalType)}. Please file a bug.") - }; - - enum InternalType - { - FiniteUnspecified, - FiniteUtc, - FiniteLocal, - Infinity, - NegativeInfinity - } - } -} diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs b/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs index a7539a404b..687ebf16b7 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs @@ -1,568 +1,986 @@ using System; +using System.Data; using Npgsql; +using Npgsql.Internal.Postgres; +using static Npgsql.Util.Statics; #pragma warning disable CA1720 // ReSharper disable once CheckNamespace -namespace NpgsqlTypes +namespace NpgsqlTypes; + +/// +/// Represents a PostgreSQL data type that can be written or read to the database. +/// Used in places such as to unambiguously specify +/// how to encode or decode values. +/// +/// +/// See https://www.postgresql.org/docs/current/static/datatype.html. +/// +// Source for PG OIDs: +public enum NpgsqlDbType { + // Note that it's important to never change the numeric values of this enum, since user applications + // compile them in. 
+ + #region Numeric Types + /// - /// Represents a PostgreSQL data type that can be written or read to the database. - /// Used in places such as to unambiguously specify - /// how to encode or decode values. + /// Corresponds to the PostgreSQL 8-byte "bigint" type. /// - /// See https://www.postgresql.org/docs/current/static/datatype.html - public enum NpgsqlDbType - { - // Note that it's important to never change the numeric values of this enum, since user applications - // compile them in. - - #region Numeric Types - - /// - /// Corresponds to the PostgreSQL 8-byte "bigint" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html - [BuiltInPostgresType("int8", PostgresTypeOIDs.Int8)] - Bigint = 1, - - /// - /// Corresponds to the PostgreSQL 8-byte floating-point "double" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html - [BuiltInPostgresType("float8", PostgresTypeOIDs.Float8)] - Double = 8, - - /// - /// Corresponds to the PostgreSQL 4-byte "integer" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html - [BuiltInPostgresType("int4", PostgresTypeOIDs.Int4)] - Integer = 9, - - /// - /// Corresponds to the PostgreSQL arbitrary-precision "numeric" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html - [BuiltInPostgresType("numeric", PostgresTypeOIDs.Numeric)] - Numeric = 13, - - /// - /// Corresponds to the PostgreSQL floating-point "real" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html - [BuiltInPostgresType("float4", PostgresTypeOIDs.Float4)] - Real = 17, - - /// - /// Corresponds to the PostgreSQL 2-byte "smallint" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html - [BuiltInPostgresType("int2", PostgresTypeOIDs.Int2)] - Smallint = 18, - - /// - /// Corresponds to the PostgreSQL "money" type. 
- /// - /// See https://www.postgresql.org/docs/current/static/datatype-money.html - [BuiltInPostgresType("money", PostgresTypeOIDs.Money)] - Money = 12, - - #endregion - - #region Boolean Type - - /// - /// Corresponds to the PostgreSQL "boolean" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-boolean.html - [BuiltInPostgresType("bool", PostgresTypeOIDs.Bool)] - Boolean = 2, - - #endregion - - #region Geometric types - - /// - /// Corresponds to the PostgreSQL geometric "box" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - [BuiltInPostgresType("box", PostgresTypeOIDs.Box)] - Box = 3, - - /// - /// Corresponds to the PostgreSQL geometric "circle" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - [BuiltInPostgresType("circle", PostgresTypeOIDs.Circle)] - Circle = 5, - - /// - /// Corresponds to the PostgreSQL geometric "line" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - [BuiltInPostgresType("line", PostgresTypeOIDs.Line)] - Line = 10, - - /// - /// Corresponds to the PostgreSQL geometric "lseg" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - [BuiltInPostgresType("lseg", PostgresTypeOIDs.LSeg)] - LSeg = 11, - - /// - /// Corresponds to the PostgreSQL geometric "path" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - [BuiltInPostgresType("path", PostgresTypeOIDs.Path)] - Path = 14, - - /// - /// Corresponds to the PostgreSQL geometric "point" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - [BuiltInPostgresType("point", PostgresTypeOIDs.Point)] - Point = 15, - - /// - /// Corresponds to the PostgreSQL geometric "polygon" type. 
- /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - [BuiltInPostgresType("polygon", PostgresTypeOIDs.Polygon)] - Polygon = 16, - - #endregion - - #region Character Types - - /// - /// Corresponds to the PostgreSQL "char(n)" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-character.html - [BuiltInPostgresType("bpchar", PostgresTypeOIDs.BPChar)] - Char = 6, - - /// - /// Corresponds to the PostgreSQL "text" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-character.html - [BuiltInPostgresType("text", PostgresTypeOIDs.Text)] - Text = 19, - - /// - /// Corresponds to the PostgreSQL "varchar" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-character.html - [BuiltInPostgresType("varchar", PostgresTypeOIDs.Varchar)] - Varchar = 22, - - /// - /// Corresponds to the PostgreSQL internal "name" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-character.html - [BuiltInPostgresType("name", PostgresTypeOIDs.Name)] - Name = 32, - - /// - /// Corresponds to the PostgreSQL "citext" type for the citext module. - /// - /// See https://www.postgresql.org/docs/current/static/citext.html - Citext = 51, // Extension type - - /// - /// Corresponds to the PostgreSQL "char" type. - /// - /// - /// This is an internal field and should normally not be used for regular applications. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-text.html - /// - [BuiltInPostgresType("char", PostgresTypeOIDs.Char)] - InternalChar = 38, - - #endregion - - #region Binary Data Types - - /// - /// Corresponds to the PostgreSQL "bytea" type, holding a raw byte string. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-binary.html - [BuiltInPostgresType("bytea", PostgresTypeOIDs.Bytea)] - Bytea = 4, - - #endregion - - #region Date/Time Types - - /// - /// Corresponds to the PostgreSQL "date" type. 
- /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - [BuiltInPostgresType("date", PostgresTypeOIDs.Date)] - Date = 7, - - /// - /// Corresponds to the PostgreSQL "time" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - [BuiltInPostgresType("time", PostgresTypeOIDs.Time)] - Time = 20, - - /// - /// Corresponds to the PostgreSQL "timestamp" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - [BuiltInPostgresType("timestamp", PostgresTypeOIDs.Timestamp)] - Timestamp = 21, - - /// - /// Corresponds to the PostgreSQL "timestamp with time zone" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - [Obsolete("Use TimestampTz instead")] // NOTE: Don't remove this (see #1694) - TimestampTZ = TimestampTz, - - /// - /// Corresponds to the PostgreSQL "timestamp with time zone" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - [BuiltInPostgresType("timestamptz", PostgresTypeOIDs.TimestampTz)] - TimestampTz = 26, - - /// - /// Corresponds to the PostgreSQL "interval" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - [BuiltInPostgresType("interval", PostgresTypeOIDs.Interval)] - Interval = 30, - - /// - /// Corresponds to the PostgreSQL "time with time zone" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - [Obsolete("Use TimeTz instead")] // NOTE: Don't remove this (see #1694) - TimeTZ = TimeTz, - - /// - /// Corresponds to the PostgreSQL "time with time zone" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - [BuiltInPostgresType("timetz", PostgresTypeOIDs.TimeTz)] - TimeTz = 31, - - /// - /// Corresponds to the obsolete PostgreSQL "abstime" type. 
- /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html - [Obsolete("The PostgreSQL abstime time is obsolete.")] - [BuiltInPostgresType("abstime", PostgresTypeOIDs.Abstime)] - Abstime = 33, - - #endregion - - #region Network Address Types - - /// - /// Corresponds to the PostgreSQL "inet" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html - [BuiltInPostgresType("inet", PostgresTypeOIDs.Inet)] - Inet = 24, - - /// - /// Corresponds to the PostgreSQL "cidr" type, a field storing an IPv4 or IPv6 network. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html - [BuiltInPostgresType("cidr", PostgresTypeOIDs.Cidr)] - Cidr = 44, - - /// - /// Corresponds to the PostgreSQL "macaddr" type, a field storing a 6-byte physical address. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html - [BuiltInPostgresType("macaddr", PostgresTypeOIDs.Macaddr)] - MacAddr = 34, - - /// - /// Corresponds to the PostgreSQL "macaddr8" type, a field storing a 6-byte or 8-byte physical address. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html - [BuiltInPostgresType("macaddr8", PostgresTypeOIDs.Macaddr8)] - MacAddr8 = 54, - - #endregion - - #region Bit String Types - - /// - /// Corresponds to the PostgreSQL "bit" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-bit.html - [BuiltInPostgresType("bit", PostgresTypeOIDs.Bit)] - Bit = 25, - - /// - /// Corresponds to the PostgreSQL "varbit" type, a field storing a variable-length string of bits. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-boolean.html - [BuiltInPostgresType("varbit", PostgresTypeOIDs.Varbit)] - Varbit = 39, - - #endregion - - #region Text Search Types - - /// - /// Corresponds to the PostgreSQL "tsvector" type. 
- /// - /// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html - [BuiltInPostgresType("tsvector", PostgresTypeOIDs.TsVector)] - TsVector = 45, - - /// - /// Corresponds to the PostgreSQL "tsquery" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html - [BuiltInPostgresType("tsquery", PostgresTypeOIDs.TsQuery)] - TsQuery = 46, - - /// - /// Corresponds to the PostgreSQL "regconfig" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html - [BuiltInPostgresType("regconfig", PostgresTypeOIDs.Regconfig)] - Regconfig = 56, - - #endregion - - #region UUID Type - - /// - /// Corresponds to the PostgreSQL "uuid" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-uuid.html - [BuiltInPostgresType("uuid", PostgresTypeOIDs.Uuid)] - Uuid = 27, - - #endregion - - #region XML Type - - /// - /// Corresponds to the PostgreSQL "xml" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-xml.html - [BuiltInPostgresType("xml", PostgresTypeOIDs.Xml)] - Xml = 28, - - #endregion - - #region JSON Types - - /// - /// Corresponds to the PostgreSQL "json" type, a field storing JSON in text format. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-json.html - /// - [BuiltInPostgresType("json", PostgresTypeOIDs.Json)] - Json = 35, - - /// - /// Corresponds to the PostgreSQL "jsonb" type, a field storing JSON in an optimized binary. - /// format. - /// - /// - /// Supported since PostgreSQL 9.4. - /// See https://www.postgresql.org/docs/current/static/datatype-json.html - /// - [BuiltInPostgresType("jsonb", PostgresTypeOIDs.Jsonb)] - Jsonb = 36, - - /// - /// Corresponds to the PostgreSQL "jsonpath" type, a field storing JSON path in text format. - /// format. - /// - /// - /// Supported since PostgreSQL 12. 
- /// See https://www.postgresql.org/docs/current/datatype-json.html#DATATYPE-JSONPATH - /// - [BuiltInPostgresType("jsonpath", PostgresTypeOIDs.JsonPath)] - JsonPath = 57, - - #endregion - - #region HSTORE Type - - /// - /// Corresponds to the PostgreSQL "hstore" type, a dictionary of string key-value pairs. - /// - /// See https://www.postgresql.org/docs/current/static/hstore.html - Hstore = 37, // Extension type - - #endregion - - #region Arrays - - /// - /// Corresponds to the PostgreSQL "array" type, a variable-length multidimensional array of - /// another type. This value must be combined with another value from - /// via a bit OR (e.g. NpgsqlDbType.Array | NpgsqlDbType.Integer) - /// - /// See https://www.postgresql.org/docs/current/static/arrays.html - Array = int.MinValue, - - #endregion - - #region Range Types - - /// - /// Corresponds to the PostgreSQL "range" type, continuous range of values of specific type. - /// This value must be combined with another value from - /// via a bit OR (e.g. NpgsqlDbType.Range | NpgsqlDbType.Integer) - /// - /// - /// Supported since PostgreSQL 9.2. - /// See https://www.postgresql.org/docs/9.2/static/rangetypes.html - /// - Range = 0x40000000, - - #endregion - - #region Internal Types - - /// - /// Corresponds to the PostgreSQL "refcursor" type. - /// - [BuiltInPostgresType("refcursor", PostgresTypeOIDs.Refcursor)] - Refcursor = 23, - - /// - /// Corresponds to the PostgreSQL internal "oidvector" type. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-oid.html - [BuiltInPostgresType("oidvector", PostgresTypeOIDs.Oidvector)] - Oidvector = 29, - - /// - /// Corresponds to the PostgreSQL internal "int2vector" type. - /// - [BuiltInPostgresType("int2vector", PostgresTypeOIDs.Int2vector)] - Int2Vector = 52, - - /// - /// Corresponds to the PostgreSQL "oid" type. 
- /// - /// See https://www.postgresql.org/docs/current/static/datatype-oid.html - [BuiltInPostgresType("oid", PostgresTypeOIDs.Oid)] - Oid = 41, - - /// - /// Corresponds to the PostgreSQL "xid" type, an internal transaction identifier. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-oid.html - [BuiltInPostgresType("xid", PostgresTypeOIDs.Xid)] - Xid = 42, - - /// - /// Corresponds to the PostgreSQL "cid" type, an internal command identifier. - /// - /// See https://www.postgresql.org/docs/current/static/datatype-oid.html - [BuiltInPostgresType("cid", PostgresTypeOIDs.Cid)] - Cid = 43, - - /// - /// Corresponds to the PostgreSQL "regtype" type, a numeric (OID) ID of a type in the pg_type table. - /// - [BuiltInPostgresType("regtype", PostgresTypeOIDs.Regtype)] - Regtype = 49, - - /// - /// Corresponds to the PostgreSQL "tid" type, a tuple id identifying the physical location of a row within its table. - /// - [BuiltInPostgresType("tid", PostgresTypeOIDs.Tid)] - Tid = 53, - - /// - /// Corresponds to the PostgreSQL "pg_lsn" type, which can be used to store LSN (Log Sequence Number) data which - /// is a pointer to a location in the WAL. - /// - /// - /// See: https://www.postgresql.org/docs/current/datatype-pg-lsn.html and - /// https://git.postgresql.org/gitweb/?p=postgresql.git;a=commit;h=7d03a83f4d0736ba869fa6f93973f7623a27038a - /// - [BuiltInPostgresType("pg_lsn", 3220)] - PgLsn = 59, - - #endregion - - #region Special - - /// - /// A special value that can be used to send parameter values to the database without - /// specifying their type, allowing the database to cast them to another value based on context. - /// The value will be converted to a string and send as text. - /// - /// - /// This value shouldn't ordinarily be used, and makes sense only when sending a data type - /// unsupported by Npgsql. 
- /// - [BuiltInPostgresType("unknown", PostgresTypeOIDs.Unknown)] - Unknown = 40, - - #endregion - - #region PostGIS - - /// - /// The geometry type for PostgreSQL spatial extension PostGIS. - /// - Geometry = 50, // Extension type - - /// - /// The geography (geodetic) type for PostgreSQL spatial extension PostGIS. - /// - Geography = 55, // Extension type - - #endregion - - #region Label tree types - - /// - /// The PostgreSQL ltree type, each value is a label path "a.label.tree.value", forming a tree in a set. - /// - /// See http://www.postgresql.org/docs/current/static/ltree.html - LTree = 60, // Extension type - - /// - /// The PostgreSQL lquery type for PostgreSQL extension ltree - /// - /// See http://www.postgresql.org/docs/current/static/ltree.html - LQuery = 61, // Extension type - - /// - /// The PostgreSQL ltxtquery type for PostgreSQL extension ltree - /// - /// See http://www.postgresql.org/docs/current/static/ltree.html - LTxtQuery = 62, // Extension type - - #endregion - } + /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html + Bigint = 1, /// - /// Represents a built-in PostgreSQL type as it appears in pg_type, including its name and OID. - /// Extension types with variable OIDs are not represented. + /// Corresponds to the PostgreSQL 8-byte floating-point "double" type. /// - class BuiltInPostgresType : Attribute - { - internal string Name { get; } - internal uint OID { get; } + /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html + Double = 8, + + /// + /// Corresponds to the PostgreSQL 4-byte "integer" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html + Integer = 9, + + /// + /// Corresponds to the PostgreSQL arbitrary-precision "numeric" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html + Numeric = 13, + + /// + /// Corresponds to the PostgreSQL floating-point "real" type. 
+ /// + /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html + Real = 17, + + /// + /// Corresponds to the PostgreSQL 2-byte "smallint" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html + Smallint = 18, + + /// + /// Corresponds to the PostgreSQL "money" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-money.html + Money = 12, + + #endregion + + #region Boolean Type + + /// + /// Corresponds to the PostgreSQL "boolean" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-boolean.html + Boolean = 2, + + #endregion + + #region Geometric types + + /// + /// Corresponds to the PostgreSQL geometric "box" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html + Box = 3, + + /// + /// Corresponds to the PostgreSQL geometric "circle" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html + Circle = 5, + + /// + /// Corresponds to the PostgreSQL geometric "line" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html + Line = 10, + + /// + /// Corresponds to the PostgreSQL geometric "lseg" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html + LSeg = 11, + + /// + /// Corresponds to the PostgreSQL geometric "path" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html + Path = 14, + + /// + /// Corresponds to the PostgreSQL geometric "point" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html + Point = 15, + + /// + /// Corresponds to the PostgreSQL geometric "polygon" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html + Polygon = 16, + + #endregion + + #region Character Types + + /// + /// Corresponds to the PostgreSQL "char(n)" type. 
+ /// + /// See https://www.postgresql.org/docs/current/static/datatype-character.html + Char = 6, + + /// + /// Corresponds to the PostgreSQL "text" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-character.html + Text = 19, + + /// + /// Corresponds to the PostgreSQL "varchar" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-character.html + Varchar = 22, + + /// + /// Corresponds to the PostgreSQL internal "name" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-character.html + Name = 32, + + /// + /// Corresponds to the PostgreSQL "citext" type for the citext module. + /// + /// See https://www.postgresql.org/docs/current/static/citext.html + Citext = 51, // Extension type + + /// + /// Corresponds to the PostgreSQL "char" type. + /// + /// + /// This is an internal field and should normally not be used for regular applications. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-text.html + /// + InternalChar = 38, + + #endregion + + #region Binary Data Types + + /// + /// Corresponds to the PostgreSQL "bytea" type, holding a raw byte string. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-binary.html + Bytea = 4, + + #endregion + + #region Date/Time Types + + /// + /// Corresponds to the PostgreSQL "date" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html + Date = 7, + + /// + /// Corresponds to the PostgreSQL "time" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html + Time = 20, + + /// + /// Corresponds to the PostgreSQL "timestamp" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html + Timestamp = 21, + + /// + /// Corresponds to the PostgreSQL "timestamp with time zone" type. 
+ /// + /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html + TimestampTz = 26, + + /// + /// Corresponds to the PostgreSQL "interval" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html + Interval = 30, + + /// + /// Corresponds to the PostgreSQL "time with time zone" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html + TimeTz = 31, + + /// + /// Corresponds to the obsolete PostgreSQL "abstime" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html + [Obsolete("The PostgreSQL abstime time is obsolete.")] + Abstime = 33, + + #endregion + + #region Network Address Types + + /// + /// Corresponds to the PostgreSQL "inet" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html + Inet = 24, + + /// + /// Corresponds to the PostgreSQL "cidr" type, a field storing an IPv4 or IPv6 network. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html + Cidr = 44, + + /// + /// Corresponds to the PostgreSQL "macaddr" type, a field storing a 6-byte physical address. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html + MacAddr = 34, + + /// + /// Corresponds to the PostgreSQL "macaddr8" type, a field storing a 6-byte or 8-byte physical address. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html + MacAddr8 = 54, + + #endregion + + #region Bit String Types + + /// + /// Corresponds to the PostgreSQL "bit" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-bit.html + Bit = 25, + + /// + /// Corresponds to the PostgreSQL "varbit" type, a field storing a variable-length string of bits. 
+ /// + /// See https://www.postgresql.org/docs/current/static/datatype-boolean.html + Varbit = 39, + + #endregion + + #region Text Search Types + + /// + /// Corresponds to the PostgreSQL "tsvector" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html + TsVector = 45, + + /// + /// Corresponds to the PostgreSQL "tsquery" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html + TsQuery = 46, + + /// + /// Corresponds to the PostgreSQL "regconfig" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html + Regconfig = 56, + + #endregion + + #region UUID Type + + /// + /// Corresponds to the PostgreSQL "uuid" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-uuid.html + Uuid = 27, + + #endregion - internal BuiltInPostgresType(string name, uint oid) + #region XML Type + + /// + /// Corresponds to the PostgreSQL "xml" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-xml.html + Xml = 28, + + #endregion + + #region JSON Types + + /// + /// Corresponds to the PostgreSQL "json" type, a field storing JSON in text format. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-json.html + /// + Json = 35, + + /// + /// Corresponds to the PostgreSQL "jsonb" type, a field storing JSON in an optimized binary. + /// format. + /// + /// + /// Supported since PostgreSQL 9.4. + /// See https://www.postgresql.org/docs/current/static/datatype-json.html + /// + Jsonb = 36, + + /// + /// Corresponds to the PostgreSQL "jsonpath" type, a field storing JSON path in text format. + /// format. + /// + /// + /// Supported since PostgreSQL 12. + /// See https://www.postgresql.org/docs/current/datatype-json.html#DATATYPE-JSONPATH + /// + JsonPath = 57, + + #endregion + + #region HSTORE Type + + /// + /// Corresponds to the PostgreSQL "hstore" type, a dictionary of string key-value pairs. 
+ /// + /// See https://www.postgresql.org/docs/current/static/hstore.html + Hstore = 37, // Extension type + + #endregion + + #region Internal Types + + /// + /// Corresponds to the PostgreSQL "refcursor" type. + /// + Refcursor = 23, + + /// + /// Corresponds to the PostgreSQL internal "oidvector" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-oid.html + Oidvector = 29, + + /// + /// Corresponds to the PostgreSQL internal "int2vector" type. + /// + Int2Vector = 52, + + /// + /// Corresponds to the PostgreSQL "oid" type. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-oid.html + Oid = 41, + + /// + /// Corresponds to the PostgreSQL "xid" type, an internal transaction identifier. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-oid.html + Xid = 42, + + /// + /// Corresponds to the PostgreSQL "xid8" type, an internal transaction identifier. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-oid.html + Xid8 = 64, + + /// + /// Corresponds to the PostgreSQL "cid" type, an internal command identifier. + /// + /// See https://www.postgresql.org/docs/current/static/datatype-oid.html + Cid = 43, + + /// + /// Corresponds to the PostgreSQL "regtype" type, a numeric (OID) ID of a type in the pg_type table. + /// + Regtype = 49, + + /// + /// Corresponds to the PostgreSQL "tid" type, a tuple id identifying the physical location of a row within its table. + /// + Tid = 53, + + /// + /// Corresponds to the PostgreSQL "pg_lsn" type, which can be used to store LSN (Log Sequence Number) data which + /// is a pointer to a location in the WAL. 
+ /// + /// + /// See: https://www.postgresql.org/docs/current/datatype-pg-lsn.html and + /// https://git.postgresql.org/gitweb/?p=postgresql.git;a=commit;h=7d03a83f4d0736ba869fa6f93973f7623a27038a + /// + PgLsn = 59, + + #endregion + + #region Special + + /// + /// A special value that can be used to send parameter values to the database without + /// specifying their type, allowing the database to cast them to another value based on context. + /// The value will be converted to a string and send as text. + /// + /// + /// This value shouldn't ordinarily be used, and makes sense only when sending a data type + /// unsupported by Npgsql. + /// + Unknown = 40, + + #endregion + + #region PostGIS + + /// + /// The geometry type for PostgreSQL spatial extension PostGIS. + /// + Geometry = 50, // Extension type + + /// + /// The geography (geodetic) type for PostgreSQL spatial extension PostGIS. + /// + Geography = 55, // Extension type + + #endregion + + #region Label tree types + + /// + /// The PostgreSQL ltree type, each value is a label path "a.label.tree.value", forming a tree in a set. + /// + /// See https://www.postgresql.org/docs/current/static/ltree.html + LTree = 60, // Extension type + + /// + /// The PostgreSQL lquery type for PostgreSQL extension ltree + /// + /// See https://www.postgresql.org/docs/current/static/ltree.html + LQuery = 61, // Extension type + + /// + /// The PostgreSQL ltxtquery type for PostgreSQL extension ltree + /// + /// See https://www.postgresql.org/docs/current/static/ltree.html + LTxtQuery = 62, // Extension type + + #endregion + + #region Range types + + /// + /// Corresponds to the PostgreSQL "int4range" type. + /// + IntegerRange = Range | Integer, + + /// + /// Corresponds to the PostgreSQL "int8range" type. + /// + BigIntRange = Range | Bigint, + + /// + /// Corresponds to the PostgreSQL "numrange" type. + /// + NumericRange = Range | Numeric, + + /// + /// Corresponds to the PostgreSQL "tsrange" type. 
+ /// + TimestampRange = Range | Timestamp, + + /// + /// Corresponds to the PostgreSQL "tstzrange" type. + /// + TimestampTzRange = Range | TimestampTz, + + /// + /// Corresponds to the PostgreSQL "daterange" type. + /// + DateRange = Range | Date, + + #endregion Range types + + #region Multirange types + + /// + /// Corresponds to the PostgreSQL "int4multirange" type. + /// + IntegerMultirange = Multirange | Integer, + + /// + /// Corresponds to the PostgreSQL "int8multirange" type. + /// + BigIntMultirange = Multirange | Bigint, + + /// + /// Corresponds to the PostgreSQL "nummultirange" type. + /// + NumericMultirange = Multirange | Numeric, + + /// + /// Corresponds to the PostgreSQL "tsmultirange" type. + /// + TimestampMultirange = Multirange | Timestamp, + + /// + /// Corresponds to the PostgreSQL "tstzmultirange" type. + /// + TimestampTzMultirange = Multirange | TimestampTz, + + /// + /// Corresponds to the PostgreSQL "datemultirange" type. + /// + DateMultirange = Multirange | Date, + + #endregion Multirange types + + #region Composables + + /// + /// Corresponds to the PostgreSQL "array" type, a variable-length multidimensional array of + /// another type. This value must be combined with another value from + /// via a bit OR (e.g. NpgsqlDbType.Array | NpgsqlDbType.Integer) + /// + /// See https://www.postgresql.org/docs/current/static/arrays.html + Array = int.MinValue, + + /// + /// Corresponds to the PostgreSQL "range" type, continuous range of values of specific type. + /// This value must be combined with another value from + /// via a bit OR (e.g. NpgsqlDbType.Range | NpgsqlDbType.Integer) + /// + /// + /// Supported since PostgreSQL 9.2. + /// See https://www.postgresql.org/docs/current/static/rangetypes.html + /// + Range = 0x40000000, + + /// + /// Corresponds to the PostgreSQL "multirange" type, continuous range of values of specific type. + /// This value must be combined with another value from + /// via a bit OR (e.g. 
NpgsqlDbType.Multirange | NpgsqlDbType.Integer) + /// + /// + /// Supported since PostgreSQL 14. + /// See https://www.postgresql.org/docs/current/static/rangetypes.html + /// + Multirange = 0x20000000, + + #endregion +} + +static class NpgsqlDbTypeExtensions +{ + internal static NpgsqlDbType? ToNpgsqlDbType(this DbType dbType) + => dbType switch { - Name = name; - OID = oid; - } + DbType.AnsiString => NpgsqlDbType.Text, + DbType.Binary => NpgsqlDbType.Bytea, + DbType.Byte => NpgsqlDbType.Smallint, + DbType.Boolean => NpgsqlDbType.Boolean, + DbType.Currency => NpgsqlDbType.Money, + DbType.Date => NpgsqlDbType.Date, + DbType.DateTime => LegacyTimestampBehavior ? NpgsqlDbType.Timestamp : NpgsqlDbType.TimestampTz, + DbType.Decimal => NpgsqlDbType.Numeric, + DbType.VarNumeric => NpgsqlDbType.Numeric, + DbType.Double => NpgsqlDbType.Double, + DbType.Guid => NpgsqlDbType.Uuid, + DbType.Int16 => NpgsqlDbType.Smallint, + DbType.Int32 => NpgsqlDbType.Integer, + DbType.Int64 => NpgsqlDbType.Bigint, + DbType.Single => NpgsqlDbType.Real, + DbType.String => NpgsqlDbType.Text, + DbType.Time => NpgsqlDbType.Time, + DbType.AnsiStringFixedLength => NpgsqlDbType.Text, + DbType.StringFixedLength => NpgsqlDbType.Text, + DbType.Xml => NpgsqlDbType.Xml, + DbType.DateTime2 => NpgsqlDbType.Timestamp, + DbType.DateTimeOffset => NpgsqlDbType.TimestampTz, + + DbType.Object => null, + DbType.SByte => null, + DbType.UInt16 => null, + DbType.UInt32 => null, + DbType.UInt64 => null, + + _ => throw new ArgumentOutOfRangeException(nameof(dbType), dbType, null) + }; + + public static DbType ToDbType(this NpgsqlDbType npgsqlDbType) + => npgsqlDbType switch + { + // Numeric types + NpgsqlDbType.Smallint => DbType.Int16, + NpgsqlDbType.Integer => DbType.Int32, + NpgsqlDbType.Bigint => DbType.Int64, + NpgsqlDbType.Real => DbType.Single, + NpgsqlDbType.Double => DbType.Double, + NpgsqlDbType.Numeric => DbType.Decimal, + NpgsqlDbType.Money => DbType.Currency, + + // Text types + NpgsqlDbType.Text => 
DbType.String, + NpgsqlDbType.Xml => DbType.Xml, + NpgsqlDbType.Varchar => DbType.String, + NpgsqlDbType.Char => DbType.String, + NpgsqlDbType.Name => DbType.String, + NpgsqlDbType.Citext => DbType.String, + NpgsqlDbType.Refcursor => DbType.Object, + NpgsqlDbType.Jsonb => DbType.Object, + NpgsqlDbType.Json => DbType.Object, + NpgsqlDbType.JsonPath => DbType.Object, + + // Date/time types + NpgsqlDbType.Timestamp => LegacyTimestampBehavior ? DbType.DateTime : DbType.DateTime2, + NpgsqlDbType.TimestampTz => LegacyTimestampBehavior ? DbType.DateTimeOffset : DbType.DateTime, + NpgsqlDbType.Date => DbType.Date, + NpgsqlDbType.Time => DbType.Time, + + // Misc data types + NpgsqlDbType.Bytea => DbType.Binary, + NpgsqlDbType.Boolean => DbType.Boolean, + NpgsqlDbType.Uuid => DbType.Guid, + + NpgsqlDbType.Unknown => DbType.Object, + + _ => DbType.Object + }; + + /// Can return null when a custom range type is used. + internal static string? ToUnqualifiedDataTypeName(this NpgsqlDbType npgsqlDbType) + => npgsqlDbType switch + { + // Numeric types + NpgsqlDbType.Smallint => "int2", + NpgsqlDbType.Integer => "int4", + NpgsqlDbType.Bigint => "int8", + NpgsqlDbType.Real => "float4", + NpgsqlDbType.Double => "float8", + NpgsqlDbType.Numeric => "numeric", + NpgsqlDbType.Money => "money", + + // Text types + NpgsqlDbType.Text => "text", + NpgsqlDbType.Xml => "xml", + NpgsqlDbType.Varchar => "varchar", + NpgsqlDbType.Char => "bpchar", + NpgsqlDbType.Name => "name", + NpgsqlDbType.Refcursor => "refcursor", + NpgsqlDbType.Jsonb => "jsonb", + NpgsqlDbType.Json => "json", + NpgsqlDbType.JsonPath => "jsonpath", + + // Date/time types + NpgsqlDbType.Timestamp => "timestamp", + NpgsqlDbType.TimestampTz => "timestamptz", + NpgsqlDbType.Date => "date", + NpgsqlDbType.Time => "time", + NpgsqlDbType.TimeTz => "timetz", + NpgsqlDbType.Interval => "interval", + + // Network types + NpgsqlDbType.Cidr => "cidr", + NpgsqlDbType.Inet => "inet", + NpgsqlDbType.MacAddr => "macaddr", + 
NpgsqlDbType.MacAddr8 => "macaddr8", + + // Full-text search types + NpgsqlDbType.TsQuery => "tsquery", + NpgsqlDbType.TsVector => "tsvector", + + // Geometry types + NpgsqlDbType.Box => "box", + NpgsqlDbType.Circle => "circle", + NpgsqlDbType.Line => "line", + NpgsqlDbType.LSeg => "lseg", + NpgsqlDbType.Path => "path", + NpgsqlDbType.Point => "point", + NpgsqlDbType.Polygon => "polygon", + + + // UInt types + NpgsqlDbType.Oid => "oid", + NpgsqlDbType.Xid => "xid", + NpgsqlDbType.Xid8 => "xid8", + NpgsqlDbType.Cid => "cid", + NpgsqlDbType.Regtype => "regtype", + NpgsqlDbType.Regconfig => "regconfig", + + // Misc types + NpgsqlDbType.Boolean => "bool", + NpgsqlDbType.Bytea => "bytea", + NpgsqlDbType.Uuid => "uuid", + NpgsqlDbType.Varbit => "varbit", + NpgsqlDbType.Bit => "bit", + + // Built-in range types + NpgsqlDbType.IntegerRange => "int4range", + NpgsqlDbType.BigIntRange => "int8range", + NpgsqlDbType.NumericRange => "numrange", + NpgsqlDbType.TimestampRange => "tsrange", + NpgsqlDbType.TimestampTzRange => "tstzrange", + NpgsqlDbType.DateRange => "daterange", + + // Built-in multirange types + NpgsqlDbType.IntegerMultirange => "int4multirange", + NpgsqlDbType.BigIntMultirange => "int8multirange", + NpgsqlDbType.NumericMultirange => "nummultirange", + NpgsqlDbType.TimestampMultirange => "tsmultirange", + NpgsqlDbType.TimestampTzMultirange => "tstzmultirange", + NpgsqlDbType.DateMultirange => "datemultirange", + + // Internal types + NpgsqlDbType.Int2Vector => "int2vector", + NpgsqlDbType.Oidvector => "oidvector", + NpgsqlDbType.PgLsn => "pg_lsn", + NpgsqlDbType.Tid => "tid", + NpgsqlDbType.InternalChar => "char", + + // Plugin types + NpgsqlDbType.Citext => "citext", + NpgsqlDbType.LQuery => "lquery", + NpgsqlDbType.LTree => "ltree", + NpgsqlDbType.LTxtQuery => "ltxtquery", + NpgsqlDbType.Hstore => "hstore", + NpgsqlDbType.Geometry => "geometry", + NpgsqlDbType.Geography => "geography", + + NpgsqlDbType.Unknown => "unknown", + + // Unknown cannot be composed + _ 
when npgsqlDbType.HasFlag(NpgsqlDbType.Array) && (npgsqlDbType & ~NpgsqlDbType.Array) == NpgsqlDbType.Unknown + => "unknown", + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Range) && (npgsqlDbType & ~NpgsqlDbType.Range) == NpgsqlDbType.Unknown + => "unknown", + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) && (npgsqlDbType & ~NpgsqlDbType.Multirange) == NpgsqlDbType.Unknown + => "unknown", + + _ => npgsqlDbType.HasFlag(NpgsqlDbType.Array) + ? ToUnqualifiedDataTypeName(npgsqlDbType & ~NpgsqlDbType.Array) is { } name ? "_" + name : null + : null // e.g. ranges + }; + + internal static string ToUnqualifiedDataTypeNameOrThrow(this NpgsqlDbType npgsqlDbType) + => npgsqlDbType.ToUnqualifiedDataTypeName() ?? throw new ArgumentOutOfRangeException(nameof(npgsqlDbType), npgsqlDbType, "Cannot convert NpgsqlDbType to DataTypeName"); + + /// Can return null when a plugin type or custom range type is used. + internal static DataTypeName? ToDataTypeName(this NpgsqlDbType npgsqlDbType) + => npgsqlDbType switch + { + // Numeric types + NpgsqlDbType.Smallint => DataTypeNames.Int2, + NpgsqlDbType.Integer => DataTypeNames.Int4, + NpgsqlDbType.Bigint => DataTypeNames.Int8, + NpgsqlDbType.Real => DataTypeNames.Float4, + NpgsqlDbType.Double => DataTypeNames.Float8, + NpgsqlDbType.Numeric => DataTypeNames.Numeric, + NpgsqlDbType.Money => DataTypeNames.Money, + + // Text types + NpgsqlDbType.Text => DataTypeNames.Text, + NpgsqlDbType.Xml => DataTypeNames.Xml, + NpgsqlDbType.Varchar => DataTypeNames.Varchar, + NpgsqlDbType.Char => DataTypeNames.Bpchar, + NpgsqlDbType.Name => DataTypeNames.Name, + NpgsqlDbType.Refcursor => DataTypeNames.RefCursor, + NpgsqlDbType.Jsonb => DataTypeNames.Jsonb, + NpgsqlDbType.Json => DataTypeNames.Json, + NpgsqlDbType.JsonPath => DataTypeNames.Jsonpath, + + // Date/time types + NpgsqlDbType.Timestamp => DataTypeNames.Timestamp, + NpgsqlDbType.TimestampTz => DataTypeNames.TimestampTz, + NpgsqlDbType.Date => DataTypeNames.Date, + NpgsqlDbType.Time => 
DataTypeNames.Time, + NpgsqlDbType.TimeTz => DataTypeNames.TimeTz, + NpgsqlDbType.Interval => DataTypeNames.Interval, + + // Network types + NpgsqlDbType.Cidr => DataTypeNames.Cidr, + NpgsqlDbType.Inet => DataTypeNames.Inet, + NpgsqlDbType.MacAddr => DataTypeNames.MacAddr, + NpgsqlDbType.MacAddr8 => DataTypeNames.MacAddr8, + + // Full-text search types + NpgsqlDbType.TsQuery => DataTypeNames.TsQuery, + NpgsqlDbType.TsVector => DataTypeNames.TsVector, + + // Geometry types + NpgsqlDbType.Box => DataTypeNames.Box, + NpgsqlDbType.Circle => DataTypeNames.Circle, + NpgsqlDbType.Line => DataTypeNames.Line, + NpgsqlDbType.LSeg => DataTypeNames.LSeg, + NpgsqlDbType.Path => DataTypeNames.Path, + NpgsqlDbType.Point => DataTypeNames.Point, + NpgsqlDbType.Polygon => DataTypeNames.Polygon, + + // UInt types + NpgsqlDbType.Oid => DataTypeNames.Oid, + NpgsqlDbType.Xid => DataTypeNames.Xid, + NpgsqlDbType.Xid8 => DataTypeNames.Xid8, + NpgsqlDbType.Cid => DataTypeNames.Cid, + NpgsqlDbType.Regtype => DataTypeNames.RegType, + NpgsqlDbType.Regconfig => DataTypeNames.RegConfig, + + // Misc types + NpgsqlDbType.Boolean => DataTypeNames.Bool, + NpgsqlDbType.Bytea => DataTypeNames.Bytea, + NpgsqlDbType.Uuid => DataTypeNames.Uuid, + NpgsqlDbType.Varbit => DataTypeNames.Varbit, + NpgsqlDbType.Bit => DataTypeNames.Bit, + + // Built-in range types + NpgsqlDbType.IntegerRange => DataTypeNames.Int4Range, + NpgsqlDbType.BigIntRange => DataTypeNames.Int8Range, + NpgsqlDbType.NumericRange => DataTypeNames.NumRange, + NpgsqlDbType.TimestampRange => DataTypeNames.TsRange, + NpgsqlDbType.TimestampTzRange => DataTypeNames.TsTzRange, + NpgsqlDbType.DateRange => DataTypeNames.DateRange, + + // Internal types + NpgsqlDbType.Int2Vector => DataTypeNames.Int2Vector, + NpgsqlDbType.Oidvector => DataTypeNames.OidVector, + NpgsqlDbType.PgLsn => DataTypeNames.PgLsn, + NpgsqlDbType.Tid => DataTypeNames.Tid, + NpgsqlDbType.InternalChar => DataTypeNames.Char, + + // Special types + NpgsqlDbType.Unknown => 
DataTypeNames.Unknown, + + // Unknown cannot be composed + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Array) && (npgsqlDbType & ~NpgsqlDbType.Array) == NpgsqlDbType.Unknown + => DataTypeNames.Unknown, + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Range) && (npgsqlDbType & ~NpgsqlDbType.Range) == NpgsqlDbType.Unknown + => DataTypeNames.Unknown, + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) && (npgsqlDbType & ~NpgsqlDbType.Multirange) == NpgsqlDbType.Unknown + => DataTypeNames.Unknown, + + // If both multirange and array are set we first remove array, so array is added to the outermost datatypename. + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Array) + => ToDataTypeName(npgsqlDbType & ~NpgsqlDbType.Array)?.ToArrayName(), + _ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) + => ToDataTypeName((npgsqlDbType | NpgsqlDbType.Range) & ~NpgsqlDbType.Multirange)?.ToDefaultMultirangeName(), + + // Plugin types don't have a stable fully qualified name. + _ => null + }; + + internal static NpgsqlDbType? ToNpgsqlDbType(this DataTypeName dataTypeName) => ToNpgsqlDbType(dataTypeName.UnqualifiedName); + /// Should not be used with display names, first normalize it instead. + internal static NpgsqlDbType? 
ToNpgsqlDbType(string dataTypeName) + { + var unqualifiedName = dataTypeName; + if (dataTypeName.IndexOf(".", StringComparison.Ordinal) is not -1 and var index) + unqualifiedName = dataTypeName.Substring(0, index); + + return unqualifiedName switch + { + // Numeric types + "int2" => NpgsqlDbType.Smallint, + "int4" => NpgsqlDbType.Integer, + "int8" => NpgsqlDbType.Bigint, + "float4" => NpgsqlDbType.Real, + "float8" => NpgsqlDbType.Double, + "numeric" => NpgsqlDbType.Numeric, + "money" => NpgsqlDbType.Money, + + // Text types + "text" => NpgsqlDbType.Text, + "xml" => NpgsqlDbType.Xml, + "varchar" => NpgsqlDbType.Varchar, + "bpchar" => NpgsqlDbType.Char, + "name" => NpgsqlDbType.Name, + "refcursor" => NpgsqlDbType.Refcursor, + "jsonb" => NpgsqlDbType.Jsonb, + "json" => NpgsqlDbType.Json, + "jsonpath" => NpgsqlDbType.JsonPath, + + // Date/time types + "timestamp" => NpgsqlDbType.Timestamp, + "timestamptz" => NpgsqlDbType.TimestampTz, + "date" => NpgsqlDbType.Date, + "time" => NpgsqlDbType.Time, + "timetz" => NpgsqlDbType.TimeTz, + "interval" => NpgsqlDbType.Interval, + + // Network types + "cidr" => NpgsqlDbType.Cidr, + "inet" => NpgsqlDbType.Inet, + "macaddr" => NpgsqlDbType.MacAddr, + "macaddr8" => NpgsqlDbType.MacAddr8, + + // Full-text search types + "tsquery" => NpgsqlDbType.TsQuery, + "tsvector" => NpgsqlDbType.TsVector, + + // Geometry types + "box" => NpgsqlDbType.Box, + "circle" => NpgsqlDbType.Circle, + "line" => NpgsqlDbType.Line, + "lseg" => NpgsqlDbType.LSeg, + "path" => NpgsqlDbType.Path, + "point" => NpgsqlDbType.Point, + "polygon" => NpgsqlDbType.Polygon, + + // UInt types + "oid" => NpgsqlDbType.Oid, + "xid" => NpgsqlDbType.Xid, + "xid8" => NpgsqlDbType.Xid8, + "cid" => NpgsqlDbType.Cid, + "regtype" => NpgsqlDbType.Regtype, + "regconfig" => NpgsqlDbType.Regconfig, + + // Misc types + "bool" => NpgsqlDbType.Boolean, + "bytea" => NpgsqlDbType.Bytea, + "uuid" => NpgsqlDbType.Uuid, + "varbit" => NpgsqlDbType.Varbit, + "bit" => NpgsqlDbType.Bit, + + // 
Built-in range types + "int4range" => NpgsqlDbType.IntegerRange, + "int8range" => NpgsqlDbType.BigIntRange, + "numrange" => NpgsqlDbType.NumericRange, + "tsrange" => NpgsqlDbType.TimestampRange, + "tstzrange" => NpgsqlDbType.TimestampTzRange, + "daterange" => NpgsqlDbType.DateRange, + + // Built-in multirange types + "int4multirange" => NpgsqlDbType.IntegerMultirange, + "int8multirange" => NpgsqlDbType.BigIntMultirange, + "nummultirange" => NpgsqlDbType.NumericMultirange, + "tsmultirange" => NpgsqlDbType.TimestampMultirange, + "tstzmultirange" => NpgsqlDbType.TimestampTzMultirange, + "datemultirange" => NpgsqlDbType.DateMultirange, + + // Internal types + "int2vector" => NpgsqlDbType.Int2Vector, + "oidvector" => NpgsqlDbType.Oidvector, + "pg_lsn" => NpgsqlDbType.PgLsn, + "tid" => NpgsqlDbType.Tid, + "char" => NpgsqlDbType.InternalChar, + + // Plugin types + "citext" => NpgsqlDbType.Citext, + "lquery" => NpgsqlDbType.LQuery, + "ltree" => NpgsqlDbType.LTree, + "ltxtquery" => NpgsqlDbType.LTxtQuery, + "hstore" => NpgsqlDbType.Hstore, + "geometry" => NpgsqlDbType.Geometry, + "geography" => NpgsqlDbType.Geography, + + _ when unqualifiedName.Contains("unknown") + => !unqualifiedName.StartsWith("_", StringComparison.Ordinal) + ? NpgsqlDbType.Unknown + : null, + _ when unqualifiedName.StartsWith("_", StringComparison.Ordinal) + => ToNpgsqlDbType(unqualifiedName.Substring(1)) is { } elementNpgsqlDbType + ? elementNpgsqlDbType | NpgsqlDbType.Array + : null, + // e.g. custom ranges, plugin types etc. + _ => null + }; } } diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlInterval.cs b/src/Npgsql/NpgsqlTypes/NpgsqlInterval.cs new file mode 100644 index 0000000000..f4b51ba4a9 --- /dev/null +++ b/src/Npgsql/NpgsqlTypes/NpgsqlInterval.cs @@ -0,0 +1,53 @@ +using System; + +// ReSharper disable once CheckNamespace +namespace NpgsqlTypes; + +/// +/// A raw representation of the PostgreSQL interval datatype. 
Use only when or NodaTime +/// Period do not have sufficient range to handle your values. +/// +/// +///

+/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. +///

+///

+/// Do not use this type unless you have to: prefer or NodaTime +/// Period when possible. +///

+///
+public readonly struct NpgsqlInterval : IEquatable +{ + /// + /// Constructs an . + /// + public NpgsqlInterval(int months, int days, long time) + => (Months, Days, Time) = (months, days, time); + + /// + /// Months and years, after time for alignment. + /// + public int Months { get; } + + /// + /// Days, after time for alignment. + /// + public int Days { get; } + + /// + /// Remaining time unit smaller than a day, in microseconds. + /// + public long Time { get; } + + /// + public bool Equals(NpgsqlInterval other) + => Months == other.Months && Days == other.Days && Time == other.Time; + + /// + public override bool Equals(object? obj) + => obj is NpgsqlInterval other && Equals(other); + + /// + public override int GetHashCode() + => HashCode.Combine(Months, Days, Time); +} \ No newline at end of file diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlLogSequenceNumber.cs b/src/Npgsql/NpgsqlTypes/NpgsqlLogSequenceNumber.cs index ea9a850ec3..00ff4131e4 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlLogSequenceNumber.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlLogSequenceNumber.cs @@ -2,340 +2,339 @@ using System.Globalization; // ReSharper disable once CheckNamespace -namespace NpgsqlTypes +namespace NpgsqlTypes; + +/// +/// Wraps a PostgreSQL Write-Ahead Log Sequence Number (see: https://www.postgresql.org/docs/current/datatype-pg-lsn.html) +/// +/// +/// Log Sequence Numbers are a fundamental concept of the PostgreSQL Write-Ahead Log and by that of +/// PostgreSQL replication. See https://www.postgresql.org/docs/current/wal-internals.html for what they represent. +/// +/// This struct provides conversions from/to and and beyond that tries to port +/// the methods and operators in https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/adt/pg_lsn.c +/// but nothing more. 
+/// +public readonly struct NpgsqlLogSequenceNumber : IEquatable, IComparable { /// - /// Wraps a PostgreSQL Write-Ahead Log Sequence Number (see: https://www.postgresql.org/docs/current/datatype-pg-lsn.html) + /// Zero is used indicate an invalid Log Sequence Number. No XLOG record can begin at zero. /// - /// - /// Log Sequence Numbers are a fundamental concept of the PostgreSQL Write-Ahead Log and by that of - /// PostgreSQL replication. See https://www.postgresql.org/docs/current/wal-internals.html for what they represent. - /// - /// This struct provides conversions from/to and and beyond that tries to port - /// the methods and operators in https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/adt/pg_lsn.c - /// but nothing more. - /// - public readonly struct NpgsqlLogSequenceNumber : IEquatable, IComparable - { - /// - /// Zero is used indicate an invalid Log Sequence Number. No XLOG record can begin at zero. - /// - public static readonly NpgsqlLogSequenceNumber Invalid = default; + public static readonly NpgsqlLogSequenceNumber Invalid = default; - readonly ulong _value; + readonly ulong _value; - /// - /// Initializes a new instance of . - /// - /// The value to wrap. - public NpgsqlLogSequenceNumber(ulong value) - => _value = value; + /// + /// Initializes a new instance of . + /// + /// The value to wrap. + public NpgsqlLogSequenceNumber(ulong value) + => _value = value; - /// - /// Returns a value indicating whether this instance is equal to a specified - /// instance. - /// - /// A instance to compare to this instance. - /// if the current instance is equal to the value parameter; - /// otherwise, . - public bool Equals(NpgsqlLogSequenceNumber other) - => _value == other._value; + /// + /// Returns a value indicating whether this instance is equal to a specified + /// instance. + /// + /// A instance to compare to this instance. + /// if the current instance is equal to the value parameter; + /// otherwise, . 
+ public bool Equals(NpgsqlLogSequenceNumber other) + => _value == other._value; - /// - /// Compares this instance to a specified and returns an indication of their - /// relative values. - /// - /// A instance to compare to this instance. - /// A signed number indicating the relative values of this instance and . - public int CompareTo(NpgsqlLogSequenceNumber value) - => _value.CompareTo(value._value); + /// + /// Compares this instance to a specified and returns an indication of their + /// relative values. + /// + /// A instance to compare to this instance. + /// A signed number indicating the relative values of this instance and . + public int CompareTo(NpgsqlLogSequenceNumber value) + => _value.CompareTo(value._value); - /// - /// Returns a value indicating whether this instance is equal to a specified object. - /// - /// An object to compare to this instance - /// if the current instance is equal to the value parameter; - /// otherwise, . - public override bool Equals(object? obj) - => obj is NpgsqlLogSequenceNumber lsn && lsn._value == _value; + /// + /// Returns a value indicating whether this instance is equal to a specified object. + /// + /// An object to compare to this instance + /// if the current instance is equal to the value parameter; + /// otherwise, . + public override bool Equals(object? obj) + => obj is NpgsqlLogSequenceNumber lsn && lsn._value == _value; - /// - /// Returns the hash code for this instance. - /// - /// A 32-bit signed integer hash code. - public override int GetHashCode() - => _value.GetHashCode(); + /// + /// Returns the hash code for this instance. + /// + /// A 32-bit signed integer hash code. + public override int GetHashCode() + => _value.GetHashCode(); - /// - /// Converts the numeric value of this instance to its equivalent string representation. 
- /// - /// The string representation of the value of this instance, consisting of two hexadecimal numbers of - /// up to 8 digits each, separated by a slash - public override string ToString() - => unchecked($"{(uint)(_value >> 32):X}/{(uint)_value:X}"); + /// + /// Converts the numeric value of this instance to its equivalent string representation. + /// + /// The string representation of the value of this instance, consisting of two hexadecimal numbers of + /// up to 8 digits each, separated by a slash + public override string ToString() + => unchecked($"{(uint)(_value >> 32):X}/{(uint)_value:X}"); - /// - /// Converts the string representation of a Log Sequence Number to a instance. - /// - /// A string that represents the Log Sequence Number to convert. - /// - /// A equivalent to the Log Sequence Number specified in . - /// - /// The parameter is . - /// - /// The parameter represents a number less than or greater than - /// . - /// - /// The parameter is not in the right format. - public static NpgsqlLogSequenceNumber Parse(string s) - // ReSharper disable once ConditionIsAlwaysTrueOrFalse - => s is null - ? throw new ArgumentNullException(nameof(s)) - : Parse(s.AsSpan()); + /// + /// Converts the string representation of a Log Sequence Number to a instance. + /// + /// A string that represents the Log Sequence Number to convert. + /// + /// A equivalent to the Log Sequence Number specified in . + /// + /// The parameter is . + /// + /// The parameter represents a number less than or greater than + /// . + /// + /// The parameter is not in the right format. + public static NpgsqlLogSequenceNumber Parse(string s) + // ReSharper disable once ConditionIsAlwaysTrueOrFalse + => s is null + ? throw new ArgumentNullException(nameof(s)) + : Parse(s.AsSpan()); - /// - /// Converts the span representation of a Log Sequence Number to a instance. - /// - /// A span containing the characters that represent the Log Sequence Number to convert. 
- /// - /// A equivalent to the Log Sequence Number specified in . - /// - /// - /// The parameter represents a number less than or greater than - /// . - /// - /// The parameter is not in the right format. - public static NpgsqlLogSequenceNumber Parse(ReadOnlySpan s) - => TryParse(s, out var parsed) - ? parsed - : throw new FormatException($"Invalid Log Sequence Number: '{s.ToString()}'."); + /// + /// Converts the span representation of a Log Sequence Number to a instance. + /// + /// A span containing the characters that represent the Log Sequence Number to convert. + /// + /// A equivalent to the Log Sequence Number specified in . + /// + /// + /// The parameter represents a number less than or greater than + /// . + /// + /// The parameter is not in the right format. + public static NpgsqlLogSequenceNumber Parse(ReadOnlySpan s) + => TryParse(s, out var parsed) + ? parsed + : throw new FormatException($"Invalid Log Sequence Number: '{s.ToString()}'."); - /// - /// Tries to convert the string representation of a Log Sequence Number to an - /// instance. A return value indicates whether the conversion succeeded or failed. - /// - /// A string that represents the Log Sequence Number to convert. - /// - /// When this method returns, contains a instance equivalent to the Log Sequence - /// Number contained in , if the conversion succeeded, or the default value for - /// (0) if the conversion failed. The conversion fails if the - /// parameter is or , is not in the right format, or represents a number - /// less than or greater than . This parameter is - /// passed uninitialized; any value originally supplied in result will be overwritten. - /// - /// - /// if c> was converted successfully; otherwise, . - /// - public static bool TryParse(string s, out NpgsqlLogSequenceNumber result) - => TryParse(s.AsSpan(), out result); + /// + /// Tries to convert the string representation of a Log Sequence Number to an + /// instance. 
A return value indicates whether the conversion succeeded or failed. + /// + /// A string that represents the Log Sequence Number to convert. + /// + /// When this method returns, contains a instance equivalent to the Log Sequence + /// Number contained in , if the conversion succeeded, or the default value for + /// (0) if the conversion failed. The conversion fails if the + /// parameter is or , is not in the right format, or represents a number + /// less than or greater than . This parameter is + /// passed uninitialized; any value originally supplied in result will be overwritten. + /// + /// + /// if c> was converted successfully; otherwise, . + /// + public static bool TryParse(string s, out NpgsqlLogSequenceNumber result) + => TryParse(s.AsSpan(), out result); - /// - /// Tries to convert the span representation of a Log Sequence Number to an - /// instance. A return value indicates whether the conversion succeeded or failed. - /// - /// A span containing the characters that represent the Log Sequence Number to convert. - /// - /// When this method returns, contains a instance equivalent to the Log Sequence - /// Number contained in , if the conversion succeeded, or the default value for - /// (0) if the conversion failed. The conversion fails if the - /// parameter is empty, is not in the right format, or represents a number less than - /// or greater than . This parameter is passed - /// uninitialized; any value originally supplied in result will be overwritten. - /// - /// - /// if was converted successfully; otherwise, . - public static bool TryParse(ReadOnlySpan s, out NpgsqlLogSequenceNumber result) + /// + /// Tries to convert the span representation of a Log Sequence Number to an + /// instance. A return value indicates whether the conversion succeeded or failed. + /// + /// A span containing the characters that represent the Log Sequence Number to convert. 
+ /// + /// When this method returns, contains a instance equivalent to the Log Sequence + /// Number contained in , if the conversion succeeded, or the default value for + /// (0) if the conversion failed. The conversion fails if the + /// parameter is empty, is not in the right format, or represents a number less than + /// or greater than . This parameter is passed + /// uninitialized; any value originally supplied in result will be overwritten. + /// + /// + /// if was converted successfully; otherwise, . + public static bool TryParse(ReadOnlySpan s, out NpgsqlLogSequenceNumber result) + { + for (var i = 0; i < s.Length; i++) { - for (var i = 0; i < s.Length; i++) - { - if (s[i] != '/') continue; + if (s[i] != '/') continue; #if NETSTANDARD2_0 - var firstPart = s.Slice(0, i).ToString(); - var secondPart = s.Slice(++i).ToString(); + var firstPart = s.Slice(0, i).ToString(); + var secondPart = s.Slice(++i).ToString(); #else - var firstPart = s.Slice(0, i); - var secondPart = s.Slice(++i); + var firstPart = s.Slice(0, i); + var secondPart = s.Slice(++i); #endif - if (!uint.TryParse(firstPart, NumberStyles.AllowHexSpecifier, null, out var first)) - { - result = default; - return false; - } - if (!uint.TryParse(secondPart, NumberStyles.AllowHexSpecifier, null, out var second)) - { - result = default; - return false; - } - result = new NpgsqlLogSequenceNumber(((ulong)first << 32) + second); - return true; + if (!uint.TryParse(firstPart, NumberStyles.AllowHexSpecifier, null, out var first)) + { + result = default; + return false; } - result = default; - return false; + if (!uint.TryParse(secondPart, NumberStyles.AllowHexSpecifier, null, out var second)) + { + result = default; + return false; + } + result = new NpgsqlLogSequenceNumber(((ulong)first << 32) + second); + return true; } + result = default; + return false; + } - /// - /// Converts the value of a 64-bit unsigned integer to a instance. - /// - /// A 64-bit unsigned integer. 
- /// A new instance of initialized to . - public static explicit operator NpgsqlLogSequenceNumber(ulong value) - => new NpgsqlLogSequenceNumber(value); + /// + /// Converts the value of a 64-bit unsigned integer to a instance. + /// + /// A 64-bit unsigned integer. + /// A new instance of initialized to . + public static explicit operator NpgsqlLogSequenceNumber(ulong value) + => new(value); - /// - /// Converts the value of a instance to a 64-bit unsigned integer value. - /// - /// A instance - /// The contents of as 64-bit unsigned integer. - public static explicit operator ulong(NpgsqlLogSequenceNumber value) - => value._value; + /// + /// Converts the value of a instance to a 64-bit unsigned integer value. + /// + /// A instance + /// The contents of as 64-bit unsigned integer. + public static explicit operator ulong(NpgsqlLogSequenceNumber value) + => value._value; - /// - /// Returns a value that indicates whether two specified instances of are equal. - /// - /// The first Log Sequence Number to compare. - /// The second Log Sequence Number to compare. - /// - /// if equals ; otherwise, . - /// - public static bool operator ==(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) - => value1._value == value2._value; + /// + /// Returns a value that indicates whether two specified instances of are equal. + /// + /// The first Log Sequence Number to compare. + /// The second Log Sequence Number to compare. + /// + /// if equals ; otherwise, . + /// + public static bool operator ==(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) + => value1._value == value2._value; - /// - /// Returns a value that indicates whether two specified instances of are not - /// equal. - /// - /// The first Log Sequence Number to compare. - /// The second Log Sequence Number to compare. - /// - /// if does not equal ; otherwise, - /// . 
- /// - public static bool operator !=(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) - => value1._value != value2._value; + /// + /// Returns a value that indicates whether two specified instances of are not + /// equal. + /// + /// The first Log Sequence Number to compare. + /// The second Log Sequence Number to compare. + /// + /// if does not equal ; otherwise, + /// . + /// + public static bool operator !=(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) + => value1._value != value2._value; - /// - /// Returns a value indicating whether a specified instance is greater than - /// another specified instance. - /// - /// The first value to compare. - /// The second value to compare. - /// - /// if is greater than ; otherwise, - /// . - /// - public static bool operator >(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) - => value1._value > value2._value; + /// + /// Returns a value indicating whether a specified instance is greater than + /// another specified instance. + /// + /// The first value to compare. + /// The second value to compare. + /// + /// if is greater than ; otherwise, + /// . + /// + public static bool operator >(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) + => value1._value > value2._value; - /// - /// Returns a value indicating whether a specified instance is less than - /// another specified instance. - /// - /// The first value to compare. - /// The second value to compare. - /// - /// if is less than ; otherwise, - /// . - /// - public static bool operator <(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) - => value1._value < value2._value; + /// + /// Returns a value indicating whether a specified instance is less than + /// another specified instance. + /// + /// The first value to compare. + /// The second value to compare. + /// + /// if is less than ; otherwise, + /// . 
+ /// + public static bool operator <(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) + => value1._value < value2._value; - /// - /// Returns a value indicating whether a specified instance is greater than or - /// equal to another specified instance. - /// - /// The first value to compare. - /// The second value to compare. - /// - /// if is greater than or equal to ; - /// otherwise, . - /// - public static bool operator >=(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) - => value1._value >= value2._value; + /// + /// Returns a value indicating whether a specified instance is greater than or + /// equal to another specified instance. + /// + /// The first value to compare. + /// The second value to compare. + /// + /// if is greater than or equal to ; + /// otherwise, . + /// + public static bool operator >=(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) + => value1._value >= value2._value; - /// - /// Returns the larger of two values. - /// - /// The first value to compare. - /// The second value to compare. - /// - /// The larger of the two values. - /// - public static NpgsqlLogSequenceNumber Larger(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) - => value1._value > value2._value ? value1 : value2; + /// + /// Returns the larger of two values. + /// + /// The first value to compare. + /// The second value to compare. + /// + /// The larger of the two values. + /// + public static NpgsqlLogSequenceNumber Larger(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) + => value1._value > value2._value ? value1 : value2; - /// - /// Returns the smaller of two values. - /// - /// The first value to compare. - /// The second value to compare. - /// - /// The smaller of the two values. - /// - public static NpgsqlLogSequenceNumber Smaller(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) - => value1._value < value2._value ? 
value1 : value2; + /// + /// Returns the smaller of two values. + /// + /// The first value to compare. + /// The second value to compare. + /// + /// The smaller of the two values. + /// + public static NpgsqlLogSequenceNumber Smaller(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) + => value1._value < value2._value ? value1 : value2; - /// - /// Returns a value indicating whether a specified instance is less than or - /// equal to another specified instance. - /// - /// The first value to compare. - /// The second value to compare. - /// - /// if is less than or equal to ; - /// otherwise, . - /// - public static bool operator <=(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) - => value1._value <= value2._value; + /// + /// Returns a value indicating whether a specified instance is less than or + /// equal to another specified instance. + /// + /// The first value to compare. + /// The second value to compare. + /// + /// if is less than or equal to ; + /// otherwise, . + /// + public static bool operator <=(NpgsqlLogSequenceNumber value1, NpgsqlLogSequenceNumber value2) + => value1._value <= value2._value; - /// - /// Subtracts two specified values. - /// - /// The first value. - /// The second value. - /// The number of bytes separating those write-ahead log locations. - public static ulong operator -(NpgsqlLogSequenceNumber first, NpgsqlLogSequenceNumber second) - => first._value < second._value - ? second._value - first._value - : first._value - second._value; + /// + /// Subtracts two specified values. + /// + /// The first value. + /// The second value. + /// The number of bytes separating those write-ahead log locations. + public static ulong operator -(NpgsqlLogSequenceNumber first, NpgsqlLogSequenceNumber second) + => first._value < second._value + ? second._value - first._value + : first._value - second._value; - /// - /// Subtract the number of bytes from a instance, giving a new - /// instance. 
- /// Handles both positive and negative numbers of bytes. - /// - /// - /// The instance representing a write-ahead log location. - /// - /// The number of bytes to subtract. - /// A new instance. - /// - /// The resulting instance would represent a number less than - /// . - /// - public static NpgsqlLogSequenceNumber operator -(NpgsqlLogSequenceNumber lsn, double nbytes) - => double.IsNaN(nbytes) || double.IsInfinity(nbytes) - ? throw new NotFiniteNumberException($"Cannot subtract {nbytes} from {nameof(NpgsqlLogSequenceNumber)}", nbytes) - : new NpgsqlLogSequenceNumber(checked((ulong)(lsn._value - nbytes))); + /// + /// Subtract the number of bytes from a instance, giving a new + /// instance. + /// Handles both positive and negative numbers of bytes. + /// + /// + /// The instance representing a write-ahead log location. + /// + /// The number of bytes to subtract. + /// A new instance. + /// + /// The resulting instance would represent a number less than + /// . + /// + public static NpgsqlLogSequenceNumber operator -(NpgsqlLogSequenceNumber lsn, double nbytes) + => double.IsNaN(nbytes) || double.IsInfinity(nbytes) + ? throw new NotFiniteNumberException($"Cannot subtract {nbytes} from {nameof(NpgsqlLogSequenceNumber)}", nbytes) + : new NpgsqlLogSequenceNumber(checked((ulong)(lsn._value - nbytes))); - /// - /// Add the number of bytes to a instance, giving a new - /// instance. - /// Handles both positive and negative numbers of bytes. - /// - /// - /// The instance representing a write-ahead log location. - /// - /// The number of bytes to add. - /// A new instance. - /// - /// The resulting instance would represent a number greater than - /// . - /// - public static NpgsqlLogSequenceNumber operator +(NpgsqlLogSequenceNumber lsn, double nbytes) - => double.IsNaN(nbytes) || double.IsInfinity(nbytes) - ? 
throw new NotFiniteNumberException($"Cannot add {nbytes} to {nameof(NpgsqlLogSequenceNumber)}", nbytes) - : new NpgsqlLogSequenceNumber(checked((ulong)(lsn._value + nbytes))); - } -} + /// + /// Add the number of bytes to a instance, giving a new + /// instance. + /// Handles both positive and negative numbers of bytes. + /// + /// + /// The instance representing a write-ahead log location. + /// + /// The number of bytes to add. + /// A new instance. + /// + /// The resulting instance would represent a number greater than + /// . + /// + public static NpgsqlLogSequenceNumber operator +(NpgsqlLogSequenceNumber lsn, double nbytes) + => double.IsNaN(nbytes) || double.IsInfinity(nbytes) + ? throw new NotFiniteNumberException($"Cannot add {nbytes} to {nameof(NpgsqlLogSequenceNumber)}", nbytes) + : new NpgsqlLogSequenceNumber(checked((ulong)(lsn._value + nbytes))); +} \ No newline at end of file diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlRange.cs b/src/Npgsql/NpgsqlTypes/NpgsqlRange.cs index 3df7c183e9..c260202ce9 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlRange.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlRange.cs @@ -5,526 +5,526 @@ using System.Text; // ReSharper disable once CheckNamespace -namespace NpgsqlTypes +namespace NpgsqlTypes; + +/// +/// Represents a PostgreSQL range type. +/// +/// The element type of the values in the range. +/// +/// See: https://www.postgresql.org/docs/current/static/rangetypes.html +/// +public readonly struct NpgsqlRange : IEquatable> { + // ----------------------------------------------------------------------------------------------- + // Regarding bitwise flag checks via @roji: + // + // > Note that Flags.HasFlag() used to be very inefficient compared to simply doing the + // > bit operation - this is why I've always avoided it. 
.NET Core 2.1 adds JIT intrinstics + // > for this, making Enum.HasFlag() fast, but I honestly don't see the value over just doing + // > a bitwise and operation, which would also be fast under .NET Core 2.0 and .NET Framework. + // + // See: + // - https://github.com/npgsql/npgsql/pull/1939#pullrequestreview-121308396 + // - https://blogs.msdn.microsoft.com/dotnet/2018/04/18/performance-improvements-in-net-core-2-1 + // ----------------------------------------------------------------------------------------------- + /// - /// Represents a PostgreSQL range type. + /// Defined by PostgreSQL to represent an empty range. /// - /// The element type of the values in the range. - /// - /// See: https://www.postgresql.org/docs/current/static/rangetypes.html - /// - public readonly struct NpgsqlRange : IEquatable> - { - // ----------------------------------------------------------------------------------------------- - // Regarding bitwise flag checks via @roji: - // - // > Note that Flags.HasFlag() used to be very inefficient compared to simply doing the - // > bit operation - this is why I've always avoided it. .NET Core 2.1 adds JIT intrinstics - // > for this, making Enum.HasFlag() fast, but I honestly don't see the value over just doing - // > a bitwise and operation, which would also be fast under .NET Core 2.0 and .NET Framework. - // - // See: - // - https://github.com/npgsql/npgsql/pull/1939#pullrequestreview-121308396 - // - https://blogs.msdn.microsoft.com/dotnet/2018/04/18/performance-improvements-in-net-core-2-1 - // ----------------------------------------------------------------------------------------------- + const string EmptyLiteral = "empty"; - /// - /// Defined by PostgreSQL to represent an empty range. - /// - const string EmptyLiteral = "empty"; + /// + /// Defined by PostgreSQL to represent an infinite lower bound. + /// Some element types may have specific handling for this value distinct from a missing or null value. 
+ /// + const string LowerInfinityLiteral = "-infinity"; - /// - /// Defined by PostgreSQL to represent an infinite lower bound. - /// Some element types may have specific handling for this value distinct from a missing or null value. - /// - const string LowerInfinityLiteral = "-infinity"; + /// + /// Defined by PostgreSQL to represent an infinite upper bound. + /// Some element types may have specific handling for this value distinct from a missing or null value. + /// + const string UpperInfinityLiteral = "infinity"; - /// - /// Defined by PostgreSQL to represent an infinite upper bound. - /// Some element types may have specific handling for this value distinct from a missing or null value. - /// - const string UpperInfinityLiteral = "infinity"; + /// + /// Defined by PostgreSQL to represent an null bound. + /// Some element types may have specific handling for this value distinct from an infinite or missing value. + /// + const string NullLiteral = "null"; - /// - /// Defined by PostgreSQL to represent an null bound. - /// Some element types may have specific handling for this value distinct from an infinite or missing value. - /// - const string NullLiteral = "null"; + /// + /// Defined by PostgreSQL to represent a lower inclusive bound. + /// + const char LowerInclusiveBound = '['; - /// - /// Defined by PostgreSQL to represent a lower inclusive bound. - /// - const char LowerInclusiveBound = '['; + /// + /// Defined by PostgreSQL to represent a lower exclusive bound. + /// + const char LowerExclusiveBound = '('; - /// - /// Defined by PostgreSQL to represent a lower exclusive bound. - /// - const char LowerExclusiveBound = '('; + /// + /// Defined by PostgreSQL to represent an upper inclusive bound. + /// + const char UpperInclusiveBound = ']'; - /// - /// Defined by PostgreSQL to represent an upper inclusive bound. - /// - const char UpperInclusiveBound = ']'; + /// + /// Defined by PostgreSQL to represent an upper exclusive bound. 
+ /// + const char UpperExclusiveBound = ')'; - /// - /// Defined by PostgreSQL to represent an upper exclusive bound. - /// - const char UpperExclusiveBound = ')'; + /// + /// Defined by PostgreSQL to separate the values for the upper and lower bounds. + /// + const char BoundSeparator = ','; - /// - /// Defined by PostgreSQL to separate the values for the upper and lower bounds. - /// - const char BoundSeparator = ','; + /// + /// The used by to convert bounds into . + /// + static TypeConverter? BoundConverter; - /// - /// The used by to convert bounds into . - /// - static readonly TypeConverter BoundConverter = TypeDescriptor.GetConverter(typeof(T)); + /// + /// True if implements ; otherwise, false. + /// + static readonly bool HasEquatableBounds = typeof(IEquatable).IsAssignableFrom(typeof(T)); - /// - /// True if implements ; otherwise, false. - /// - static readonly bool HasEquatableBounds = typeof(IEquatable).IsAssignableFrom(typeof(T)); + /// + /// Represents the empty range. This field is read-only. + /// + public static readonly NpgsqlRange Empty = new(default, default, RangeFlags.Empty); - /// - /// Represents the empty range. This field is read-only. - /// - public static readonly NpgsqlRange Empty = new NpgsqlRange(default, default, RangeFlags.Empty); + /// + /// The lower bound of the range. Only valid when is false. + /// + [MaybeNull, AllowNull] + public T LowerBound { get; } - /// - /// The lower bound of the range. Only valid when is false. - /// - [MaybeNull, AllowNull] - public T LowerBound { get; } + /// + /// The upper bound of the range. Only valid when is false. + /// + [MaybeNull, AllowNull] + public T UpperBound { get; } - /// - /// The upper bound of the range. Only valid when is false. - /// - [MaybeNull, AllowNull] - public T UpperBound { get; } + /// + /// The characteristics of the boundaries. + /// + internal readonly RangeFlags Flags; - /// - /// The characteristics of the boundaries. 
- /// - internal readonly RangeFlags Flags; + /// + /// True if the lower bound is part of the range (i.e. inclusive); otherwise, false. + /// + public bool LowerBoundIsInclusive => (Flags & RangeFlags.LowerBoundInclusive) != 0; - /// - /// True if the lower bound is part of the range (i.e. inclusive); otherwise, false. - /// - public bool LowerBoundIsInclusive => (Flags & RangeFlags.LowerBoundInclusive) != 0; + /// + /// True if the upper bound is part of the range (i.e. inclusive); otherwise, false. + /// + public bool UpperBoundIsInclusive => (Flags & RangeFlags.UpperBoundInclusive) != 0; - /// - /// True if the upper bound is part of the range (i.e. inclusive); otherwise, false. - /// - public bool UpperBoundIsInclusive => (Flags & RangeFlags.UpperBoundInclusive) != 0; + /// + /// True if the lower bound is indefinite (i.e. infinite or unbounded); otherwise, false. + /// + public bool LowerBoundInfinite => (Flags & RangeFlags.LowerBoundInfinite) != 0; - /// - /// True if the lower bound is indefinite (i.e. infinite or unbounded); otherwise, false. - /// - public bool LowerBoundInfinite => (Flags & RangeFlags.LowerBoundInfinite) != 0; + /// + /// True if the upper bound is indefinite (i.e. infinite or unbounded); otherwise, false. + /// + public bool UpperBoundInfinite => (Flags & RangeFlags.UpperBoundInfinite) != 0; - /// - /// True if the upper bound is indefinite (i.e. infinite or unbounded); otherwise, false. - /// - public bool UpperBoundInfinite => (Flags & RangeFlags.UpperBoundInfinite) != 0; + /// + /// True if the range is empty; otherwise, false. + /// + public bool IsEmpty => (Flags & RangeFlags.Empty) != 0; - /// - /// True if the range is empty; otherwise, false. - /// - public bool IsEmpty => (Flags & RangeFlags.Empty) != 0; + /// + /// Constructs an with inclusive and definite bounds. + /// + /// The lower bound of the range. + /// The upper bound of the range. 
+ public NpgsqlRange([AllowNull] T lowerBound, [AllowNull] T upperBound) + : this(lowerBound, true, false, upperBound, true, false) { } - /// - /// Constructs an with inclusive and definite bounds. - /// - /// The lower bound of the range. - /// The upper bound of the range. - public NpgsqlRange([AllowNull] T lowerBound, [AllowNull] T upperBound) - : this(lowerBound, true, false, upperBound, true, false) { } + /// + /// Constructs an with definite bounds. + /// + /// The lower bound of the range. + /// True if the lower bound is is part of the range (i.e. inclusive); otherwise, false. + /// The upper bound of the range. + /// True if the upper bound is part of the range (i.e. inclusive); otherwise, false. + public NpgsqlRange( + [AllowNull] T lowerBound, bool lowerBoundIsInclusive, + [AllowNull] T upperBound, bool upperBoundIsInclusive) + : this(lowerBound, lowerBoundIsInclusive, false, upperBound, upperBoundIsInclusive, false) { } - /// - /// Constructs an with definite bounds. - /// - /// The lower bound of the range. - /// True if the lower bound is is part of the range (i.e. inclusive); otherwise, false. - /// The upper bound of the range. - /// True if the upper bound is part of the range (i.e. inclusive); otherwise, false. - public NpgsqlRange( - [AllowNull] T lowerBound, bool lowerBoundIsInclusive, - [AllowNull] T upperBound, bool upperBoundIsInclusive) - : this(lowerBound, lowerBoundIsInclusive, false, upperBound, upperBoundIsInclusive, false) { } + /// + /// Constructs an . + /// + /// The lower bound of the range. + /// True if the lower bound is is part of the range (i.e. inclusive); otherwise, false. + /// True if the lower bound is indefinite (i.e. infinite or unbounded); otherwise, false. + /// The upper bound of the range. + /// True if the upper bound is part of the range (i.e. inclusive); otherwise, false. + /// True if the upper bound is indefinite (i.e. infinite or unbounded); otherwise, false. 
+ public NpgsqlRange( + [AllowNull] T lowerBound, bool lowerBoundIsInclusive, bool lowerBoundInfinite, + [AllowNull] T upperBound, bool upperBoundIsInclusive, bool upperBoundInfinite) + : this( + lowerBound, + upperBound, + EvaluateBoundaryFlags( + lowerBoundIsInclusive, + upperBoundIsInclusive, + lowerBoundInfinite, + upperBoundInfinite)) { } - /// - /// Constructs an . - /// - /// The lower bound of the range. - /// True if the lower bound is is part of the range (i.e. inclusive); otherwise, false. - /// True if the lower bound is indefinite (i.e. infinite or unbounded); otherwise, false. - /// The upper bound of the range. - /// True if the upper bound is part of the range (i.e. inclusive); otherwise, false. - /// True if the upper bound is indefinite (i.e. infinite or unbounded); otherwise, false. - public NpgsqlRange( - [AllowNull] T lowerBound, bool lowerBoundIsInclusive, bool lowerBoundInfinite, - [AllowNull] T upperBound, bool upperBoundIsInclusive, bool upperBoundInfinite) - : this( - lowerBound, - upperBound, - EvaluateBoundaryFlags( - lowerBoundIsInclusive, - upperBoundIsInclusive, - lowerBoundInfinite, - upperBoundInfinite)) { } + /// + /// Constructs an . + /// + /// The lower bound of the range. + /// The upper bound of the range. + /// The characteristics of the range boundaries. + internal NpgsqlRange([AllowNull] T lowerBound, [AllowNull] T upperBound, RangeFlags flags) : this() + { + // TODO: We need to check if the bounds are implicitly empty. E.g. '(1,1)' or '(0,0]'. + // See: https://github.com/npgsql/npgsql/issues/1943. - /// - /// Constructs an . - /// - /// The lower bound of the range. - /// The upper bound of the range. - /// The characteristics of the range boundaries. - internal NpgsqlRange([AllowNull] T lowerBound, [AllowNull] T upperBound, RangeFlags flags) : this() - { - // TODO: We need to check if the bounds are implicitly empty. E.g. '(1,1)' or '(0,0]'. - // See: https://github.com/npgsql/npgsql/issues/1943. 
- - LowerBound = (flags & RangeFlags.LowerBoundInfinite) != 0 ? default : lowerBound; - UpperBound = (flags & RangeFlags.UpperBoundInfinite) != 0 ? default : upperBound; - Flags = flags; - - if (IsEmptyRange(LowerBound, UpperBound, Flags)) - { - LowerBound = default!; - UpperBound = default!; - Flags = RangeFlags.Empty; - } - } + LowerBound = (flags & RangeFlags.LowerBoundInfinite) != 0 ? default : lowerBound; + UpperBound = (flags & RangeFlags.UpperBoundInfinite) != 0 ? default : upperBound; + Flags = flags; - /// - /// Attempts to determine if the range is malformed or implicitly empty. - /// - /// The lower bound of the range. - /// The upper bound of the range. - /// The characteristics of the range boundaries. - /// - /// True if the range is implicitly empty; otherwise, false. - /// - static bool IsEmptyRange([AllowNull] T lowerBound, [AllowNull] T upperBound, RangeFlags flags) + if (IsEmptyRange(LowerBound, UpperBound, Flags)) { - // --------------------------------------------------------------------------------- - // We only want to check for those conditions that are unambiguously erroneous: - // 1. The bounds must not be default values (including null). - // 2. The bounds must be definite (non-infinite). - // 3. The bounds must be inclusive. - // 4. The bounds must be considered equal. - // - // See: - // - https://github.com/npgsql/npgsql/pull/1939 - // - https://github.com/npgsql/npgsql/issues/1943 - // --------------------------------------------------------------------------------- - - if ((flags & RangeFlags.Empty) == RangeFlags.Empty) - return true; - - if ((flags & RangeFlags.Infinite) == RangeFlags.Infinite) - return false; + LowerBound = default!; + UpperBound = default!; + Flags = RangeFlags.Empty; + } + } - if ((flags & RangeFlags.Inclusive) == RangeFlags.Inclusive) - return false; + /// + /// Attempts to determine if the range is malformed or implicitly empty. + /// + /// The lower bound of the range. + /// The upper bound of the range. 
+ /// The characteristics of the range boundaries. + /// + /// True if the range is implicitly empty; otherwise, false. + /// + static bool IsEmptyRange([AllowNull] T lowerBound, [AllowNull] T upperBound, RangeFlags flags) + { + // --------------------------------------------------------------------------------- + // We only want to check for those conditions that are unambiguously erroneous: + // 1. The bounds must not be default values (including null). + // 2. The bounds must be definite (non-infinite). + // 3. The bounds must be inclusive. + // 4. The bounds must be considered equal. + // + // See: + // - https://github.com/npgsql/npgsql/pull/1939 + // - https://github.com/npgsql/npgsql/issues/1943 + // --------------------------------------------------------------------------------- - if (lowerBound is null || upperBound is null) - return false; + if ((flags & RangeFlags.Empty) == RangeFlags.Empty) + return true; - if (!HasEquatableBounds) - return lowerBound?.Equals(upperBound) ?? false; + if ((flags & RangeFlags.Infinite) == RangeFlags.Infinite) + return false; - var lower = (IEquatable)lowerBound; - var upper = (IEquatable)upperBound; + if ((flags & RangeFlags.Inclusive) == RangeFlags.Inclusive) + return false; - return !lower.Equals(default!) && !upper.Equals(default!) && lower.Equals(upperBound); - } + if (lowerBound is null || upperBound is null) + return false; - /// - /// Evaluates the boundary flags. - /// - /// True if the lower bound is is part of the range (i.e. inclusive); otherwise, false. - /// True if the lower bound is indefinite (i.e. infinite or unbounded); otherwise, false. - /// True if the upper bound is part of the range (i.e. inclusive); otherwise, false. - /// True if the upper bound is indefinite (i.e. infinite or unbounded); otherwise, false. - /// - /// The boundary characteristics. 
- /// - static RangeFlags EvaluateBoundaryFlags(bool lowerBoundIsInclusive, bool upperBoundIsInclusive, bool lowerBoundInfinite, bool upperBoundInfinite) - { - var result = RangeFlags.None; - - // This is the only place flags are calculated. - if (lowerBoundIsInclusive) - result |= RangeFlags.LowerBoundInclusive; - if (upperBoundIsInclusive) - result |= RangeFlags.UpperBoundInclusive; - if (lowerBoundInfinite) - result |= RangeFlags.LowerBoundInfinite; - if (upperBoundInfinite) - result |= RangeFlags.UpperBoundInfinite; - - // PostgreSQL automatically converts inclusive-infinities. - // See: https://www.postgresql.org/docs/current/static/rangetypes.html#RANGETYPES-INFINITE - if ((result & RangeFlags.LowerInclusiveInfinite) == RangeFlags.LowerInclusiveInfinite) - result &= ~RangeFlags.LowerBoundInclusive; - - if ((result & RangeFlags.UpperInclusiveInfinite) == RangeFlags.UpperInclusiveInfinite) - result &= ~RangeFlags.UpperBoundInclusive; - - return result; - } + if (!HasEquatableBounds) + return lowerBound?.Equals(upperBound) ?? false; - /// - /// Indicates whether the on the left is equal to the on the right. - /// - /// The on the left. - /// The on the right. - /// - /// True if the on the left is equal to the on the right; otherwise, false. - /// - public static bool operator ==(NpgsqlRange x, NpgsqlRange y) => x.Equals(y); + var lower = (IEquatable)lowerBound; + var upper = (IEquatable)upperBound; - /// - /// Indicates whether the on the left is not equal to the on the right. - /// - /// The on the left. - /// The on the right. - /// - /// True if the on the left is not equal to the on the right; otherwise, false. - /// - public static bool operator !=(NpgsqlRange x, NpgsqlRange y) => !x.Equals(y); + return !lower.Equals(default!) && !upper.Equals(default!) && lower.Equals(upperBound); + } - /// - public override bool Equals(object? o) => o is NpgsqlRange range && Equals(range); + /// + /// Evaluates the boundary flags. 
+ /// + /// True if the lower bound is is part of the range (i.e. inclusive); otherwise, false. + /// True if the lower bound is indefinite (i.e. infinite or unbounded); otherwise, false. + /// True if the upper bound is part of the range (i.e. inclusive); otherwise, false. + /// True if the upper bound is indefinite (i.e. infinite or unbounded); otherwise, false. + /// + /// The boundary characteristics. + /// + static RangeFlags EvaluateBoundaryFlags(bool lowerBoundIsInclusive, bool upperBoundIsInclusive, bool lowerBoundInfinite, bool upperBoundInfinite) + { + var result = RangeFlags.None; + + // This is the only place flags are calculated. + if (lowerBoundIsInclusive) + result |= RangeFlags.LowerBoundInclusive; + if (upperBoundIsInclusive) + result |= RangeFlags.UpperBoundInclusive; + if (lowerBoundInfinite) + result |= RangeFlags.LowerBoundInfinite; + if (upperBoundInfinite) + result |= RangeFlags.UpperBoundInfinite; + + // PostgreSQL automatically converts inclusive-infinities. + // See: https://www.postgresql.org/docs/current/static/rangetypes.html#RANGETYPES-INFINITE + if ((result & RangeFlags.LowerInclusiveInfinite) == RangeFlags.LowerInclusiveInfinite) + result &= ~RangeFlags.LowerBoundInclusive; + + if ((result & RangeFlags.UpperInclusiveInfinite) == RangeFlags.UpperInclusiveInfinite) + result &= ~RangeFlags.UpperBoundInclusive; + + return result; + } - /// - public bool Equals(NpgsqlRange other) - { - if (Flags != other.Flags) - return false; + /// + /// Indicates whether the on the left is equal to the on the right. + /// + /// The on the left. + /// The on the right. + /// + /// True if the on the left is equal to the on the right; otherwise, false. + /// + public static bool operator ==(NpgsqlRange x, NpgsqlRange y) => x.Equals(y); - if (HasEquatableBounds) - { - var lowerEqual = LowerBound is null - ? 
other.LowerBound is null - : !(other.LowerBound is null) && ((IEquatable)LowerBound).Equals(other.LowerBound); + /// + /// Indicates whether the on the left is not equal to the on the right. + /// + /// The on the left. + /// The on the right. + /// + /// True if the on the left is not equal to the on the right; otherwise, false. + /// + public static bool operator !=(NpgsqlRange x, NpgsqlRange y) => !x.Equals(y); + + /// + public override bool Equals(object? o) => o is NpgsqlRange range && Equals(range); + + /// + public bool Equals(NpgsqlRange other) + { + if (Flags != other.Flags) + return false; - if (!lowerEqual) - return false; + if (HasEquatableBounds) + { + var lowerEqual = LowerBound is null + ? other.LowerBound is null + : !(other.LowerBound is null) && ((IEquatable)LowerBound).Equals(other.LowerBound); - return UpperBound is null - ? other.UpperBound is null - : !(other.UpperBound is null) && ((IEquatable)UpperBound).Equals(other.UpperBound); - } + if (!lowerEqual) + return false; - return - (LowerBound?.Equals(other.LowerBound) ?? other.LowerBound is null) && - (UpperBound?.Equals(other.UpperBound) ?? other.UpperBound is null); + return UpperBound is null + ? other.UpperBound is null + : !(other.UpperBound is null) && ((IEquatable)UpperBound).Equals(other.UpperBound); } - /// - public override int GetHashCode() - => unchecked((397 * (int)Flags) ^ (397 * (LowerBound?.GetHashCode() ?? 0)) ^ (397 * (UpperBound?.GetHashCode() ?? 0))); + return + (LowerBound?.Equals(other.LowerBound) ?? other.LowerBound is null) && + (UpperBound?.Equals(other.UpperBound) ?? other.UpperBound is null); + } - /// - public override string ToString() - { - if (IsEmpty) - return EmptyLiteral; + /// + public override int GetHashCode() + => unchecked((397 * (int)Flags) ^ (397 * (LowerBound?.GetHashCode() ?? 0)) ^ (397 * (UpperBound?.GetHashCode() ?? 
0))); - var sb = new StringBuilder(); + /// + public override string ToString() + { + if (IsEmpty) + return EmptyLiteral; - sb.Append(LowerBoundIsInclusive ? LowerInclusiveBound : LowerExclusiveBound); + var sb = new StringBuilder(); - if (!LowerBoundInfinite) - sb.Append(LowerBound); + sb.Append(LowerBoundIsInclusive ? LowerInclusiveBound : LowerExclusiveBound); - sb.Append(BoundSeparator); + if (!LowerBoundInfinite) + sb.Append(LowerBound); - if (!UpperBoundInfinite) - sb.Append(UpperBound); + sb.Append(BoundSeparator); - sb.Append(UpperBoundIsInclusive ? UpperInclusiveBound : UpperExclusiveBound); + if (!UpperBoundInfinite) + sb.Append(UpperBound); - return sb.ToString(); - } + sb.Append(UpperBoundIsInclusive ? UpperInclusiveBound : UpperExclusiveBound); - // TODO: rewrite this to use ReadOnlySpan for the 4.1 release - /// - /// Parses the well-known text representation of a PostgreSQL range type into a . - /// - /// A PosgreSQL range type in a well-known text format. - /// - /// The represented by the . - /// - /// - /// Malformed range literal. - /// - /// - /// Malformed range literal. Missing left parenthesis or bracket. - /// - /// - /// Malformed range literal. Missing right parenthesis or bracket. - /// - /// - /// Malformed range literal. Missing comma after lower bound. - /// - /// - /// See: https://www.postgresql.org/docs/current/static/rangetypes.html - /// - public static NpgsqlRange Parse(string value) - { - if (value is null) - throw new ArgumentNullException(nameof(value)); + return sb.ToString(); + } - value = value.Trim(); + // TODO: rewrite this to use ReadOnlySpan for the 4.1 release + /// + /// Parses the well-known text representation of a PostgreSQL range type into a . + /// + /// A PosgreSQL range type in a well-known text format. + /// + /// The represented by the . + /// + /// + /// Malformed range literal. + /// + /// + /// Malformed range literal. Missing left parenthesis or bracket. + /// + /// + /// Malformed range literal. 
Missing right parenthesis or bracket. + /// + /// + /// Malformed range literal. Missing comma after lower bound. + /// + /// + /// See: https://www.postgresql.org/docs/current/static/rangetypes.html + /// + [RequiresUnreferencedCode("Parse implementations for certain types of T may require members that have been trimmed.")] + public static NpgsqlRange Parse(string value) + { + if (value is null) + throw new ArgumentNullException(nameof(value)); - if (value.Length < 3) - throw new FormatException("Malformed range literal."); + value = value.Trim(); - if (string.Equals(value, EmptyLiteral, StringComparison.OrdinalIgnoreCase)) - return Empty; + if (value.Length < 3) + throw new FormatException("Malformed range literal."); - var lowerInclusive = value[0] == LowerInclusiveBound; - var lowerExclusive = value[0] == LowerExclusiveBound; + if (string.Equals(value, EmptyLiteral, StringComparison.OrdinalIgnoreCase)) + return Empty; - if (!lowerInclusive && !lowerExclusive) - throw new FormatException("Malformed range literal. Missing left parenthesis or bracket."); + var lowerInclusive = value[0] == LowerInclusiveBound; + var lowerExclusive = value[0] == LowerExclusiveBound; - var upperInclusive = value[value.Length - 1] == UpperInclusiveBound; - var upperExclusive = value[value.Length - 1] == UpperExclusiveBound; + if (!lowerInclusive && !lowerExclusive) + throw new FormatException("Malformed range literal. Missing left parenthesis or bracket."); - if (!upperInclusive && !upperExclusive) - throw new FormatException("Malformed range literal. Missing right parenthesis or bracket."); + var upperInclusive = value[value.Length - 1] == UpperInclusiveBound; + var upperExclusive = value[value.Length - 1] == UpperExclusiveBound; - var separator = value.IndexOf(BoundSeparator); + if (!upperInclusive && !upperExclusive) + throw new FormatException("Malformed range literal. 
Missing right parenthesis or bracket."); - if (separator == -1) - throw new FormatException("Malformed range literal. Missing comma after lower bound."); + var separator = value.IndexOf(BoundSeparator); - if (separator != value.LastIndexOf(BoundSeparator)) - // TODO: this should be replaced to handle quoted commas. - throw new NotSupportedException("Ranges with embedded commas are not currently supported."); + if (separator == -1) + throw new FormatException("Malformed range literal. Missing comma after lower bound."); - // Skip the opening bracket and stop short of the separator. - var lowerSegment = value.Substring(1, separator - 1).Trim(); + if (separator != value.LastIndexOf(BoundSeparator)) + // TODO: this should be replaced to handle quoted commas. + throw new NotSupportedException("Ranges with embedded commas are not currently supported."); - // Skip past the separator and stop short of the closing bracket. - var upperSegment = value.Substring(separator + 1, value.Length - separator - 2).Trim(); + // Skip the opening bracket and stop short of the separator. + var lowerSegment = value.Substring(1, separator - 1).Trim(); - // TODO: infinity literals have special meaning to some types (e.g. daterange), we should consider a flag to track them. + // Skip past the separator and stop short of the closing bracket. + var upperSegment = value.Substring(separator + 1, value.Length - separator - 2).Trim(); - var lowerInfinite = - lowerSegment.Length == 0 || - string.Equals(lowerSegment, string.Empty, StringComparison.OrdinalIgnoreCase) || - string.Equals(lowerSegment, NullLiteral, StringComparison.OrdinalIgnoreCase) || - string.Equals(lowerSegment, LowerInfinityLiteral, StringComparison.OrdinalIgnoreCase); + // TODO: infinity literals have special meaning to some types (e.g. daterange), we should consider a flag to track them. 
- var upperInfinite = - upperSegment.Length == 0 || - string.Equals(upperSegment, string.Empty, StringComparison.OrdinalIgnoreCase) || - string.Equals(upperSegment, NullLiteral, StringComparison.OrdinalIgnoreCase) || - string.Equals(upperSegment, UpperInfinityLiteral, StringComparison.OrdinalIgnoreCase); + var lowerInfinite = + lowerSegment.Length == 0 || + string.Equals(lowerSegment, string.Empty, StringComparison.OrdinalIgnoreCase) || + string.Equals(lowerSegment, NullLiteral, StringComparison.OrdinalIgnoreCase) || + string.Equals(lowerSegment, LowerInfinityLiteral, StringComparison.OrdinalIgnoreCase); - T lower = lowerInfinite ? default : (T)BoundConverter.ConvertFromString(lowerSegment); - T upper = upperInfinite ? default : (T)BoundConverter.ConvertFromString(upperSegment); + var upperInfinite = + upperSegment.Length == 0 || + string.Equals(upperSegment, string.Empty, StringComparison.OrdinalIgnoreCase) || + string.Equals(upperSegment, NullLiteral, StringComparison.OrdinalIgnoreCase) || + string.Equals(upperSegment, UpperInfinityLiteral, StringComparison.OrdinalIgnoreCase); - return new NpgsqlRange(lower, lowerInclusive, lowerInfinite, upper, upperInclusive, upperInfinite); - } + BoundConverter ??= TypeDescriptor.GetConverter(typeof(T)); + var lower = lowerInfinite ? default : (T?)BoundConverter.ConvertFromString(lowerSegment); + var upper = upperInfinite ? default : (T?)BoundConverter.ConvertFromString(upperSegment); -#nullable disable - /// - /// Represents a type converter for . - /// - public class RangeTypeConverter : TypeConverter - { - /// - /// Adds a to the closed form . 
- /// - public static void Register() => - TypeDescriptor.AddAttributes( - typeof(NpgsqlRange), - new TypeConverterAttribute(typeof(RangeTypeConverter))); - - /// - public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) - => sourceType == typeof(string); - - /// - public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) - => destinationType == typeof(string); - - /// - public override object ConvertFrom(ITypeDescriptorContext context, CultureInfo culture, object value) - => value is string s ? Parse(s) : base.ConvertFrom(context, culture, value); - - /// - public override object ConvertTo(ITypeDescriptorContext context, CultureInfo culture, object value, Type destinationType) - => value.ToString(); - } + return new NpgsqlRange(lower, lowerInclusive, lowerInfinite, upper, upperInclusive, upperInfinite); } -#nullable restore /// - /// Represents characteristics of range type boundaries. + /// Represents a type converter for . /// - /// - /// See: https://www.postgresql.org/docs/current/static/rangetypes.html - /// - [Flags] - enum RangeFlags : byte + [RequiresUnreferencedCode("ConvertFrom implementations for certain types of T may require members that have been trimmed.")] + public class RangeTypeConverter : TypeConverter { /// - /// The default flag. The range is not empty and has boundaries that are definite and exclusive. + /// Adds a to the closed form . /// - None = 0, + public static void Register() => + TypeDescriptor.AddAttributes( + typeof(NpgsqlRange), + new TypeConverterAttribute(typeof(RangeTypeConverter))); - /// - /// The range is empty. E.g. '(0,0)', 'empty'. - /// - Empty = 1, + /// + public override bool CanConvertFrom(ITypeDescriptorContext? context, Type sourceType) + => sourceType == typeof(string); - /// - /// The lower bound is inclusive. E.g. '[0,5]', '[0,5)', '[0,)'. - /// - LowerBoundInclusive = 2, + /// + public override bool CanConvertTo(ITypeDescriptorContext? context, Type? 
destinationType) + => destinationType == typeof(string); - /// - /// The upper bound is inclusive. E.g. '[0,5]', '(0,5]', '(,5]'. - /// - UpperBoundInclusive = 4, + /// + public override object? ConvertFrom(ITypeDescriptorContext? context, CultureInfo? culture, object value) + => value is string s ? Parse(s) : base.ConvertFrom(context, culture, value); - /// - /// The lower bound is infinite or indefinite. E.g. '(null,5]', '(-infinity,5]', '(,5]'. - /// - LowerBoundInfinite = 8, + /// + public override object? ConvertTo(ITypeDescriptorContext? context, CultureInfo? culture, object? value, Type destinationType) + => value is null ? string.Empty : value.ToString(); + } +} - /// - /// The upper bound is infinite or indefinite. E.g. '[0,null)', '[0,infinity)', '[0,)'. - /// - UpperBoundInfinite = 16, +/// +/// Represents characteristics of range type boundaries. +/// +/// +/// See: https://www.postgresql.org/docs/current/static/rangetypes.html +/// +[Flags] +enum RangeFlags : byte +{ + /// + /// The default flag. The range is not empty and has boundaries that are definite and exclusive. + /// + None = 0, - /// - /// Both the lower and upper bounds are inclusive. - /// - Inclusive = LowerBoundInclusive | UpperBoundInclusive, + /// + /// The range is empty. E.g. '(0,0)', 'empty'. + /// + Empty = 1, - /// - /// Both the lower and upper bounds are indefinite. - /// - Infinite = LowerBoundInfinite | UpperBoundInfinite, + /// + /// The lower bound is inclusive. E.g. '[0,5]', '[0,5)', '[0,)'. + /// + LowerBoundInclusive = 2, - /// - /// The lower bound is both inclusive and indefinite. This represents an error condition. - /// - LowerInclusiveInfinite = LowerBoundInclusive | LowerBoundInfinite, + /// + /// The upper bound is inclusive. E.g. '[0,5]', '(0,5]', '(,5]'. + /// + UpperBoundInclusive = 4, - /// - /// The upper bound is both inclusive and indefinite. This represents an error condition. 
- /// - UpperInclusiveInfinite = UpperBoundInclusive | UpperBoundInfinite - } + /// + /// The lower bound is infinite or indefinite. E.g. '(null,5]', '(-infinity,5]', '(,5]'. + /// + LowerBoundInfinite = 8, + + /// + /// The upper bound is infinite or indefinite. E.g. '[0,null)', '[0,infinity)', '[0,)'. + /// + UpperBoundInfinite = 16, + + /// + /// Both the lower and upper bounds are inclusive. + /// + Inclusive = LowerBoundInclusive | UpperBoundInclusive, + + /// + /// Both the lower and upper bounds are indefinite. + /// + Infinite = LowerBoundInfinite | UpperBoundInfinite, + + /// + /// The lower bound is both inclusive and indefinite. This represents an error condition. + /// + LowerInclusiveInfinite = LowerBoundInclusive | LowerBoundInfinite, + + /// + /// The upper bound is both inclusive and indefinite. This represents an error condition. + /// + UpperInclusiveInfinite = UpperBoundInclusive | UpperBoundInfinite } diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlTimeSpan.cs b/src/Npgsql/NpgsqlTypes/NpgsqlTimeSpan.cs deleted file mode 100644 index 4fc958e147..0000000000 --- a/src/Npgsql/NpgsqlTypes/NpgsqlTimeSpan.cs +++ /dev/null @@ -1,905 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Text; -using Npgsql; - -// ReSharper disable once CheckNamespace -namespace NpgsqlTypes -{ - /// - /// Represents the PostgreSQL interval datatype. - /// - /// PostgreSQL differs from .NET in how it's interval type doesn't assume 24 hours in a day - /// (to deal with 23- and 25-hour days caused by daylight savings adjustments) and has a concept - /// of months that doesn't exist in .NET's class. (Neither datatype - /// has any concessions for leap-seconds). 
- /// For most uses just casting to and from TimeSpan will work correctly — in particular, - /// the results of subtracting one or the PostgreSQL date, time and - /// timestamp types from another should be the same whether you do so in .NET or PostgreSQL — - /// but if the handling of days and months in PostgreSQL is important to your application then you - /// should use this class instead of . - /// If you don't know whether these differences are important to your application, they - /// probably arent! Just use and do not use this class directly ☺ - /// To avoid forcing unnecessary provider-specific concerns on users who need not be concerned - /// with them a call to on a field containing an - /// value will return a rather than an - /// . If you need the extra functionality of - /// then use . - /// - /// - /// - /// - /// - [Serializable] - public readonly struct NpgsqlTimeSpan : IComparable, IComparer, IEquatable, IComparable, - IComparer - { - #region Constants - - /// - /// Represents the number of ticks (100ns periods) in one microsecond. This field is constant. - /// - public const long TicksPerMicrosecond = TimeSpan.TicksPerMillisecond / 1000; - - /// - /// Represents the number of ticks (100ns periods) in one millisecond. This field is constant. - /// - public const long TicksPerMillsecond = TimeSpan.TicksPerMillisecond; - - /// - /// Represents the number of ticks (100ns periods) in one second. This field is constant. - /// - public const long TicksPerSecond = TimeSpan.TicksPerSecond; - - /// - /// Represents the number of ticks (100ns periods) in one minute. This field is constant. - /// - public const long TicksPerMinute = TimeSpan.TicksPerMinute; - - /// - /// Represents the number of ticks (100ns periods) in one hour. This field is constant. - /// - public const long TicksPerHour = TimeSpan.TicksPerHour; - - /// - /// Represents the number of ticks (100ns periods) in one day. This field is constant. 
- /// - public const long TicksPerDay = TimeSpan.TicksPerDay; - - /// - /// Represents the number of hours in one day (assuming no daylight savings adjustments). This field is constant. - /// - public const int HoursPerDay = 24; - - /// - /// Represents the number of days assumed in one month if month justification or unjustifcation is performed. - /// This is set to 30 for consistency with PostgreSQL. Note that this is means that month adjustments cause - /// a year to be taken as 30 × 12 = 360 rather than 356/366 days. - /// - public const int DaysPerMonth = 30; - - /// - /// Represents the number of ticks (100ns periods) in one day, assuming 30 days per month. - /// - public const long TicksPerMonth = TicksPerDay * DaysPerMonth; - - /// - /// Represents the number of months in a year. This field is constant. - /// - public const int MonthsPerYear = 12; - - /// - /// Represents the maximum . This field is read-only. - /// - public static readonly NpgsqlTimeSpan MaxValue = new NpgsqlTimeSpan(long.MaxValue); - - /// - /// Represents the minimum . This field is read-only. - /// - public static readonly NpgsqlTimeSpan MinValue = new NpgsqlTimeSpan(long.MinValue); - - /// - /// Represents the zero . This field is read-only. - /// - public static readonly NpgsqlTimeSpan Zero = new NpgsqlTimeSpan(0); - - #endregion - - readonly int _months; - readonly int _days; - readonly long _ticks; - - #region Constructors - - /// - /// Initializes a new to the specified number of ticks. - /// - /// A time period expressed in 100ns units. - public NpgsqlTimeSpan(long ticks) - : this(new TimeSpan(ticks)) - { - } - - /// - /// Initializes a new to hold the same time as a - /// - /// A time period expressed in a - public NpgsqlTimeSpan(TimeSpan timespan) - : this(0, timespan.Days, timespan.Ticks - (TicksPerDay * timespan.Days)) - { - } - - /// - /// Initializes a new to the specified number of months, days - /// & ticks. - /// - /// Number of months. - /// Number of days. 
- /// Number of 100ns units. - public NpgsqlTimeSpan(int months, int days, long ticks) - { - _months = months; - _days = days; - _ticks = ticks; - } - - /// - /// Initializes a new to the specified number of - /// days, hours, minutes & seconds. - /// - /// Number of days. - /// Number of hours. - /// Number of minutes. - /// Number of seconds. - public NpgsqlTimeSpan(int days, int hours, int minutes, int seconds) - : this(0, days, new TimeSpan(hours, minutes, seconds).Ticks) - { - } - - /// - /// Initializes a new to the specified number of - /// days, hours, minutes, seconds & milliseconds. - /// - /// Number of days. - /// Number of hours. - /// Number of minutes. - /// Number of seconds. - /// Number of milliseconds. - public NpgsqlTimeSpan(int days, int hours, int minutes, int seconds, int milliseconds) - : this(0, days, new TimeSpan(0, hours, minutes, seconds, milliseconds).Ticks) - { - } - - /// - /// Initializes a new to the specified number of - /// months, days, hours, minutes, seconds & milliseconds. - /// - /// Number of months. - /// Number of days. - /// Number of hours. - /// Number of minutes. - /// Number of seconds. - /// Number of milliseconds. - public NpgsqlTimeSpan(int months, int days, int hours, int minutes, int seconds, int milliseconds) - : this(months, days, new TimeSpan(0, hours, minutes, seconds, milliseconds).Ticks) - { - } - - /// - /// Initializes a new to the specified number of - /// years, months, days, hours, minutes, seconds & milliseconds. - /// Years are calculated exactly equivalent to 12 months. - /// - /// Number of years. - /// Number of months. - /// Number of days. - /// Number of hours. - /// Number of minutes. - /// Number of seconds. - /// Number of milliseconds. 
- public NpgsqlTimeSpan(int years, int months, int days, int hours, int minutes, int seconds, int milliseconds) - : this(years * 12 + months, days, new TimeSpan(0, hours, minutes, seconds, milliseconds).Ticks) - { - } - - #endregion - - #region Whole Parts - - /// - /// The total number of ticks(100ns units) contained. This is the resolution of the - /// type. This ignores the number of days and - /// months held. If you want them included use first. - /// The resolution of the PostgreSQL - /// interval type is by default 1µs = 1,000 ns. It may be smaller as follows: - /// - /// - /// interval(0) - /// resolution of 1s (1 second) - /// - /// - /// interval(1) - /// resolution of 100ms = 0.1s (100 milliseconds) - /// - /// - /// interval(2) - /// resolution of 10ms = 0.01s (10 milliseconds) - /// - /// - /// interval(3) - /// resolution of 1ms = 0.001s (1 millisecond) - /// - /// - /// interval(4) - /// resolution of 100µs = 0.0001s (100 microseconds) - /// - /// - /// interval(5) - /// resolution of 10µs = 0.00001s (10 microseconds) - /// - /// - /// interval(6) or interval - /// resolution of 1µs = 0.000001s (1 microsecond) - /// - /// - /// As such, if the 100-nanosecond resolution is significant to an application, a PostgreSQL interval will - /// not suffice for those purposes. - /// In more frequent cases though, the resolution of the interval suffices. - /// will always suffice to handle the resolution of any interval value, and upon - /// writing to the database, will be rounded to the resolution used. - /// - /// The number of ticks in the instance. - /// - public long Ticks => _ticks; - - /// - /// Gets the number of whole microseconds held in the instance. - /// An in the range [-999999, 999999]. - /// - public int Microseconds => (int)((_ticks / 10) % 1000000); - - /// - /// Gets the number of whole milliseconds held in the instance. - /// An in the range [-999, 999]. 
- /// - public int Milliseconds => (int)((_ticks / TicksPerMillsecond) % 1000); - - /// - /// Gets the number of whole seconds held in the instance. - /// An in the range [-59, 59]. - /// - public int Seconds => (int)((_ticks / TicksPerSecond) % 60); - - /// - /// Gets the number of whole minutes held in the instance. - /// An in the range [-59, 59]. - /// - public int Minutes => (int)((_ticks / TicksPerMinute) % 60); - - /// - /// Gets the number of whole hours held in the instance. - /// Note that this can be less than -23 or greater than 23 unless - /// has been used to produce this instance. - /// - public int Hours => (int)(_ticks / TicksPerHour); - - /// - /// Gets the number of days held in the instance. - /// Note that this does not pay attention to a time component with -24 or less hours or - /// 24 or more hours, unless has been called to produce this instance. - /// - public int Days => _days; - - /// - /// Gets the number of months held in the instance. - /// Note that this does not pay attention to a day component with -30 or less days or - /// 30 or more days, unless has been called to produce this instance. - /// - public int Months => _months; - - /// - /// Returns a representing the time component of the instance. - /// Note that this may have a value beyond the range ±23:59:59.9999999 unless - /// has been called to produce this instance. - /// - public TimeSpan Time => new TimeSpan(_ticks); - - #endregion - - #region Total Parts - - /// - /// The total number of ticks (100ns units) in the instance, assuming 24 hours in each day and - /// 30 days in a month. - /// - public long TotalTicks => Ticks + Days * TicksPerDay + Months * TicksPerMonth; - - /// - /// The total number of microseconds in the instance, assuming 24 hours in each day and - /// 30 days in a month. - /// - public double TotalMicroseconds => TotalTicks / 10d; - - /// - /// The total number of milliseconds in the instance, assuming 24 hours in each day and - /// 30 days in a month. 
- /// - public double TotalMilliseconds => TotalTicks / (double)TicksPerMillsecond; - - /// - /// The total number of seconds in the instance, assuming 24 hours in each day and - /// 30 days in a month. - /// - public double TotalSeconds => TotalTicks / (double)TicksPerSecond; - - /// - /// The total number of minutes in the instance, assuming 24 hours in each day and - /// 30 days in a month. - /// - public double TotalMinutes => TotalTicks / (double)TicksPerMinute; - - /// - /// The total number of hours in the instance, assuming 24 hours in each day and - /// 30 days in a month. - /// - public double TotalHours => TotalTicks / (double)TicksPerHour; - - /// - /// The total number of days in the instance, assuming 24 hours in each day and - /// 30 days in a month. - /// - public double TotalDays => TotalTicks / (double)TicksPerDay; - - /// - /// The total number of months in the instance, assuming 24 hours in each day and - /// 30 days in a month. - /// - public double TotalMonths => TotalTicks / (double)TicksPerMonth; - - #endregion - - #region Create From Part - - /// - /// Creates an from a number of ticks. - /// - /// The number of ticks (100ns units) in the interval. - /// A d with the given number of ticks. - public static NpgsqlTimeSpan FromTicks(long ticks) => new NpgsqlTimeSpan(ticks).Canonicalize(); - - /// - /// Creates an from a number of microseconds. - /// - /// The number of microseconds in the interval. - /// A d with the given number of microseconds. - public static NpgsqlTimeSpan FromMicroseconds(double micro) => FromTicks((long)(micro * TicksPerMicrosecond)); - - /// - /// Creates an from a number of milliseconds. - /// - /// The number of milliseconds in the interval. - /// A d with the given number of milliseconds. - public static NpgsqlTimeSpan FromMilliseconds(double milli) => FromTicks((long)(milli * TicksPerMillsecond)); - - /// - /// Creates an from a number of seconds. - /// - /// The number of seconds in the interval. 
- /// A d with the given number of seconds. - public static NpgsqlTimeSpan FromSeconds(double seconds) => FromTicks((long)(seconds * TicksPerSecond)); - - /// - /// Creates an from a number of minutes. - /// - /// The number of minutes in the interval. - /// A d with the given number of minutes. - public static NpgsqlTimeSpan FromMinutes(double minutes) => FromTicks((long)(minutes * TicksPerMinute)); - - /// - /// Creates an from a number of hours. - /// - /// The number of hours in the interval. - /// A d with the given number of hours. - public static NpgsqlTimeSpan FromHours(double hours) => FromTicks((long)(hours * TicksPerHour)); - - /// - /// Creates an from a number of days. - /// - /// The number of days in the interval. - /// A d with the given number of days. - public static NpgsqlTimeSpan FromDays(double days) => FromTicks((long)(days * TicksPerDay)); - - /// - /// Creates an from a number of months. - /// - /// The number of months in the interval. - /// A d with the given number of months. - public static NpgsqlTimeSpan FromMonths(double months) => FromTicks((long)(months * TicksPerMonth)); - - #endregion - - #region Arithmetic - - /// - /// Adds another interval to this instance and returns the result. - /// - /// An to add to this instance. - /// An whose values are the sums of the two instances. - public NpgsqlTimeSpan Add(in NpgsqlTimeSpan interval) - => new NpgsqlTimeSpan(Months + interval.Months, Days + interval.Days, Ticks + interval.Ticks); - - /// - /// Subtracts another interval from this instance and returns the result. - /// - /// An to subtract from this instance. - /// An whose values are the differences of the two instances. - public NpgsqlTimeSpan Subtract(in NpgsqlTimeSpan interval) - => new NpgsqlTimeSpan(Months - interval.Months, Days - interval.Days, Ticks - interval.Ticks); - - /// - /// Returns an whose value is the negated value of this instance. - /// - /// An whose value is the negated value of this instance. 
- public NpgsqlTimeSpan Negate() => new NpgsqlTimeSpan(-Months, -Days, -Ticks); - - /// - /// This absolute value of this instance. In the case of some, but not all, components being negative, - /// the rules used for justification are used to determine if the instance is positive or negative. - /// - /// An whose value is the absolute value of this instance. - public NpgsqlTimeSpan Duration() - => UnjustifyInterval().Ticks < 0 ? Negate() : this; - - #endregion - - #region Justification - - /// - /// Equivalent to PostgreSQL's justify_days function. - /// - /// An based on this one, but with any hours outside of the range [-23, 23] - /// converted into days. - public NpgsqlTimeSpan JustifyDays() - { - return new NpgsqlTimeSpan(Months, Days + (int)(Ticks / TicksPerDay), Ticks % TicksPerDay); - } - - /// - /// Opposite to PostgreSQL's justify_days function. - /// - /// An based on this one, but with any days converted to multiples of ±24hours. - public NpgsqlTimeSpan UnjustifyDays() - { - return new NpgsqlTimeSpan(Months, 0, Ticks + Days * TicksPerDay); - } - - /// - /// Equivalent to PostgreSQL's justify_months function. - /// - /// An based on this one, but with any days outside of the range [-30, 30] - /// converted into months. - public NpgsqlTimeSpan JustifyMonths() - { - return new NpgsqlTimeSpan(Months + Days / DaysPerMonth, Days % DaysPerMonth, Ticks); - } - - /// - /// Opposite to PostgreSQL's justify_months function. - /// - /// An based on this one, but with any months converted to multiples of ±30days. - public NpgsqlTimeSpan UnjustifyMonths() - { - return new NpgsqlTimeSpan(0, Days + Months * DaysPerMonth, Ticks); - } - - /// - /// Equivalent to PostgreSQL's justify_interval function. 
- /// - /// An based on this one, - /// but with any months converted to multiples of ±30days - /// and then with any days converted to multiples of ±24hours - public NpgsqlTimeSpan JustifyInterval() - { - return JustifyMonths().JustifyDays(); - } - - /// - /// Opposite to PostgreSQL's justify_interval function. - /// - /// An based on this one, but with any months converted to multiples of ±30days and then any days converted to multiples of ±24hours; - public NpgsqlTimeSpan UnjustifyInterval() - { - return new NpgsqlTimeSpan(Ticks + Days * TicksPerDay + Months * DaysPerMonth * TicksPerDay); - } - - /// - /// Produces a canonical NpgslInterval with 0 months and hours in the range of [-23, 23]. - /// - /// - /// While the fact that for many purposes, two different instances could be considered - /// equivalent (e.g. one with 2days, 3hours and one with 1day 27hours) there are different possible canonical forms. - /// - /// E.g. we could move all excess hours into days and all excess days into months and have the most readable form, - /// or we could move everything into the ticks and have the form that allows for the easiest arithmetic) the form - /// chosen has two important properties that make it the best choice. - /// First, it is closest two how - /// objects are most often represented. Second, it is compatible with results of many - /// PostgreSQL functions, particularly with age() and the results of subtracting one date, time or timestamp from - /// another. - /// - /// Note that the results of casting a to is - /// canonicalised. - /// - /// - /// An based on this one, but with months converted to multiples of ±30days and with any hours outside of the range [-23, 23] - /// converted into days. 
- public NpgsqlTimeSpan Canonicalize() - { - return new NpgsqlTimeSpan(0, Days + Months * DaysPerMonth + (int)(Ticks / TicksPerDay), Ticks % TicksPerDay); - } - - #endregion - - #region Casts - - /// - /// Implicit cast of a to an - /// - /// A - /// An eqivalent, canonical, . - public static implicit operator NpgsqlTimeSpan(TimeSpan timespan) => ToNpgsqlTimeSpan(timespan); - - /// - /// Casts a to an . - /// - public static NpgsqlTimeSpan ToNpgsqlTimeSpan(TimeSpan timespan) => new NpgsqlTimeSpan(timespan).Canonicalize(); - - /// - /// Explicit cast of an to a . - /// - /// A . - /// An equivalent . - public static explicit operator TimeSpan(NpgsqlTimeSpan interval) - => ToTimeSpan(interval); - - /// - /// Casts an to a . - /// - public static TimeSpan ToTimeSpan(in NpgsqlTimeSpan interval) - => new TimeSpan(interval.Ticks + interval.Days * TicksPerDay + interval.Months * DaysPerMonth * TicksPerDay); - - #endregion - - #region Comparison - - /// - /// Returns true if another is exactly the same as this instance. - /// - /// An for comparison. - /// true if the two instances are exactly the same, - /// false otherwise. - public bool Equals(NpgsqlTimeSpan other) - => Ticks == other.Ticks && Days == other.Days && Months == other.Months; - - /// - /// Returns true if another object is an , that is exactly the same as - /// this instance - /// - /// An for comparison. - /// true if the argument is an and is exactly the same - /// as this one, false otherwise. - public override bool Equals(object? obj) => obj is NpgsqlTimeSpan span && Equals(span); - - /// - /// Compares two instances. - /// - /// The first . - /// The second . - /// 0 if the two are equal or equivalent. A value greater than zero if x is greater than y, - /// a value less than zero if x is less than y. - public static int Compare(NpgsqlTimeSpan x, NpgsqlTimeSpan y) => x.CompareTo(y); - - int IComparer.Compare(NpgsqlTimeSpan x, NpgsqlTimeSpan y) => x.CompareTo(y); - - int IComparer.Compare(object? 
x, object? y) - { - if (x == null) - return y == null ? 0 : 1; - if (y == null) - return -1; - try { - return ((IComparable)x).CompareTo(y); - } catch (Exception) { - throw new ArgumentException(); - } - } - - /// - /// A hash code suitable for uses with hashing algorithms. - /// - /// An signed integer. - public override int GetHashCode() => UnjustifyInterval().Ticks.GetHashCode(); - - /// - /// Compares this instance with another/ - /// - /// An to compare this with. - /// 0 if the instances are equal or equivalent. A value less than zero if - /// this instance is less than the argument. A value greater than zero if this instance - /// is greater than the instance. - public int CompareTo(NpgsqlTimeSpan other) - => UnjustifyInterval().Ticks.CompareTo(other.UnjustifyInterval().Ticks); - - /// - /// Compares this instance with another/ - /// - /// An object to compare this with. - /// 0 if the argument is an and the instances are equal or equivalent. - /// A value less than zero if the argument is an and - /// this instance is less than the argument. - /// A value greater than zero if the argument is an and this instance - /// is greater than the instance. - /// A value greater than zero if the argument is null. - /// The argument is not an . - public int CompareTo(object? other) - { - if (other == null) - return 1; - if (other is NpgsqlTimeSpan) - return CompareTo((NpgsqlTimeSpan)other); - throw new ArgumentException(nameof(other)); - } - - #endregion - - #region String Conversions - - /// - /// Parses a and returns a instance. - /// Designed to use the formats generally returned by PostgreSQL. - /// - /// The to parse. - /// An represented by the argument. - /// The string was null. - /// A value obtained from parsing the string exceeded the values allowed for the relevant component. - /// The string was not in a format that could be parsed to produce an . 
- public static NpgsqlTimeSpan Parse(string str) - { - if (str == null) { - throw new ArgumentNullException(nameof(str)); - } - str = str.Replace('s', ' '); //Quick and easy way to catch plurals. - try { - var years = 0; - var months = 0; - var days = 0; - var hours = 0; - var minutes = 0; - var seconds = 0m; - var idx = str.IndexOf("year", StringComparison.Ordinal); - if (idx > 0) { - years = int.Parse(str.Substring(0, idx)); - str = SafeSubstring(str, idx + 5); - } - idx = str.IndexOf("mon", StringComparison.Ordinal); - if (idx > 0) { - months = int.Parse(str.Substring(0, idx)); - str = SafeSubstring(str, idx + 4); - } - idx = str.IndexOf("day", StringComparison.Ordinal); - if (idx > 0) { - days = int.Parse(str.Substring(0, idx)); - str = SafeSubstring(str, idx + 4).Trim(); - } - if (str.Length > 0) { - var isNegative = str[0] == '-'; - var parts = str.Split(':'); - switch (parts.Length) //One of those times that fall-through would actually be good. - { - case 1: - hours = int.Parse(parts[0]); - break; - case 2: - hours = int.Parse(parts[0]); - minutes = int.Parse(parts[1]); - break; - default: - hours = int.Parse(parts[0]); - minutes = int.Parse(parts[1]); - seconds = decimal.Parse(parts[2], System.Globalization.CultureInfo.InvariantCulture.NumberFormat); - break; - } - if (isNegative) { - minutes *= -1; - seconds *= -1; - } - } - var ticks = hours * TicksPerHour + minutes * TicksPerMinute + (long)(seconds * TicksPerSecond); - return new NpgsqlTimeSpan(years * MonthsPerYear + months, days, ticks); - } catch (OverflowException) { - throw; - } catch (Exception) { - throw new FormatException(); - } - } - - private static string SafeSubstring(string s, int startIndex) - { - if (startIndex >= s.Length) - return string.Empty; - else - return s.Substring(startIndex); - } - - /// - /// Attempt to parse a to produce an . - /// - /// The to parse. - /// (out) The produced, or if the parsing failed. - /// true if the parsing succeeded, false otherwise. 
- public static bool TryParse(string str, out NpgsqlTimeSpan result) - { - try { - result = Parse(str); - return true; - } catch (Exception) { - result = Zero; - return false; - } - } - - /// - /// Create a representation of the instance. - /// The format returned is of the form: - /// [M mon[s]] [d day[s]] [HH:mm:ss[.f[f[f[f[f[f[f[f[f]]]]]]]]]] - /// A zero is represented as 00:00:00 - /// - /// Ticks are 100ns, Postgress resolution is only to 1µs at most. Hence we lose 1 or more decimal - /// precision in storing values in the database. Despite this, this method will output that extra - /// digit of precision. It's forward-compatible with any future increases in resolution up to 100ns, - /// and also makes this ToString() more applicable to any other use-case. - /// - /// - /// The representation. - public override string ToString() - { - var sb = new StringBuilder(); - if (Months != 0) { - sb.Append(Months).Append(Math.Abs(Months) == 1 ? " mon " : " mons "); - } - if (Days != 0) { - if (Months < 0 && Days > 0) { - sb.Append('+'); - } - sb.Append(Days).Append(Math.Abs(Days) == 1 ? " day " : " days "); - } - if (Ticks != 0 || sb.Length == 0) { - if (Ticks < 0) { - sb.Append('-'); - } else if (Days < 0 || (Days == 0 && Months < 0)) { - sb.Append('+'); - } - // calculate total seconds and then subtract total whole minutes in seconds to get just the seconds and fractional part - var seconds = _ticks / (decimal)TicksPerSecond - (_ticks / TicksPerMinute) * 60; - sb.Append(Math.Abs(Hours).ToString("D2")).Append(':').Append(Math.Abs(Minutes).ToString("D2")).Append(':').Append(Math.Abs(seconds).ToString("0#.######", System.Globalization.CultureInfo.InvariantCulture.NumberFormat)); - - } - if (sb[sb.Length - 1] == ' ') { - sb.Remove(sb.Length - 1, 1); - } - return sb.ToString(); - } - - #endregion - - #region Common Operators - - /// - /// Adds two together. - /// - /// The first to add. - /// The second to add. - /// An whose values are the sum of the arguments. 
- public static NpgsqlTimeSpan operator +(NpgsqlTimeSpan x, NpgsqlTimeSpan y) - { - return x.Add(y); - } - - /// - /// Subtracts one from another. - /// - /// The to subtract the other from. - /// The to subtract from the other. - /// An whose values are the difference of the arguments - public static NpgsqlTimeSpan operator -(NpgsqlTimeSpan x, NpgsqlTimeSpan y) - { - return x.Subtract(y); - } - - /// - /// Returns true if two are exactly the same. - /// - /// The first to compare. - /// The second to compare. - /// true if the two arguments are exactly the same, false otherwise. - public static bool operator ==(NpgsqlTimeSpan x, NpgsqlTimeSpan y) - { - return x.Equals(y); - } - - /// - /// Returns false if two are exactly the same. - /// - /// The first to compare. - /// The second to compare. - /// false if the two arguments are exactly the same, true otherwise. - public static bool operator !=(NpgsqlTimeSpan x, NpgsqlTimeSpan y) - { - return !(x == y); - } - - /// - /// Compares two instances to see if the first is less than the second - /// - /// The first to compare. - /// The second to compare. - /// true if the first is less than second, false otherwise. - public static bool operator <(NpgsqlTimeSpan x, NpgsqlTimeSpan y) - { - return x.UnjustifyInterval().Ticks < y.UnjustifyInterval().Ticks; - } - - /// - /// Compares two instances to see if the first is less than or equivalent to the second - /// - /// The first to compare. - /// The second to compare. - /// true if the first is less than or equivalent to second, false otherwise. - public static bool operator <=(NpgsqlTimeSpan x, NpgsqlTimeSpan y) - { - return x.UnjustifyInterval().Ticks <= y.UnjustifyInterval().Ticks; - } - - /// - /// Compares two instances to see if the first is greater than the second - /// - /// The first to compare. - /// The second to compare. - /// true if the first is greater than second, false otherwise. 
- public static bool operator >(NpgsqlTimeSpan x, NpgsqlTimeSpan y) - { - return !(x <= y); - } - - /// - /// Compares two instances to see if the first is greater than or equivalent the second - /// - /// The first to compare. - /// The second to compare. - /// true if the first is greater than or equivalent to the second, false otherwise. - public static bool operator >=(NpgsqlTimeSpan x, NpgsqlTimeSpan y) - { - return !(x < y); - } - - /// - /// Returns the instance. - /// - public static NpgsqlTimeSpan operator +(NpgsqlTimeSpan x) => Plus(x); - - /// - /// Returns the instance. - /// - public static NpgsqlTimeSpan Plus(in NpgsqlTimeSpan x) => x; - - /// - /// Negates an instance. - /// - /// An . - /// The negation of the argument. - public static NpgsqlTimeSpan operator -(NpgsqlTimeSpan x) => x.Negate(); - - #endregion - } -} diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs b/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs index f731fa4ac6..9d49dc7c3a 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs @@ -5,641 +5,741 @@ #pragma warning disable CA1034 // ReSharper disable once CheckNamespace -namespace NpgsqlTypes +namespace NpgsqlTypes; + +/// +/// Represents a PostgreSQL tsquery. This is the base class for the +/// lexeme, not, or, and, and "followed by" nodes. +/// +public abstract class NpgsqlTsQuery : IEquatable { /// - /// Represents a PostgreSQL tsquery. This is the base class for the - /// lexeme, not, or, and, and "followed by" nodes. + /// Node kind /// - public abstract class NpgsqlTsQuery + public NodeKind Kind { get; } + + /// + /// NodeKind + /// + public enum NodeKind { /// - /// Node kind + /// Represents the empty tsquery. Should only be used at top level. /// - public NodeKind Kind { get; } - + Empty = -1, /// - /// NodeKind + /// Lexeme /// - public enum NodeKind - { - /// - /// Represents the empty tsquery. Should only be used at top level. 
- /// - Empty = -1, - /// - /// Lexeme - /// - Lexeme = 0, - /// - /// Not operator - /// - Not = 1, - /// - /// And operator - /// - And = 2, - /// - /// Or operator - /// - Or = 3, - /// - /// "Followed by" operator - /// - Phrase = 4 - } - + Lexeme = 0, /// - /// Constructs an . + /// Not operator /// - /// - protected NpgsqlTsQuery(NodeKind kind) => Kind = kind; - - internal abstract void Write(StringBuilder sb, bool first = false); - + Not = 1, /// - /// Writes the tsquery in PostgreSQL's text format. + /// And operator /// - /// - public override string ToString() - { - var sb = new StringBuilder(); - Write(sb, true); - return sb.ToString(); - } - + And = 2, /// - /// Parses a tsquery in PostgreSQL's text format. + /// Or operator /// - /// - /// - public static NpgsqlTsQuery Parse(string value) - { - if (value == null) - throw new ArgumentNullException(nameof(value)); + Or = 3, + /// + /// "Followed by" operator + /// + Phrase = 4 + } + + /// + /// Constructs an . + /// + /// + protected NpgsqlTsQuery(NodeKind kind) => Kind = kind; - var valStack = new Stack(); - var opStack = new Stack(); + /// + /// Writes the tsquery in PostgreSQL's text format. + /// + public void Write(StringBuilder stringBuilder) => WriteCore(stringBuilder, true); - var sb = new StringBuilder(); - var pos = 0; - var expectingBinOp = false; + internal abstract void WriteCore(StringBuilder sb, bool first = false); - var lastFollowedByOpDistance = -1; + /// + /// Writes the tsquery in PostgreSQL's text format. + /// + public override string ToString() + { + var sb = new StringBuilder(); + Write(sb); + return sb.ToString(); + } - NextToken: - if (pos >= value.Length) - goto Finish; - var ch = value[pos++]; - if (ch == '\'') - goto WaitEndComplex; - if ((ch == ')' || ch == '|' || ch == '&') && !expectingBinOp || (ch == '(' || ch == '!') && expectingBinOp) - throw new FormatException("Syntax error in tsquery. 
Unexpected token."); - - if (ch == '<') - { - var endOfOperatorConsumed = false; - var sbCurrentLength = sb.Length; + /// + /// Parses a tsquery in PostgreSQL's text format. + /// + /// + /// + [Obsolete("Client-side parsing of NpgsqlTsQuery is unreliable and cannot fully duplicate the PostgreSQL logic. Use PG functions instead (e.g. to_tsquery)")] + public static NpgsqlTsQuery Parse(string value) + { + if (value == null) + throw new ArgumentNullException(nameof(value)); - while (pos < value.Length) - { - var c = value[pos++]; - if (c == '>') - { - endOfOperatorConsumed = true; - break; - } - - sb.Append(c); - } + var valStack = new Stack(); + var opStack = new Stack(); - if (sb.Length == sbCurrentLength || !endOfOperatorConsumed) - throw new FormatException("Syntax error in tsquery. Malformed 'followed by' operator."); + var sb = new StringBuilder(); + var pos = 0; + var expectingBinOp = false; - var followedByOpDistanceString = sb.ToString(sbCurrentLength, sb.Length - sbCurrentLength); - if (followedByOpDistanceString == "-") - { - lastFollowedByOpDistance = 1; - } - else if (!int.TryParse(followedByOpDistanceString, out lastFollowedByOpDistance) - || lastFollowedByOpDistance < 0) + short lastFollowedByOpDistance = -1; + + NextToken: + if (pos >= value.Length) + goto Finish; + var ch = value[pos++]; + if (ch == '\'') + goto WaitEndComplex; + if ((ch == ')' || ch == '|' || ch == '&') && !expectingBinOp || (ch == '(' || ch == '!') && expectingBinOp) + throw new FormatException("Syntax error in tsquery. Unexpected token."); + + if (ch == '<') + { + var endOfOperatorConsumed = false; + var sbCurrentLength = sb.Length; + + while (pos < value.Length) + { + var c = value[pos++]; + if (c == '>') { - throw new FormatException("Syntax error in tsquery. Malformed distance in 'followed by' operator."); + endOfOperatorConsumed = true; + break; } - sb.Length -= followedByOpDistanceString.Length; + sb.Append(c); } - if (ch == '(' || ch == '!' 
|| ch == '&' || ch == '<') + if (sb.Length == sbCurrentLength || !endOfOperatorConsumed) + throw new FormatException("Syntax error in tsquery. Malformed 'followed by' operator."); + + var followedByOpDistanceString = sb.ToString(sbCurrentLength, sb.Length - sbCurrentLength); + if (followedByOpDistanceString == "-") { - opStack.Push(new NpgsqlTsQueryOperator(ch, lastFollowedByOpDistance)); - expectingBinOp = false; - lastFollowedByOpDistance = 0; - goto NextToken; + lastFollowedByOpDistance = 1; + } + else if (!short.TryParse(followedByOpDistanceString, out lastFollowedByOpDistance) + || lastFollowedByOpDistance < 0) + { + throw new FormatException("Syntax error in tsquery. Malformed distance in 'followed by' operator."); } - if (ch == '|') + sb.Length -= followedByOpDistanceString.Length; + } + + if (ch == '(' || ch == '!' || ch == '&' || ch == '<') + { + opStack.Push(new NpgsqlTsQueryOperator(ch, lastFollowedByOpDistance)); + expectingBinOp = false; + lastFollowedByOpDistance = 0; + goto NextToken; + } + + if (ch == '|') + { + if (opStack.Count > 0 && opStack.Peek() == '|') { - if (opStack.Count > 0 && opStack.Peek() == '|') - { - if (valStack.Count < 2) - throw new FormatException("Syntax error in tsquery"); - var right = valStack.Pop(); - var left = valStack.Pop(); - valStack.Push(new NpgsqlTsQueryOr(left, right)); - // Implicit pop and repush | - } - else - opStack.Push('|'); - expectingBinOp = false; - goto NextToken; + if (valStack.Count < 2) + throw new FormatException("Syntax error in tsquery"); + var right = valStack.Pop(); + var left = valStack.Pop(); + valStack.Push(new NpgsqlTsQueryOr(left, right)); + // Implicit pop and repush | } + else + opStack.Push('|'); + expectingBinOp = false; + goto NextToken; + } - if (ch == ')') + if (ch == ')') + { + while (opStack.Count > 0 && opStack.Peek() != '(') { - while (opStack.Count > 0 && opStack.Peek() != '(') + if (valStack.Count < 2 || opStack.Peek() == '!') + throw new FormatException("Syntax error in 
tsquery"); + + var right = valStack.Pop(); + var left = valStack.Pop(); + + var tsOp = opStack.Pop(); + valStack.Push((char)tsOp switch { - if (valStack.Count < 2 || opStack.Peek() == '!') - throw new FormatException("Syntax error in tsquery"); - - var right = valStack.Pop(); - var left = valStack.Pop(); - - var tsOp = opStack.Pop(); - valStack.Push((char)tsOp switch - { - '&' => (NpgsqlTsQuery)new NpgsqlTsQueryAnd(left, right), - '|' => new NpgsqlTsQueryOr(left, right), - '<' => new NpgsqlTsQueryFollowedBy(left, tsOp.FollowedByDistance, right), - _ => throw new FormatException("Syntax error in tsquery") - }); - } - if (opStack.Count == 0) - throw new FormatException("Syntax error in tsquery: closing parenthesis without an opening parenthesis"); - opStack.Pop(); - goto PushedVal; + '&' => new NpgsqlTsQueryAnd(left, right), + '|' => new NpgsqlTsQueryOr(left, right), + '<' => new NpgsqlTsQueryFollowedBy(left, tsOp.FollowedByDistance, right), + _ => throw new FormatException("Syntax error in tsquery") + }); } + if (opStack.Count == 0) + throw new FormatException("Syntax error in tsquery: closing parenthesis without an opening parenthesis"); + opStack.Pop(); + goto PushedVal; + } - if (ch == ':') - throw new FormatException("Unexpected : while parsing tsquery"); + if (ch == ':') + throw new FormatException("Unexpected : while parsing tsquery"); - if (char.IsWhiteSpace(ch)) - goto NextToken; + if (char.IsWhiteSpace(ch)) + goto NextToken; - pos--; - if (expectingBinOp) - throw new FormatException("Unexpected lexeme while parsing tsquery"); - // Proceed to WaitEnd + pos--; + if (expectingBinOp) + throw new FormatException("Unexpected lexeme while parsing tsquery"); + // Proceed to WaitEnd - WaitEnd: - if (pos >= value.Length || char.IsWhiteSpace(ch = value[pos]) || ch == '!' || ch == '&' || ch == '|' || ch == '(' || ch == ')') + WaitEnd: + if (pos >= value.Length || char.IsWhiteSpace(ch = value[pos]) || ch == '!' 
|| ch == '&' || ch == '|' || ch == '(' || ch == ')') + { + valStack.Push(new NpgsqlTsQueryLexeme(sb.ToString())); + goto PushedVal; + } + pos++; + if (ch == ':') + { + valStack.Push(new NpgsqlTsQueryLexeme(sb.ToString())); + sb.Clear(); + goto InWeightInfo; + } + if (ch == '\\') + { + if (pos >= value.Length) + throw new FormatException(@"Unexpected \ in end of value"); + ch = value[pos++]; + } + sb.Append(ch); + goto WaitEnd; + + WaitEndComplex: + if (pos >= value.Length) + throw new FormatException("Missing terminating ' in string literal"); + ch = value[pos++]; + if (ch == '\'') + { + if (pos < value.Length && value[pos] == '\'') { - valStack.Push(new NpgsqlTsQueryLexeme(sb.ToString())); - goto PushedVal; + ch = '\''; + pos++; } - pos++; - if (ch == ':') + else { valStack.Push(new NpgsqlTsQueryLexeme(sb.ToString())); - sb.Clear(); - goto InWeightInfo; - } - if (ch == '\\') - { - if (pos >= value.Length) - throw new FormatException(@"Unexpected \ in end of value"); - ch = value[pos++]; - } - sb.Append(ch); - goto WaitEnd; - - WaitEndComplex: - if (pos >= value.Length) - throw new FormatException("Missing terminating ' in string literal"); - ch = value[pos++]; - if (ch == '\'') - { - if (pos < value.Length && value[pos] == '\'') + if (pos < value.Length && value[pos] == ':') { - ch = '\''; pos++; + goto InWeightInfo; } - else - { - valStack.Push(new NpgsqlTsQueryLexeme(sb.ToString())); - if (pos < value.Length && value[pos] == ':') - { - pos++; - goto InWeightInfo; - } - goto PushedVal; - } - } - if (ch == '\\') - { - if (pos >= value.Length) - throw new FormatException(@"Unexpected \ in end of value"); - ch = value[pos++]; + goto PushedVal; } - sb.Append(ch); - goto WaitEndComplex; - - - InWeightInfo: + } + if (ch == '\\') + { if (pos >= value.Length) - goto Finish; - ch = value[pos]; - if (ch == '*') - ((NpgsqlTsQueryLexeme)valStack.Peek()).IsPrefixSearch = true; - else if (ch == 'a' || ch == 'A') - ((NpgsqlTsQueryLexeme)valStack.Peek()).Weights |= 
NpgsqlTsQueryLexeme.Weight.A; - else if (ch == 'b' || ch == 'B') - ((NpgsqlTsQueryLexeme)valStack.Peek()).Weights |= NpgsqlTsQueryLexeme.Weight.B; - else if (ch == 'c' || ch == 'C') - ((NpgsqlTsQueryLexeme)valStack.Peek()).Weights |= NpgsqlTsQueryLexeme.Weight.C; - else if (ch == 'd' || ch == 'D') - ((NpgsqlTsQueryLexeme)valStack.Peek()).Weights |= NpgsqlTsQueryLexeme.Weight.D; - else - goto PushedVal; - pos++; - goto InWeightInfo; - - PushedVal: - sb.Clear(); - var processTightBindingOperator = true; - while (opStack.Count > 0 && processTightBindingOperator) - { - var tsOp = opStack.Peek(); - switch (tsOp) - { - case '&': - if (valStack.Count < 2) - throw new FormatException("Syntax error in tsquery"); - var andRight = valStack.Pop(); - var andLeft = valStack.Pop(); - valStack.Push(new NpgsqlTsQueryAnd(andLeft, andRight)); - opStack.Pop(); - break; + throw new FormatException(@"Unexpected \ in end of value"); + ch = value[pos++]; + } + sb.Append(ch); + goto WaitEndComplex; - case '!': - if (valStack.Count == 0) - throw new FormatException("Syntax error in tsquery"); - valStack.Push(new NpgsqlTsQueryNot(valStack.Pop())); - opStack.Pop(); - break; - case '<': - if (valStack.Count < 2) - throw new FormatException("Syntax error in tsquery"); - var followedByRight = valStack.Pop(); - var followedByLeft = valStack.Pop(); - valStack.Push( - new NpgsqlTsQueryFollowedBy( - followedByLeft, - tsOp.FollowedByDistance, - followedByRight)); - opStack.Pop(); - break; + InWeightInfo: + if (pos >= value.Length) + goto Finish; + ch = value[pos]; + switch (ch) + { + case '*': + ((NpgsqlTsQueryLexeme)valStack.Peek()).IsPrefixSearch = true; + break; + case 'a' or 'A': + ((NpgsqlTsQueryLexeme)valStack.Peek()).Weights |= NpgsqlTsQueryLexeme.Weight.A; + break; + case 'b' or 'B': + ((NpgsqlTsQueryLexeme)valStack.Peek()).Weights |= NpgsqlTsQueryLexeme.Weight.B; + break; + case 'c' or 'C': + ((NpgsqlTsQueryLexeme)valStack.Peek()).Weights |= NpgsqlTsQueryLexeme.Weight.C; + break; + case 'd' 
or 'D': + ((NpgsqlTsQueryLexeme)valStack.Peek()).Weights |= NpgsqlTsQueryLexeme.Weight.D; + break; + default: + goto PushedVal; + } - default: - processTightBindingOperator = false; - break; - } - } - expectingBinOp = true; - goto NextToken; + pos++; + goto InWeightInfo; - Finish: - while (opStack.Count > 0) + PushedVal: + sb.Clear(); + var processTightBindingOperator = true; + while (opStack.Count > 0 && processTightBindingOperator) + { + var tsOp = opStack.Peek(); + switch (tsOp) { + case '&': if (valStack.Count < 2) throw new FormatException("Syntax error in tsquery"); + var andRight = valStack.Pop(); + var andLeft = valStack.Pop(); + valStack.Push(new NpgsqlTsQueryAnd(andLeft, andRight)); + opStack.Pop(); + break; - var right = valStack.Pop(); - var left = valStack.Pop(); + case '!': + if (valStack.Count == 0) + throw new FormatException("Syntax error in tsquery"); + valStack.Push(new NpgsqlTsQueryNot(valStack.Pop())); + opStack.Pop(); + break; - var tsOp = opStack.Pop(); - var query = (char)tsOp switch - { - '&' => (NpgsqlTsQuery)new NpgsqlTsQueryAnd(left, right), - '|' => new NpgsqlTsQueryOr(left, right), - '<' => new NpgsqlTsQueryFollowedBy(left, tsOp.FollowedByDistance, right), - _ => throw new FormatException("Syntax error in tsquery") - }; - valStack.Push(query); + case '<': + if (valStack.Count < 2) + throw new FormatException("Syntax error in tsquery"); + var followedByRight = valStack.Pop(); + var followedByLeft = valStack.Pop(); + valStack.Push( + new NpgsqlTsQueryFollowedBy( + followedByLeft, + tsOp.FollowedByDistance, + followedByRight)); + opStack.Pop(); + break; + + default: + processTightBindingOperator = false; + break; } - if (valStack.Count != 1) + } + expectingBinOp = true; + goto NextToken; + + Finish: + while (opStack.Count > 0) + { + if (valStack.Count < 2) throw new FormatException("Syntax error in tsquery"); - return valStack.Pop(); + + var right = valStack.Pop(); + var left = valStack.Pop(); + + var tsOp = opStack.Pop(); + var query = 
(char)tsOp switch + { + '&' => (NpgsqlTsQuery)new NpgsqlTsQueryAnd(left, right), + '|' => new NpgsqlTsQueryOr(left, right), + '<' => new NpgsqlTsQueryFollowedBy(left, tsOp.FollowedByDistance, right), + _ => throw new FormatException("Syntax error in tsquery") + }; + valStack.Push(query); } + if (valStack.Count != 1) + throw new FormatException("Syntax error in tsquery"); + return valStack.Pop(); } - readonly struct NpgsqlTsQueryOperator - { - public readonly char Char; - public readonly int FollowedByDistance; + /// + public override int GetHashCode() + => throw new NotSupportedException("Must be overridden"); - public NpgsqlTsQueryOperator(char character, int followedByDistance) - { - Char = character; - FollowedByDistance = followedByDistance; - } + /// + public override bool Equals(object? obj) + => obj is NpgsqlTsQuery query && query.Equals(this); - public static implicit operator NpgsqlTsQueryOperator(char c) => new NpgsqlTsQueryOperator(c, 0); - public static implicit operator char(NpgsqlTsQueryOperator o) => o.Char; - } + /// + /// Returns a value indicating whether this instance and a specified object represent the same value. + /// + /// An object to compare to this instance. + /// if g is equal to this instance; otherwise, . + public abstract bool Equals(NpgsqlTsQuery? other); /// - /// TsQuery Lexeme node. + /// Indicates whether the values of two specified objects are equal. /// - public sealed class NpgsqlTsQueryLexeme : NpgsqlTsQuery + /// The first object to compare. + /// The second object to compare. + /// if and are equal; otherwise, . + public static bool operator ==(NpgsqlTsQuery? left, NpgsqlTsQuery? right) + => left is null ? right is null : left.Equals(right); + + /// + /// Indicates whether the values of two specified objects are not equal. + /// + /// The first object to compare. + /// The second object to compare. + /// if and are not equal; otherwise, . + public static bool operator !=(NpgsqlTsQuery? left, NpgsqlTsQuery? 
right) + => left is null ? right is not null : !left.Equals(right); +} + +readonly struct NpgsqlTsQueryOperator +{ + public readonly char Char; + public readonly short FollowedByDistance; + + public NpgsqlTsQueryOperator(char character, short followedByDistance) { - string _text; + Char = character; + FollowedByDistance = followedByDistance; + } - /// - /// Lexeme text. - /// - public string Text + public static implicit operator NpgsqlTsQueryOperator(char c) => new(c, 0); + public static implicit operator char(NpgsqlTsQueryOperator o) => o.Char; +} + +/// +/// TsQuery Lexeme node. +/// +public sealed class NpgsqlTsQueryLexeme : NpgsqlTsQuery +{ + string _text; + + /// + /// Lexeme text. + /// + public string Text + { + get => _text; + set { - get => _text; - set - { - if (string.IsNullOrEmpty(value)) - throw new ArgumentException("Text is null or empty string", nameof(value)); + if (string.IsNullOrEmpty(value)) + throw new ArgumentException("Text is null or empty string", nameof(value)); - _text = value; - } + _text = value; } + } - Weight _weights; + Weight _weights; - /// - /// Weights is a bitmask of the Weight enum. - /// - public Weight Weights + /// + /// Weights is a bitmask of the Weight enum. + /// + public Weight Weights + { + get => _weights; + set { - get => _weights; - set - { - if (((byte)value >> 4) != 0) - throw new ArgumentOutOfRangeException(nameof(value), "Illegal weights"); + if (((byte)value >> 4) != 0) + throw new ArgumentOutOfRangeException(nameof(value), "Illegal weights"); - _weights = value; - } + _weights = value; } + } + + /// + /// Prefix search. + /// + public bool IsPrefixSearch { get; set; } + + /// + /// Creates a tsquery lexeme with only lexeme text. + /// + /// Lexeme text. + public NpgsqlTsQueryLexeme(string text) : this(text, Weight.None, false) { } + + /// + /// Creates a tsquery lexeme with lexeme text and weights. + /// + /// Lexeme text. + /// Bitmask of enum Weight. 
+ public NpgsqlTsQueryLexeme(string text, Weight weights) : this(text, weights, false) { } + /// + /// Creates a tsquery lexeme with lexeme text, weights and prefix search flag. + /// + /// Lexeme text. + /// Bitmask of enum Weight. + /// Is prefix search? + public NpgsqlTsQueryLexeme(string text, Weight weights, bool isPrefixSearch) + : base(NodeKind.Lexeme) + { + _text = text; + Weights = weights; + IsPrefixSearch = isPrefixSearch; + } + + /// + /// Weight enum, can be OR'ed together. + /// +#pragma warning disable CA1714 + [Flags] + public enum Weight +#pragma warning restore CA1714 + { /// - /// Prefix search. + /// None /// - public bool IsPrefixSearch { get; set; } - + None = 0, /// - /// Creates a tsquery lexeme with only lexeme text. + /// D /// - /// Lexeme text. - public NpgsqlTsQueryLexeme(string text) : this(text, Weight.None, false) { } - + D = 1, /// - /// Creates a tsquery lexeme with lexeme text and weights. + /// C /// - /// Lexeme text. - /// Bitmask of enum Weight. - public NpgsqlTsQueryLexeme(string text, Weight weights) : this(text, weights, false) { } - + C = 2, /// - /// Creates a tsquery lexeme with lexeme text, weights and prefix search flag. + /// B /// - /// Lexeme text. - /// Bitmask of enum Weight. - /// Is prefix search? - public NpgsqlTsQueryLexeme(string text, Weight weights, bool isPrefixSearch) - : base(NodeKind.Lexeme) - { - _text = text; - Weights = weights; - IsPrefixSearch = isPrefixSearch; - } - + B = 4, /// - /// Weight enum, can be OR'ed together. 
+ /// A /// -#pragma warning disable CA1714 - [Flags] - public enum Weight -#pragma warning restore CA1714 - { - /// - /// None - /// - None = 0, - /// - /// D - /// - D = 1, - /// - /// C - /// - C = 2, - /// - /// B - /// - B = 4, - /// - /// A - /// - A = 8 - } + A = 8 + } - internal override void Write(StringBuilder sb, bool first = false) - { - sb.Append('\'').Append(Text.Replace(@"\", @"\\").Replace("'", "''")).Append('\''); - if (IsPrefixSearch || Weights != Weight.None) - sb.Append(':'); - if (IsPrefixSearch) - sb.Append('*'); - if ((Weights & Weight.A) != Weight.None) - sb.Append('A'); - if ((Weights & Weight.B) != Weight.None) - sb.Append('B'); - if ((Weights & Weight.C) != Weight.None) - sb.Append('C'); - if ((Weights & Weight.D) != Weight.None) - sb.Append('D'); - } + internal override void WriteCore(StringBuilder sb, bool first = false) + { + sb.Append('\'').Append(Text.Replace(@"\", @"\\").Replace("'", "''")).Append('\''); + if (IsPrefixSearch || Weights != Weight.None) + sb.Append(':'); + if (IsPrefixSearch) + sb.Append('*'); + if ((Weights & Weight.A) != Weight.None) + sb.Append('A'); + if ((Weights & Weight.B) != Weight.None) + sb.Append('B'); + if ((Weights & Weight.C) != Weight.None) + sb.Append('C'); + if ((Weights & Weight.D) != Weight.None) + sb.Append('D'); } + /// + public override bool Equals(NpgsqlTsQuery? other) + => other is NpgsqlTsQueryLexeme lexeme && + lexeme.Text == Text && + lexeme.Weights == Weights && + lexeme.IsPrefixSearch == IsPrefixSearch; + + /// + public override int GetHashCode() + => HashCode.Combine(Text, Weights, IsPrefixSearch); +} + +/// +/// TsQuery Not node. +/// +public sealed class NpgsqlTsQueryNot : NpgsqlTsQuery +{ /// - /// TsQuery Not node. + /// Child node /// - public sealed class NpgsqlTsQueryNot : NpgsqlTsQuery + public NpgsqlTsQuery Child { get; set; } + + /// + /// Creates a not operator, with a given child node. 
+ /// + /// + public NpgsqlTsQueryNot(NpgsqlTsQuery child) + : base(NodeKind.Not) { - /// - /// Child node - /// - public NpgsqlTsQuery? Child { get; set; } + Child = child; + } - /// - /// Creates a not operator, with a given child node. - /// - /// - public NpgsqlTsQueryNot(NpgsqlTsQuery? child) - : base(NodeKind.Not) + internal override void WriteCore(StringBuilder sb, bool first = false) + { + sb.Append('!'); + if (Child == null) { - Child = child; + sb.Append("''"); } - - internal override void Write(StringBuilder sb, bool first = false) + else { - sb.Append('!'); - if (Child == null) - { - sb.Append("''"); - } - else - { - if (Child.Kind != NodeKind.Lexeme) - sb.Append("( "); - Child.Write(sb, true); - if (Child.Kind != NodeKind.Lexeme) - sb.Append(" )"); - } + if (Child.Kind != NodeKind.Lexeme) + sb.Append("( "); + Child.WriteCore(sb, true); + if (Child.Kind != NodeKind.Lexeme) + sb.Append(" )"); } } + /// + public override bool Equals(NpgsqlTsQuery? other) + => other is NpgsqlTsQueryNot not && not.Child == Child; + + /// + public override int GetHashCode() + => Child?.GetHashCode() ?? 0; +} + +/// +/// Base class for TsQuery binary operators (& and |). +/// +public abstract class NpgsqlTsQueryBinOp : NpgsqlTsQuery +{ /// - /// Base class for TsQuery binary operators (& and |). + /// Left child /// - public abstract class NpgsqlTsQueryBinOp : NpgsqlTsQuery - { - /// - /// Left child - /// - public NpgsqlTsQuery Left { get; set; } + public NpgsqlTsQuery Left { get; set; } - /// - /// Right child - /// - public NpgsqlTsQuery Right { get; set; } + /// + /// Right child + /// + public NpgsqlTsQuery Right { get; set; } - /// - /// Constructs a . - /// - protected NpgsqlTsQueryBinOp(NodeKind kind, NpgsqlTsQuery left, NpgsqlTsQuery right) - : base(kind) - { - Left = left; - Right = right; - } + /// + /// Constructs a . 
+ /// + protected NpgsqlTsQueryBinOp(NodeKind kind, NpgsqlTsQuery left, NpgsqlTsQuery right) + : base(kind) + { + Left = left; + Right = right; } +} +/// +/// TsQuery And node. +/// +public sealed class NpgsqlTsQueryAnd : NpgsqlTsQueryBinOp +{ /// - /// TsQuery And node. + /// Creates an and operator, with two given child nodes. /// - public sealed class NpgsqlTsQueryAnd : NpgsqlTsQueryBinOp - { - /// - /// Creates an and operator, with two given child nodes. - /// - /// - /// - public NpgsqlTsQueryAnd(NpgsqlTsQuery left, NpgsqlTsQuery right) - : base(NodeKind.And, left, right) {} + /// + /// + public NpgsqlTsQueryAnd(NpgsqlTsQuery left, NpgsqlTsQuery right) + : base(NodeKind.And, left, right) {} - internal override void Write(StringBuilder sb, bool first = false) - { - Left.Write(sb); - sb.Append(" & "); - Right.Write(sb); - } + internal override void WriteCore(StringBuilder sb, bool first = false) + { + Left.WriteCore(sb); + sb.Append(" & "); + Right.WriteCore(sb); } + /// + public override bool Equals(NpgsqlTsQuery? other) + => other is NpgsqlTsQueryAnd and && and.Left == Left && and.Right == Right; + + /// + public override int GetHashCode() + => HashCode.Combine(Left, Right); +} + +/// +/// TsQuery Or Node. +/// +public sealed class NpgsqlTsQueryOr : NpgsqlTsQueryBinOp +{ /// - /// TsQuery Or Node. + /// Creates an or operator, with two given child nodes. /// - public sealed class NpgsqlTsQueryOr : NpgsqlTsQueryBinOp - { - /// - /// Creates an or operator, with two given child nodes. 
- /// - /// - /// - public NpgsqlTsQueryOr(NpgsqlTsQuery left, NpgsqlTsQuery right) - : base(NodeKind.Or, left, right) {} + /// + /// + public NpgsqlTsQueryOr(NpgsqlTsQuery left, NpgsqlTsQuery right) + : base(NodeKind.Or, left, right) {} - internal override void Write(StringBuilder sb, bool first = false) - { - // TODO: Figure out the nullability strategy here - if (!first) - sb.Append("( "); + internal override void WriteCore(StringBuilder sb, bool first = false) + { + // TODO: Figure out the nullability strategy here + if (!first) + sb.Append("( "); - Left.Write(sb); - sb.Append(" | "); - Right.Write(sb); + Left.WriteCore(sb); + sb.Append(" | "); + Right.WriteCore(sb); - if (!first) - sb.Append(" )"); - } + if (!first) + sb.Append(" )"); } + /// + public override bool Equals(NpgsqlTsQuery? other) + => other is NpgsqlTsQueryOr or && or.Left == Left && or.Right == Right; + + /// + public override int GetHashCode() + => HashCode.Combine(Left, Right); +} + +/// +/// TsQuery "Followed by" Node. +/// +public sealed class NpgsqlTsQueryFollowedBy : NpgsqlTsQueryBinOp +{ /// - /// TsQuery "Followed by" Node. + /// The distance between the 2 nodes, in lexemes. /// - public sealed class NpgsqlTsQueryFollowedBy : NpgsqlTsQueryBinOp - { - /// - /// The distance between the 2 nodes, in lexemes. - /// - public int Distance { get; set; } + public short Distance { get; set; } - /// - /// Creates a "followed by" operator, specifying 2 child nodes and the - /// distance between them in lexemes. - /// - /// - /// - /// - public NpgsqlTsQueryFollowedBy( - NpgsqlTsQuery left, - int distance, - NpgsqlTsQuery right) - : base(NodeKind.Phrase, left, right) - { - if (distance < 0) - throw new ArgumentOutOfRangeException(nameof(distance)); + /// + /// Creates a "followed by" operator, specifying 2 child nodes and the + /// distance between them in lexemes. 
+ /// + /// + /// + /// + public NpgsqlTsQueryFollowedBy( + NpgsqlTsQuery left, + short distance, + NpgsqlTsQuery right) + : base(NodeKind.Phrase, left, right) + { + if (distance < 0) + throw new ArgumentOutOfRangeException(nameof(distance)); - Distance = distance; - } + Distance = distance; + } - internal override void Write(StringBuilder sb, bool first = false) - { - // TODO: Figure out the nullability strategy here - if (!first) - sb.Append("( "); + internal override void WriteCore(StringBuilder sb, bool first = false) + { + // TODO: Figure out the nullability strategy here + if (!first) + sb.Append("( "); - Left.Write(sb); + Left.WriteCore(sb); - sb.Append(" <"); - if (Distance == 1) sb.Append("-"); - else sb.Append(Distance); - sb.Append("> "); + sb.Append(" <"); + if (Distance == 1) sb.Append("-"); + else sb.Append(Distance); + sb.Append("> "); - Right.Write(sb); + Right.WriteCore(sb); - if (!first) - sb.Append(" )"); - } + if (!first) + sb.Append(" )"); } + /// + public override bool Equals(NpgsqlTsQuery? other) + => other is NpgsqlTsQueryFollowedBy followedBy && + followedBy.Left == Left && + followedBy.Right == Right && + followedBy.Distance == Distance; + + /// + public override int GetHashCode() + => HashCode.Combine(Left, Right, Distance); +} + +/// +/// Represents an empty tsquery. Shold only be used as top node. +/// +public sealed class NpgsqlTsQueryEmpty : NpgsqlTsQuery +{ /// - /// Represents an empty tsquery. Shold only be used as top node. + /// Creates a tsquery that represents an empty query. Should not be used as child node. /// - public sealed class NpgsqlTsQueryEmpty : NpgsqlTsQuery - { - /// - /// Creates a tsquery that represents an empty query. Should not be used as child node. 
- /// - public NpgsqlTsQueryEmpty() : base(NodeKind.Empty) {} + public NpgsqlTsQueryEmpty() : base(NodeKind.Empty) {} - internal override void Write(StringBuilder sb, bool first = false) {} - } + internal override void WriteCore(StringBuilder sb, bool first = false) { } + + /// + public override bool Equals(NpgsqlTsQuery? other) + => other is NpgsqlTsQueryEmpty; + + /// + public override int GetHashCode() + => Kind.GetHashCode(); } diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlTsVector.cs b/src/Npgsql/NpgsqlTypes/NpgsqlTsVector.cs index 53835691b3..76f097f0ac 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlTsVector.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlTsVector.cs @@ -5,507 +5,551 @@ #pragma warning disable CA1040, CA1034 // ReSharper disable once CheckNamespace -namespace NpgsqlTypes +namespace NpgsqlTypes; + +/// +/// Represents a PostgreSQL tsvector. +/// +public sealed class NpgsqlTsVector : IEnumerable, IEquatable { - /// - /// Represents a PostgreSQL tsvector. - /// - public sealed class NpgsqlTsVector : IEnumerable - { - readonly List _lexemes; + readonly List _lexemes; - internal NpgsqlTsVector(List lexemes, bool noCheck = false) + internal NpgsqlTsVector(List lexemes, bool noCheck = false) + { + if (noCheck) { - if (noCheck) - { - _lexemes = lexemes; - return; - } + _lexemes = lexemes; + return; + } - _lexemes = new List(lexemes); + _lexemes = new List(lexemes); - if (_lexemes.Count == 0) - return; + if (_lexemes.Count == 0) + return; - // Culture-specific comparisons doesn't really matter for the backend. It's sorting on its own if it detects an unsorted collection. - // Only when a .NET user wants to print the sort order. - _lexemes.Sort((a, b) => a.Text.CompareTo(b.Text)); + // Culture-specific comparisons doesn't really matter for the backend. It's sorting on its own if it detects an unsorted collection. + // Only when a .NET user wants to print the sort order. 
+ _lexemes.Sort((a, b) => string.Compare(a.Text, b.Text, StringComparison.CurrentCulture)); - var res = 0; - var pos = 1; - while (pos < _lexemes.Count) + var res = 0; + var pos = 1; + while (pos < _lexemes.Count) + { + if (_lexemes[pos].Text != _lexemes[res].Text) + { + // We're done with this lexeme. First make sure the word pos list is sorted and contains unique elements. + _lexemes[res] = new Lexeme(_lexemes[res].Text, Lexeme.UniquePos(_lexemes[res].WordEntryPositions), true); + res++; + if (res != pos) + _lexemes[res] = _lexemes[pos]; + } + else { - if (_lexemes[pos].Text != _lexemes[res].Text) + // Just concatenate the word pos lists + var wordEntryPositions = _lexemes[res].WordEntryPositions; + if (wordEntryPositions != null) { - // We're done with this lexeme. First make sure the word pos list is sorted and contains unique elements. - _lexemes[res] = new Lexeme(_lexemes[res].Text, Lexeme.UniquePos(_lexemes[res].WordEntryPositions), true); - res++; - if (res != pos) - _lexemes[res] = _lexemes[pos]; + var lexeme = _lexemes[pos]; + if (lexeme.WordEntryPositions != null) + wordEntryPositions.AddRange(lexeme.WordEntryPositions); } else { - // Just concatenate the word pos lists - var wordEntryPositions = _lexemes[res].WordEntryPositions; - if (wordEntryPositions != null) - { - var lexeme = _lexemes[pos]; - if (lexeme.WordEntryPositions != null) - wordEntryPositions.AddRange(lexeme.WordEntryPositions); - } - else - { - _lexemes[res] = _lexemes[pos]; - } + _lexemes[res] = _lexemes[pos]; } - pos++; } + pos++; + } - // Last element - _lexemes[res] = new Lexeme(_lexemes[res].Text, Lexeme.UniquePos(_lexemes[res].WordEntryPositions), true); - if (res != pos - 1) - { - _lexemes.RemoveRange(res, pos - 1 - res); - } + // Last element + _lexemes[res] = new Lexeme(_lexemes[res].Text, Lexeme.UniquePos(_lexemes[res].WordEntryPositions), true); + if (res != pos - 1) + { + _lexemes.RemoveRange(res, pos - 1 - res); } + } - /// - /// Parses a tsvector in PostgreSQL's text format. 
- /// - /// - /// - public static NpgsqlTsVector Parse(string value) + /// + /// Parses a tsvector in PostgreSQL's text format. + /// + /// + /// + [Obsolete("Client-side parsing of NpgsqlTsVector is unreliable and cannot fully duplicate the PostgreSQL logic. Use PG functions instead (e.g. to_tsvector)")] + public static NpgsqlTsVector Parse(string value) + { + if (value == null) + throw new ArgumentNullException(nameof(value)); + + var lexemes = new List(); + var pos = 0; + var wordPos = 0; + var sb = new StringBuilder(); + List wordEntryPositions; + + WaitWord: + if (pos >= value.Length) + goto Finish; + if (char.IsWhiteSpace(value[pos])) { - if (value == null) - throw new ArgumentNullException(nameof(value)); + pos++; + goto WaitWord; + } + sb.Clear(); + if (value[pos] == '\'') + { + pos++; + goto WaitEndComplex; + } + if (value[pos] == '\\') + { + pos++; + goto WaitNextChar; + } + sb.Append(value[pos++]); + goto WaitEndWord; - var lexemes = new List(); - var pos = 0; - var wordPos = 0; - var sb = new StringBuilder(); - List wordEntryPositions; + WaitNextChar: + if (pos >= value.Length) + throw new FormatException("Missing escaped character after \\ at end of value"); + sb.Append(value[pos++]); - WaitWord: + WaitEndWord: + if (pos >= value.Length || char.IsWhiteSpace(value[pos])) + { + lexemes.Add(new Lexeme(sb.ToString())); if (pos >= value.Length) goto Finish; - if (char.IsWhiteSpace(value[pos])) - { - pos++; - goto WaitWord; - } - sb.Clear(); - if (value[pos] == '\'') - { - pos++; - goto WaitEndComplex; - } - if (value[pos] == '\\') - { - pos++; - goto WaitNextChar; - } - sb.Append(value[pos++]); - goto WaitEndWord; + pos++; + goto WaitWord; + } + if (value[pos] == '\\') + { + pos++; + goto WaitNextChar; + } + if (value[pos] == ':') + { + pos++; + goto StartPosInfo; + } + sb.Append(value[pos++]); + goto WaitEndWord; - WaitNextChar: + WaitEndComplex: + if (pos >= value.Length) + throw new FormatException("Unexpected end of value"); + if (value[pos] == '\'') + 
{ + pos++; + goto WaitCharComplex; + } + if (value[pos] == '\\') + { + pos++; if (pos >= value.Length) throw new FormatException("Missing escaped character after \\ at end of value"); - sb.Append(value[pos++]); - - WaitEndWord: - if (pos >= value.Length || char.IsWhiteSpace(value[pos])) - { - lexemes.Add(new Lexeme(sb.ToString())); - if (pos >= value.Length) - goto Finish; - pos++; - goto WaitWord; - } - if (value[pos] == '\\') - { - pos++; - goto WaitNextChar; - } - if (value[pos] == ':') - { - pos++; - goto StartPosInfo; - } - sb.Append(value[pos++]); - goto WaitEndWord; + } + sb.Append(value[pos++]); + goto WaitEndComplex; - WaitEndComplex: - if (pos >= value.Length) - throw new FormatException("Unexpected end of value"); - if (value[pos] == '\'') + WaitCharComplex: + if (pos < value.Length && value[pos] == '\'') + { + sb.Append('\''); + pos++; + goto WaitEndComplex; + } + if (pos < value.Length && value[pos] == ':') + { + pos++; + goto StartPosInfo; + } + lexemes.Add(new Lexeme(sb.ToString())); + goto WaitWord; + + StartPosInfo: + wordEntryPositions = new List(); + + InPosInfo: + var digitPos = pos; + while (pos < value.Length && value[pos] >= '0' && value[pos] <= '9') + pos++; + if (digitPos == pos) + throw new FormatException("Missing length after :"); + wordPos = int.Parse(value.Substring(digitPos, pos - digitPos)); + + // Note: PostgreSQL backend parser matches also for example 1DD2A, which is parsed into 1A, but not 1AA2D ... + if (pos < value.Length) + { + if (value[pos] == 'A' || value[pos] == 'a' || value[pos] == '*') // Why * ? 
{ + wordEntryPositions.Add(new Lexeme.WordEntryPos(wordPos, Lexeme.Weight.A)); pos++; - goto WaitCharComplex; + goto WaitPosDelim; } - if (value[pos] == '\\') + if (value[pos] >= 'B' && value[pos] <= 'D' || value[pos] >= 'b' && value[pos] <= 'd') { + var weight = value[pos]; + if (weight >= 'b' && weight <= 'd') + weight = (char)(weight - ('b' - 'B')); + wordEntryPositions.Add(new Lexeme.WordEntryPos(wordPos, Lexeme.Weight.D + ('D' - weight))); pos++; - if (pos >= value.Length) - throw new FormatException("Missing escaped character after \\ at end of value"); + goto WaitPosDelim; } - sb.Append(value[pos++]); - goto WaitEndComplex; + } + wordEntryPositions.Add(new Lexeme.WordEntryPos(wordPos)); - WaitCharComplex: - if (pos < value.Length && value[pos] == '\'') - { - sb.Append('\''); - pos++; - goto WaitEndComplex; - } - if (pos < value.Length && value[pos] == ':') - { + WaitPosDelim: + if (pos >= value.Length || char.IsWhiteSpace(value[pos])) + { + if (pos < value.Length) pos++; - goto StartPosInfo; - } - lexemes.Add(new Lexeme(sb.ToString())); + lexemes.Add(new Lexeme(sb.ToString(), wordEntryPositions)); goto WaitWord; + } + if (value[pos] == ',') + { + pos++; + goto InPosInfo; + } + throw new FormatException("Missing comma, whitespace or end of value after lexeme pos info"); - StartPosInfo: - wordEntryPositions = new List(); + Finish: + return new NpgsqlTsVector(lexemes); + } - InPosInfo: - var digitPos = pos; - while (pos < value.Length && value[pos] >= '0' && value[pos] <= '9') - pos++; - if (digitPos == pos) - throw new FormatException("Missing length after :"); - wordPos = int.Parse(value.Substring(digitPos, pos - digitPos)); + /// + /// Returns the lexeme at a specific index + /// + /// + /// + public Lexeme this[int index] + { + get + { + if (index < 0 || index >= _lexemes.Count) + throw new ArgumentException(nameof(index)); - // Note: PostgreSQL backend parser matches also for example 1DD2A, which is parsed into 1A, but not 1AA2D ... 
- if (pos < value.Length) + return _lexemes[index]; + } + } + + /// + /// Gets the number of lexemes. + /// + public int Count => _lexemes.Count; + + /// + /// Returns an enumerator. + /// + /// + public IEnumerator GetEnumerator() => _lexemes.GetEnumerator(); + + /// + /// Returns an enumerator. + /// + /// + IEnumerator IEnumerable.GetEnumerator() => _lexemes.GetEnumerator(); + + /// + /// Gets a string representation in PostgreSQL's format. + /// + /// + public override string ToString() => string.Join(" ", _lexemes); + + /// + public bool Equals(NpgsqlTsVector? other) + { + if (ReferenceEquals(this, other)) + return true; + + if (other is null || _lexemes.Count != other._lexemes.Count) + return false; + + for (var i = 0; i < _lexemes.Count; i++) + if (!_lexemes[i].Equals(other._lexemes[i])) + return false; + + return true; + } + + /// + public override bool Equals(object? obj) + => obj is NpgsqlTsVector other && Equals(other); + + /// + public override int GetHashCode() + { + var hash = new HashCode(); + + foreach (var lexeme in _lexemes) + hash.Add(lexeme); + + return hash.ToHashCode(); + } + + /// + /// Represents a lexeme. A lexeme consists of a text string and optional word entry positions. + /// + public struct Lexeme : IEquatable + { + /// + /// Gets or sets the text. + /// + public string Text { get; set; } + + internal readonly List? WordEntryPositions; + + /// + /// Creates a lexeme with no word entry positions. + /// + /// + public Lexeme(string text) + { + Text = text; + WordEntryPositions = null; + } + + /// + /// Creates a lexeme with word entry positions. + /// + /// + /// + public Lexeme(string text, List? wordEntryPositions) + : this(text, wordEntryPositions, false) {} + + internal Lexeme(string text, List? wordEntryPositions, bool noCopy) + { + Text = text; + if (wordEntryPositions != null) + WordEntryPositions = noCopy ? wordEntryPositions : new List(wordEntryPositions); + else + WordEntryPositions = null; + } + + internal static List? 
UniquePos(List? list) + { + if (list == null) + return null; + var needsProcessing = false; + for (var i = 1; i < list.Count; i++) { - if (value[pos] == 'A' || value[pos] == 'a' || value[pos] == '*') // Why * ? - { - wordEntryPositions.Add(new Lexeme.WordEntryPos(wordPos, Lexeme.Weight.A)); - pos++; - goto WaitPosDelim; - } - if (value[pos] >= 'B' && value[pos] <= 'D' || value[pos] >= 'b' && value[pos] <= 'd') + if (list[i - 1].Pos >= list[i].Pos) { - var weight = value[pos]; - if (weight >= 'b' && weight <= 'd') - weight = (char)(weight - ('b' - 'B')); - wordEntryPositions.Add(new Lexeme.WordEntryPos(wordPos, Lexeme.Weight.D + ('D' - weight))); - pos++; - goto WaitPosDelim; + needsProcessing = true; + break; } } - wordEntryPositions.Add(new Lexeme.WordEntryPos(wordPos)); + if (!needsProcessing) + return list; + + // Don't change the original list, as the user might inspect it later if he holds a reference to the lexeme's list + list = new List(list); - WaitPosDelim: - if (pos >= value.Length || char.IsWhiteSpace(value[pos])) + list.Sort((x, y) => x.Pos.CompareTo(y.Pos)); + + var a = 0; + for (var b = 1; b < list.Count; b++) { - if (pos < value.Length) - pos++; - lexemes.Add(new Lexeme(sb.ToString(), wordEntryPositions)); - goto WaitWord; + if (list[a].Pos != list[b].Pos) + { + a++; + if (a != b) + list[a] = list[b]; + } + else if (list[b].Weight > list[a].Weight) + list[a] = list[b]; } - if (value[pos] == ',') + if (a != list.Count - 1) { - pos++; - goto InPosInfo; + list.RemoveRange(a, list.Count - 1 - a); } - throw new FormatException("Missing comma, whitespace or end of value after lexeme pos info"); - - Finish: - return new NpgsqlTsVector(lexemes); + return list; } /// - /// Returns the lexeme at a specific index + /// Gets a word entry position. 
/// /// /// - public Lexeme this[int index] + public WordEntryPos this[int index] { get { - if (index < 0 || index >= _lexemes.Count) + if (index < 0 || WordEntryPositions == null || index >= WordEntryPositions.Count) throw new ArgumentException(nameof(index)); - return _lexemes[index]; + return WordEntryPositions[index]; } - } - - /// - /// Gets the number of lexemes. - /// - public int Count => _lexemes.Count; + internal set + { + if (index < 0 || WordEntryPositions == null || index >= WordEntryPositions.Count) + throw new ArgumentOutOfRangeException(nameof(index)); - /// - /// Returns an enumerator. - /// - /// - public IEnumerator GetEnumerator() => _lexemes.GetEnumerator(); + WordEntryPositions[index] = value; + } + } /// - /// Returns an enumerator. + /// Gets the number of word entry positions. /// - /// - IEnumerator IEnumerable.GetEnumerator() => _lexemes.GetEnumerator(); + public int Count => WordEntryPositions?.Count ?? 0; /// - /// Gets a string representation in PostgreSQL's format. + /// Creates a string representation in PostgreSQL's format. /// /// - public override string ToString() => string.Join(" ", _lexemes); + public override string ToString() + { + var str = '\'' + (Text ?? "").Replace(@"\", @"\\").Replace("'", "''") + '\''; + if (Count > 0) + str += ":" + string.Join(",", WordEntryPositions!); + return str; + } /// - /// Represents a lexeme. A lexeme consists of a text string and optional word entry positions. + /// Represents a word entry position and an optional weight. /// - public struct Lexeme : IEquatable + public struct WordEntryPos : IEquatable { - /// - /// Gets or sets the text. - /// - public string Text { get; set; } + internal short Value { get; } - internal readonly List? WordEntryPositions; - - /// - /// Creates a lexeme with no word entry positions. 
- /// - /// - public Lexeme(string text) + internal WordEntryPos(short value) { - Text = text; - WordEntryPositions = null; + Value = value; } /// - /// Creates a lexeme with word entry positions. + /// Creates a WordEntryPos with a given position and weight. /// - /// - /// - public Lexeme(string text, List? wordEntryPositions) - : this(text, wordEntryPositions, false) {} - - internal Lexeme(string text, List? wordEntryPositions, bool noCopy) - { - Text = text; - if (wordEntryPositions != null) - WordEntryPositions = noCopy ? wordEntryPositions : new List(wordEntryPositions); - else - WordEntryPositions = null; - } - - internal static List? UniquePos(List? list) + /// Position values can range from 1 to 16383; larger numbers are silently set to 16383. + /// A weight labeled between A and D. + public WordEntryPos(int pos, Weight weight = Weight.D) { - if (list == null) - return null; - var needsProcessing = false; - for (var i = 1; i < list.Count; i++) - { - if (list[i - 1].Pos >= list[i].Pos) - { - needsProcessing = true; - break; - } - } - if (!needsProcessing) - return list; - - // Don't change the original list, as the user might inspect it later if he holds a reference to the lexeme's list - list = new List(list); + if (pos == 0) + throw new ArgumentOutOfRangeException(nameof(pos), "Lexeme position is out of range. Min value is 1, max value is 2^14-1. Value was: " + pos); + if (weight < Weight.D || weight > Weight.A) + throw new ArgumentOutOfRangeException(nameof(weight)); - list.Sort((x, y) => x.Pos.CompareTo(y.Pos)); + // Per documentation: "Position values can range from 1 to 16383; larger numbers are silently set to 16383." 
+ if (pos >> 14 != 0) + pos = (1 << 14) - 1; - var a = 0; - for (var b = 1; b < list.Count; b++) - { - if (list[a].Pos != list[b].Pos) - { - a++; - if (a != b) - list[a] = list[b]; - } - else if (list[b].Weight > list[a].Weight) - list[a] = list[b]; - } - if (a != list.Count - 1) - { - list.RemoveRange(a, list.Count - 1 - a); - } - return list; + Value = (short)(((int)weight << 14) | pos); } /// - /// Gets a word entry position. + /// The weight is labeled from A to D. D is the default, and not printed. /// - /// - /// - public WordEntryPos this[int index] - { - get - { - if (index < 0 || WordEntryPositions == null || index >= WordEntryPositions.Count) - throw new ArgumentException(nameof(index)); - - return WordEntryPositions[index]; - } - internal set - { - if (index < 0 || WordEntryPositions == null || index >= WordEntryPositions.Count) - throw new ArgumentOutOfRangeException(nameof(index)); - - WordEntryPositions[index] = value; - } - } + public Weight Weight => (Weight)((Value >> 14) & 3); /// - /// Gets the number of word entry positions. + /// The position is a 14-bit unsigned integer indicating the position in the text this lexeme occurs. Cannot be 0. /// - public int Count => WordEntryPositions?.Count ?? 0; + public int Pos => Value & ((1 << 14) - 1); /// - /// Creates a string representation in PostgreSQL's format. + /// Prints this lexeme in PostgreSQL's format, i.e. position is followed by weight (weight is only printed if A, B or C). /// /// public override string ToString() { - var str = '\'' + (Text ?? "").Replace(@"\", @"\\").Replace("'", "''") + '\''; - if (Count > 0) - str += ":" + string.Join(",", WordEntryPositions!); - return str; + if (Weight != Weight.D) + return Pos + Weight.ToString(); + return Pos.ToString(); } /// - /// Represents a word entry position and an optional weight. + /// Determines whether the specified object is equal to the current object. 
/// - public struct WordEntryPos : IEquatable - { - internal short Value { get; } - - internal WordEntryPos(short value) - { - Value = value; - } + public bool Equals(WordEntryPos o) => Value == o.Value; - /// - /// Creates a WordEntryPos with a given position and weight. - /// - /// Position values can range from 1 to 16383; larger numbers are silently set to 16383. - /// A weight labeled between A and D. - public WordEntryPos(int pos, Weight weight = Weight.D) - { - if (pos == 0) - throw new ArgumentOutOfRangeException(nameof(pos), "Lexeme position is out of range. Min value is 1, max value is 2^14-1. Value was: " + pos); - if (weight < Weight.D || weight > Weight.A) - throw new ArgumentOutOfRangeException(nameof(weight)); - - // Per documentation: "Position values can range from 1 to 16383; larger numbers are silently set to 16383." - if (pos >> 14 != 0) - pos = (1 << 14) - 1; - - Value = (short)(((int)weight << 14) | pos); - } - - /// - /// The weight is labeled from A to D. D is the default, and not printed. - /// - public Weight Weight => (Weight)((Value >> 14) & 3); - - /// - /// The position is a 14-bit unsigned integer indicating the position in the text this lexeme occurs. Cannot be 0. - /// - public int Pos => Value & ((1 << 14) - 1); - - /// - /// Prints this lexeme in PostgreSQL's format, i.e. position is followed by weight (weight is only printed if A, B or C). - /// - /// - public override string ToString() - { - if (Weight != Weight.D) - return Pos + Weight.ToString(); - return Pos.ToString(); - } - - /// - /// Determines whether the specified object is equal to the current object. - /// - public bool Equals(WordEntryPos o) => Value == o.Value; - - /// - /// Determines whether the specified object is equal to the current object. - /// - public override bool Equals(object? o) => o is WordEntryPos pos && Equals(pos); - - /// - /// Gets a hash code for the current object. 
- /// - public override int GetHashCode() => Value.GetHashCode(); - - /// - /// Determines whether the specified object is equal to the current object. - /// - public static bool operator ==(WordEntryPos left, WordEntryPos right) => left.Equals(right); - - /// - /// Determines whether the specified object is unequal to the current object. - /// - public static bool operator !=(WordEntryPos left, WordEntryPos right) => !left.Equals(right); - } + /// + /// Determines whether the specified object is equal to the current object. + /// + public override bool Equals(object? o) => o is WordEntryPos pos && Equals(pos); /// - /// The weight is labeled from A to D. D is the default, and not printed. + /// Gets a hash code for the current object. /// - public enum Weight - { - /// - /// D, the default - /// - D = 0, - - /// - /// C - /// - C = 1, - - /// - /// B - /// - B = 2, - - /// - /// A - /// - A = 3 - } + public override int GetHashCode() => Value.GetHashCode(); /// /// Determines whether the specified object is equal to the current object. /// - public bool Equals(Lexeme o) - => Text == o.Text && - ((WordEntryPositions == null && o.WordEntryPositions == null) || - (WordEntryPositions != null && WordEntryPositions.Equals(o.WordEntryPositions))); + public static bool operator ==(WordEntryPos left, WordEntryPos right) => left.Equals(right); /// - /// Determines whether the specified object is equal to the current object. + /// Determines whether the specified object is unequal to the current object. /// - public override bool Equals(object? o) => o is Lexeme lexeme && Equals(lexeme); + public static bool operator !=(WordEntryPos left, WordEntryPos right) => !left.Equals(right); + } + /// + /// The weight is labeled from A to D. D is the default, and not printed. + /// + public enum Weight + { /// - /// Gets a hash code for the current object. 
+ /// D, the default /// - public override int GetHashCode() => Text.GetHashCode(); + D = 0, /// - /// Determines whether the specified object is equal to the current object. + /// C /// - public static bool operator ==(Lexeme left, Lexeme right) => left.Equals(right); + C = 1, /// - /// Determines whether the specified object is unequal to the current object. + /// B /// - public static bool operator !=(Lexeme left, Lexeme right) => !left.Equals(right); + B = 2, + + /// + /// A + /// + A = 3 + } + + /// + /// Determines whether the specified object is equal to the current object. + /// + public bool Equals(Lexeme o) + { + if (Text != o.Text) + return false; + + if (WordEntryPositions is null) + return o.WordEntryPositions is null; + + if (o.WordEntryPositions is null || WordEntryPositions.Count != o.WordEntryPositions.Count) + return false; + + for (var i = 0; i < WordEntryPositions.Count; i++) + if (!WordEntryPositions[i].Equals(o.WordEntryPositions[i])) + return false; + + return true; } + + /// + /// Determines whether the specified object is equal to the current object. + /// + public override bool Equals(object? o) => o is Lexeme lexeme && Equals(lexeme); + + /// + /// Gets a hash code for the current object. + /// + public override int GetHashCode() => Text.GetHashCode(); + + /// + /// Determines whether the specified object is equal to the current object. + /// + public static bool operator ==(Lexeme left, Lexeme right) => left.Equals(right); + + /// + /// Determines whether the specified object is unequal to the current object. 
+ /// + public static bool operator !=(Lexeme left, Lexeme right) => !left.Equals(right); } } diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs b/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs index 6ea43575aa..753f0f0919 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs @@ -1,651 +1,585 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Diagnostics; using System.Globalization; using System.Net; using System.Net.Sockets; using System.Text; -using System.Text.RegularExpressions; -using Npgsql.Util; #pragma warning disable 1591 // ReSharper disable once CheckNamespace -namespace NpgsqlTypes +namespace NpgsqlTypes; + +/// +/// Represents a PostgreSQL point type. +/// +/// +/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html +/// +public struct NpgsqlPoint : IEquatable { - /// - /// Represents a PostgreSQL point type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - /// - public struct NpgsqlPoint : IEquatable - { - static readonly Regex Regex = new Regex(@"\((-?\d+.?\d*),(-?\d+.?\d*)\)"); + public double X { get; set; } + public double Y { get; set; } - public double X { get; set; } - public double Y { get; set; } - - public NpgsqlPoint(double x, double y) - : this() - { - X = x; - Y = y; - } + public NpgsqlPoint(double x, double y) + : this() + { + X = x; + Y = y; + } - // ReSharper disable CompareOfFloatsByEqualityOperator - public bool Equals(NpgsqlPoint other) => X == other.X && Y == other.Y; - // ReSharper restore CompareOfFloatsByEqualityOperator + // ReSharper disable CompareOfFloatsByEqualityOperator + public bool Equals(NpgsqlPoint other) => X == other.X && Y == other.Y; + // ReSharper restore CompareOfFloatsByEqualityOperator - public override bool Equals(object? obj) - => obj is NpgsqlPoint point && Equals(point); + public override bool Equals(object? 
obj) + => obj is NpgsqlPoint point && Equals(point); - public static bool operator ==(NpgsqlPoint x, NpgsqlPoint y) => x.Equals(y); + public static bool operator ==(NpgsqlPoint x, NpgsqlPoint y) => x.Equals(y); - public static bool operator !=(NpgsqlPoint x, NpgsqlPoint y) => !(x == y); + public static bool operator !=(NpgsqlPoint x, NpgsqlPoint y) => !(x == y); - public override int GetHashCode() - => X.GetHashCode() ^ PGUtil.RotateShift(Y.GetHashCode(), PGUtil.BitsInInt / 2); + public override int GetHashCode() + => HashCode.Combine(X, Y); - public static NpgsqlPoint Parse(string s) - { - var m = Regex.Match(s); - if (!m.Success) { - throw new FormatException("Not a valid point: " + s); - } - return new NpgsqlPoint(double.Parse(m.Groups[1].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[2].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat)); - } + public override string ToString() + => string.Format(CultureInfo.InvariantCulture, "({0},{1})", X, Y); +} - public override string ToString() - => string.Format(CultureInfo.InvariantCulture, "({0},{1})", X, Y); - } +/// +/// Represents a PostgreSQL line type. +/// +/// +/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html +/// +public struct NpgsqlLine : IEquatable +{ + public double A { get; set; } + public double B { get; set; } + public double C { get; set; } - /// - /// Represents a PostgreSQL line type. 
- /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - /// - public struct NpgsqlLine : IEquatable + public NpgsqlLine(double a, double b, double c) + : this() { - static readonly Regex Regex = new Regex(@"\{(-?\d+.?\d*),(-?\d+.?\d*),(-?\d+.?\d*)\}"); + A = a; + B = b; + C = c; + } - public double A { get; set; } - public double B { get; set; } - public double C { get; set; } + public override string ToString() + => string.Format(CultureInfo.InvariantCulture, "{{{0},{1},{2}}}", A, B, C); - public NpgsqlLine(double a, double b, double c) - : this() - { - A = a; - B = b; - C = c; - } + public override int GetHashCode() + => HashCode.Combine(A, B, C); - public static NpgsqlLine Parse(string s) - { - var m = Regex.Match(s); - if (!m.Success) - throw new FormatException("Not a valid line: " + s); - return new NpgsqlLine( - double.Parse(m.Groups[1].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[2].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[3].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat) - ); - } + public bool Equals(NpgsqlLine other) + => A == other.A && B == other.B && C == other.C; - public override string ToString() - => string.Format(CultureInfo.InvariantCulture, "{{{0},{1},{2}}}", A, B, C); + public override bool Equals(object? obj) + => obj is NpgsqlLine line && Equals(line); - public override int GetHashCode() => A.GetHashCode() * B.GetHashCode() * C.GetHashCode(); - - public bool Equals(NpgsqlLine other) => A == other.A && B == other.B && C == other.C; + public static bool operator ==(NpgsqlLine x, NpgsqlLine y) => x.Equals(y); + public static bool operator !=(NpgsqlLine x, NpgsqlLine y) => !(x == y); +} - public override bool Equals(object? obj) - => obj is NpgsqlLine line && Equals(line); +/// +/// Represents a PostgreSQL Line Segment type. 
+/// +public struct NpgsqlLSeg : IEquatable +{ + public NpgsqlPoint Start { get; set; } + public NpgsqlPoint End { get; set; } - public static bool operator ==(NpgsqlLine x, NpgsqlLine y) => x.Equals(y); - public static bool operator !=(NpgsqlLine x, NpgsqlLine y) => !(x == y); + public NpgsqlLSeg(NpgsqlPoint start, NpgsqlPoint end) + : this() + { + Start = start; + End = end; } - /// - /// Represents a PostgreSQL Line Segment type. - /// - public struct NpgsqlLSeg : IEquatable + public NpgsqlLSeg(double startx, double starty, double endx, double endy) : this() { - static readonly Regex Regex = new Regex(@"\[\((-?\d+.?\d*),(-?\d+.?\d*)\),\((-?\d+.?\d*),(-?\d+.?\d*)\)\]"); + Start = new NpgsqlPoint(startx, starty); + End = new NpgsqlPoint(endx, endy); + } + + public override string ToString() + => string.Format(CultureInfo.InvariantCulture, "[{0},{1}]", Start, End); - public NpgsqlPoint Start { get; set; } - public NpgsqlPoint End { get; set; } + public override int GetHashCode() + => HashCode.Combine(Start.X, Start.Y, End.X, End.Y); - public NpgsqlLSeg(NpgsqlPoint start, NpgsqlPoint end) - : this() + public bool Equals(NpgsqlLSeg other) + => Start == other.Start && End == other.End; + + public override bool Equals(object? obj) + => obj is NpgsqlLSeg seg && Equals(seg); + + public static bool operator ==(NpgsqlLSeg x, NpgsqlLSeg y) => x.Equals(y); + public static bool operator !=(NpgsqlLSeg x, NpgsqlLSeg y) => !(x == y); +} + +/// +/// Represents a PostgreSQL box type. 
+/// +/// +/// See https://www.postgresql.org/docs/current/static/datatype-geometric.html +/// +public struct NpgsqlBox : IEquatable +{ + NpgsqlPoint _upperRight; + public NpgsqlPoint UpperRight + { + get => _upperRight; + set { - Start = start; - End = end; + _upperRight = value; + NormalizeBox(); } + } - public NpgsqlLSeg(double startx, double starty, double endx, double endy) : this() + NpgsqlPoint _lowerLeft; + public NpgsqlPoint LowerLeft + { + get => _lowerLeft; + set { - Start = new NpgsqlPoint(startx, starty); - End = new NpgsqlPoint(endx, endy); + _lowerLeft = value; + NormalizeBox(); } + } - public static NpgsqlLSeg Parse(string s) - { - var m = Regex.Match(s); - if (!m.Success) { - throw new FormatException("Not a valid line: " + s); - } - return new NpgsqlLSeg( - double.Parse(m.Groups[1].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[2].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[3].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[4].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat) - ); + public NpgsqlBox(NpgsqlPoint upperRight, NpgsqlPoint lowerLeft) : this() + { + _upperRight = upperRight; + _lowerLeft = lowerLeft; + NormalizeBox(); + } - } + public NpgsqlBox(double top, double right, double bottom, double left) + : this(new NpgsqlPoint(right, top), new NpgsqlPoint(left, bottom)) { } - public override string ToString() - => string.Format(CultureInfo.InvariantCulture, "[{0},{1}]", Start, End); + public double Left => LowerLeft.X; + public double Right => UpperRight.X; + public double Bottom => LowerLeft.Y; + public double Top => UpperRight.Y; + public double Width => Right - Left; + public double Height => Top - Bottom; - public override int GetHashCode() - => Start.X.GetHashCode() ^ - PGUtil.RotateShift(Start.Y.GetHashCode(), PGUtil.BitsInInt / 4) ^ - 
PGUtil.RotateShift(End.X.GetHashCode(), PGUtil.BitsInInt / 2) ^ - PGUtil.RotateShift(End.Y.GetHashCode(), PGUtil.BitsInInt * 3 / 4); + public bool IsEmpty => Width == 0 || Height == 0; - public bool Equals(NpgsqlLSeg other) => Start == other.Start && End == other.End; + public bool Equals(NpgsqlBox other) + => UpperRight == other.UpperRight && LowerLeft == other.LowerLeft; - public override bool Equals(object? obj) - => obj is NpgsqlLSeg seg && Equals(seg); + public override bool Equals(object? obj) + => obj is NpgsqlBox box && Equals(box); - public static bool operator ==(NpgsqlLSeg x, NpgsqlLSeg y) => x.Equals(y); - public static bool operator !=(NpgsqlLSeg x, NpgsqlLSeg y) => !(x == y); - } + public static bool operator ==(NpgsqlBox x, NpgsqlBox y) => x.Equals(y); + public static bool operator !=(NpgsqlBox x, NpgsqlBox y) => !(x == y); + public override string ToString() + => string.Format(CultureInfo.InvariantCulture, "{0},{1}", UpperRight, LowerLeft); - /// - /// Represents a PostgreSQL box type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html - /// - public struct NpgsqlBox : IEquatable - { - static readonly Regex Regex = new Regex(@"\((-?\d+.?\d*),(-?\d+.?\d*)\),\((-?\d+.?\d*),(-?\d+.?\d*)\)"); + public override int GetHashCode() + => HashCode.Combine(Top, Right, Bottom, LowerLeft); - public NpgsqlPoint UpperRight { get; set; } - public NpgsqlPoint LowerLeft { get; set; } + // Swaps corners for isomorphic boxes, to mirror postgres behavior. 
+ // See: https://github.com/postgres/postgres/blob/af2324fabf0020e464b0268be9ef03e8f46ed84b/src/backend/utils/adt/geo_ops.c#L435-L447 + void NormalizeBox() + { + if (_upperRight.X < _lowerLeft.X) + (_upperRight.X, _lowerLeft.X) = (_lowerLeft.X, _upperRight.X); - public NpgsqlBox(NpgsqlPoint upperRight, NpgsqlPoint lowerLeft) : this() - { - UpperRight = upperRight; - LowerLeft = lowerLeft; - } + if (_upperRight.Y < _lowerLeft.Y) + (_upperRight.Y, _lowerLeft.Y) = (_lowerLeft.Y, _upperRight.Y); + } +} - public NpgsqlBox(double top, double right, double bottom, double left) - : this(new NpgsqlPoint(right, top), new NpgsqlPoint(left, bottom)) { } +/// +/// Represents a PostgreSQL Path type. +/// +public struct NpgsqlPath : IList, IEquatable +{ + readonly List _points; + public bool Open { get; set; } - public double Left => LowerLeft.X; - public double Right => UpperRight.X; - public double Bottom => LowerLeft.Y; - public double Top => UpperRight.Y; - public double Width => Right - Left; - public double Height => Top - Bottom; + public NpgsqlPath() + => _points = new(); - public bool IsEmpty => Width == 0 || Height == 0; + public NpgsqlPath(IEnumerable points, bool open) + { + _points = new List(points); + Open = open; + } - public bool Equals(NpgsqlBox other) => UpperRight == other.UpperRight && LowerLeft == other.LowerLeft; + public NpgsqlPath(IEnumerable points) : this(points, false) {} + public NpgsqlPath(params NpgsqlPoint[] points) : this(points, false) {} - public override bool Equals(object? 
obj) - => obj is NpgsqlBox box && Equals(box); + public NpgsqlPath(bool open) : this() + { + _points = new List(); + Open = open; + } - public static bool operator ==(NpgsqlBox x, NpgsqlBox y) => x.Equals(y); - public static bool operator !=(NpgsqlBox x, NpgsqlBox y) => !(x == y); - public override string ToString() - => string.Format(CultureInfo.InvariantCulture, "{0},{1}", UpperRight, LowerLeft); + public NpgsqlPath(int capacity, bool open) : this() + { + _points = new List(capacity); + Open = open; + } - public static NpgsqlBox Parse(string s) - { - var m = Regex.Match(s); - return new NpgsqlBox( - new NpgsqlPoint(double.Parse(m.Groups[1].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[2].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat)), - new NpgsqlPoint(double.Parse(m.Groups[3].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[4].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat)) - ); - } + public NpgsqlPath(int capacity) : this(capacity, false) {} - public override int GetHashCode() - => Top.GetHashCode() ^ - PGUtil.RotateShift(Right.GetHashCode(), PGUtil.BitsInInt / 4) ^ - PGUtil.RotateShift(Bottom.GetHashCode(), PGUtil.BitsInInt / 2) ^ - PGUtil.RotateShift(LowerLeft.GetHashCode(), PGUtil.BitsInInt * 3 / 4); + public NpgsqlPoint this[int index] + { + get => _points[index]; + set => _points[index] = value; } - /// - /// Represents a PostgreSQL Path type. 
- /// - public struct NpgsqlPath : IList, IEquatable + public int Capacity => _points.Capacity; + public int Count => _points.Count; + public bool IsReadOnly => false; + + public int IndexOf(NpgsqlPoint item) => _points.IndexOf(item); + public void Insert(int index, NpgsqlPoint item) => _points.Insert(index, item); + public void RemoveAt(int index) => _points.RemoveAt(index); + public void Add(NpgsqlPoint item) => _points.Add(item); + public void Clear() => _points.Clear(); + public bool Contains(NpgsqlPoint item) => _points.Contains(item); + public void CopyTo(NpgsqlPoint[] array, int arrayIndex) => _points.CopyTo(array, arrayIndex); + public bool Remove(NpgsqlPoint item) => _points.Remove(item); + public IEnumerator GetEnumerator() => _points.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public bool Equals(NpgsqlPath other) { - readonly List _points; - public bool Open { get; set; } + if (Open != other.Open || Count != other.Count) + return false; + if (ReferenceEquals(_points, other._points))//Short cut for shallow copies. + return true; + for (var i = 0; i != Count; ++i) + if (this[i] != other[i]) + return false; + return true; + } - public NpgsqlPath(IEnumerable points, bool open) : this() - { - _points = new List(points); - Open = open; - } + public override bool Equals(object? 
obj) + => obj is NpgsqlPath path && Equals(path); - public NpgsqlPath(IEnumerable points) : this(points, false) {} - public NpgsqlPath(params NpgsqlPoint[] points) : this(points, false) {} + public static bool operator ==(NpgsqlPath x, NpgsqlPath y) => x.Equals(y); + public static bool operator !=(NpgsqlPath x, NpgsqlPath y) => !(x == y); - public NpgsqlPath(bool open) : this() - { - _points = new List(); - Open = open; - } + public override int GetHashCode() + { + var hashCode = new HashCode(); + hashCode.Add(Open); - public NpgsqlPath(int capacity, bool open) : this() + foreach (var point in this) { - _points = new List(capacity); - Open = open; + hashCode.Add(point.X); + hashCode.Add(point.Y); } - public NpgsqlPath(int capacity) : this(capacity, false) {} + return hashCode.ToHashCode(); + } - public NpgsqlPoint this[int index] - { - get => _points[index]; - set => _points[index] = value; - } + public override string ToString() + { + var sb = new StringBuilder(); + sb.Append(Open ? '[' : '('); + int i; + for (i = 0; i < _points.Count; i++) + { + var p = _points[i]; + sb.AppendFormat(CultureInfo.InvariantCulture, "({0},{1})", p.X, p.Y); + if (i < _points.Count - 1) + sb.Append(","); + } + sb.Append(Open ? 
']' : ')'); + return sb.ToString(); + } +} - public int Capacity => _points.Capacity; - public int Count => _points.Count; - public bool IsReadOnly => false; - - public int IndexOf(NpgsqlPoint item) => _points.IndexOf(item); - public void Insert(int index, NpgsqlPoint item) => _points.Insert(index, item); - public void RemoveAt(int index) => _points.RemoveAt(index); - public void Add(NpgsqlPoint item) => _points.Add(item); - public void Clear() => _points.Clear(); - public bool Contains(NpgsqlPoint item) => _points.Contains(item); - public void CopyTo(NpgsqlPoint[] array, int arrayIndex) => _points.CopyTo(array, arrayIndex); - public bool Remove(NpgsqlPoint item) => _points.Remove(item); - public IEnumerator GetEnumerator() => _points.GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - public bool Equals(NpgsqlPath other) - { - if (Open != other.Open || Count != other.Count) - return false; - if (ReferenceEquals(_points, other._points))//Short cut for shallow copies. - return true; - for (var i = 0; i != Count; ++i) - if (this[i] != other[i]) - return false; - return true; - } +/// +/// Represents a PostgreSQL Polygon type. +/// +public readonly struct NpgsqlPolygon : IList, IEquatable +{ + readonly List _points; - public override bool Equals(object? obj) - => obj is NpgsqlPath path && Equals(path); + public NpgsqlPolygon() + => _points = new(); - public static bool operator ==(NpgsqlPath x, NpgsqlPath y) => x.Equals(y); - public static bool operator !=(NpgsqlPath x, NpgsqlPath y) => !(x == y); + public NpgsqlPolygon(IEnumerable points) + => _points = new List(points); - public override int GetHashCode() - { - var ret = 266370105;//seed with something other than zero to make paths of all zeros hash differently. - foreach (var point in this) - { - //The ideal amount to shift each value is one that would evenly spread it throughout - //the resultant bytes. 
Using the current result % 32 is essentially using a random value - //but one that will be the same on subsequent calls. - ret ^= PGUtil.RotateShift(point.GetHashCode(), ret % PGUtil.BitsInInt); - } - return Open ? ret : -ret; - } + public NpgsqlPolygon(params NpgsqlPoint[] points) : this((IEnumerable) points) {} - public override string ToString() - { - var sb = new StringBuilder(); - sb.Append(Open ? '[' : '('); - int i; - for (i = 0; i < _points.Count; i++) - { - var p = _points[i]; - sb.AppendFormat(CultureInfo.InvariantCulture, "({0},{1})", p.X, p.Y); - if (i < _points.Count - 1) - sb.Append(","); - } - sb.Append(Open ? ']' : ')'); - return sb.ToString(); - } + public NpgsqlPolygon(int capacity) + => _points = new List(capacity); - public static NpgsqlPath Parse(string s) - { - var open = s[0] switch - { - '[' => true, - '(' => false, - _ => throw new Exception("Invalid path string: " + s) - }; - Debug.Assert(s[s.Length - 1] == (open ? ']' : ')')); - var result = new NpgsqlPath(open); - var i = 1; - while (true) - { - var i2 = s.IndexOf(')', i); - result.Add(NpgsqlPoint.Parse(s.Substring(i, i2 - i + 1))); - if (s[i2 + 1] != ',') - break; - i = i2 + 2; - } - return result; - } + public NpgsqlPoint this[int index] + { + get => _points[index]; + set => _points[index] = value; } - /// - /// Represents a PostgreSQL Polygon type. 
- /// - public struct NpgsqlPolygon : IList, IEquatable + public int Capacity => _points.Capacity; + public int Count => _points.Count; + public bool IsReadOnly => false; + + public int IndexOf(NpgsqlPoint item) => _points.IndexOf(item); + public void Insert(int index, NpgsqlPoint item) => _points.Insert(index, item); + public void RemoveAt(int index) => _points.RemoveAt(index); + public void Add(NpgsqlPoint item) => _points.Add(item); + public void Clear() => _points.Clear(); + public bool Contains(NpgsqlPoint item) => _points.Contains(item); + public void CopyTo(NpgsqlPoint[] array, int arrayIndex) => _points.CopyTo(array, arrayIndex); + public bool Remove(NpgsqlPoint item) => _points.Remove(item); + public IEnumerator GetEnumerator() => _points.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public bool Equals(NpgsqlPolygon other) { - readonly List _points; + if (Count != other.Count) + return false; + if (ReferenceEquals(_points, other._points)) + return true; + for (var i = 0; i != Count; ++i) + if (this[i] != other[i]) + return false; + return true; + } - public NpgsqlPolygon(IEnumerable points) - { - _points = new List(points); - } + public override bool Equals(object? 
obj) + => obj is NpgsqlPolygon polygon && Equals(polygon); - public NpgsqlPolygon(params NpgsqlPoint[] points) : this ((IEnumerable) points) {} + public static bool operator ==(NpgsqlPolygon x, NpgsqlPolygon y) => x.Equals(y); + public static bool operator !=(NpgsqlPolygon x, NpgsqlPolygon y) => !(x == y); - public NpgsqlPolygon(int capacity) - { - _points = new List(capacity); - } - - public NpgsqlPoint this[int index] - { - get => _points[index]; - set => _points[index] = value; - } + public override int GetHashCode() + { + var hashCode = new HashCode(); - public int Capacity => _points.Capacity; - public int Count => _points.Count; - public bool IsReadOnly => false; - - public int IndexOf(NpgsqlPoint item) => _points.IndexOf(item); - public void Insert(int index, NpgsqlPoint item) => _points.Insert(index, item); - public void RemoveAt(int index) => _points.RemoveAt(index); - public void Add(NpgsqlPoint item) => _points.Add(item); - public void Clear() => _points.Clear(); - public bool Contains(NpgsqlPoint item) => _points.Contains(item); - public void CopyTo(NpgsqlPoint[] array, int arrayIndex) => _points.CopyTo(array, arrayIndex); - public bool Remove(NpgsqlPoint item) => _points.Remove(item); - public IEnumerator GetEnumerator() => _points.GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - public bool Equals(NpgsqlPolygon other) + foreach (var point in this) { - if (Count != other.Count) - return false; - if (ReferenceEquals(_points, other._points)) - return true; - for (var i = 0; i != Count; ++i) - if (this[i] != other[i]) - return false; - return true; + hashCode.Add(point.X); + hashCode.Add(point.Y); } - public override bool Equals(object? 
obj) - => obj is NpgsqlPolygon polygon && Equals(polygon); - - public static bool operator ==(NpgsqlPolygon x, NpgsqlPolygon y) => x.Equals(y); - public static bool operator !=(NpgsqlPolygon x, NpgsqlPolygon y) => !(x == y); + return hashCode.ToHashCode(); + } - public override int GetHashCode() - { - var ret = 266370105;//seed with something other than zero to make paths of all zeros hash differently. - foreach (var point in this) - { - //The ideal amount to shift each value is one that would evenly spread it throughout - //the resultant bytes. Using the current result % 32 is essentially using a random value - //but one that will be the same on subsequent calls. - ret ^= PGUtil.RotateShift(point.GetHashCode(), ret % PGUtil.BitsInInt); + public override string ToString() + { + var sb = new StringBuilder(); + sb.Append('('); + int i; + for (i = 0; i < _points.Count; i++) + { + var p = _points[i]; + sb.AppendFormat(CultureInfo.InvariantCulture, "({0},{1})", p.X, p.Y); + if (i < _points.Count - 1) { + sb.Append(","); } - return ret; } + sb.Append(')'); + return sb.ToString(); + } +} - public static NpgsqlPolygon Parse(string s) - { - var points = new List(); - var i = 1; - while (true) - { - var i2 = s.IndexOf(')', i); - points.Add(NpgsqlPoint.Parse(s.Substring(i, i2 - i + 1))); - if (s[i2 + 1] != ',') - break; - i = i2 + 2; - } - return new NpgsqlPolygon(points); - } +/// +/// Represents a PostgreSQL Circle type. 
+/// +public struct NpgsqlCircle : IEquatable +{ + public double X { get; set; } + public double Y { get; set; } + public double Radius { get; set; } - public override string ToString() - { - var sb = new StringBuilder(); - sb.Append('('); - int i; - for (i = 0; i < _points.Count; i++) - { - var p = _points[i]; - sb.AppendFormat(CultureInfo.InvariantCulture, "({0},{1})", p.X, p.Y); - if (i < _points.Count - 1) { - sb.Append(","); - } - } - sb.Append(')'); - return sb.ToString(); - } + public NpgsqlCircle(NpgsqlPoint center, double radius) + : this() + { + X = center.X; + Y = center.Y; + Radius = radius; } - /// - /// Represents a PostgreSQL Circle type. - /// - public struct NpgsqlCircle : IEquatable + public NpgsqlCircle(double x, double y, double radius) : this() { - static readonly Regex Regex = new Regex(@"<\((-?\d+.?\d*),(-?\d+.?\d*)\),(\d+.?\d*)>"); + X = x; + Y = y; + Radius = radius; + } - public double X { get; set; } - public double Y { get; set; } - public double Radius { get; set; } + public NpgsqlPoint Center + { + get => new(X, Y); + set => (X, Y) = (value.X, value.Y); + } - public NpgsqlCircle(NpgsqlPoint center, double radius) - : this() - { - X = center.X; - Y = center.Y; - Radius = radius; - } + // ReSharper disable CompareOfFloatsByEqualityOperator + public bool Equals(NpgsqlCircle other) + => X == other.X && Y == other.Y && Radius == other.Radius; + // ReSharper restore CompareOfFloatsByEqualityOperator - public NpgsqlCircle(double x, double y, double radius) : this() - { - X = x; - Y = y; - Radius = radius; - } - - public NpgsqlPoint Center - { - get => new NpgsqlPoint(X, Y); - set - { - X = value.X; - Y = value.Y; - } - } + public override bool Equals(object? 
obj) + => obj is NpgsqlCircle circle && Equals(circle); - // ReSharper disable CompareOfFloatsByEqualityOperator - public bool Equals(NpgsqlCircle other) - => X == other.X && Y == other.Y && Radius == other.Radius; - // ReSharper restore CompareOfFloatsByEqualityOperator + public override string ToString() + => string.Format(CultureInfo.InvariantCulture, "<({0},{1}),{2}>", X, Y, Radius); - public override bool Equals(object? obj) - => obj is NpgsqlCircle circle && Equals(circle); + public static bool operator ==(NpgsqlCircle x, NpgsqlCircle y) => x.Equals(y); + public static bool operator !=(NpgsqlCircle x, NpgsqlCircle y) => !(x == y); - public static NpgsqlCircle Parse(string s) - { - var m = Regex.Match(s); - if (!m.Success) - throw new FormatException("Not a valid circle: " + s); - - return new NpgsqlCircle( - double.Parse(m.Groups[1].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[2].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat), - double.Parse(m.Groups[3].ToString(), NumberStyles.Any, CultureInfo.InvariantCulture.NumberFormat) - ); - } - - public override string ToString() - => string.Format(CultureInfo.InvariantCulture, "<({0},{1}),{2}>", X, Y, Radius); + public override int GetHashCode() + => HashCode.Combine(X, Y, Radius); +} - public static bool operator ==(NpgsqlCircle x, NpgsqlCircle y) => x.Equals(y); - public static bool operator !=(NpgsqlCircle x, NpgsqlCircle y) => !(x == y); +/// +/// Represents a PostgreSQL inet type, which is a combination of an IPAddress and a subnet mask. 
+/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-net-types.html +/// +public readonly record struct NpgsqlInet +{ + public IPAddress Address { get; } + public byte Netmask { get; } - public override int GetHashCode() - => X.GetHashCode() * Y.GetHashCode() * Radius.GetHashCode(); + public NpgsqlInet(IPAddress address, byte netmask) + { + CheckAddressFamily(address); + Address = address; + Netmask = netmask; } - /// - /// Represents a PostgreSQL inet type, which is a combination of an IPAddress and a - /// subnet mask. - /// - /// - /// https://www.postgresql.org/docs/current/static/datatype-net-types.html - /// - [Obsolete("Use ValueTuple instead")] - public struct NpgsqlInet : IEquatable + public NpgsqlInet(IPAddress address) + : this(address, (byte)(address.AddressFamily == AddressFamily.InterNetwork ? 32 : 128)) { - public IPAddress Address { get; set; } - public int Netmask { get; set; } - - public NpgsqlInet(IPAddress address, int netmask) - { - if (address.AddressFamily != AddressFamily.InterNetwork && address.AddressFamily != AddressFamily.InterNetworkV6) - throw new ArgumentException("Only IPAddress of InterNetwork or InterNetworkV6 address families are accepted", nameof(address)); + } - Address = address; - Netmask = netmask; + public NpgsqlInet(string addr) + { + switch (addr.Split('/')) + { + case { Length: 2 } segments: + (Address, Netmask) = (IPAddress.Parse(segments[0]), byte.Parse(segments[1])); + break; + case { Length: 1 } segments: + var ipAddr = IPAddress.Parse(segments[0]); + CheckAddressFamily(ipAddr); + (Address, Netmask) = ( + ipAddr, + ipAddr.AddressFamily == AddressFamily.InterNetworkV6 ? 
(byte)128 : (byte)32); + break; + default: + throw new FormatException("Invalid number of parts in CIDR specification"); } + } - public NpgsqlInet(IPAddress address) - { - if (address.AddressFamily != AddressFamily.InterNetwork && address.AddressFamily != AddressFamily.InterNetworkV6) - throw new ArgumentException("Only IPAddress of InterNetwork or InterNetworkV6 address families are accepted", nameof(address)); + public override string ToString() + => (Address.AddressFamily == AddressFamily.InterNetwork && Netmask == 32) || + (Address.AddressFamily == AddressFamily.InterNetworkV6 && Netmask == 128) + ? Address.ToString() + : $"{Address}/{Netmask}"; - Address = address; - Netmask = address.AddressFamily == AddressFamily.InterNetwork ? 32 : 128; - } + public static explicit operator IPAddress(NpgsqlInet inet) + => inet.Address; - public NpgsqlInet(string addr) - { - if (addr.IndexOf('/') > 0) - { - var addrbits = addr.Split('/'); - if (addrbits.GetUpperBound(0) != 1) { - throw new FormatException("Invalid number of parts in CIDR specification"); - } - Address = IPAddress.Parse(addrbits[0]); - Netmask = int.Parse(addrbits[1]); - } - else - { - Address = IPAddress.Parse(addr); - Netmask = 32; - } - } + public static implicit operator NpgsqlInet(IPAddress ip) + => new(ip); - public override string ToString() - { - if ((Address.AddressFamily == AddressFamily.InterNetwork && Netmask == 32) || - (Address.AddressFamily == AddressFamily.InterNetworkV6 && Netmask == 128)) - { - return Address.ToString(); - } - return $"{Address}/{Netmask}"; - } + public void Deconstruct(out IPAddress address, out byte netmask) + { + address = Address; + netmask = Netmask; + } - // ReSharper disable once InconsistentNaming - public static IPAddress ToIPAddress(NpgsqlInet inet) - { - if (inet.Netmask != 32) - throw new InvalidCastException("Cannot cast CIDR network to address"); - return inet.Address; - } + static void CheckAddressFamily(IPAddress address) + { + if (address.AddressFamily != 
AddressFamily.InterNetwork && address.AddressFamily != AddressFamily.InterNetworkV6) + throw new ArgumentException("Only IPAddress of InterNetwork or InterNetworkV6 address families are accepted", nameof(address)); + } +} - public static explicit operator IPAddress(NpgsqlInet inet) => ToIPAddress(inet); +/// +/// Represents a PostgreSQL cidr type. +/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-net-types.html +/// +public readonly record struct NpgsqlCidr +{ + public IPAddress Address { get; } + public byte Netmask { get; } - public static NpgsqlInet ToNpgsqlInet(IPAddress? ip) - => ip is null ? default : new NpgsqlInet(ip); - //=> ReferenceEquals(ip, null) ? default : new NpgsqlInet(ip); + public NpgsqlCidr(IPAddress address, byte netmask) + { + if (address.AddressFamily != AddressFamily.InterNetwork && address.AddressFamily != AddressFamily.InterNetworkV6) + throw new ArgumentException("Only IPAddress of InterNetwork or InterNetworkV6 address families are accepted", nameof(address)); - public static implicit operator NpgsqlInet(IPAddress ip) => ToNpgsqlInet(ip); + Address = address; + Netmask = netmask; + } - public void Deconstruct(out IPAddress address, out int netmask) + public NpgsqlCidr(string addr) + => (Address, Netmask) = addr.Split('/') switch { - address = Address; - netmask = Netmask; - } + { Length: 2 } segments => (IPAddress.Parse(segments[0]), byte.Parse(segments[1])), + { Length: 1 } => throw new FormatException("Missing netmask"), + _ => throw new FormatException("Invalid number of parts in CIDR specification") + }; - public bool Equals(NpgsqlInet other) => Address.Equals(other.Address) && Netmask == other.Netmask; + public static implicit operator NpgsqlInet(NpgsqlCidr cidr) + => new(cidr.Address, cidr.Netmask); - public override bool Equals(object? 
obj) - => obj is NpgsqlInet inet && Equals(inet); + public static explicit operator IPAddress(NpgsqlCidr cidr) + => cidr.Address; - public override int GetHashCode() - => PGUtil.RotateShift(Address.GetHashCode(), Netmask%32); + public override string ToString() + => $"{Address}/{Netmask}"; - public static bool operator ==(NpgsqlInet x, NpgsqlInet y) => x.Equals(y); - public static bool operator !=(NpgsqlInet x, NpgsqlInet y) => !(x == y); + public void Deconstruct(out IPAddress address, out byte netmask) + { + address = Address; + netmask = Netmask; } +} +/// +/// Represents a PostgreSQL tid value +/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-oid.html +/// +public readonly struct NpgsqlTid : IEquatable +{ /// - /// Represents a PostgreSQL tid value + /// Block number /// - /// - /// https://www.postgresql.org/docs/current/static/datatype-oid.html - /// - public readonly struct NpgsqlTid : IEquatable - { - /// - /// Block number - /// - public uint BlockNumber { get; } + public uint BlockNumber { get; } - /// - /// Tuple index within block - /// - public ushort OffsetNumber { get; } + /// + /// Tuple index within block + /// + public ushort OffsetNumber { get; } - public NpgsqlTid(uint blockNumber, ushort offsetNumber) - { - BlockNumber = blockNumber; - OffsetNumber = offsetNumber; - } + public NpgsqlTid(uint blockNumber, ushort offsetNumber) + { + BlockNumber = blockNumber; + OffsetNumber = offsetNumber; + } - public bool Equals(NpgsqlTid other) - => BlockNumber == other.BlockNumber && OffsetNumber == other.OffsetNumber; + public bool Equals(NpgsqlTid other) + => BlockNumber == other.BlockNumber && OffsetNumber == other.OffsetNumber; - public override bool Equals(object? o) - => o is NpgsqlTid tid && Equals(tid); + public override bool Equals(object? 
o) + => o is NpgsqlTid tid && Equals(tid); - public override int GetHashCode() => (int)BlockNumber ^ OffsetNumber; - public static bool operator ==(NpgsqlTid left, NpgsqlTid right) => left.Equals(right); - public static bool operator !=(NpgsqlTid left, NpgsqlTid right) => !(left == right); - public override string ToString() => $"({BlockNumber},{OffsetNumber})"; - } + public override int GetHashCode() => (int)BlockNumber ^ OffsetNumber; + public static bool operator ==(NpgsqlTid left, NpgsqlTid right) => left.Equals(right); + public static bool operator !=(NpgsqlTid left, NpgsqlTid right) => !(left == right); + public override string ToString() => $"({BlockNumber},{OffsetNumber})"; } #pragma warning restore 1591 diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlUserTypes.cs b/src/Npgsql/NpgsqlTypes/NpgsqlUserTypes.cs deleted file mode 100644 index 8e0e06b496..0000000000 --- a/src/Npgsql/NpgsqlTypes/NpgsqlUserTypes.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System; - -// ReSharper disable once CheckNamespace -namespace NpgsqlTypes -{ - /// - /// Indicates that this property or field correspond to a PostgreSQL field with the specified name - /// - [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property | AttributeTargets.Parameter)] - public class PgNameAttribute : Attribute - { - /// - /// The name of PostgreSQL field that corresponds to this CLR property or field - /// - public string PgName { get; private set; } - - /// - /// Indicates that this property or field correspond to a PostgreSQL field with the specified name - /// - /// The name of PostgreSQL field that corresponds to this CLR property or field - public PgNameAttribute(string pgName) - { - PgName = pgName; - } - } -} diff --git a/src/Npgsql/NpgsqlTypes/PgNameAttribute.cs b/src/Npgsql/NpgsqlTypes/PgNameAttribute.cs new file mode 100644 index 0000000000..48cbc955e4 --- /dev/null +++ b/src/Npgsql/NpgsqlTypes/PgNameAttribute.cs @@ -0,0 +1,29 @@ +using System; + +// ReSharper disable once CheckNamespace 
+namespace NpgsqlTypes; + +/// +/// Indicates that this property or field corresponds to a PostgreSQL field with the specified name +/// +[AttributeUsage( + AttributeTargets.Enum | + AttributeTargets.Class | + AttributeTargets.Struct | + AttributeTargets.Field | + AttributeTargets.Property | + AttributeTargets.Parameter)] +public class PgNameAttribute : Attribute +{ + /// + /// The name of PostgreSQL field that corresponds to this CLR property or field + /// + public string PgName { get; } + + /// + /// Indicates that this property or field corresponds to a PostgreSQL field with the specified name + /// + /// The name of PostgreSQL field that corresponds to this CLR property or field + public PgNameAttribute(string pgName) + => PgName = pgName; +} diff --git a/src/Npgsql/NpgsqlWriteBuffer.Stream.cs b/src/Npgsql/NpgsqlWriteBuffer.Stream.cs deleted file mode 100644 index 36b2eeaf21..0000000000 --- a/src/Npgsql/NpgsqlWriteBuffer.Stream.cs +++ /dev/null @@ -1,123 +0,0 @@ -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace Npgsql -{ - public sealed partial class NpgsqlWriteBuffer - { - sealed class ParameterStream : Stream - { - readonly NpgsqlWriteBuffer _buf; - bool _disposed; - - internal ParameterStream(NpgsqlWriteBuffer buf) - => _buf = buf; - - internal void Init() - => _disposed = false; - - public override bool CanRead => false; - - public override bool CanWrite => true; - - public override bool CanSeek => false; - - public override long Length => throw new NotSupportedException(); - - public override void SetLength(long value) - => throw new NotSupportedException(); - - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - - public override long Seek(long offset, SeekOrigin origin) - => throw new NotSupportedException(); - - public override void Flush() - => CheckDisposed(); - - public override Task FlushAsync(CancellationToken 
cancellationToken = default) - { - CheckDisposed(); - return cancellationToken.IsCancellationRequested - ? Task.FromCanceled(cancellationToken) : Task.CompletedTask; - } - - public override int Read(byte[] buffer, int offset, int count) - => throw new NotSupportedException(); - - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - => throw new NotSupportedException(); - - public override void Write(byte[] buffer, int offset, int count) - => Write(buffer, offset, count, false); - - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - { - using (NoSynchronizationContextScope.Enter()) - return Write(buffer, offset, count, true, cancellationToken); - } - - Task Write(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - CheckDisposed(); - - if (buffer == null) - throw new ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentNullException(nameof(offset)); - if (count < 0) - throw new ArgumentNullException(nameof(count)); - if (buffer.Length - offset < count) - throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); - if (cancellationToken.IsCancellationRequested) - return Task.FromCanceled(cancellationToken); - - while (count > 0) - { - var left = _buf.WriteSpaceLeft; - if (left == 0) - return WriteLong(buffer, offset, count, async, cancellationToken); - - var slice = Math.Min(count, left); - _buf.WriteBytes(buffer, offset, slice); - offset += slice; - count -= slice; - } - - return Task.CompletedTask; - } - - async Task WriteLong(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - while (count > 0) - { - var left = _buf.WriteSpaceLeft; - if (left == 0) - { - await _buf.Flush(async, 
cancellationToken); - continue; - } - var slice = Math.Min(count, left); - _buf.WriteBytes(buffer, offset, slice); - offset += slice; - count -= slice; - } - } - - void CheckDisposed() - { - if (_disposed) - throw new ObjectDisposedException(null); - } - - protected override void Dispose(bool disposing) - => _disposed = true; - } - } -} diff --git a/src/Npgsql/NpgsqlWriteBuffer.cs b/src/Npgsql/NpgsqlWriteBuffer.cs deleted file mode 100644 index 07d534fde0..0000000000 --- a/src/Npgsql/NpgsqlWriteBuffer.cs +++ /dev/null @@ -1,597 +0,0 @@ -using System; -using System.Buffers; -using System.Buffers.Binary; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.Util; -using static System.Threading.Timeout; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member -namespace Npgsql -{ - /// - /// A buffer used by Npgsql to write data to the socket efficiently. - /// Provides methods which encode different values types and tracks the current position. - /// - public sealed partial class NpgsqlWriteBuffer : IDisposable - { - #region Fields and Properties - - internal readonly NpgsqlConnector Connector; - - internal Stream Underlying { private get; set; } - - readonly Socket? _underlyingSocket; - - readonly ResettableCancellationTokenSource _timeoutCts; - - /// - /// Timeout for sync and async writes - /// - internal TimeSpan Timeout - { - get => _timeoutCts.Timeout; - set - { - if (_timeoutCts.Timeout != value) - { - Debug.Assert(_underlyingSocket != null); - - if (value > TimeSpan.Zero) - { - _underlyingSocket.SendTimeout = (int)value.TotalMilliseconds; - _timeoutCts.Timeout = value; - } - else - { - _underlyingSocket.SendTimeout = -1; - _timeoutCts.Timeout = InfiniteTimeSpan; - } - } - } - } - - /// - /// The total byte length of the buffer. 
- /// - internal int Size { get; private set; } - - bool _copyMode; - internal Encoding TextEncoding { get; } - - public int WriteSpaceLeft => Size - WritePosition; - - internal readonly byte[] Buffer; - readonly Encoder _textEncoder; - - internal int WritePosition; - - ParameterStream? _parameterStream; - - bool _disposed; - - /// - /// The minimum buffer size possible. - /// - internal const int MinimumSize = 4096; - internal const int DefaultSize = 8192; - - #endregion - - #region Constructors - - internal NpgsqlWriteBuffer(NpgsqlConnector connector, Stream stream, Socket? socket, int size, Encoding textEncoding) - { - if (size < MinimumSize) - throw new ArgumentOutOfRangeException(nameof(size), size, "Buffer size must be at least " + MinimumSize); - - Connector = connector; - Underlying = stream; - _underlyingSocket = socket; - _timeoutCts = new ResettableCancellationTokenSource(); - Size = size; - Buffer = ArrayPool.Shared.Rent(size); - TextEncoding = textEncoding; - _textEncoder = TextEncoding.GetEncoder(); - } - - #endregion - - #region I/O - - public async Task Flush(bool async, CancellationToken cancellationToken = default) - { - if (_copyMode) - { - // In copy mode, we write CopyData messages. The message code has already been - // written to the beginning of the buffer, but we need to go back and write the - // length. 
- if (WritePosition == 1) - return; - var pos = WritePosition; - WritePosition = 1; - WriteInt32(pos - 1); - WritePosition = pos; - } else if (WritePosition == 0) - return; - - var finalCt = cancellationToken; - if (async && Timeout > TimeSpan.Zero) - finalCt = _timeoutCts.Start(cancellationToken); - - try - { - if (async) - { - await Underlying.WriteAsync(Buffer, 0, WritePosition, finalCt); - await Underlying.FlushAsync(finalCt); - _timeoutCts.Stop(); - } - else - { - Underlying.Write(Buffer, 0, WritePosition); - Underlying.Flush(); - } - } - catch (Exception e) - { - // Stopping twice (in case the previous Stop() call succeeded) doesn't hurt. - // Not stopping will cause an assertion failure in debug mode when we call Start() the next time. - // We can't stop in a finally block because Connector.Break() will dispose the buffer and the contained - // _timeoutCts - _timeoutCts.Stop(); - switch (e) - { - // User requested the cancellation - case OperationCanceledException _ when (cancellationToken.IsCancellationRequested): - throw Connector.Break(e); - // Read timeout - case OperationCanceledException _: - // Note that mono throws SocketException with the wrong error (see #1330) - case IOException _ when (e.InnerException as SocketException)?.SocketErrorCode == - (Type.GetType("Mono.Runtime") == null ? SocketError.TimedOut : SocketError.WouldBlock): - Debug.Assert(e is OperationCanceledException ? 
async : !async); - throw Connector.Break(new NpgsqlException("Exception while writing to stream", new TimeoutException("Timeout during writing attempt"))); - } - - throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); - } - NpgsqlEventSource.Log.BytesWritten(WritePosition); - //NpgsqlEventSource.Log.RequestFailed(); - - WritePosition = 0; - if (CurrentCommand != null) - { - CurrentCommand.FlushOccurred = true; - CurrentCommand = null; - } - if (_copyMode) - WriteCopyDataHeader(); - } - - internal void Flush() => Flush(false).GetAwaiter().GetResult(); - - internal NpgsqlCommand? CurrentCommand { get; set; } - - #endregion - - #region Direct write - - internal void DirectWrite(ReadOnlySpan buffer) - { - Flush(); - - if (_copyMode) - { - // Flush has already written the CopyData header for us, but write the CopyData - // header to the socket with the write length before we can start writing the data directly. - Debug.Assert(WritePosition == 5); - - WritePosition = 1; - WriteInt32(buffer.Length + 4); - WritePosition = 5; - _copyMode = false; - Flush(); - _copyMode = true; - WriteCopyDataHeader(); // And ready the buffer after the direct write completes - } - else - Debug.Assert(WritePosition == 0); - - try - { - Underlying.Write(buffer); - } - catch (Exception e) - { - throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); - } - } - - internal async Task DirectWrite(ReadOnlyMemory memory, bool async, CancellationToken cancellationToken = default) - { - await Flush(async, cancellationToken); - - if (_copyMode) - { - // Flush has already written the CopyData header for us, but write the CopyData - // header to the socket with the write length before we can start writing the data directly. 
- Debug.Assert(WritePosition == 5); - - WritePosition = 1; - WriteInt32(memory.Length + 4); - WritePosition = 5; - _copyMode = false; - await Flush(async, cancellationToken); - _copyMode = true; - WriteCopyDataHeader(); // And ready the buffer after the direct write completes - } - else - Debug.Assert(WritePosition == 0); - - try - { - if (async) - await Underlying.WriteAsync(memory, cancellationToken); - else - Underlying.Write(memory.Span); - } - catch (Exception e) - { - throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); - } - } - - #endregion Direct write - - #region Write Simple - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteSByte(sbyte value) => Write(value); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteByte(byte value) => Write(value); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void WriteInt16(int value) - => WriteInt16((short)value, false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteInt16(short value) - => WriteInt16(value, false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteInt16(short value, bool littleEndian) - => Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value)); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteUInt16(ushort value) - => WriteUInt16(value, false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteUInt16(ushort value, bool littleEndian) - => Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value)); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteInt32(int value) - => WriteInt32(value, false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteInt32(int value, bool littleEndian) - => Write(littleEndian == BitConverter.IsLittleEndian ? 
value : BinaryPrimitives.ReverseEndianness(value)); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteUInt32(uint value) - => WriteUInt32(value, false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteUInt32(uint value, bool littleEndian) - => Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value)); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteInt64(long value) - => WriteInt64(value, false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteInt64(long value, bool littleEndian) - => Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value)); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteUInt64(ulong value) - => WriteUInt64(value, false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteUInt64(ulong value, bool littleEndian) - => Write(littleEndian == BitConverter.IsLittleEndian ? 
value : BinaryPrimitives.ReverseEndianness(value)); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteSingle(float value) - => WriteSingle(value, false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteSingle(float value, bool littleEndian) - => WriteInt32(Unsafe.As(ref value), littleEndian); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteDouble(double value) - => WriteDouble(value, false); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteDouble(double value, bool littleEndian) - => WriteInt64(Unsafe.As(ref value), littleEndian); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void Write(T value) - { - if (Unsafe.SizeOf() > WriteSpaceLeft) - ThrowNotSpaceLeft(); - - Unsafe.WriteUnaligned(ref Buffer[WritePosition], value); - WritePosition += Unsafe.SizeOf(); - } - - [MethodImpl(MethodImplOptions.NoInlining)] - static void ThrowNotSpaceLeft() - => throw new InvalidOperationException("There is not enough space left in the buffer."); - - public Task WriteString(string s, int byteLen, bool async, CancellationToken cancellationToken = default) - => WriteString(s, s.Length, byteLen, async, cancellationToken); - - public Task WriteString(string s, int charLen, int byteLen, bool async, CancellationToken cancellationToken = default) - { - if (byteLen <= WriteSpaceLeft) - { - WriteString(s, charLen); - return Task.CompletedTask; - } - return WriteStringLong(this, async, s, charLen, byteLen, cancellationToken); - - static async Task WriteStringLong(NpgsqlWriteBuffer buffer, bool async, string s, int charLen, int byteLen, CancellationToken cancellationToken) - { - Debug.Assert(byteLen > buffer.WriteSpaceLeft); - if (byteLen <= buffer.Size) - { - // String can fit entirely in an empty buffer. 
Flush and retry rather than - // going into the partial writing flow below (which requires ToCharArray()) - await buffer.Flush(async, cancellationToken); - buffer.WriteString(s, charLen); - } - else - { - var charPos = 0; - while (true) - { - buffer.WriteStringChunked(s, charPos, charLen - charPos, true, out var charsUsed, out var completed); - if (completed) - break; - await buffer.Flush(async, cancellationToken); - charPos += charsUsed; - } - } - } - } - - internal Task WriteChars(char[] chars, int offset, int charLen, int byteLen, bool async, CancellationToken cancellationToken = default) - { - if (byteLen <= WriteSpaceLeft) - { - WriteChars(chars, offset, charLen); - return Task.CompletedTask; - } - return WriteCharsLong(this, async, chars, offset, charLen, byteLen, cancellationToken); - - static async Task WriteCharsLong(NpgsqlWriteBuffer buffer, bool async, char[] chars, int offset, int charLen, int byteLen, CancellationToken cancellationToken) - { - Debug.Assert(byteLen > buffer.WriteSpaceLeft); - if (byteLen <= buffer.Size) - { - // String can fit entirely in an empty buffer. Flush and retry rather than - // going into the partial writing flow below (which requires ToCharArray()) - await buffer.Flush(async, cancellationToken); - buffer.WriteChars(chars, offset, charLen); - } - else - { - var charPos = 0; - - while (true) - { - buffer.WriteStringChunked(chars, charPos + offset, charLen - charPos, true, out var charsUsed, out var completed); - if (completed) - break; - await buffer.Flush(async, cancellationToken); - charPos += charsUsed; - } - } - } - } - - public void WriteString(string s, int len = 0) - { - Debug.Assert(TextEncoding.GetByteCount(s) <= WriteSpaceLeft); - WritePosition += TextEncoding.GetBytes(s, 0, len == 0 ? s.Length : len, Buffer, WritePosition); - } - - internal void WriteChars(char[] chars, int offset, int len) - { - var charCount = len == 0 ? 
chars.Length : len; - Debug.Assert(TextEncoding.GetByteCount(chars, 0, charCount) <= WriteSpaceLeft); - WritePosition += TextEncoding.GetBytes(chars, offset, charCount, Buffer, WritePosition); - } - - public void WriteBytes(ReadOnlySpan buf) - { - Debug.Assert(buf.Length <= WriteSpaceLeft); - buf.CopyTo(new Span(Buffer, WritePosition, Buffer.Length - WritePosition)); - WritePosition += buf.Length; - } - - public void WriteBytes(byte[] buf, int offset, int count) - => WriteBytes(new ReadOnlySpan(buf, offset, count)); - - public Task WriteBytesRaw(byte[] bytes, bool async, CancellationToken cancellationToken = default) - { - if (bytes.Length <= WriteSpaceLeft) - { - WriteBytes(bytes); - return Task.CompletedTask; - } - return WriteBytesLong(this, async, bytes, cancellationToken); - - static async Task WriteBytesLong(NpgsqlWriteBuffer buffer, bool async, byte[] bytes, CancellationToken cancellationToken) - { - if (bytes.Length <= buffer.Size) - { - // value can fit entirely in an empty buffer. 
Flush and retry rather than - // going into the partial writing flow below - await buffer.Flush(async, cancellationToken); - buffer.WriteBytes(bytes); - } - else - { - var remaining = bytes.Length; - do - { - if (buffer.WriteSpaceLeft == 0) - await buffer.Flush(async, cancellationToken); - var writeLen = Math.Min(remaining, buffer.WriteSpaceLeft); - var offset = bytes.Length - remaining; - buffer.WriteBytes(bytes, offset, writeLen); - remaining -= writeLen; - } - while (remaining > 0); - } - } - } - - public void WriteNullTerminatedString(string s) - { - Debug.Assert(s.All(c => c < 128), "Method only supports ASCII strings"); - Debug.Assert(WriteSpaceLeft >= s.Length + 1); - WritePosition += Encoding.ASCII.GetBytes(s, 0, s.Length, Buffer, WritePosition); - WriteByte(0); - } - - #endregion - - #region Write Complex - - public Stream GetStream() - { - if (_parameterStream == null) - _parameterStream = new ParameterStream(this); - - _parameterStream.Init(); - return _parameterStream; - } - - internal void WriteStringChunked(char[] chars, int charIndex, int charCount, - bool flush, out int charsUsed, out bool completed) - { - if (WriteSpaceLeft < _textEncoder.GetByteCount(chars, charIndex, 1, flush: false)) - { - charsUsed = 0; - completed = false; - return; - } - - _textEncoder.Convert(chars, charIndex, charCount, Buffer, WritePosition, WriteSpaceLeft, - flush, out charsUsed, out var bytesUsed, out completed); - WritePosition += bytesUsed; - } - - internal unsafe void WriteStringChunked(string s, int charIndex, int charCount, - bool flush, out int charsUsed, out bool completed) - { - int bytesUsed; - - fixed (char* sPtr = s) - fixed (byte* bufPtr = Buffer) - { - if (WriteSpaceLeft < _textEncoder.GetByteCount(sPtr + charIndex, 1, flush: false)) - { - charsUsed = 0; - completed = false; - return; - } - - _textEncoder.Convert(sPtr + charIndex, charCount, bufPtr + WritePosition, WriteSpaceLeft, - flush, out charsUsed, out bytesUsed, out completed); - } - - WritePosition 
+= bytesUsed; - } - - #endregion - - #region Copy - - internal void StartCopyMode() - { - _copyMode = true; - Size -= 5; - WriteCopyDataHeader(); - } - - internal void EndCopyMode() - { - // EndCopyMode is usually called after a Flush which ended the last CopyData message. - // That Flush also wrote the header for another CopyData which we clear here. - _copyMode = false; - Size += 5; - Clear(); - } - - void WriteCopyDataHeader() - { - Debug.Assert(_copyMode); - Debug.Assert(WritePosition == 0); - WriteByte(FrontendMessageCode.CopyData); - // Leave space for the message length - WriteInt32(0); - } - - #endregion - - #region Dispose - - public void Dispose() - { - if (_disposed) - return; - - ArrayPool.Shared.Return(Buffer); - - _timeoutCts.Dispose(); - _disposed = true; - } - - #endregion - - #region Misc - - internal void Clear() - { - WritePosition = 0; - } - - /// - /// Returns all contents currently written to the buffer (but not flushed). - /// Useful for pre-generating messages. - /// - internal byte[] GetContents() - { - var buf = new byte[WritePosition]; - Array.Copy(Buffer, buf, WritePosition); - return buf; - } - - #endregion - } -} diff --git a/src/Npgsql/PgPassFile.cs b/src/Npgsql/PgPassFile.cs index 011121d9e2..364d2b7409 100644 --- a/src/Npgsql/PgPassFile.cs +++ b/src/Npgsql/PgPassFile.cs @@ -1,160 +1,201 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; -using System.Text.RegularExpressions; +using System.Text; -namespace Npgsql +namespace Npgsql; + +/// +/// Represents a .pgpass file, which contains passwords for noninteractive connections +/// +sealed class PgPassFile { + #region Properties + /// - /// Represents a .pgpass file, which contains passwords for noninteractive connections + /// File name being parsed for credentials /// - class PgPassFile - { - #region Properties + internal string FileName { get; } - /// - /// File name being parsed for credentials - /// - internal string FileName { get; } + 
#endregion - #endregion + #region Construction + + /// + /// Initializes a new instance of the class + /// + /// + public PgPassFile(string fileName) + => FileName = fileName; + + #endregion + + /// + /// Parses file content and gets all credentials from the file + /// + /// corresponding to all lines in the .pgpass file + internal IEnumerable Entries + { + get + { + var bytes = File.ReadAllBytes(FileName); + var mem = new MemoryStream(bytes); + using var reader = new StreamReader(mem); + while (reader.ReadLine() is { } l) + { + var line = l.Trim(); + if (line.Length > 0 && line[0] != '#') + yield return Entry.Parse(line); + } + } + } + + /// + /// Searches queries loaded from .PGPASS file to find first entry matching the provided parameters. + /// + /// Hostname to query. Use null to match any. + /// Port to query. Use null to match any. + /// Database to query. Use null to match any. + /// User name to query. Use null to match any. + /// Matching if match was found. Otherwise, returns null. + internal Entry? GetFirstMatchingEntry(string? host = null, int? port = null, string? database = null, string? username = null) + { + foreach (var entry in Entries) + if (entry.IsMatch(host, port, database, username)) + return entry; + return null; + } - #region Construction + /// + /// Represents a hostname, port, database, username, and password combination that has been retrieved from a .pgpass file + /// + internal sealed class Entry + { + #region Fields and Properties /// - /// Initializes a new instance of the class + /// Hostname parsed from the .pgpass file /// - /// - public PgPassFile(string fileName) - => FileName = fileName; + internal string? Host { get; } + /// + /// Port parsed from the .pgpass file + /// + internal int? Port { get; } + /// + /// Database parsed from the .pgpass file + /// + internal string? Database { get; } + /// + /// User name parsed from the .pgpass file + /// + internal string? 
Username { get; } + /// + /// Password parsed from the .pgpass file + /// + internal string? Password { get; } #endregion - /// - /// Parses file content and gets all credentials from the file - /// - /// corresponding to all lines in the .pgpass file - internal IEnumerable Entries => File.ReadLines(FileName) - .Select(line => line.Trim()) - .Where(line => line.Any() && line[0] != '#') - .Select(Entry.Parse); + #region Construction / Initialization /// - /// Searches queries loaded from .PGPASS file to find first entry matching the provided parameters. + /// This class represents an entry from the .pgpass file /// - /// Hostname to query. Use null to match any. - /// Port to query. Use null to match any. - /// Database to query. Use null to match any. - /// User name to query. Use null to match any. - /// Matching if match was found. Otherwise, returns null. - internal Entry? GetFirstMatchingEntry(string? host = null, int? port = null, string? database = null, string? username = null) - => Entries.FirstOrDefault(entry => entry.IsMatch(host, port, database, username)); + /// Hostname parsed from the .pgpass file + /// Port parsed from the .pgpass file + /// Database parsed from the .pgpass file + /// User name parsed from the .pgpass file + /// Password parsed from the .pgpass file + Entry(string? host, int? port, string? database, string? username, string? password) + { + Host = host; + Port = port; + Database = database; + Username = username; + Password = password; + } /// - /// Represents a hostname, port, database, username, and password combination that has been retrieved from a .pgpass file + /// Creates new based on string in the format hostname:port:database:username:password. The : and \ characters should be escaped with a \. 
/// - internal class Entry + /// string for the entry from the pgpass file + /// New instance of for the string + /// Entry is not formatted as hostname:port:database:username:password or non-wildcard port is not a number + internal static Entry Parse(string serializedEntry) { - const string PgPassWildcard = "*"; - - #region Fields and Properties - - /// - /// Hostname parsed from the .pgpass file - /// - internal string? Host { get; } - /// - /// Port parsed from the .pgpass file - /// - internal int? Port { get; } - /// - /// Database parsed from the .pgpass file - /// - internal string? Database { get; } - /// - /// User name parsed from the .pgpass file - /// - internal string? Username { get; } - /// - /// Password parsed from the .pgpass file - /// - internal string? Password { get; } - - #endregion - - #region Construction / Initialization - - /// - /// This class represents an entry from the .pgpass file - /// - /// Hostname parsed from the .pgpass file - /// Port parsed from the .pgpass file - /// Database parsed from the .pgpass file - /// User name parsed from the .pgpass file - /// Password parsed from the .pgpass file - Entry(string? host, int? port, string? database, string? username, string? password) - { - Host = host; - Port = port; - Database = database; - Username = username; - Password = password; - } + var parts = new List(5); - /// - /// Creates new based on string in the format hostname:port:database:username:password. The : and \ characters should be escaped with a \. - /// - /// string for the entry from the pgpass file - /// New instance of for the string - /// Entry is not formatted as hostname:port:database:username:password or non-wildcard port is not a number - internal static Entry Parse(string serializedEntry) + var builder = new StringBuilder(); + for (var pos = 0; pos < serializedEntry.Length; pos++) { - var parts = Regex.Split(serializedEntry, @"(? 
part.Replace("\\:", ":").Replace("\\\\", "\\")) // unescape any escaped characters - .Select(part => part == PgPassWildcard ? null : part) - .ToArray(); + var c = serializedEntry[pos]; - int? port = null; - if (processedParts[1] != null) + switch (c) { - if (!int.TryParse(processedParts[1], out var tempPort)) - throw new FormatException("pgpass entry was not formatted correctly. Port must be a valid integer."); - port = tempPort; + case '\\' when pos < serializedEntry.Length - 1: + // Strip backslash before colon or backslash, otherwise preserve it + c = serializedEntry[++pos]; + if (c is not (':' or '\\')) + { + builder.Append('\\'); + } + + builder.Append(c); + continue; + + case ':': + var part = builder.ToString(); + parts.Add(part == "*" ? null : part); + builder.Clear(); + continue; + + default: + builder.Append(c); + continue; } + } + + var lastPart = builder.ToString(); + parts.Add(lastPart == "*" ? null : lastPart); - return new Entry(processedParts[0], port, processedParts[2], processedParts[3], processedParts[4]); + if (parts.Count != 5) + throw new FormatException("pgpass entry was not well-formed. Please ensure all non-comment entries are formatted as hostname:port:database:username:password. If colon is included, it must be escaped like \\:."); + + int? port = null; + if (parts[1] != null) + { + if (!int.TryParse(parts[1], out var tempPort)) + throw new FormatException("pgpass entry was not formatted correctly. Port must be a valid integer."); + port = tempPort; } - #endregion - - - /// - /// Checks whether this matches the parameters supplied - /// - /// Hostname to check against this entry - /// Port to check against this entry - /// Database to check against this entry - /// Username to check against this entry - /// True if the entry is a match. False otherwise. - internal bool IsMatch(string? host, int? port, string? database, string? 
username) => - AreValuesMatched(host, Host) && AreValuesMatched(port, Port) && AreValuesMatched(database, Database) && AreValuesMatched(username, Username); - - /// - /// Checks if 2 strings are a match for a considering that either value can be a wildcard (*) - /// - /// Value being searched - /// Value from the PGPASS entry - /// True if the values are a match. False otherwise. - bool AreValuesMatched(string? query, string? actual) - => query == actual || actual == null || query == null; - - bool AreValuesMatched(int? query, int? actual) - => query == actual || actual == null || query == null; + return new Entry(parts[0], port, parts[2], parts[3], parts[4]); } + + #endregion + + + /// + /// Checks whether this matches the parameters supplied + /// + /// Hostname to check against this entry + /// Port to check against this entry + /// Database to check against this entry + /// Username to check against this entry + /// True if the entry is a match. False otherwise. + internal bool IsMatch(string? host, int? port, string? database, string? username) => + AreValuesMatched(host, Host) && AreValuesMatched(port, Port) && AreValuesMatched(database, Database) && AreValuesMatched(username, Username); + + /// + /// Checks if 2 strings are a match for a considering that either value can be a wildcard (*) + /// + /// Value being searched + /// Value from the PGPASS entry + /// True if the values are a match. False otherwise. + bool AreValuesMatched(string? query, string? actual) + => query == actual || actual == null || query == null; + + bool AreValuesMatched(int? query, int? 
actual) + => query == actual || actual == null || query == null; } } diff --git a/src/Npgsql/PoolManager.cs b/src/Npgsql/PoolManager.cs index b4fcae8d46..d1086b5196 100644 --- a/src/Npgsql/PoolManager.cs +++ b/src/Npgsql/PoolManager.cs @@ -1,131 +1,48 @@ using System; -using System.Diagnostics.CodeAnalysis; -using System.Threading; +using System.Collections.Concurrent; -namespace Npgsql -{ - /// - /// Provides lookup for a pool based on a connection string. - /// - /// - /// is lock-free, to avoid contention, but the same isn't - /// true of , which acquires a lock. The calling code always tries - /// before trying to . - /// - static class PoolManager - { - internal const int InitialPoolsSize = 10; - - static readonly object Lock = new object(); - static volatile (string Key, ConnectorPool Pool)[] _pools = new (string, ConnectorPool)[InitialPoolsSize]; - static volatile int _nextSlot; - - internal static (string Key, ConnectorPool Pool)[] Pools => _pools; - - internal static bool TryGetValue(string key, [NotNullWhen(true)] out ConnectorPool? pool) - { - // Note that pools never get removed. _pools is strictly append-only. - var nextSlot = _nextSlot; - var pools = _pools; - var sw = new SpinWait(); - - // First scan the pools and do reference equality on the connection strings - for (var i = 0; i < nextSlot; i++) - { - var cp = pools[i]; - if (ReferenceEquals(cp.Key, key)) - { - // It's possible that this pool entry is currently being written: the connection string - // component has already been writte, but the pool component is just about to be. 
So we - // loop on the pool until it's non-null - while (Volatile.Read(ref cp.Pool) == null) - sw.SpinOnce(); - pool = cp.Pool; - return true; - } - } - - // Next try value comparison on the strings - for (var i = 0; i < nextSlot; i++) - { - var cp = pools[i]; - if (cp.Key == key) - { - // See comment above - while (Volatile.Read(ref cp.Pool) == null) - sw.SpinOnce(); - pool = cp.Pool; - return true; - } - } - - pool = null; - return false; - } - - internal static ConnectorPool GetOrAdd(string key, ConnectorPool pool) - { - lock (Lock) - { - if (TryGetValue(key, out var result)) - return result; +namespace Npgsql; - // May need to grow the array. - if (_nextSlot == _pools.Length) - { - var newPools = new (string, ConnectorPool)[_pools.Length * 2]; - Array.Copy(_pools, newPools, _pools.Length); - _pools = newPools; - } - - _pools[_nextSlot].Key = key; - _pools[_nextSlot].Pool = pool; - Interlocked.Increment(ref _nextSlot); - return pool; - } - } +/// +/// Provides lookup for a pool based on a connection string. +/// +/// +/// Note that pools created directly as are referenced directly by users, and aren't managed here. 
+/// +static class PoolManager +{ + internal static ConcurrentDictionary Pools { get; } = new(); - internal static void Clear(string connString) - { - if (TryGetValue(connString, out var pool)) - pool.Clear(); - } + internal static void Clear(string connString) + { + // TODO: Actually remove the pools from here, #3387 (but be careful of concurrency) + if (Pools.TryGetValue(connString, out var pool)) + pool.Clear(); + } - internal static void ClearAll() - { - lock (Lock) - { - var pools = _pools; - for (var i = 0; i < _nextSlot; i++) - { - var cp = pools[i]; - if (cp.Key == null) - return; - cp.Pool?.Clear(); - } - } - } + internal static void ClearAll() + { + // TODO: Actually remove the pools from here, #3387 (but be careful of concurrency) + foreach (var pool in Pools.Values) + pool.Clear(); + } - static PoolManager() - { - // When the appdomain gets unloaded (e.g. web app redeployment) attempt to nicely - // close idle connectors to prevent errors in PostgreSQL logs (#491). - AppDomain.CurrentDomain.DomainUnload += (sender, args) => ClearAll(); - AppDomain.CurrentDomain.ProcessExit += (sender, args) => ClearAll(); - } + static PoolManager() + { + // When the appdomain gets unloaded (e.g. web app redeployment) attempt to nicely + // close idle connectors to prevent errors in PostgreSQL logs (#491). + AppDomain.CurrentDomain.DomainUnload += (_, _) => ClearAll(); + AppDomain.CurrentDomain.ProcessExit += (_, _) => ClearAll(); + } - /// - /// Resets the pool manager to its initial state, for test purposes only. - /// Assumes that no other threads are accessing the pool. - /// - internal static void Reset() - { - lock (Lock) - { - ClearAll(); - _pools = new (string, ConnectorPool)[InitialPoolsSize]; - _nextSlot = 0; - } - } + /// + /// Resets the pool manager to its initial state, for test purposes only. + /// Assumes that no other threads are accessing the pool. 
+ /// + internal static void Reset() + { + // TODO: Remove once #3387 is implemented + ClearAll(); + Pools.Clear(); } -} +} \ No newline at end of file diff --git a/src/Npgsql/PoolingDataSource.cs b/src/Npgsql/PoolingDataSource.cs new file mode 100644 index 0000000000..192a86c052 --- /dev/null +++ b/src/Npgsql/PoolingDataSource.cs @@ -0,0 +1,471 @@ +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using System.Transactions; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; +using Npgsql.Util; + +namespace Npgsql; + +class PoolingDataSource : NpgsqlDataSource +{ + #region Fields and properties + + internal int MaxConnections { get; } + internal int MinConnections { get; } + + readonly TimeSpan _connectionLifetime; + + volatile int _numConnectors; + + volatile int _idleCount; + + /// + /// Tracks all connectors currently managed by this pool, whether idle or busy. + /// Only updated rarely - when physical connections are opened/closed - but is read in perf-sensitive contexts. + /// + private protected readonly NpgsqlConnector?[] Connectors; + + readonly NpgsqlMultiHostDataSource? _parentPool; + + /// + /// Reader side for the idle connector channel. Contains nulls in order to release waiting attempts after + /// a connector has been physically closed/broken. + /// + readonly ChannelReader _idleConnectorReader; + internal ChannelWriter IdleConnectorWriter { get; } + + readonly ILogger _logger; + + /// + /// Incremented every time this pool is cleared via or + /// . Allows us to identify connections which were + /// created before the clear. 
+ /// + volatile int _clearCounter; + + static readonly TimerCallback PruningTimerCallback = PruneIdleConnectors; + readonly Timer _pruningTimer; + readonly TimeSpan _pruningSamplingInterval; + readonly int _pruningSampleSize; + readonly int[] _pruningSamples; + readonly int _pruningMedianIndex; + volatile bool _pruningTimerEnabled; + int _pruningSampleIndex; + + volatile int _isClearing; + + #endregion + + internal sealed override (int Total, int Idle, int Busy) Statistics + { + get + { + var numConnectors = _numConnectors; + var idleCount = _idleCount; + return (numConnectors, idleCount, numConnectors - idleCount); + } + } + + internal sealed override bool OwnsConnectors => true; + + internal PoolingDataSource( + NpgsqlConnectionStringBuilder settings, + NpgsqlDataSourceConfiguration dataSourceConfig, + NpgsqlMultiHostDataSource? parentPool = null) + : base(settings, dataSourceConfig) + { + if (settings.MaxPoolSize < settings.MinPoolSize) + throw new ArgumentException($"Connection can't have 'Max Pool Size' {settings.MaxPoolSize} under 'Min Pool Size' {settings.MinPoolSize}"); + + _parentPool = parentPool; + + // We enforce Max Pool Size, so no need to to create a bounded channel (which is less efficient) + // On the consuming side, we have the multiplexing write loop but also non-multiplexing Rents + // On the producing side, we have connections being released back into the pool (both multiplexing and not) + var idleChannel = Channel.CreateUnbounded(); + _idleConnectorReader = idleChannel.Reader; + IdleConnectorWriter = idleChannel.Writer; + + MaxConnections = settings.MaxPoolSize; + MinConnections = settings.MinPoolSize; + + if (settings.ConnectionPruningInterval == 0) + throw new ArgumentException("ConnectionPruningInterval can't be 0."); + var connectionIdleLifetime = TimeSpan.FromSeconds(settings.ConnectionIdleLifetime); + var pruningSamplingInterval = TimeSpan.FromSeconds(settings.ConnectionPruningInterval); + if (connectionIdleLifetime < 
pruningSamplingInterval) + throw new ArgumentException($"Connection can't have {nameof(settings.ConnectionIdleLifetime)} {connectionIdleLifetime} under {nameof(settings.ConnectionPruningInterval)} {pruningSamplingInterval}"); + + _pruningTimer = new Timer(PruningTimerCallback, this, Timeout.Infinite, Timeout.Infinite); + _pruningSampleSize = DivideRoundingUp(settings.ConnectionIdleLifetime, settings.ConnectionPruningInterval); + _pruningMedianIndex = DivideRoundingUp(_pruningSampleSize, 2) - 1; // - 1 to go from length to index + _pruningSamplingInterval = pruningSamplingInterval; + _pruningSamples = new int[_pruningSampleSize]; + _pruningTimerEnabled = false; + + _connectionLifetime = TimeSpan.FromSeconds(settings.ConnectionLifetime); + Connectors = new NpgsqlConnector[MaxConnections]; + + _logger = LoggingConfiguration.ConnectionLogger; + } + + static SemaphoreSlim SyncOverAsyncSemaphore { get; } = new(Math.Max(1, Environment.ProcessorCount / 2)); + + internal sealed override ValueTask Get( + NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + { + CheckDisposed(); + + return TryGetIdleConnector(out var connector) + ? new ValueTask(connector) + : RentAsync(conn, timeout, async, cancellationToken); + + async ValueTask RentAsync( + NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + { + // First, try to open a new physical connector. This will fail if we're at max capacity. + var connector = await OpenNewConnector(conn, timeout, async, cancellationToken).ConfigureAwait(false); + if (connector != null) + return connector; + + // We're at max capacity. Block on the idle channel with a timeout. + // Note that Channels guarantee fair FIFO behavior to callers of ReadAsync (first-come first- + // served), which is crucial to us. 
+ using var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var finalToken = linkedSource.Token; + linkedSource.CancelAfter(timeout.CheckAndGetTimeLeft()); + MetricsReporter.ReportPendingConnectionRequestStart(); + + try + { + while (true) + { + try + { + if (async) + connector = await _idleConnectorReader.ReadAsync(finalToken).ConfigureAwait(false); + else + { + SyncOverAsyncSemaphore.Wait(finalToken); + try + { + var awaiter = _idleConnectorReader.ReadAsync(finalToken).ConfigureAwait(false).GetAwaiter(); + var mres = new ManualResetEventSlim(false, 0); + + // Cancellation happens through the ReadAsync call, which will complete the task. + awaiter.UnsafeOnCompleted(() => mres.Set()); + mres.Wait(CancellationToken.None); + connector = awaiter.GetResult(); + } + finally + { + SyncOverAsyncSemaphore.Release(); + } + } + + if (CheckIdleConnector(connector)) + return connector; + } + catch (OperationCanceledException) + { + cancellationToken.ThrowIfCancellationRequested(); + Debug.Assert(finalToken.IsCancellationRequested); + + MetricsReporter.ReportConnectionPoolTimeout(); + throw new NpgsqlException( + $"The connection pool has been exhausted, either raise 'Max Pool Size' (currently {MaxConnections}) " + + $"or 'Timeout' (currently {Settings.Timeout} seconds) in your connection string.", + new TimeoutException()); + } + catch (ChannelClosedException) + { + throw new NpgsqlException("The connection pool has been shut down."); + } + + // If we're here, our waiting attempt on the idle connector channel was released with a null + // (or bad connector), or we're in sync mode. Check again if a new idle connector has appeared since we last checked. + if (TryGetIdleConnector(out connector)) + return connector; + + // We might have closed a connector in the meantime and no longer be at max capacity + // so try to open a new connector and if that fails, loop again. 
+ connector = await OpenNewConnector(conn, timeout, async, cancellationToken).ConfigureAwait(false); + if (connector != null) + return connector; + } + } + finally + { + MetricsReporter.ReportPendingConnectionRequestStop(); + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal sealed override bool TryGetIdleConnector([NotNullWhen(true)] out NpgsqlConnector? connector) + { + while (_idleConnectorReader.TryRead(out connector)) + if (CheckIdleConnector(connector)) + return true; + + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + bool CheckIdleConnector([NotNullWhen(true)] NpgsqlConnector? connector) + { + if (connector is null) + return false; + + // Only decrement when the connector has a value. + Interlocked.Decrement(ref _idleCount); + + // An connector could be broken because of a keepalive that occurred while it was + // idling in the pool + // TODO: Consider removing the pool from the keepalive code. The following branch is simply irrelevant + // if keepalive isn't turned on. + if (connector.IsBroken) + { + CloseConnector(connector); + return false; + } + + if (_connectionLifetime != TimeSpan.Zero && DateTime.UtcNow > connector.OpenTimestamp + _connectionLifetime) + { + LogMessages.ConnectionExceededMaximumLifetime(_logger, _connectionLifetime, connector.Id); + CloseConnector(connector); + return false; + } + + // The connector directly references the data source type mapper into the connector, to protect it against changes by a concurrent + // ReloadTypes. We update them here before returning the connector from the pool. 
+ Debug.Assert(SerializerOptions is not null); + Debug.Assert(DatabaseInfo is not null); + connector.SerializerOptions = SerializerOptions; + connector.DatabaseInfo = DatabaseInfo; + + Debug.Assert(connector.State == ConnectorState.Ready, + $"Got idle connector but {nameof(connector.State)} is {connector.State}"); + Debug.Assert(connector.CommandsInFlightCount == 0, + $"Got idle connector but {nameof(connector.CommandsInFlightCount)} is {connector.CommandsInFlightCount}"); + Debug.Assert(connector.MultiplexAsyncWritingLock == 0, + $"Got idle connector but {nameof(connector.MultiplexAsyncWritingLock)} is 1"); + + return true; + } + + internal sealed override async ValueTask OpenNewConnector( + NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + { + // As long as we're under max capacity, attempt to increase the connector count and open a new connection. + for (var numConnectors = _numConnectors; numConnectors < MaxConnections; numConnectors = _numConnectors) + { + // Note that we purposefully don't use SpinWait for this: https://github.com/dotnet/coreclr/pull/21437 + if (Interlocked.CompareExchange(ref _numConnectors, numConnectors + 1, numConnectors) != numConnectors) + continue; + + try + { + // We've managed to increase the open counter, open a physical connections. 
+#if NET7_0_OR_GREATER + var startTime = Stopwatch.GetTimestamp(); +#endif + var connector = new NpgsqlConnector(this, conn) { ClearCounter = _clearCounter }; + await connector.Open(timeout, async, cancellationToken).ConfigureAwait(false); +#if NET7_0_OR_GREATER + MetricsReporter.ReportConnectionCreateTime(Stopwatch.GetElapsedTime(startTime)); +#endif + + var i = 0; + for (; i < MaxConnections; i++) + if (Interlocked.CompareExchange(ref Connectors[i], connector, null) == null) + break; + + Debug.Assert(i < MaxConnections, $"Could not find free slot in {Connectors} when opening."); + if (i == MaxConnections) + throw new NpgsqlException($"Could not find free slot in {Connectors} when opening. Please report a bug."); + + // Only start pruning if we've incremented open count past _min. + // Note that we don't do it only once, on equality, because the thread which incremented open count past _min might get exception + // on NpgsqlConnector.Open due to timeout, CancellationToken or other reasons. + if (numConnectors >= MinConnections) + UpdatePruningTimer(); + + return connector; + } + catch + { + // Physical open failed, decrement the open and busy counter back down. + Interlocked.Decrement(ref _numConnectors); + + // In case there's a waiting attempt on the channel, we write a null to the idle connector channel + // to wake it up, so it will try opening (and probably throw immediately) + // Statement order is important since we have synchronous completions on the channel. 
+ IdleConnectorWriter.TryWrite(null); + + // Just in case we always call UpdatePruningTimer for failed physical open + UpdatePruningTimer(); + + throw; + } + } + + return null; + } + + internal sealed override void Return(NpgsqlConnector connector) + { + Debug.Assert(!connector.InTransaction); + Debug.Assert(connector.MultiplexAsyncWritingLock == 0 || connector.IsBroken || connector.IsClosed, + $"About to return multiplexing connector to the pool, but {nameof(connector.MultiplexAsyncWritingLock)} is {connector.MultiplexAsyncWritingLock}"); + + // If Clear/ClearAll has been been called since this connector was first opened, + // throw it away. The same if it's broken (in which case CloseConnector is only + // used to update state/perf counter). + if (connector.ClearCounter != _clearCounter || connector.IsBroken) + { + CloseConnector(connector); + return; + } + + // Statement order is important since we have synchronous completions on the channel. + Interlocked.Increment(ref _idleCount); + var written = IdleConnectorWriter.TryWrite(connector); + Debug.Assert(written); + } + + internal override void Clear() + { + Interlocked.Increment(ref _clearCounter); + + if (Interlocked.CompareExchange(ref _isClearing, 1, 0) == 1) + return; + + try + { + var count = _idleCount; + while (count > 0 && _idleConnectorReader.TryRead(out var connector)) + { + if (CheckIdleConnector(connector)) + { + CloseConnector(connector); + count--; + } + } + } + finally + { + _isClearing = 0; + } + } + + void CloseConnector(NpgsqlConnector connector) + { + try + { + connector.Close(); + } + catch (Exception exception) + { + LogMessages.ExceptionWhenClosingPhysicalConnection(_logger, connector.Id, exception); + } + + var i = 0; + for (; i < MaxConnections; i++) + if (Interlocked.CompareExchange(ref Connectors[i], null, connector) == connector) + break; + + // If CloseConnector is being called from within OpenNewConnector (e.g. 
an error happened during a connection initializer which + // causes the connector to Break, and therefore return the connector), then we haven't yet added the connector to Connectors. + // In this case, there's no state to revert here (that's all taken care of in OpenNewConnector), skip it. + if (i == MaxConnections) + return; + + var numConnectors = Interlocked.Decrement(ref _numConnectors); + Debug.Assert(numConnectors >= 0); + + // If a connector has been closed for any reason, we write a null to the idle connector channel to wake up + // a waiter, who will open a new physical connection + // Statement order is important since we have synchronous completions on the channel. + IdleConnectorWriter.TryWrite(null); + + // Only turn off the timer one time, when it was this Close that brought Open back to _min. + if (numConnectors == MinConnections) + UpdatePruningTimer(); + } + + internal override bool TryRemovePendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) + => _parentPool is null + ? base.TryRemovePendingEnlistedConnector(connector, transaction) + : _parentPool.TryRemovePendingEnlistedConnector(connector, transaction); + + #region Pruning + + void UpdatePruningTimer() + { + lock (_pruningTimer) + { + var numConnectors = _numConnectors; + if (numConnectors > MinConnections && !_pruningTimerEnabled) + { + _pruningTimerEnabled = true; + _pruningTimer.Change(_pruningSamplingInterval, Timeout.InfiniteTimeSpan); + } + else if (numConnectors <= MinConnections && _pruningTimerEnabled) + { + _pruningTimer.Change(Timeout.Infinite, Timeout.Infinite); + _pruningSampleIndex = 0; + _pruningTimerEnabled = false; + } + } + } + + static void PruneIdleConnectors(object? state) + { + var pool = (PoolingDataSource)state!; + var samples = pool._pruningSamples; + int toPrune; + lock (pool._pruningTimer) + { + // Check if we might have been contending with DisablePruning. 
+ if (!pool._pruningTimerEnabled) + return; + + var sampleIndex = pool._pruningSampleIndex; + samples[sampleIndex] = pool._idleCount; + if (sampleIndex != pool._pruningSampleSize - 1) + { + pool._pruningSampleIndex = sampleIndex + 1; + pool._pruningTimer.Change(pool._pruningSamplingInterval, Timeout.InfiniteTimeSpan); + return; + } + + // Calculate median value for pruning, reset index and timer, and release the lock. + Array.Sort(samples); + toPrune = samples[pool._pruningMedianIndex]; + pool._pruningSampleIndex = 0; + pool._pruningTimer.Change(pool._pruningSamplingInterval, Timeout.InfiniteTimeSpan); + } + + while (toPrune > 0 && + pool._numConnectors > pool.MinConnections && + pool._idleConnectorReader.TryRead(out var connector) && + connector != null) + { + if (pool.CheckIdleConnector(connector)) + pool.CloseConnector(connector); + + toPrune--; + } + } + + static int DivideRoundingUp(int value, int divisor) => 1 + (value - 1) / divisor; + + #endregion +} diff --git a/src/Npgsql/PostgresDatabaseInfo.cs b/src/Npgsql/PostgresDatabaseInfo.cs index 269623693e..4d793238b6 100644 --- a/src/Npgsql/PostgresDatabaseInfo.cs +++ b/src/Npgsql/PostgresDatabaseInfo.cs @@ -1,134 +1,119 @@ using System; using System.Collections.Generic; -using System.Data; using System.Diagnostics; using System.Globalization; -using System.Linq; +using System.Text; using System.Threading.Tasks; -using Npgsql.Logging; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Npgsql.BackendMessages; +using Npgsql.Internal; using Npgsql.PostgresTypes; using Npgsql.Util; +using static Npgsql.Util.Statics; // ReSharper disable StringLiteralTypo // ReSharper disable CommentTypo -namespace Npgsql +namespace Npgsql; + +/// +/// The default implementation of , for standard PostgreSQL databases.. +/// +sealed class PostgresDatabaseInfoFactory : INpgsqlDatabaseInfoFactory { - /// - /// The default implementation of , for standard PostgreSQL databases.. 
- /// - class PostgresDatabaseInfoFactory : INpgsqlDatabaseInfoFactory + /// + public async Task Load(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async) { - /// - public async Task Load(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async) - { - var db = new PostgresDatabaseInfo(conn); - await db.LoadPostgresInfo(conn, timeout, async); - Debug.Assert(db.LongVersion != null); - return db; - } + var db = new PostgresDatabaseInfo(conn); + await db.LoadPostgresInfo(conn, timeout, async).ConfigureAwait(false); + Debug.Assert(db.LongVersion != null); + return db; } +} + +/// +/// The default implementation of NpgsqlDatabase, for standard PostgreSQL databases. +/// +class PostgresDatabaseInfo : NpgsqlDatabaseInfo +{ + readonly ILogger _connectionLogger; /// - /// The default implementation of NpgsqlDatabase, for standard PostgreSQL databases. + /// The PostgreSQL types detected in the database. /// - class PostgresDatabaseInfo : NpgsqlDatabaseInfo - { - /// - /// The Npgsql logger instance. - /// - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(PostgresDatabaseInfo)); - - /// - /// The PostgreSQL types detected in the database. - /// - List? _types; - - /// - protected override IEnumerable GetTypes() => _types ?? Enumerable.Empty(); - - /// - /// The PostgreSQL version string as returned by the version() function. Populated during loading. - /// - public string LongVersion { get; set; } = default!; - - /// - /// True if the backend is Amazon Redshift; otherwise, false. - /// - public bool IsRedshift { get; private set; } - - /// - public override bool SupportsUnlisten => Version >= new Version(6, 4, 0) && !IsRedshift; - - /// - /// True if the 'pg_enum' table includes the 'enumsortorder' column; otherwise, false. - /// - public virtual bool HasEnumSortOrder => Version >= new Version(9, 1, 0); - - /// - /// True if the 'pg_type' table includes the 'typcategory' column; otherwise, false. 
- /// - /// - /// pg_type.typcategory is added after 8.4. - /// see: https://www.postgresql.org/docs/8.4/static/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE - /// - public virtual bool HasTypeCategory => Version >= new Version(8, 4, 0); - - internal PostgresDatabaseInfo(NpgsqlConnection conn) - : base(conn.Host!, conn.Port, conn.Database!, ParseServerVersion(conn.PostgresParameters["server_version"])) - { - } + List? _types; - /// - /// Loads database information from the PostgreSQL database specified by . - /// - /// The database connection. - /// The timeout while loading types from the backend. - /// True to load types asynchronously. - /// - /// A task representing the asynchronous operation. - /// - internal async Task LoadPostgresInfo(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async) - { - HasIntegerDateTimes = - conn.PostgresParameters.TryGetValue("integer_datetimes", out var intDateTimes) && - intDateTimes == "on"; + /// + protected override IEnumerable GetTypes() => _types ?? (IEnumerable)Array.Empty(); - IsRedshift = conn.Settings.ServerCompatibilityMode == ServerCompatibilityMode.Redshift; - _types = await LoadBackendTypes(conn, timeout, async); - } + /// + /// The PostgreSQL version string as returned by the version() function. Populated during loading. + /// + public string LongVersion { get; set; } = default!; - /// - /// Generates a raw SQL query string to select type information. - /// - /// True to load range types. - /// True to load enum types. - /// - /// True to load table composites. - /// - /// A raw SQL query string that selects type information. - /// - /// - /// Select all types (base, array which is also base, enum, range, composite). - /// Note that arrays are distinguished from primitive types through them having typreceive=array_recv. - /// Order by primitives first, container later. - /// For arrays and ranges, join in the element OID and type (to filter out arrays of unhandled - /// types). 
- /// - static string GenerateTypesQuery(bool withRange, bool withEnum, bool withEnumSortOrder, - bool loadTableComposites) - => $@" -SELECT version(); - -SELECT ns.nspname, typ_and_elem_type.*, - CASE - WHEN typtype IN ('b', 'e', 'p') THEN 0 -- First base types, enums, pseudo-types - WHEN typtype = 'r' THEN 1 -- Ranges after - WHEN typtype = 'c' THEN 2 -- Composites after - WHEN typtype = 'd' AND elemtyptype <> 'a' THEN 3 -- Domains over non-arrays after - WHEN typtype = 'a' THEN 4 -- Arrays before - WHEN typtype = 'd' AND elemtyptype = 'a' THEN 5 -- Domains over arrays last - END AS ord + /// + /// True if the backend is Amazon Redshift; otherwise, false. + /// + public bool IsRedshift { get; private set; } + + /// + public override bool SupportsUnlisten => Version.IsGreaterOrEqual(6, 4) && !IsRedshift; + + /// + /// True if the 'pg_enum' table includes the 'enumsortorder' column; otherwise, false. + /// + public virtual bool HasEnumSortOrder => Version.IsGreaterOrEqual(9, 1); + + /// + /// True if the 'pg_type' table includes the 'typcategory' column; otherwise, false. + /// + /// + /// pg_type.typcategory is added after 8.4. + /// see: https://www.postgresql.org/docs/8.4/static/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE + /// + public virtual bool HasTypeCategory => Version.IsGreaterOrEqual(8, 4); + + internal PostgresDatabaseInfo(NpgsqlConnector conn) + : base(conn.Host!, conn.Port, conn.Database!, conn.PostgresParameters["server_version"]) + => _connectionLogger = conn.LoggingConfiguration.ConnectionLogger; + + private protected PostgresDatabaseInfo(string host, int port, string databaseName, string serverVersion) + : base(host, port, databaseName, serverVersion) + => _connectionLogger = NullLogger.Instance; + + /// + /// Loads database information from the PostgreSQL database specified by . + /// + /// The database connection. + /// The timeout while loading types from the backend. + /// True to load types asynchronously. 
+ /// + /// A task representing the asynchronous operation. + /// + internal async Task LoadPostgresInfo(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async) + { + HasIntegerDateTimes = + conn.PostgresParameters.TryGetValue("integer_datetimes", out var intDateTimes) && + intDateTimes == "on"; + + IsRedshift = conn.Settings.ServerCompatibilityMode == ServerCompatibilityMode.Redshift; + _types = await LoadBackendTypes(conn, timeout, async).ConfigureAwait(false); + } + + /// + /// Generates a raw SQL query string to select type information. + /// + /// + /// Select all types (base, array which is also base, enum, range, composite). + /// Note that arrays are distinguished from primitive types through them having typreceive=array_recv. + /// Order by primitives first, container later. + /// For arrays and ranges, join in the element OID and type (to filter out arrays of unhandled + /// types). + /// + static string GenerateLoadTypesQuery(bool withRange, bool withMultirange, bool loadTableComposites) + => $@" +SELECT ns.nspname, t.oid, t.typname, t.typtype, t.typnotnull, t.elemtypoid FROM ( -- Arrays have typtype=b - this subquery identifies them by their typreceive and converts their typtype to a -- We first do this for the type (innerest-most subquery), and then for its element type @@ -143,6 +128,7 @@ END AS ord CASE WHEN proc.proname='array_recv' THEN typ.typelem {(withRange ? "WHEN typ.typtype='r' THEN rngsubtype" : "")} + {(withMultirange ? 
"WHEN typ.typtype='m' THEN (SELECT rngtypid FROM pg_range WHERE rngmultitypid = typ.oid)" : "")} WHEN typ.typtype='d' THEN typ.typbasetype END AS elemtypoid FROM pg_type AS typ @@ -153,19 +139,29 @@ LEFT JOIN pg_class AS cls ON (cls.oid = typ.typrelid) LEFT JOIN pg_type AS elemtyp ON elemtyp.oid = elemtypoid LEFT JOIN pg_class AS elemcls ON (elemcls.oid = elemtyp.typrelid) LEFT JOIN pg_proc AS elemproc ON elemproc.oid = elemtyp.typreceive -) AS typ_and_elem_type +) AS t JOIN pg_namespace AS ns ON (ns.oid = typnamespace) WHERE - typtype IN ('b', 'r', 'e', 'd') OR -- Base, range, enum, domain + typtype IN ('b', 'r', 'm', 'e', 'd') OR -- Base, range, multirange, enum, domain (typtype = 'c' AND {(loadTableComposites ? "ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')" : "relkind='c'")}) OR -- User-defined free-standing composites (not table composites) by default - (typtype = 'p' AND typname IN ('record', 'void')) OR -- Some special supported pseudo-types + (typtype = 'p' AND typname IN ('record', 'void', 'unknown')) OR -- Some special supported pseudo-types (typtype = 'a' AND ( -- Array of... - elemtyptype IN ('b', 'r', 'e', 'd') OR -- Array of base, range, enum, domain + elemtyptype IN ('b', 'r', 'm', 'e', 'd') OR -- Array of base, range, multirange, enum, domain (elemtyptype = 'p' AND elemtypname IN ('record', 'void')) OR -- Arrays of special supported pseudo-types (elemtyptype = 'c' AND {(loadTableComposites ? 
"ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')" : "elemrelkind='c'")}) -- Array of user-defined free-standing composites (not table composites) by default )) -ORDER BY ord; - +ORDER BY CASE + WHEN typtype IN ('b', 'e', 'p') THEN 0 -- First base types, enums, pseudo-types + WHEN typtype = 'c' THEN 1 -- Composites after (fields loaded later in 2nd pass) + WHEN typtype = 'r' THEN 2 -- Ranges after + WHEN typtype = 'm' THEN 3 -- Multiranges after + WHEN typtype = 'd' AND elemtyptype <> 'a' THEN 4 -- Domains over non-arrays after + WHEN typtype = 'a' THEN 5 -- Arrays after + WHEN typtype = 'd' AND elemtyptype = 'a' THEN 6 -- Domains over arrays last +END;"; + + static string GenerateLoadCompositeTypesQuery(bool loadTableComposites) + => $@" -- Load field definitions for (free-standing) composite types SELECT typ.oid, att.attname, att.atttypid FROM pg_type AS typ @@ -176,235 +172,345 @@ JOIN pg_attribute AS att ON (att.attrelid = typ.typrelid) (typ.typtype = 'c' AND {(loadTableComposites ? "ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')" : "cls.relkind='c'")}) AND attnum > 0 AND -- Don't load system attributes NOT attisdropped -ORDER BY typ.oid, att.attnum; +ORDER BY typ.oid, att.attnum;"; -{(withEnum ? $@" + static string GenerateLoadEnumFieldsQuery(bool withEnumSortOrder) + => $@" -- Load enum fields SELECT pg_type.oid, enumlabel FROM pg_enum JOIN pg_type ON pg_type.oid=enumtypid -ORDER BY oid{(withEnumSortOrder ? ", enumsortorder" : "")};" : "")} -"; - - /// - /// Loads type information from the backend specified by . - /// - /// The database connection. - /// The timeout while loading types from the backend. - /// True to load types asynchronously. - /// - /// A collection of types loaded from the backend. - /// - /// - /// Unknown typtype for type '{internalName}' in pg_type: {typeChar}. - internal async Task> LoadBackendTypes(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async) +ORDER BY oid{(withEnumSortOrder ? 
", enumsortorder" : "")};"; + + /// + /// Loads type information from the backend specified by . + /// + /// The database connection. + /// The timeout while loading types from the backend. + /// True to load types asynchronously. + /// + /// A collection of types loaded from the backend. + /// + /// + /// Unknown typtype for type '{internalName}' in pg_type: {typeChar}. + internal async Task> LoadBackendTypes(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async) + { + var versionQuery = "SELECT version();"; + var loadTypesQuery = GenerateLoadTypesQuery(SupportsRangeTypes, SupportsMultirangeTypes, conn.Settings.LoadTableComposites); + var loadCompositeTypesQuery = GenerateLoadCompositeTypesQuery(conn.Settings.LoadTableComposites); + var loadEnumFieldsQuery = SupportsEnumTypes + ? GenerateLoadEnumFieldsQuery(HasEnumSortOrder) + : string.Empty; + + timeout.CheckAndApply(conn); + // The Lexer (https://github.com/postgres/postgres/blob/master/src/backend/replication/repl_scanner.l) + // and Parser (https://github.com/postgres/postgres/blob/master/src/backend/replication/repl_gram.y) + // for replication connections are pretty picky and somewhat flawed. + // Currently (2022-01-22) they do not support + // - SQL batches containing multiple commands + // - The ('\r') in Windows or Mac newlines + // - Comments + // For this reason we need clean up our type loading queries for replication connections and execute + // them individually instead of batched. + // Theoretically we cold even use the extended protocol + batching for regular (non-replication) + // connections but that would branch our code even more for very little gain. 
+ var isReplicationConnection = conn.Settings.ReplicationMode != ReplicationMode.Off; + if (isReplicationConnection) { - var commandTimeout = 0; // Default to infinity - if (timeout.IsSet) - { - commandTimeout = (int)timeout.TimeLeft.TotalSeconds; - if (commandTimeout <= 0) - throw new TimeoutException(); - } + await conn.WriteQuery(versionQuery, async).ConfigureAwait(false); + await conn.WriteQuery(SanitizeForReplicationConnection(loadTypesQuery), async).ConfigureAwait(false); + await conn.WriteQuery(SanitizeForReplicationConnection(loadCompositeTypesQuery), async).ConfigureAwait(false); + if (SupportsEnumTypes) + await conn.WriteQuery(SanitizeForReplicationConnection(loadEnumFieldsQuery), async).ConfigureAwait(false); - var typeLoadingQuery = GenerateTypesQuery(SupportsRangeTypes, SupportsEnumTypes, HasEnumSortOrder, conn.Settings.LoadTableComposites); - using var command = new NpgsqlCommand(typeLoadingQuery, conn) + static string SanitizeForReplicationConnection(string str) { - CommandTimeout = commandTimeout, - AllResultTypesAreUnknown = true - }; - - timeout.CheckAndApply(conn.Connector!); - using var reader = async ? 
await command.ExecuteReaderAsync() : command.ExecuteReader(); - var byOID = new Dictionary(); + var sb = new StringBuilder(str.Length); + using var c = str.GetEnumerator(); + while (c.MoveNext()) + { + switch (c.Current) + { + case '\r': + sb.Append('\n'); + // Check for a \n after the \r + // and swallow it if it exists + if (c.MoveNext()) + { + if (c.Current == '-') + goto case '-'; + if (c.Current != '\n') + sb.Append(c.Current); + } + break; + case '-': + // Check if there is a second dash + if (c.MoveNext()) + { + if (c.Current == '\r') + { + sb.Append('-'); + goto case '\r'; + } + if (c.Current != '-') + { + sb.Append('-'); + sb.Append(c.Current); + break; + } + + // Comment mode + // Swallow everything until we find a newline + while (c.MoveNext()) + { + if (c.Current == '\r') + goto case '\r'; + if (c.Current == '\n') + { + sb.Append('\n'); + break; + } + } + } + break; + default: + sb.Append(c.Current); + break; + } + } - // First the PostgreSQL version - if (async) - { - await reader.ReadAsync(); - LongVersion = reader.GetString(0); - await reader.NextResultAsync(); - } - else - { - reader.Read(); - LongVersion = reader.GetString(0); - reader.NextResult(); + return sb.ToString(); } + } + else + { + var batchQuery = new StringBuilder( + versionQuery.Length + + loadTypesQuery.Length + + loadCompositeTypesQuery.Length + + (SupportsEnumTypes + ? loadEnumFieldsQuery.Length + : 0)) + .AppendLine(versionQuery) + .AppendLine(loadTypesQuery) + .AppendLine(loadCompositeTypesQuery); - // Then load the types - while (async ? 
await reader.ReadAsync() : reader.Read()) + if (SupportsEnumTypes) + batchQuery.AppendLine(loadEnumFieldsQuery); + await conn.WriteQuery(batchQuery.ToString(), async).ConfigureAwait(false); + } + await conn.Flush(async).ConfigureAwait(false); + var byOID = new Dictionary(); + + // First read the PostgreSQL version + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); + + // We read the message in non-sequential mode which buffers the whole message. + // There is no need to ensure data within the message boundaries + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); + // Note that here and below we don't assign ReadBuffer to a variable + // because we might allocate oversize buffer + conn.ReadBuffer.Skip(2); // Column count + LongVersion = ReadNonNullableString(conn.ReadBuffer); + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); + if (isReplicationConnection) + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); + + // Then load the types + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); + IBackendMessage msg; + while (true) + { + msg = await conn.ReadMessage(async).ConfigureAwait(false); + if (msg is not DataRowMessage) + break; + + conn.ReadBuffer.Skip(2); // Column count + var nspname = ReadNonNullableString(conn.ReadBuffer); + var oid = uint.Parse(ReadNonNullableString(conn.ReadBuffer), NumberFormatInfo.InvariantInfo); + Debug.Assert(oid != 0); + var typname = ReadNonNullableString(conn.ReadBuffer); + var typtype = ReadNonNullableString(conn.ReadBuffer)[0]; + var typnotnull = ReadNonNullableString(conn.ReadBuffer)[0] == 't'; + var len = conn.ReadBuffer.ReadInt32(); + var elemtypoid = len == -1 ? 
0 : uint.Parse(conn.ReadBuffer.ReadString(len), NumberFormatInfo.InvariantInfo); + + switch (typtype) { - var ns = reader.GetString("nspname"); - var internalName = reader.GetString("typname"); - var oid = uint.Parse(reader.GetString("oid"), NumberFormatInfo.InvariantInfo); - Debug.Assert(oid != 0); + case 'b': // Normal base type + var baseType = new PostgresBaseType(nspname, typname, oid); + byOID[baseType.OID] = baseType; + continue; - var elementOID = reader.IsDBNull("elemtypoid") - ? 0 - : uint.Parse(reader.GetString("elemtypoid"), NumberFormatInfo.InvariantInfo); - - var typeChar = reader.GetChar("typtype"); - switch (typeChar) + case 'a': // Array + { + Debug.Assert(elemtypoid > 0); + if (!byOID.TryGetValue(elemtypoid, out var elementPostgresType)) { - case 'b': // Normal base type - var baseType = new PostgresBaseType(ns, internalName, oid); - byOID[baseType.OID] = baseType; - continue; + _connectionLogger.LogTrace("Array type '{ArrayTypeName}' refers to unknown element with OID {ElementTypeOID}, skipping", + typname, elemtypoid); + continue; + } - case 'a': // Array - { - Debug.Assert(elementOID > 0); - if (!byOID.TryGetValue(elementOID, out var elementPostgresType)) - { - Log.Trace($"Array type '{internalName}' refers to unknown element with OID {elementOID}, skipping", conn.ProcessID); - continue; - } + var arrayType = new PostgresArrayType(nspname, typname, oid, elementPostgresType); + byOID[arrayType.OID] = arrayType; + continue; + } - var arrayType = new PostgresArrayType(ns, internalName, oid, elementPostgresType); - byOID[arrayType.OID] = arrayType; - continue; - } + case 'r': // Range + { + Debug.Assert(elemtypoid > 0); + if (!byOID.TryGetValue(elemtypoid, out var subtypePostgresType)) + { + _connectionLogger.LogTrace("Range type '{RangeTypeName}' refers to unknown subtype with OID {ElementTypeOID}, skipping", + typname, elemtypoid); + continue; + } - case 'r': // Range - { - Debug.Assert(elementOID > 0); - if (!byOID.TryGetValue(elementOID, out 
var subtypePostgresType)) - { - Log.Trace($"Range type '{internalName}' refers to unknown subtype with OID {elementOID}, skipping", conn.ProcessID); - continue; - } + var rangeType = new PostgresRangeType(nspname, typname, oid, subtypePostgresType); + byOID[rangeType.OID] = rangeType; + continue; + } - var rangeType = new PostgresRangeType(ns, internalName, oid, subtypePostgresType); - byOID[rangeType.OID] = rangeType; - continue; - } + case 'm': // Multirange + Debug.Assert(elemtypoid > 0); + if (!byOID.TryGetValue(elemtypoid, out var type)) + { + _connectionLogger.LogTrace("Multirange type '{MultirangeTypeName}' refers to unknown range with OID {ElementTypeOID}, skipping", + typname, elemtypoid); + continue; + } - case 'e': // Enum - var enumType = new PostgresEnumType(ns, internalName, oid); - byOID[enumType.OID] = enumType; - continue; + if (type is not PostgresRangeType rangePostgresType) + { + _connectionLogger.LogTrace("Multirange type '{MultirangeTypeName}' refers to non-range type '{TypeName}', skipping", + typname, type.Name); + continue; + } - case 'c': // Composite - var compositeType = new PostgresCompositeType(ns, internalName, oid); - byOID[compositeType.OID] = compositeType; - continue; + var multirangeType = new PostgresMultirangeType(nspname, typname, oid, rangePostgresType); + byOID[multirangeType.OID] = multirangeType; + continue; - case 'd': // Domain - Debug.Assert(elementOID > 0); - if (!byOID.TryGetValue(elementOID, out var basePostgresType)) - { - Log.Trace($"Domain type '{internalName}' refers to unknown base type with OID {elementOID}, skipping", conn.ProcessID); - continue; - } - var domainType = new PostgresDomainType(ns, internalName, oid, basePostgresType, reader.GetString("typnotnull") == "t"); - byOID[domainType.OID] = domainType; - continue; + case 'e': // Enum + var enumType = new PostgresEnumType(nspname, typname, oid); + byOID[enumType.OID] = enumType; + continue; - case 'p': // pseudo-type (record, void) - goto case 'b'; // 
Hack this as a base type + case 'c': // Composite + var compositeType = new PostgresCompositeType(nspname, typname, oid); + byOID[compositeType.OID] = compositeType; + continue; - default: - throw new ArgumentOutOfRangeException($"Unknown typtype for type '{internalName}' in pg_type: {typeChar}"); + case 'd': // Domain + Debug.Assert(elemtypoid > 0); + if (!byOID.TryGetValue(elemtypoid, out var basePostgresType)) + { + _connectionLogger.LogTrace("Domain type '{DomainTypeName}' refers to unknown base type with OID {ElementTypeOID}, skipping", + typname, elemtypoid); + continue; } - } - - if (async) - await reader.NextResultAsync(); - else - reader.NextResult(); - LoadCompositeFields(reader, byOID); + var domainType = new PostgresDomainType(nspname, typname, oid, basePostgresType, typnotnull); + byOID[domainType.OID] = domainType; + continue; - if (SupportsEnumTypes) - { - if (async) - await reader.NextResultAsync(); - else - reader.NextResult(); + case 'p': // pseudo-type (record, void) + goto case 'b'; // Hack this as a base type - LoadEnumLabels(reader, byOID); + default: + throw new ArgumentOutOfRangeException($"Unknown typtype for type '{typname}' in pg_type: {typtype}"); } - - return byOID.Values.ToList(); } + Expect(msg, conn); + if (isReplicationConnection) + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); - /// - /// Loads composite fields for the composite type specified by the OID. - /// - /// The reader from which to read composite fields. - /// The OID of the composite type for which fields are read. - static void LoadCompositeFields(NpgsqlDataReader reader, Dictionary byOID) - { - var currentOID = uint.MaxValue; - PostgresCompositeType? 
currentComposite = null; - var skipCurrent = false; + // Then load the composite type fields + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); - while (reader.Read()) - { - var oid = uint.Parse(reader.GetString("oid"), NumberFormatInfo.InvariantInfo); - if (oid != currentOID) - { - currentOID = oid; + var currentOID = uint.MaxValue; + PostgresCompositeType? currentComposite = null; + var skipCurrent = false; - if (!byOID.TryGetValue(oid, out var type)) // See #2020 - { - Log.Warn($"Skipping composite type with OID {oid} which was not found in pg_type"); - byOID.Remove(oid); - skipCurrent = true; - continue; - } + while (true) + { + msg = await conn.ReadMessage(async).ConfigureAwait(false); + if (msg is not DataRowMessage) + break; - currentComposite = type as PostgresCompositeType; - if (currentComposite == null) - { - Log.Warn($"Type {type.Name} was referenced as a composite type but is a {type.GetType()}"); - byOID.Remove(oid); - skipCurrent = true; - continue; - } + conn.ReadBuffer.Skip(2); // Column count + var oid = uint.Parse(ReadNonNullableString(conn.ReadBuffer), NumberFormatInfo.InvariantInfo); + var attname = ReadNonNullableString(conn.ReadBuffer); + var atttypid = uint.Parse(ReadNonNullableString(conn.ReadBuffer), NumberFormatInfo.InvariantInfo); - skipCurrent = false; - } + if (oid != currentOID) + { + currentOID = oid; - if (skipCurrent) + if (!byOID.TryGetValue(oid, out var type)) // See #2020 + { + _connectionLogger.LogWarning("Skipping composite type with OID {CompositeTypeOID} which was not found in pg_type", oid); + byOID.Remove(oid); + skipCurrent = true; continue; + } - var fieldName = reader.GetString("attname"); - var fieldTypeOID = uint.Parse(reader.GetString("atttypid"), NumberFormatInfo.InvariantInfo); - if (!byOID.TryGetValue(fieldTypeOID, out var fieldType)) // See #2020 + currentComposite = type as PostgresCompositeType; + if (currentComposite == null) { - Log.Warn($"Skipping composite type 
{currentComposite!.DisplayName} with field {fieldName} with type OID {fieldTypeOID}, which could not be resolved to a PostgreSQL type."); + _connectionLogger.LogWarning("Type {TypeName} was referenced as a composite type but is a {type}", type.Name, type.GetType()); byOID.Remove(oid); skipCurrent = true; continue; } - currentComposite!.MutableFields.Add(new PostgresCompositeType.Field(fieldName, fieldType)); + skipCurrent = false; } + + if (skipCurrent) + continue; + + if (!byOID.TryGetValue(atttypid, out var fieldType)) // See #2020 + { + _connectionLogger.LogWarning("Skipping composite type '{CompositeTypeName}' with field '{fieldName}' with type OID '{FieldTypeOID}', which could not be resolved to a PostgreSQL type.", + currentComposite!.DisplayName, attname, atttypid); + byOID.Remove(oid); + skipCurrent = true; + continue; + } + + currentComposite!.MutableFields.Add(new PostgresCompositeType.Field(attname, fieldType)); } + Expect(msg, conn); + if (isReplicationConnection) + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); - /// - /// Loads enum labels for the enum type specified by the OID. - /// - /// The reader from which to read enum labels. - /// The OID of the enum type for which labels are read. - static void LoadEnumLabels(NpgsqlDataReader reader, Dictionary byOID) + if (SupportsEnumTypes) { - var currentOID = uint.MaxValue; + // Then load the enum fields + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); + + currentOID = uint.MaxValue; PostgresEnumType? 
currentEnum = null; - var skipCurrent = false; + skipCurrent = false; - while (reader.Read()) + while (true) { - var oid = uint.Parse(reader.GetString("oid"), NumberFormatInfo.InvariantInfo); + msg = await conn.ReadMessage(async).ConfigureAwait(false); + if (msg is not DataRowMessage) + break; + + conn.ReadBuffer.Skip(2); // Column count + var oid = uint.Parse(ReadNonNullableString(conn.ReadBuffer), NumberFormatInfo.InvariantInfo); + var enumlabel = ReadNonNullableString(conn.ReadBuffer); if (oid != currentOID) { currentOID = oid; if (!byOID.TryGetValue(oid, out var type)) // See #2020 { - Log.Warn($"Skipping enum type with OID {oid} which was not found in pg_type"); + _connectionLogger.LogWarning("Skipping enum type with OID {OID} which was not found in pg_type", oid); byOID.Remove(oid); skipCurrent = true; continue; @@ -413,7 +519,7 @@ static void LoadEnumLabels(NpgsqlDataReader reader, Dictionary(msg, conn); + if (isReplicationConnection) + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); } + + if (!isReplicationConnection) + Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); + + return new(byOID.Values); + + static string ReadNonNullableString(NpgsqlReadBuffer buffer) + => buffer.ReadString(buffer.ReadInt32()); } } diff --git a/src/Npgsql/PostgresEnvironment.cs b/src/Npgsql/PostgresEnvironment.cs index 66f393498e..69036601e5 100644 --- a/src/Npgsql/PostgresEnvironment.cs +++ b/src/Npgsql/PostgresEnvironment.cs @@ -1,41 +1,58 @@ using System; -using System.Collections.Generic; using System.IO; -using Npgsql.Util; +using System.Runtime.InteropServices; -namespace Npgsql +namespace Npgsql; + +static class PostgresEnvironment { - static class PostgresEnvironment - { - public static string? User => Environment.GetEnvironmentVariable("PGUSER"); + internal static string? User => Environment.GetEnvironmentVariable("PGUSER"); + + internal static string? 
Password => Environment.GetEnvironmentVariable("PGPASSWORD"); + + internal static string? PassFile => Environment.GetEnvironmentVariable("PGPASSFILE"); + + internal static string? PassFileDefault + => (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? GetHomePostgresDir() : GetHomeDir()) is string homedir && + Path.Combine(homedir, RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "pgpass.conf" : ".pgpass") is var path && + File.Exists(path) + ? path + : null; - public static string? Password => Environment.GetEnvironmentVariable("PGPASSWORD"); + internal static string? SslCert => Environment.GetEnvironmentVariable("PGSSLCERT"); - public static string? PassFile => Environment.GetEnvironmentVariable("PGPASSFILE"); + internal static string? SslCertDefault + => GetHomePostgresDir() is string homedir && Path.Combine(homedir, "postgresql.crt") is var path && File.Exists(path) + ? path + : null; - public static string? PassFileDefault => GetDefaultFilePath(PGUtil.IsWindows ? "pgpass.conf" : ".pgpass"); + internal static string? SslKey => Environment.GetEnvironmentVariable("PGSSLKEY"); - public static string? SslCert => Environment.GetEnvironmentVariable("PGSSLCERT"); + internal static string? SslKeyDefault + => GetHomePostgresDir() is string homedir && Path.Combine(homedir, "postgresql.key") is var path && File.Exists(path) + ? path + : null; - public static string? SslCertDefault => GetDefaultFilePath("postgresql.crt"); + internal static string? SslCertRoot => Environment.GetEnvironmentVariable("PGSSLROOTCERT"); - public static string? SslCertRoot => Environment.GetEnvironmentVariable("PGSSLROOTCERT"); + internal static string? SslCertRootDefault + => GetHomePostgresDir() is string homedir && Path.Combine(homedir, "root.crt") is var path && File.Exists(path) + ? path + : null; - public static string? SslCertRootDefault => GetDefaultFilePath("root.crt"); + internal static string? 
ClientEncoding => Environment.GetEnvironmentVariable("PGCLIENTENCODING"); - public static string? SslKey => Environment.GetEnvironmentVariable("PGSSLKEY"); + internal static string? TimeZone => Environment.GetEnvironmentVariable("PGTZ"); - public static string? ClientEncoding => Environment.GetEnvironmentVariable("PGCLIENTENCODING"); + internal static string? Options => Environment.GetEnvironmentVariable("PGOPTIONS"); - public static string? TimeZone => Environment.GetEnvironmentVariable("PGTZ"); + internal static string? TargetSessionAttributes => Environment.GetEnvironmentVariable("PGTARGETSESSIONATTRS"); - public static string? Options => Environment.GetEnvironmentVariable("PGOPTIONS"); + static string? GetHomeDir() + => Environment.GetEnvironmentVariable(RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "APPDATA" : "HOME"); - static string? GetDefaultFilePath(string fileName) => - Environment.GetEnvironmentVariable(PGUtil.IsWindows ? "APPDATA" : "HOME") is string appData && - Path.Combine(appData, "postgresql", fileName) is string filePath && - File.Exists(filePath) - ? filePath - : null; - } -} + static string? GetHomePostgresDir() + => GetHomeDir() is string homedir + ? Path.Combine(homedir, RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "postgresql" : ".postgresql") + : null; +} \ No newline at end of file diff --git a/src/Npgsql/PostgresErrorCodes.cs b/src/Npgsql/PostgresErrorCodes.cs index c260f8cd65..afeadbf2c6 100644 --- a/src/Npgsql/PostgresErrorCodes.cs +++ b/src/Npgsql/PostgresErrorCodes.cs @@ -1,465 +1,488 @@ #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member -namespace Npgsql +using System; + +namespace Npgsql; + +/// +/// Provides constants for PostgreSQL error codes. +/// +/// +/// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html +/// +public static class PostgresErrorCodes { - /// - /// Provides constants for PostgreSQL error codes. 
- /// - /// - /// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html - /// - public static class PostgresErrorCodes - { - #region Class 00 - Successful Completion + #region Class 00 - Successful Completion + + public const string SuccessfulCompletion = "00000"; - public const string SuccessfulCompletion = "00000"; + #endregion Class 00 - Successful Completion - #endregion Class 00 - Successful Completion + #region Class 01 - Warning - #region Class 01 - Warning + public const string Warning = "01000"; + public const string DynamicResultSetsReturnedWarning = "0100C"; + public const string ImplicitZeroBitPaddingWarning = "01008"; + public const string NullValueEliminatedInSetFunctionWarning = "01003"; + public const string PrivilegeNotGrantedWarning = "01007"; + public const string PrivilegeNotRevokedWarning = "01006"; + public const string StringDataRightTruncationWarning = "01004"; + public const string DeprecatedFeatureWarning = "01P01"; - public const string Warning = "01000"; - public const string DynamicResultSetsReturnedWarning = "0100C"; - public const string ImplicitZeroBitPaddingWarning = "01008"; - public const string NullValueEliminatedInSetFunctionWarning = "01003"; - public const string PrivilegeNotGrantedWarning = "01007"; - public const string PrivilegeNotRevokedWarning = "01006"; - public const string StringDataRightTruncationWarning = "01004"; - public const string DeprecatedFeatureWarning = "01P01"; + #endregion Class 01 - Warning - #endregion Class 01 - Warning + #region Class 02 - No Data - #region Class 02 - No Data + public const string NoData = "02000"; + public const string NoAdditionalDynamicResultSetsReturned = "02001"; - public const string NoData = "02000"; - public const string NoAdditionalDynamicResultSetsReturned = "02001"; + #endregion Class 02 - No Data - #endregion Class 02 - No Data + #region Class 03 - SQL Statement Not Yet Complete - #region Class 03 - SQL Statement Not Yet Complete + public const string 
SqlStatementNotYetComplete = "03000"; - public const string SqlStatementNotYetComplete = "03000"; + #endregion Class 03 - SQL Statement Not Yet Complete - #endregion Class 03 - SQL Statement Not Yet Complete + #region Class 08 - Connection Exception - #region Class 08 - Connection Exception + public const string ConnectionException = "08000"; + public const string ConnectionDoesNotExist = "08003"; + public const string ConnectionFailure = "08006"; + public const string SqlClientUnableToEstablishSqlConnection = "08001"; + public const string SqlServerRejectedEstablishmentOfSqlConnection = "08004"; + public const string TransactionResolutionUnknown = "08007"; + public const string ProtocolViolation = "08P01"; - public const string ConnectionException = "08000"; - public const string ConnectionDoesNotExist = "08003"; - public const string ConnectionFailure = "08006"; - public const string SqlClientUnableToEstablishSqlConnection = "08001"; - public const string SqlServerRejectedEstablishmentOfSqlConnection = "08004"; - public const string TransactionResolutionUnknown = "08007"; - public const string ProtocolViolation = "08P01"; + #endregion Class 08 - Connection Exception - #endregion Class 08 - Connection Exception + #region Class 09 - Triggered Action Exception - #region Class 09 - Triggered Action Exception + public const string TriggeredActionException = "09000"; - public const string TriggeredActionException = "09000"; + #endregion Class 09 - Triggered Action Exception - #endregion Class 09 - Triggered Action Exception + #region Class 0A - Feature Not Supported - #region Class 0A - Feature Not Supported + public const string FeatureNotSupported = "0A000"; - public const string FeatureNotSupported = "0A000"; + #endregion Class 0A - Feature Not Supported - #endregion Class 0A - Feature Not Supported + #region Class 0B - Invalid Transaction Initiation - #region Class 0B - Invalid Transaction Initiation + public const string InvalidTransactionInitiation = "0B000"; - 
public const string InvalidTransactionInitiation = "0B000"; + #endregion Class 0B - Invalid Transaction Initiation - #endregion Class 0B - Invalid Transaction Initiation + #region Class 0F - Locator Exception - #region Class 0F - Locator Exception + public const string LocatorException = "0F000"; + public const string InvalidLocatorSpecification = "0F001"; - public const string LocatorException = "0F000"; - public const string InvalidLocatorSpecification = "0F001"; + #endregion Class 0F - Locator Exception - #endregion Class 0F - Locator Exception + #region Class 0L - Invalid Grantor - #region Class 0L - Invalid Grantor + public const string InvalidGrantor = "0L000"; + public const string InvalidGrantOperation = "0LP01"; - public const string InvalidGrantor = "0L000"; - public const string InvalidGrantOperation = "0LP01"; + #endregion Class 0L - Invalid Grantor - #endregion Class 0L - Invalid Grantor + #region Class 0P - Invalid Role Specification - #region Class 0P - Invalid Role Specification + public const string InvalidRoleSpecification = "0P000"; - public const string InvalidRoleSpecification = "0P000"; + #endregion Class 0P - Invalid Role Specification - #endregion Class 0P - Invalid Role Specification + #region Class 0Z - Diagnostics Exception - #region Class 0Z - Diagnostics Exception + public const string DiagnosticsException = "0Z000"; + public const string StackedDiagnosticsAccessedWithoutActiveHandler = "0Z002"; - public const string DiagnosticsException = "0Z000"; - public const string StackedDiagnosticsAccessedWithoutActiveHandler = "0Z002"; + #endregion Class 0Z - Diagnostics Exception - #endregion Class 0Z - Diagnostics Exception + #region Class 20 - Case Not Found - #region Class 20 - Case Not Found + public const string CaseNotFound = "20000"; - public const string CaseNotFound = "20000"; + #endregion Class 20 - Case Not Found - #endregion Class 20 - Case Not Found + #region Class 21 - CardinalityViolation - #region Class 21 - CardinalityViolation 
+ public const string CardinalityViolation = "21000"; - public const string CardinalityViolation = "21000"; + #endregion Class 21 - CardinalityViolation - #endregion Class 21 - CardinalityViolation + #region Class 22 - Data Exception - #region Class 22 - Data Exception + public const string DataException = "22000"; + public const string ArraySubscriptError = "2202E"; + public const string CharacterNotInRepertoire = "22021"; + public const string DatetimeFieldOverflow = "22008"; + public const string DivisionByZero = "22012"; + public const string ErrorInAssignment = "22005"; + public const string EscapeCharacterConflict = "2200B"; + public const string IndicatorOverflow = "22022"; + public const string IntervalFieldOverflow = "22015"; + public const string InvalidArgumentForLogarithm = "2201E"; + public const string InvalidArgumentForNtileFunction = "22014"; + public const string InvalidArgumentForNthValueFunction = "22016"; + public const string InvalidArgumentForPowerFunction = "2201F"; + public const string InvalidArgumentForWidthBucketFunction = "2201G"; + public const string InvalidCharacterValueForCast = "22018"; + public const string InvalidDatetimeFormat = "22007"; + public const string InvalidEscapeCharacter = "22019"; + public const string InvalidEscapeOctet = "2200D"; + public const string InvalidEscapeSequence = "22025"; + public const string NonstandardUseOfEscapeCharacter = "22P06"; + public const string InvalidIndicatorParameterValue = "22010"; + public const string InvalidParameterValue = "22023"; + public const string InvalidRegularExpression = "2201B"; + public const string InvalidRowCountInLimitClause = "2201W"; + public const string InvalidRowCountInResultOffsetClause = "2201X"; + public const string InvalidTablesampleArgument = "2202H"; + public const string InvalidTablesampleRepeat = "2202G"; + public const string InvalidTimeZoneDisplacementValue = "22009"; + public const string InvalidUseOfEscapeCharacter = "2200C"; + public const string 
MostSpecificTypeMismatch = "2200G"; + public const string NullValueNotAllowed = "22004"; + public const string NullValueNoIndicatorParameter = "22002"; + public const string NumericValueOutOfRange = "22003"; + public const string StringDataLengthMismatch = "22026"; + public const string StringDataRightTruncation = "22001"; + public const string SubstringError = "22011"; + public const string TrimError = "22027"; + public const string UnterminatedCString = "22024"; + public const string ZeroLengthCharacterString = "2200F"; + public const string FloatingPointException = "22P01"; + public const string InvalidTextRepresentation = "22P02"; + public const string InvalidBinaryRepresentation = "22P03"; + public const string BadCopyFileFormat = "22P04"; + public const string UntranslatableCharacter = "22P05"; + public const string NotAnXmlDocument = "2200L"; + public const string InvalidXmlDocument = "2200M"; + public const string InvalidXmlContent = "2200N"; + public const string InvalidXmlComment = "2200S"; + public const string InvalidXmlProcessingInstruction = "2200T"; - public const string DataException = "22000"; - public const string ArraySubscriptError = "2202E"; - public const string CharacterNotInRepertoire = "22021"; - public const string DatetimeFieldOverflow = "22008"; - public const string DivisionByZero = "22012"; - public const string ErrorInAssignment = "22005"; - public const string EscapeCharacterConflict = "2200B"; - public const string IndicatorOverflow = "22022"; - public const string IntervalFieldOverflow = "22015"; - public const string InvalidArgumentForLogarithm = "2201E"; - public const string InvalidArgumentForNtileFunction = "22014"; - public const string InvalidArgumentForNthValueFunction = "22016"; - public const string InvalidArgumentForPowerFunction = "2201F"; - public const string InvalidArgumentForWidthBucketFunction = "2201G"; - public const string InvalidCharacterValueForCast = "22018"; - public const string InvalidDatetimeFormat = 
"22007"; - public const string InvalidEscapeCharacter = "22019"; - public const string InvalidEscapeOctet = "2200D"; - public const string InvalidEscapeSequence = "22025"; - public const string NonstandardUseOfEscapeCharacter = "22P06"; - public const string InvalidIndicatorParameterValue = "22010"; - public const string InvalidParameterValue = "22023"; - public const string InvalidRegularExpression = "2201B"; - public const string InvalidRowCountInLimitClause = "2201W"; - public const string InvalidRowCountInResultOffsetClause = "2201X"; - public const string InvalidTablesampleArgument = "2202H"; - public const string InvalidTablesampleRepeat = "2202G"; - public const string InvalidTimeZoneDisplacementValue = "22009"; - public const string InvalidUseOfEscapeCharacter = "2200C"; - public const string MostSpecificTypeMismatch = "2200G"; - public const string NullValueNotAllowed = "22004"; - public const string NullValueNoIndicatorParameter = "22002"; - public const string NumericValueOutOfRange = "22003"; - public const string StringDataLengthMismatch = "22026"; - public const string StringDataRightTruncation = "22001"; - public const string SubstringError = "22011"; - public const string TrimError = "22027"; - public const string UnterminatedCString = "22024"; - public const string ZeroLengthCharacterString = "2200F"; - public const string FloatingPointException = "22P01"; - public const string InvalidTextRepresentation = "22P02"; - public const string InvalidBinaryRepresentation = "22P03"; - public const string BadCopyFileFormat = "22P04"; - public const string UntranslatableCharacter = "22P05"; - public const string NotAnXmlDocument = "2200L"; - public const string InvalidXmlDocument = "2200M"; - public const string InvalidXmlContent = "2200N"; - public const string InvalidXmlComment = "2200S"; - public const string InvalidXmlProcessingInstruction = "2200T"; + #endregion Class 22 - Data Exception - #endregion Class 22 - Data Exception + #region Class 23 - 
Integrity Constraint Violation - #region Class 23 - Integrity Constraint Violation + public const string IntegrityConstraintViolation = "23000"; + public const string RestrictViolation = "23001"; + public const string NotNullViolation = "23502"; + public const string ForeignKeyViolation = "23503"; + public const string UniqueViolation = "23505"; + public const string CheckViolation = "23514"; + public const string ExclusionViolation = "23P01"; - public const string IntegrityConstraintViolation = "23000"; - public const string RestrictViolation = "23001"; - public const string NotNullViolation = "23502"; - public const string ForeignKeyViolation = "23503"; - public const string UniqueViolation = "23505"; - public const string CheckViolation = "23514"; - public const string ExclusionViolation = "23P01"; + #endregion Class 23 - Integrity Constraint Violation - #endregion Class 23 - Integrity Constraint Violation + #region Class 24 - Invalid Cursor State - #region Class 24 - Invalid Cursor State + public const string InvalidCursorState = "24000"; - public const string InvalidCursorState = "24000"; + #endregion Class 24 - Invalid Cursor State - #endregion Class 24 - Invalid Cursor State + #region Class 25 - Invalid Transaction State - #region Class 25 - Invalid Transaction State + public const string InvalidTransactionState = "25000"; + public const string ActiveSqlTransaction = "25001"; + public const string BranchTransactionAlreadyActive = "25002"; + public const string HeldCursorRequiresSameIsolationLevel = "25008"; + public const string InappropriateAccessModeForBranchTransaction = "25003"; + public const string InappropriateIsolationLevelForBranchTransaction = "25004"; + public const string NoActiveSqlTransactionForBranchTransaction = "25005"; + public const string ReadOnlySqlTransaction = "25006"; + public const string SchemaAndDataStatementMixingNotSupported = "25007"; + public const string NoActiveSqlTransaction = "25P01"; + public const string 
InFailedSqlTransaction = "25P02"; - public const string InvalidTransactionState = "25000"; - public const string ActiveSqlTransaction = "25001"; - public const string BranchTransactionAlreadyActive = "25002"; - public const string HeldCursorRequiresSameIsolationLevel = "25008"; - public const string InappropriateAccessModeForBranchTransaction = "25003"; - public const string InappropriateIsolationLevelForBranchTransaction = "25004"; - public const string NoActiveSqlTransactionForBranchTransaction = "25005"; - public const string ReadOnlySqlTransaction = "25006"; - public const string SchemaAndDataStatementMixingNotSupported = "25007"; - public const string NoActiveSqlTransaction = "25P01"; - public const string InFailedSqlTransaction = "25P02"; + #endregion Class 25 - Invalid Transaction State - #endregion Class 25 - Invalid Transaction State + #region Class 26 - Invalid SQL Statement Name - #region Class 26 - Invalid SQL Statement Name + public const string InvalidSqlStatementName = "26000"; - public const string InvalidSqlStatementName = "26000"; + #endregion Class 26 - Invalid SQL Statement Name - #endregion Class 26 - Invalid SQL Statement Name + #region Class 27 - Triggered Data Change Violation - #region Class 27 - Triggered Data Change Violation + public const string TriggeredDataChangeViolation = "27000"; - public const string TriggeredDataChangeViolation = "27000"; + #endregion Class 27 - Triggered Data Change Violation - #endregion Class 27 - Triggered Data Change Violation + #region Class 28 - Invalid Authorization Scheme - #region Class 28 - Invalid Authorization Scheme + public const string InvalidAuthorizationSpecification = "28000"; + public const string InvalidPassword = "28P01"; - public const string InvalidAuthorizationSpecification = "28000"; - public const string InvalidPassword = "28P01"; + #endregion Class 28 - Invalid Authorization Scheme - #endregion Class 28 - Invalid Authorization Scheme + #region Class 2B - Dependent Privilege Descriptors 
Still Exist - #region Class 2B - Dependent Privilege Descriptors Still Exist + public const string DependentPrivilegeDescriptorsStillExist = "2B000"; + public const string DependentObjectsStillExist = "2BP01"; - public const string DependentPrivilegeDescriptorsStillExist = "2B000"; - public const string DependentObjectsStillExist = "2BP01"; + #endregion Class 2B - Dependent Privilege Descriptors Still Exist - #endregion Class 2B - Dependent Privilege Descriptors Still Exist + #region Class 2D - Invalid Transaction Termination - #region Class 2D - Invalid Transaction Termination + public const string InvalidTransactionTermination = "2D000"; - public const string InvalidTransactionTermination = "2D000"; + #endregion Class 2D - Invalid Transaction Termination - #endregion Class 2D - Invalid Transaction Termination + #region Class 2F - SQL Routine Exception - #region Class 2F - SQL Routine Exception + public const string SqlRoutineException = "2F000"; + public const string FunctionExecutedNoReturnStatementSqlRoutineException = "2F005"; + public const string ModifyingSqlDataNotPermittedSqlRoutineException = "2F002"; + public const string ProhibitedSqlStatementAttemptedSqlRoutineException = "2F003"; + public const string ReadingSqlDataNotPermittedSqlRoutineException = "2F004"; - public const string SqlRoutineException = "2F000"; - public const string FunctionExecutedNoReturnStatementSqlRoutineException = "2F005"; - public const string ModifyingSqlDataNotPermittedSqlRoutineException = "2F002"; - public const string ProhibitedSqlStatementAttemptedSqlRoutineException = "2F003"; - public const string ReadingSqlDataNotPermittedSqlRoutineException = "2F004"; + #endregion Class 2F - SQL Routine Exception - #endregion Class 2F - SQL Routine Exception + #region Class 34 - Invalid Cursor Name - #region Class 34 - Invalid Cursor Name + public const string InvalidCursorName = "34000"; - public const string InvalidCursorName = "34000"; + #endregion Class 34 - Invalid Cursor Name - 
#endregion Class 34 - Invalid Cursor Name + #region Class 38 - External Routine Exception - #region Class 38 - External Routine Exception + public const string ExternalRoutineException = "38000"; + public const string ContainingSqlNotPermittedExternalRoutineException = "38001"; + public const string ModifyingSqlDataNotPermittedExternalRoutineException = "38002"; + public const string ProhibitedSqlStatementAttemptedExternalRoutineException = "38003"; + public const string ReadingSqlDataNotPermittedExternalRoutineException = "38004"; - public const string ExternalRoutineException = "38000"; - public const string ContainingSqlNotPermittedExternalRoutineException = "38001"; - public const string ModifyingSqlDataNotPermittedExternalRoutineException = "38002"; - public const string ProhibitedSqlStatementAttemptedExternalRoutineException = "38003"; - public const string ReadingSqlDataNotPermittedExternalRoutineException = "38004"; + #endregion Class 38 - External Routine Exception - #endregion Class 38 - External Routine Exception + #region Class 39 - External Routine Invocation Exception - #region Class 39 - External Routine Invocation Exception + public const string ExternalRoutineInvocationException = "39000"; + public const string InvalidSqlstateReturnedExternalRoutineInvocationException = "39001"; + public const string NullValueNotAllowedExternalRoutineInvocationException = "39004"; + public const string TriggerProtocolViolatedExternalRoutineInvocationException = "39P01"; + public const string SrfProtocolViolatedExternalRoutineInvocationException = "39P02"; + public const string EventTriggerProtocolViolatedExternalRoutineInvocationException = "39P03"; - public const string ExternalRoutineInvocationException = "39000"; - public const string InvalidSqlstateReturnedExternalRoutineInvocationException = "39001"; - public const string NullValueNotAllowedExternalRoutineInvocationException = "39004"; - public const string 
TriggerProtocolViolatedExternalRoutineInvocationException = "39P01"; - public const string SrfProtocolViolatedExternalRoutineInvocationException = "39P02"; - public const string EventTriggerProtocolViolatedExternalRoutineInvocationException = "39P03"; + #endregion Class 39 - External Routine Invocation Exception - #endregion Class 39 - External Routine Invocation Exception + #region Class 3B - Savepoint Exception - #region Class 3B - Savepoint Exception + public const string SavepointException = "3B000"; + public const string InvalidSavepointSpecification = "3B001"; - public const string SavepointException = "3B000"; - public const string InvalidSavepointSpecification = "3B001"; + #endregion Class 3B - Savepoint Exception - #endregion Class 3B - Savepoint Exception + #region Class 3D - Invalid Catalog Name - #region Class 3D - Invalid Catalog Name + public const string InvalidCatalogName = "3D000"; - public const string InvalidCatalogName = "3D000"; + #endregion Class 3D - Invalid Catalog Name - #endregion Class 3D - Invalid Catalog Name + #region Class 3F - Invalid Schema Name - #region Class 3F - Invalid Schema Name + public const string InvalidSchemaName = "3F000"; - public const string InvalidSchemaName = "3F000"; + #endregion Class 3F - Invalid Schema Name - #endregion Class 3F - Invalid Schema Name + #region Class 40 - Transaction Rollback - #region Class 40 - Transaction Rollback + public const string TransactionRollback = "40000"; + public const string TransactionIntegrityConstraintViolation = "40002"; + public const string SerializationFailure = "40001"; + public const string StatementCompletionUnknown = "40003"; + public const string DeadlockDetected = "40P01"; - public const string TransactionRollback = "40000"; - public const string TransactionIntegrityConstraintViolation = "40002"; - public const string SerializationFailure = "40001"; - public const string StatementCompletionUnknown = "40003"; - public const string DeadlockDetected = "40P01"; + 
#endregion Class 40 - Transaction Rollback - #endregion Class 40 - Transaction Rollback + #region Class 42 - Syntax Error or Access Rule Violation - #region Class 42 - Syntax Error or Access Rule Violation + public const string SyntaxErrorOrAccessRuleViolation = "42000"; + public const string SyntaxError = "42601"; + public const string InsufficientPrivilege = "42501"; + public const string CannotCoerce = "42846"; + public const string GroupingError = "42803"; + public const string WindowingError = "42P20"; + public const string InvalidRecursion = "42P19"; + public const string InvalidForeignKey = "42830"; + public const string InvalidName = "42602"; + public const string NameTooLong = "42622"; + public const string ReservedName = "42939"; + public const string DatatypeMismatch = "42804"; + public const string IndeterminateDatatype = "42P18"; + public const string CollationMismatch = "42P21"; + public const string IndeterminateCollation = "42P22"; + public const string WrongObjectType = "42809"; + public const string UndefinedColumn = "42703"; + public const string UndefinedFunction = "42883"; + public const string UndefinedTable = "42P01"; + public const string UndefinedParameter = "42P02"; + public const string UndefinedObject = "42704"; + public const string DuplicateColumn = "42701"; + public const string DuplicateCursor = "42P03"; + public const string DuplicateDatabase = "42P04"; + public const string DuplicateFunction = "42723"; + public const string DuplicatePreparedStatement = "42P05"; + public const string DuplicateSchema = "42P06"; + public const string DuplicateTable = "42P07"; + public const string DuplicateAlias = "42712"; + public const string DuplicateObject = "42710"; + public const string AmbiguousColumn = "42702"; + public const string AmbiguousFunction = "42725"; + public const string AmbiguousParameter = "42P08"; + public const string AmbiguousAlias = "42P09"; + public const string InvalidColumnReference = "42P10"; + public const string 
InvalidColumnDefinition = "42611"; + public const string InvalidCursorDefinition = "42P11"; + public const string InvalidDatabaseDefinition = "42P12"; + public const string InvalidFunctionDefinition = "42P13"; + public const string InvalidPreparedStatementDefinition = "42P14"; + public const string InvalidSchemaDefinition = "42P15"; + public const string InvalidTableDefinition = "42P16"; + public const string InvalidObjectDefinition = "42P17"; - public const string SyntaxErrorOrAccessRuleViolation = "42000"; - public const string SyntaxError = "42601"; - public const string InsufficientPrivilege = "42501"; - public const string CannotCoerce = "42846"; - public const string GroupingError = "42803"; - public const string WindowingError = "42P20"; - public const string InvalidRecursion = "42P19"; - public const string InvalidForeignKey = "42830"; - public const string InvalidName = "42602"; - public const string NameTooLong = "42622"; - public const string ReservedName = "42939"; - public const string DatatypeMismatch = "42804"; - public const string IndeterminateDatatype = "42P18"; - public const string CollationMismatch = "42P21"; - public const string IndeterminateCollation = "42P22"; - public const string WrongObjectType = "42809"; - public const string UndefinedColumn = "42703"; - public const string UndefinedFunction = "42883"; - public const string UndefinedTable = "42P01"; - public const string UndefinedParameter = "42P02"; - public const string UndefinedObject = "42704"; - public const string DuplicateColumn = "42701"; - public const string DuplicateCursor = "42P03"; - public const string DuplicateDatabase = "42P04"; - public const string DuplicateFunction = "42723"; - public const string DuplicatePreparedStatement = "42P05"; - public const string DuplicateSchema = "42P06"; - public const string DuplicateTable = "42P07"; - public const string DuplicateAlias = "42712"; - public const string DuplicateObject = "42710"; - public const string AmbiguousColumn = 
"42702"; - public const string AmbiguousFunction = "42725"; - public const string AmbiguousParameter = "42P08"; - public const string AmbiguousAlias = "42P09"; - public const string InvalidColumnReference = "42P10"; - public const string InvalidColumnDefinition = "42611"; - public const string InvalidCursorDefinition = "42P11"; - public const string InvalidDatabaseDefinition = "42P12"; - public const string InvalidFunctionDefinition = "42P13"; - public const string InvalidPreparedStatementDefinition = "42P14"; - public const string InvalidSchemaDefinition = "42P15"; - public const string InvalidTableDefinition = "42P16"; - public const string InvalidObjectDefinition = "42P17"; + #endregion Class 42 - Syntax Error or Access Rule Violation - #endregion Class 42 - Syntax Error or Access Rule Violation + #region Class 44 - WITH CHECK OPTION Violation - #region Class 44 - WITH CHECK OPTION Violation + public const string WithCheckOptionViolation = "44000"; + + #endregion Class 44 - WITH CHECK OPTION Violation + + #region Class 53 - Insufficient Resources + + public const string InsufficientResources = "53000"; + public const string DiskFull = "53100"; + public const string OutOfMemory = "53200"; + public const string TooManyConnections = "53300"; + public const string ConfigurationLimitExceeded = "53400"; + + #endregion Class 53 - Insufficient Resources + + #region Class 54 - Program Limit Exceeded - public const string WithCheckOptionViolation = "44000"; - - #endregion Class 44 - WITH CHECK OPTION Violation - - #region Class 53 - Insufficient Resources - - public const string InsufficientResources = "53000"; - public const string DiskFull = "53100"; - public const string OutOfMemory = "53200"; - public const string TooManyConnections = "53300"; - public const string ConfigurationLimitExceeded = "53400"; - - #endregion Class 53 - Insufficient Resources - - #region Class 54 - Program Limit Exceeded + public const string ProgramLimitExceeded = "54000"; + public const 
string StatementTooComplex = "54001"; + public const string TooManyColumns = "54011"; + public const string TooManyArguments = "54023"; - public const string ProgramLimitExceeded = "54000"; - public const string StatementTooComplex = "54001"; - public const string TooManyColumns = "54011"; - public const string TooManyArguments = "54023"; + #endregion Class 54 - Program Limit Exceeded - #endregion Class 54 - Program Limit Exceeded - - #region Class 55 - Object Not In Prerequisite State - - public const string ObjectNotInPrerequisiteState = "55000"; - public const string ObjectInUse = "55006"; - public const string CantChangeRuntimeParam = "55P02"; - public const string LockNotAvailable = "55P03"; - - #endregion Class 55 - Object Not In Prerequisite State - - #region Class 57 - Operator Intervention - - public const string OperatorIntervention = "57000"; - public const string QueryCanceled = "57014"; - public const string AdminShutdown = "57P01"; - public const string CrashShutdown = "57P02"; - public const string CannotConnectNow = "57P03"; - public const string DatabaseDropped = "57P04"; - - #endregion Class 57 - Operator Intervention - - #region Class 58 - System Error (errors external to PostgreSQL itself) - - public const string SystemError = "58000"; - public const string IoError = "58030"; - public const string UndefinedFile = "58P01"; - public const string DuplicateFile = "58P02"; - - #endregion Class 58 - System Error (errors external to PostgreSQL itself) - - #region Class 72 - Snapshot Failure - - public const string SnapshotFailure = "72000"; - - #endregion Class 72 - Snapshot Failure - - #region Class F0 - Configuration File Error - - public const string ConfigFileError = "F0000"; - public const string LockFileExists = "F0001"; - - #endregion Class F0 - Configuration File Error - - #region Class HV - Foreign Data Wrapper Error (SQL/MED) - - public const string FdwError = "HV000"; - public const string FdwColumnNameNotFound = "HV005"; - public const 
string FdwDynamicParameterValueNeeded = "HV002"; - public const string FdwFunctionSequenceError = "HV010"; - public const string FdwInconsistentDescriptorInformation = "HV021"; - public const string FdwInvalidAttributeValue = "HV024"; - public const string FdwInvalidColumnName = "HV007"; - public const string FdwInvalidColumnNumber = "HV008"; - public const string FdwInvalidDataType = "HV004"; - public const string FdwInvalidDataTypeDescriptors = "HV006"; - public const string FdwInvalidDescriptorFieldIdentifier = "HV091"; - public const string FdwInvalidHandle = "HV00B"; - public const string FdwInvalidOptionIndex = "HV00C"; - public const string FdwInvalidOptionName = "HV00D"; - public const string FdwInvalidStringLengthOrBufferLength = "HV090"; - public const string FdwInvalidStringFormat = "HV00A"; - public const string FdwInvalidUseOfNullPointer = "HV009"; - public const string FdwTooManyHandles = "HV014"; - public const string FdwOutOfMemory = "HV001"; - public const string FdwNoSchemas = "HV00P"; - public const string FdwOptionNameNotFound = "HV00J"; - public const string FdwReplyHandle = "HV00K"; - public const string FdwSchemaNotFound = "HV00Q"; - public const string FdwTableNotFound = "HV00R"; - public const string FdwUnableToCreateExecution = "HV00L"; - public const string FdwUnableToCreateReply = "HV00M"; - public const string FdwUnableToEstablishConnection = "HV00N"; + #region Class 55 - Object Not In Prerequisite State + + public const string ObjectNotInPrerequisiteState = "55000"; + public const string ObjectInUse = "55006"; + public const string CantChangeRuntimeParam = "55P02"; + public const string LockNotAvailable = "55P03"; + + #endregion Class 55 - Object Not In Prerequisite State + + #region Class 57 - Operator Intervention + + public const string OperatorIntervention = "57000"; + public const string QueryCanceled = "57014"; + public const string AdminShutdown = "57P01"; + public const string CrashShutdown = "57P02"; + public const string 
CannotConnectNow = "57P03"; + public const string DatabaseDropped = "57P04"; + public const string IdleSessionTimeout = "57P05"; + + #endregion Class 57 - Operator Intervention + + #region Class 58 - System Error (errors external to PostgreSQL itself) + + public const string SystemError = "58000"; + public const string IoError = "58030"; + public const string UndefinedFile = "58P01"; + public const string DuplicateFile = "58P02"; + + #endregion Class 58 - System Error (errors external to PostgreSQL itself) + + #region Class 72 - Snapshot Failure + + public const string SnapshotFailure = "72000"; + + #endregion Class 72 - Snapshot Failure + + #region Class F0 - Configuration File Error + + public const string ConfigFileError = "F0000"; + public const string LockFileExists = "F0001"; + + #endregion Class F0 - Configuration File Error + + #region Class HV - Foreign Data Wrapper Error (SQL/MED) + + public const string FdwError = "HV000"; + public const string FdwColumnNameNotFound = "HV005"; + public const string FdwDynamicParameterValueNeeded = "HV002"; + public const string FdwFunctionSequenceError = "HV010"; + public const string FdwInconsistentDescriptorInformation = "HV021"; + public const string FdwInvalidAttributeValue = "HV024"; + public const string FdwInvalidColumnName = "HV007"; + public const string FdwInvalidColumnNumber = "HV008"; + public const string FdwInvalidDataType = "HV004"; + public const string FdwInvalidDataTypeDescriptors = "HV006"; + public const string FdwInvalidDescriptorFieldIdentifier = "HV091"; + public const string FdwInvalidHandle = "HV00B"; + public const string FdwInvalidOptionIndex = "HV00C"; + public const string FdwInvalidOptionName = "HV00D"; + public const string FdwInvalidStringLengthOrBufferLength = "HV090"; + public const string FdwInvalidStringFormat = "HV00A"; + public const string FdwInvalidUseOfNullPointer = "HV009"; + public const string FdwTooManyHandles = "HV014"; + public const string FdwOutOfMemory = "HV001"; + public 
const string FdwNoSchemas = "HV00P"; + public const string FdwOptionNameNotFound = "HV00J"; + public const string FdwReplyHandle = "HV00K"; + public const string FdwSchemaNotFound = "HV00Q"; + public const string FdwTableNotFound = "HV00R"; + public const string FdwUnableToCreateExecution = "HV00L"; + public const string FdwUnableToCreateReply = "HV00M"; + public const string FdwUnableToEstablishConnection = "HV00N"; - #endregion Class HV - Foreign Data Wrapper Error (SQL/MED) + #endregion Class HV - Foreign Data Wrapper Error (SQL/MED) - #region Class P0 - PL/pgSQL Error + #region Class P0 - PL/pgSQL Error - public const string PlpgsqlError = "P0000"; - public const string RaiseException = "P0001"; - public const string NoDataFound = "P0002"; - public const string TooManyRows = "P0003"; - public const string AssertFailure = "P0004"; + public const string PlpgsqlError = "P0000"; + public const string RaiseException = "P0001"; + public const string NoDataFound = "P0002"; + public const string TooManyRows = "P0003"; + public const string AssertFailure = "P0004"; - #endregion Class P0 - PL/pgSQL Error + #endregion Class P0 - PL/pgSQL Error - #region Class XX - Internal Error + #region Class XX - Internal Error - public const string InternalError = "XX000"; - public const string DataCorrupted = "XX001"; - public const string IndexCorrupted = "XX002"; + public const string InternalError = "XX000"; + public const string DataCorrupted = "XX001"; + public const string IndexCorrupted = "XX002"; + + #endregion Class XX - Internal Error + + static readonly string[] CriticalFailureCodes = + { + "53", // Insufficient resources + AdminShutdown, // Self explanatory + CrashShutdown, // Self explanatory + CannotConnectNow, // Database is starting up + "58", // System errors, external to PG (server is dying) + "F0", // Configuration file error + "XX", // Internal error (database is dying) + }; + + internal static bool IsCriticalFailure(PostgresException e, bool clusterError = true) 
+ { + foreach (var x in CriticalFailureCodes) + if (e.SqlState.StartsWith(x, StringComparison.Ordinal)) + return true; - #endregion Class XX - Internal Error + // We only treat ProtocolViolation as critical for connection + return !clusterError && e.SqlState == ProtocolViolation; } } diff --git a/src/Npgsql/PostgresException.cs b/src/Npgsql/PostgresException.cs index 9fb3198158..51b8e1e543 100644 --- a/src/Npgsql/PostgresException.cs +++ b/src/Npgsql/PostgresException.cs @@ -2,366 +2,371 @@ using System.Collections.Generic; using System.Runtime.Serialization; using System.Text; -using JetBrains.Annotations; +using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; - -#pragma warning disable CA1032 - -namespace Npgsql +using Npgsql.Internal; + +namespace Npgsql; + +/// +/// The exception that is thrown when the PostgreSQL backend reports errors (e.g. query +/// SQL issues, constraint violations). +/// +/// +/// This exception only corresponds to a PostgreSQL-delivered error. +/// Other errors (e.g. network issues) will be raised via , +/// and purely Npgsql-related issues which aren't related to the server will be raised +/// via the standard CLR exceptions (e.g. ). +/// +/// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html, +/// https://www.postgresql.org/docs/current/static/protocol-error-fields.html +/// +[Serializable] +public sealed class PostgresException : NpgsqlException { /// - /// The exception that is thrown when the PostgreSQL backend reports errors (e.g. query - /// SQL issues, constraint violations). + /// Creates a new instance. /// - /// - /// This exception only corresponds to a PostgreSQL-delivered error. - /// Other errors (e.g. network issues) will be raised via , - /// and purely Npgsql-related issues which aren't related to the server will be raised - /// via the standard CLR exceptions (e.g. ). 
- /// - /// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html, - /// https://www.postgresql.org/docs/current/static/protocol-error-fields.html - /// - [Serializable] - public sealed class PostgresException : NpgsqlException + public PostgresException(string messageText, string severity, string invariantSeverity, string sqlState) + : this(messageText, severity, invariantSeverity, sqlState, detail: null) {} + + /// + /// Creates a new instance. + /// + public PostgresException( + string messageText, string severity, string invariantSeverity, string sqlState, + string? detail = null, string? hint = null, int position = 0, int internalPosition = 0, + string? internalQuery = null, string? where = null, string? schemaName = null, string? tableName = null, + string? columnName = null, string? dataTypeName = null, string? constraintName = null, string? file = null, + string? line = null, string? routine = null) + : base(GetMessage(sqlState, messageText, position, detail)) { - /// - /// Creates a new instance. - /// - public PostgresException(string messageText, string severity, string invariantSeverity, string sqlState) - : this(messageText, severity, invariantSeverity, sqlState, detail: null) {} - - /// - /// Creates a new instance. - /// - public PostgresException( - string messageText, string severity, string invariantSeverity, string sqlState, - string? detail = null, string? hint = null, int position = 0, int internalPosition = 0, - string? internalQuery = null, string? where = null, string? schemaName = null, string? tableName = null, - string? columnName = null, string? dataTypeName = null, string? constraintName = null, string? file = null, - string? line = null, string? 
routine = null) - : base(sqlState + ": " + messageText) + MessageText = messageText; + Severity = severity; + InvariantSeverity = invariantSeverity; + SqlState = sqlState; + + Detail = detail; + Hint = hint; + Position = position; + InternalPosition = internalPosition; + InternalQuery = internalQuery; + Where = where; + SchemaName = schemaName; + TableName = tableName; + ColumnName = columnName; + DataTypeName = dataTypeName; + ConstraintName = constraintName; + File = file; + Line = line; + Routine = routine; + + AddData(nameof(Severity), Severity); + AddData(nameof(InvariantSeverity), InvariantSeverity); + AddData(nameof(SqlState), SqlState); + AddData(nameof(MessageText), MessageText); + AddData(nameof(Detail), Detail); + AddData(nameof(Hint), Hint); + AddData(nameof(Position), Position); + AddData(nameof(InternalPosition), InternalPosition); + AddData(nameof(InternalQuery), InternalQuery); + AddData(nameof(Where), Where); + AddData(nameof(SchemaName), SchemaName); + AddData(nameof(TableName), TableName); + AddData(nameof(ColumnName), ColumnName); + AddData(nameof(DataTypeName), DataTypeName); + AddData(nameof(ConstraintName), ConstraintName); + AddData(nameof(File), File); + AddData(nameof(Line), Line); + AddData(nameof(Routine), Routine); + + void AddData(string key, T value) { - MessageText = messageText; - Severity = severity; - InvariantSeverity = invariantSeverity; - SqlState = sqlState; - - Detail = detail; - Hint = hint; - Position = position; - InternalPosition = internalPosition; - InternalQuery = internalQuery; - Where = where; - SchemaName = schemaName; - TableName = tableName; - ColumnName = columnName; - DataTypeName = dataTypeName; - ConstraintName = constraintName; - File = file; - Line = line; - Routine = routine; - - AddData(nameof(Severity), Severity); - AddData(nameof(InvariantSeverity), InvariantSeverity); - AddData(nameof(SqlState), SqlState); - AddData(nameof(MessageText), MessageText); - AddData(nameof(Detail), Detail); - 
AddData(nameof(Hint), Hint); - AddData(nameof(Position), Position); - AddData(nameof(InternalPosition), InternalPosition); - AddData(nameof(InternalQuery), InternalQuery); - AddData(nameof(Where), Where); - AddData(nameof(SchemaName), SchemaName); - AddData(nameof(TableName), TableName); - AddData(nameof(ColumnName), ColumnName); - AddData(nameof(DataTypeName), DataTypeName); - AddData(nameof(ConstraintName), ConstraintName); - AddData(nameof(File), File); - AddData(nameof(Line), Line); - AddData(nameof(Routine), Routine); - - void AddData(string key, T value) - { - if (!EqualityComparer.Default.Equals(value, default!)) - Data.Add(key, value); - } + if (!EqualityComparer.Default.Equals(value, default!)) + Data.Add(key, value); } + } + + static string GetMessage(string sqlState, string messageText, int position, string? detail) + { + var baseMessage = sqlState + ": " + messageText; + var additionalMessage = + TryAddString("POSITION", position == 0 ? null : position.ToString()) + + TryAddString("DETAIL", detail); + return string.IsNullOrEmpty(additionalMessage) + ? baseMessage + : baseMessage + Environment.NewLine + additionalMessage; + } - PostgresException(ErrorOrNoticeMessage msg) - : this( - msg.Message, msg.Severity, msg.InvariantSeverity, msg.SqlState, - msg.Detail, msg.Hint, msg.Position, msg.InternalPosition, msg.InternalQuery, - msg.Where, msg.SchemaName, msg.TableName, msg.ColumnName, msg.DataTypeName, - msg.ConstraintName, msg.File, msg.Line, msg.Routine) {} + static string TryAddString(string text, string? value) => !string.IsNullOrWhiteSpace(value) ? 
$"{Environment.NewLine}{text}: {value}" : string.Empty; - internal static PostgresException Load(NpgsqlReadBuffer buf, bool includeDetail) - => new PostgresException(ErrorOrNoticeMessage.Load(buf, includeDetail)); + PostgresException(ErrorOrNoticeMessage msg) + : this( + msg.Message, msg.Severity, msg.InvariantSeverity, msg.SqlState, + msg.Detail, msg.Hint, msg.Position, msg.InternalPosition, msg.InternalQuery, + msg.Where, msg.SchemaName, msg.TableName, msg.ColumnName, msg.DataTypeName, + msg.ConstraintName, msg.File, msg.Line, msg.Routine) {} - internal PostgresException(SerializationInfo info, StreamingContext context) - : base(info, context) - { - Severity = GetValue(nameof(Severity)); - InvariantSeverity = GetValue(nameof(InvariantSeverity)); - SqlState = GetValue(nameof(SqlState)); - MessageText = GetValue(nameof(MessageText)); - Detail = GetValue(nameof(Detail)); - Hint = GetValue(nameof(Hint)); - Position = GetValue(nameof(Position)); - InternalPosition = GetValue(nameof(InternalPosition)); - InternalQuery = GetValue(nameof(InternalQuery)); - Where = GetValue(nameof(Where)); - SchemaName = GetValue(nameof(SchemaName)); - TableName = GetValue(nameof(TableName)); - ColumnName = GetValue(nameof(ColumnName)); - DataTypeName = GetValue(nameof(DataTypeName)); - ConstraintName = GetValue(nameof(ConstraintName)); - File = GetValue(nameof(File)); - Line = GetValue(nameof(Line)); - Routine = GetValue(nameof(Routine)); - - T GetValue(string propertyName) => (T)info.GetValue(propertyName, typeof(T))!; - } + internal static PostgresException Load(NpgsqlReadBuffer buf, bool includeDetail, ILogger exceptionLogger) + => new(ErrorOrNoticeMessage.Load(buf, includeDetail, exceptionLogger)); - /// - /// Populates a with the data needed to serialize the target object. - /// - /// The to populate with data. - /// The destination (see ) for this serialization. 
- public override void GetObjectData(SerializationInfo info, StreamingContext context) - { - base.GetObjectData(info, context); - info.AddValue(nameof(Severity), Severity); - info.AddValue(nameof(InvariantSeverity), InvariantSeverity); - info.AddValue(nameof(SqlState), SqlState); - info.AddValue(nameof(MessageText), MessageText); - info.AddValue(nameof(Detail), Detail); - info.AddValue(nameof(Hint), Hint); - info.AddValue(nameof(Position), Position); - info.AddValue(nameof(InternalPosition), InternalPosition); - info.AddValue(nameof(InternalQuery), InternalQuery); - info.AddValue(nameof(Where), Where); - info.AddValue(nameof(SchemaName), SchemaName); - info.AddValue(nameof(TableName), TableName); - info.AddValue(nameof(ColumnName), ColumnName); - info.AddValue(nameof(DataTypeName), DataTypeName); - info.AddValue(nameof(ConstraintName), ConstraintName); - info.AddValue(nameof(File), File); - info.AddValue(nameof(Line), Line); - info.AddValue(nameof(Routine), Routine); - } +#if NET8_0_OR_GREATER + [Obsolete("This API supports obsolete formatter-based serialization. 
It should not be called or extended by application code.")] +#endif + internal PostgresException(SerializationInfo info, StreamingContext context) + : base(info, context) + { + Severity = GetValue(nameof(Severity)); + InvariantSeverity = GetValue(nameof(InvariantSeverity)); + SqlState = GetValue(nameof(SqlState)); + MessageText = GetValue(nameof(MessageText)); + Detail = GetValue(nameof(Detail)); + Hint = GetValue(nameof(Hint)); + Position = GetValue(nameof(Position)); + InternalPosition = GetValue(nameof(InternalPosition)); + InternalQuery = GetValue(nameof(InternalQuery)); + Where = GetValue(nameof(Where)); + SchemaName = GetValue(nameof(SchemaName)); + TableName = GetValue(nameof(TableName)); + ColumnName = GetValue(nameof(ColumnName)); + DataTypeName = GetValue(nameof(DataTypeName)); + ConstraintName = GetValue(nameof(ConstraintName)); + File = GetValue(nameof(File)); + Line = GetValue(nameof(Line)); + Routine = GetValue(nameof(Routine)); + + T GetValue(string propertyName) => (T)info.GetValue(propertyName, typeof(T))!; + } - /// - public override string ToString() + /// + /// Populates a with the data needed to serialize the target object. + /// + /// The to populate with data. + /// The destination (see ) for this serialization. +#if NET8_0_OR_GREATER + [Obsolete("This API supports obsolete formatter-based serialization. 
It should not be called or extended by application code.")] +#endif + public override void GetObjectData(SerializationInfo info, StreamingContext context) + { + base.GetObjectData(info, context); + info.AddValue(nameof(Severity), Severity); + info.AddValue(nameof(InvariantSeverity), InvariantSeverity); + info.AddValue(nameof(SqlState), SqlState); + info.AddValue(nameof(MessageText), MessageText); + info.AddValue(nameof(Detail), Detail); + info.AddValue(nameof(Hint), Hint); + info.AddValue(nameof(Position), Position); + info.AddValue(nameof(InternalPosition), InternalPosition); + info.AddValue(nameof(InternalQuery), InternalQuery); + info.AddValue(nameof(Where), Where); + info.AddValue(nameof(SchemaName), SchemaName); + info.AddValue(nameof(TableName), TableName); + info.AddValue(nameof(ColumnName), ColumnName); + info.AddValue(nameof(DataTypeName), DataTypeName); + info.AddValue(nameof(ConstraintName), ConstraintName); + info.AddValue(nameof(File), File); + info.AddValue(nameof(Line), Line); + info.AddValue(nameof(Routine), Routine); + } + + /// + public override string ToString() + { + var builder = new StringBuilder(base.ToString()) + .AppendLine().Append(" Exception data:"); + + AppendLine(nameof(Severity), Severity); + AppendLine(nameof(SqlState), SqlState); + AppendLine(nameof(MessageText), MessageText); + AppendLine(nameof(Detail), Detail); + AppendLine(nameof(Hint), Hint); + AppendLine(nameof(Position), Position); + AppendLine(nameof(InternalPosition), InternalPosition); + AppendLine(nameof(InternalQuery), InternalQuery); + AppendLine(nameof(Where), Where); + AppendLine(nameof(SchemaName), SchemaName); + AppendLine(nameof(TableName), TableName); + AppendLine(nameof(ColumnName), ColumnName); + AppendLine(nameof(DataTypeName), DataTypeName); + AppendLine(nameof(ConstraintName), ConstraintName); + AppendLine(nameof(File), File); + AppendLine(nameof(Line), Line); + AppendLine(nameof(Routine), Routine); + + return builder.ToString(); + + void AppendLine(string 
propertyName, T propertyValue) { - var builder = new StringBuilder(base.ToString()) - .AppendLine().Append(" Exception data:"); - - AppendLine(nameof(Severity), Severity); - AppendLine(nameof(SqlState), SqlState); - AppendLine(nameof(MessageText), MessageText); - AppendLine(nameof(Detail), Detail); - AppendLine(nameof(Hint), Hint); - AppendLine(nameof(Position), Position); - AppendLine(nameof(InternalPosition), InternalPosition); - AppendLine(nameof(InternalQuery), InternalQuery); - AppendLine(nameof(Where), Where); - AppendLine(nameof(SchemaName), SchemaName); - AppendLine(nameof(TableName), TableName); - AppendLine(nameof(ColumnName), ColumnName); - AppendLine(nameof(DataTypeName), DataTypeName); - AppendLine(nameof(ConstraintName), ConstraintName); - AppendLine(nameof(File), File); - AppendLine(nameof(Line), Line); - AppendLine(nameof(Routine), Routine); - - return builder.ToString(); - - void AppendLine(string propertyName, T propertyValue) - { - if (!EqualityComparer.Default.Equals(propertyValue, default!)) - builder.AppendLine().Append(" ").Append(propertyName).Append(": ").Append(propertyValue); - } + if (!EqualityComparer.Default.Equals(propertyValue, default!)) + builder.AppendLine().Append(" ").Append(propertyName).Append(": ").Append(propertyValue); } + } - /// - /// Specifies whether the exception is considered transient, that is, whether retrying the operation could - /// succeed (e.g. a network error). Check . - /// - public override bool IsTransient + /// + /// Specifies whether the exception is considered transient, that is, whether retrying the operation could + /// succeed (e.g. a network error). Check . 
+ /// + public override bool IsTransient + { + get { - get + switch (SqlState) { - switch (SqlState) - { - case PostgresErrorCodes.InsufficientResources: - case PostgresErrorCodes.DiskFull: - case PostgresErrorCodes.OutOfMemory: - case PostgresErrorCodes.TooManyConnections: - case PostgresErrorCodes.ConfigurationLimitExceeded: - case PostgresErrorCodes.CannotConnectNow: - case PostgresErrorCodes.SystemError: - case PostgresErrorCodes.IoError: - case PostgresErrorCodes.SerializationFailure: - case PostgresErrorCodes.LockNotAvailable: - case PostgresErrorCodes.ObjectInUse: - case PostgresErrorCodes.ObjectNotInPrerequisiteState: - case PostgresErrorCodes.ConnectionException: - case PostgresErrorCodes.ConnectionDoesNotExist: - case PostgresErrorCodes.ConnectionFailure: - case PostgresErrorCodes.SqlClientUnableToEstablishSqlConnection: - case PostgresErrorCodes.SqlServerRejectedEstablishmentOfSqlConnection: - case PostgresErrorCodes.TransactionResolutionUnknown: - return true; - default: - return false; - } + case PostgresErrorCodes.InsufficientResources: + case PostgresErrorCodes.DiskFull: + case PostgresErrorCodes.OutOfMemory: + case PostgresErrorCodes.TooManyConnections: + case PostgresErrorCodes.ConfigurationLimitExceeded: + case PostgresErrorCodes.CannotConnectNow: + case PostgresErrorCodes.SystemError: + case PostgresErrorCodes.IoError: + case PostgresErrorCodes.SerializationFailure: + case PostgresErrorCodes.DeadlockDetected: + case PostgresErrorCodes.LockNotAvailable: + case PostgresErrorCodes.ObjectInUse: + case PostgresErrorCodes.ObjectNotInPrerequisiteState: + case PostgresErrorCodes.ConnectionException: + case PostgresErrorCodes.ConnectionDoesNotExist: + case PostgresErrorCodes.ConnectionFailure: + case PostgresErrorCodes.SqlClientUnableToEstablishSqlConnection: + case PostgresErrorCodes.SqlServerRejectedEstablishmentOfSqlConnection: + case PostgresErrorCodes.TransactionResolutionUnknown: + case PostgresErrorCodes.AdminShutdown: + case 
PostgresErrorCodes.CrashShutdown: + case PostgresErrorCodes.IdleSessionTimeout: + return true; + default: + return false; } } + } + + #region Message Fields - /// - /// Returns the statement which triggered this exception. - /// - public NpgsqlStatement? Statement { get; internal set; } - - #region Message Fields - - /// - /// Severity of the error or notice. - /// Always present. - /// - public string Severity { get; } - - /// - /// Severity of the error or notice, not localized. - /// Always present since PostgreSQL 9.6. - /// - public string InvariantSeverity { get; } - - /// - /// The SQLSTATE code for the error. - /// - /// - /// Always present. - /// Constants are defined in . - /// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html - /// -#if NET - public override string SqlState { get; } + /// + /// Severity of the error or notice. + /// Always present. + /// + public string Severity { get; } + + /// + /// Severity of the error or notice, not localized. + /// Always present since PostgreSQL 9.6. + /// + public string InvariantSeverity { get; } + + /// + /// The SQLSTATE code for the error. + /// + /// + /// Always present. + /// Constants are defined in . + /// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html + /// +#if NET5_0_OR_GREATER + public override string SqlState { get; } #else - public string SqlState { get; } + public string SqlState { get; } #endif - /// - /// The SQLSTATE code for the error. - /// - /// - /// Always present. - /// Constants are defined in . - /// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html - /// - [Obsolete("Use SqlState instead")] - public string Code => SqlState; - - /// - /// The primary human-readable error message. This should be accurate but terse. - /// - /// - /// Always present. - /// - public string MessageText { get; } - - /// - /// An optional secondary error message carrying more detail about the problem. - /// May run to multiple lines. 
- /// - public string? Detail { get; } - - /// - /// An optional suggestion what to do about the problem. - /// This is intended to differ from Detail in that it offers advice (potentially inappropriate) rather than hard facts. - /// May run to multiple lines. - /// - public string? Hint { get; } - - /// - /// The field value is a decimal ASCII integer, indicating an error cursor position as an index into the original query string. - /// The first character has index 1, and positions are measured in characters not bytes. - /// 0 means not provided. - /// - public int Position { get; } - - /// - /// This is defined the same as the field, but it is used when the cursor position refers to an internally generated command rather than the one submitted by the client. - /// The field will always appear when this field appears. - /// 0 means not provided. - /// - public int InternalPosition { get; } - - /// - /// The text of a failed internally-generated command. - /// This could be, for example, a SQL query issued by a PL/pgSQL function. - /// - public string? InternalQuery { get; } - - /// - /// An indication of the context in which the error occurred. - /// Presently this includes a call stack traceback of active PL functions. - /// The trace is one entry per line, most recent first. - /// - public string? Where { get; } - - /// - /// If the error was associated with a specific database object, the name of the schema containing that object, if any. - /// - /// PostgreSQL 9.3 and up. - public string? SchemaName { get; } - - /// - /// Table name: if the error was associated with a specific table, the name of the table. - /// (Refer to the schema name field for the name of the table's schema.) - /// - /// PostgreSQL 9.3 and up. - public string? TableName { get; } - - /// - /// If the error was associated with a specific table column, the name of the column. - /// (Refer to the schema and table name fields to identify the table.) - /// - /// PostgreSQL 9.3 and up. 
- public string? ColumnName { get; } - - /// - /// If the error was associated with a specific data type, the name of the data type. - /// (Refer to the schema name field for the name of the data type's schema.) - /// - /// PostgreSQL 9.3 and up. - public string? DataTypeName { get; } - - /// - /// If the error was associated with a specific constraint, the name of the constraint. - /// Refer to fields listed above for the associated table or domain. - /// (For this purpose, indexes are treated as constraints, even if they weren't created with constraint syntax.) - /// - /// PostgreSQL 9.3 and up. - public string? ConstraintName { get; } - - /// - /// The file name of the source-code location where the error was reported. - /// - /// PostgreSQL 9.3 and up. - public string? File { get; } - - /// - /// The line number of the source-code location where the error was reported. - /// - public string? Line { get; } - - /// - /// The name of the source-code routine reporting the error. - /// - public string? Routine { get; } - - #endregion - } + /// + /// The primary human-readable error message. This should be accurate but terse. + /// + /// + /// Always present. + /// + public string MessageText { get; } + + /// + /// An optional secondary error message carrying more detail about the problem. + /// May run to multiple lines. + /// + public string? Detail { get; } + + /// + /// An optional suggestion what to do about the problem. + /// This is intended to differ from Detail in that it offers advice (potentially inappropriate) rather than hard facts. + /// May run to multiple lines. + /// + public string? Hint { get; } + + /// + /// The field value is a decimal ASCII integer, indicating an error cursor position as an index into the original query string. + /// The first character has index 1, and positions are measured in characters not bytes. + /// 0 means not provided. 
+ /// + public int Position { get; } + + /// + /// This is defined the same as the field, but it is used when the cursor position refers to an internally generated command rather than the one submitted by the client. + /// The field will always appear when this field appears. + /// 0 means not provided. + /// + public int InternalPosition { get; } + + /// + /// The text of a failed internally-generated command. + /// This could be, for example, a SQL query issued by a PL/pgSQL function. + /// + public string? InternalQuery { get; } + + /// + /// An indication of the context in which the error occurred. + /// Presently this includes a call stack traceback of active PL functions. + /// The trace is one entry per line, most recent first. + /// + public string? Where { get; } + + /// + /// If the error was associated with a specific database object, the name of the schema containing that object, if any. + /// + /// PostgreSQL 9.3 and up. + public string? SchemaName { get; } + + /// + /// Table name: if the error was associated with a specific table, the name of the table. + /// (Refer to the schema name field for the name of the table's schema.) + /// + /// PostgreSQL 9.3 and up. + public string? TableName { get; } + + /// + /// If the error was associated with a specific table column, the name of the column. + /// (Refer to the schema and table name fields to identify the table.) + /// + /// PostgreSQL 9.3 and up. + public string? ColumnName { get; } + + /// + /// If the error was associated with a specific data type, the name of the data type. + /// (Refer to the schema name field for the name of the data type's schema.) + /// + /// PostgreSQL 9.3 and up. + public string? DataTypeName { get; } + + /// + /// If the error was associated with a specific constraint, the name of the constraint. + /// Refer to fields listed above for the associated table or domain. 
+ /// (For this purpose, indexes are treated as constraints, even if they weren't created with constraint syntax.) + /// + /// PostgreSQL 9.3 and up. + public string? ConstraintName { get; } + + /// + /// The file name of the source-code location where the error was reported. + /// + /// PostgreSQL 9.3 and up. + public string? File { get; } + + /// + /// The line number of the source-code location where the error was reported. + /// + public string? Line { get; } + + /// + /// The name of the source-code routine reporting the error. + /// + public string? Routine { get; } + + #endregion } diff --git a/src/Npgsql/PostgresMinimalDatabaseInfo.cs b/src/Npgsql/PostgresMinimalDatabaseInfo.cs index dbcba31a96..eb90453062 100644 --- a/src/Npgsql/PostgresMinimalDatabaseInfo.cs +++ b/src/Npgsql/PostgresMinimalDatabaseInfo.cs @@ -1,38 +1,145 @@ using System.Collections.Generic; -using System.Linq; -using System.Reflection; using System.Threading.Tasks; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; using Npgsql.PostgresTypes; using Npgsql.Util; -using NpgsqlTypes; -namespace Npgsql +namespace Npgsql; + +sealed class PostgresMinimalDatabaseInfoFactory : INpgsqlDatabaseInfoFactory { - class PostgresMinimalDatabaseInfoFactory : INpgsqlDatabaseInfoFactory + public Task Load(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async) + => Task.FromResult( + conn.Settings.ServerCompatibilityMode == ServerCompatibilityMode.NoTypeLoading + ? (NpgsqlDatabaseInfo)new PostgresMinimalDatabaseInfo(conn) + : null); +} + +sealed class PostgresMinimalDatabaseInfo : PostgresDatabaseInfo +{ + static PostgresType[]? _typesWithMultiranges, _typesWithoutMultiranges; + + static PostgresType[] CreateTypes(bool withMultiranges) { - public Task Load(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async) - => Task.FromResult( - new NpgsqlConnectionStringBuilder(conn.ConnectionString).ServerCompatibilityMode == ServerCompatibilityMode.NoTypeLoading - ? 
(NpgsqlDatabaseInfo)new PostgresMinimalDatabaseInfo(conn) - : null - ); + var types = new List(); + + Add(DataTypeNames.Int2, oid: 21, arrayOid: 1005); + AddWithRange(DataTypeNames.Int4, oid: 23, arrayOid: 1007, + rangeName: DataTypeNames.Int4Range, rangeOid: 3904, rangeArrayOid: 3905, multirangeOid: 4451, multirangeArrayOid: 6150); + Add(DataTypeNames.Int8, oid: 20, arrayOid: 1016); + AddWithRange(DataTypeNames.Int8, oid: 20, arrayOid: 1016, + rangeName: DataTypeNames.Int8Range, rangeOid: 3926, rangeArrayOid: 3927, multirangeOid: 4536, multirangeArrayOid: 6157); + Add(DataTypeNames.Float4, oid: 700, arrayOid: 1021); + Add(DataTypeNames.Float8, oid: 701, arrayOid: 1022); + AddWithRange(DataTypeNames.Numeric, oid: 1700, arrayOid: 1231, + rangeName: DataTypeNames.NumRange, rangeOid: 3906, rangeArrayOid: 3907, multirangeOid: 4532, multirangeArrayOid: 6151); + Add(DataTypeNames.Money, oid: 790, arrayOid: 791); + Add(DataTypeNames.Bool, oid: 16, arrayOid: 1000); + Add(DataTypeNames.Box, oid: 603, arrayOid: 1020); + Add(DataTypeNames.Circle, oid: 718, arrayOid: 719); + Add(DataTypeNames.Line, oid: 628, arrayOid: 629); + Add(DataTypeNames.LSeg, oid: 601, arrayOid: 1018); + Add(DataTypeNames.Path, oid: 602, arrayOid: 1019); + Add(DataTypeNames.Point, oid: 600, arrayOid: 1017); + Add(DataTypeNames.Polygon, oid: 604, arrayOid: 1027); + Add(DataTypeNames.Bpchar, oid: 1042, arrayOid: 1014); + Add(DataTypeNames.Text, oid: 25, arrayOid: 1009); + Add(DataTypeNames.Varchar, oid: 1043, arrayOid: 1015); + Add(DataTypeNames.Name, oid: 19, arrayOid: 1003); + Add(DataTypeNames.Bytea, oid: 17, arrayOid: 1001); + AddWithRange(DataTypeNames.Date, oid: 1082, arrayOid: 1182, + rangeName: DataTypeNames.DateRange, rangeOid: 3912, rangeArrayOid: 3913, multirangeOid: 4535, multirangeArrayOid: 6155); + Add(DataTypeNames.Time, oid: 1083, arrayOid: 1183); + AddWithRange(DataTypeNames.Timestamp, oid: 1114, arrayOid: 1115, + rangeName: DataTypeNames.TsRange, rangeOid: 3908, rangeArrayOid: 3909, 
multirangeOid: 4533, multirangeArrayOid: 6152); + AddWithRange(DataTypeNames.TimestampTz, oid: 1184, arrayOid: 1185, + rangeName: DataTypeNames.TsTzRange, rangeOid: 3910, rangeArrayOid: 3911, multirangeOid: 4534, multirangeArrayOid: 6153); + Add(DataTypeNames.Interval, oid: 1186, arrayOid: 1187); + Add(DataTypeNames.TimeTz, oid: 1266, arrayOid: 1270); + Add(DataTypeNames.Inet, oid: 869, arrayOid: 1041); + Add(DataTypeNames.Cidr, oid: 650, arrayOid: 651); + Add(DataTypeNames.MacAddr, oid: 829, arrayOid: 1040); + Add(DataTypeNames.MacAddr8, oid: 774, arrayOid: 775); + Add(DataTypeNames.Bit, oid: 1560, arrayOid: 1561); + Add(DataTypeNames.Varbit, oid: 1562, arrayOid: 1563); + Add(DataTypeNames.TsVector, oid: 3614, arrayOid: 3643); + Add(DataTypeNames.TsQuery, oid: 3615, arrayOid: 3645); + Add(DataTypeNames.RegConfig, oid: 3734, arrayOid: 3735); + Add(DataTypeNames.Uuid, oid: 2950, arrayOid: 2951); + Add(DataTypeNames.Xml, oid: 142, arrayOid: 143); + Add(DataTypeNames.Json, oid: 114, arrayOid: 199); + Add(DataTypeNames.Jsonb, oid: 3802, arrayOid: 3807); + Add(DataTypeNames.Jsonpath, oid: 4072, arrayOid: 4073); + Add(DataTypeNames.RefCursor, oid: 1790, arrayOid: 2201); + Add(DataTypeNames.OidVector, oid: 30, arrayOid: 1013); + Add(DataTypeNames.Int2Vector, oid: 22, arrayOid: 1006); + Add(DataTypeNames.Oid, oid: 26, arrayOid: 1028); + Add(DataTypeNames.Xid, oid: 28, arrayOid: 1011); + Add(DataTypeNames.Xid8, oid: 5069, arrayOid: 271); + Add(DataTypeNames.Cid, oid: 29, arrayOid: 1012); + Add(DataTypeNames.RegType, oid: 2206, arrayOid: 2211); + Add(DataTypeNames.Tid, oid: 27, arrayOid: 1010); + Add(DataTypeNames.PgLsn, oid: 3220, arrayOid: 3221); + Add(DataTypeNames.Unknown, oid: 705, arrayOid: 0); + Add(DataTypeNames.Void, oid: 2278, arrayOid: 0); + + return types.ToArray(); + + void Add(DataTypeName name, uint oid, uint arrayOid) + { + var type = new PostgresBaseType(name, oid); + types.Add(type); + if (arrayOid is not 0) + types.Add(new 
PostgresArrayType(name.ToArrayName(), arrayOid, type)); + } + + void AddWithRange(DataTypeName name, uint oid, uint arrayOid, DataTypeName rangeName, uint rangeOid, uint rangeArrayOid, uint multirangeOid, uint multirangeArrayOid) + { + var type = new PostgresBaseType(name, oid); + var rangeType = new PostgresRangeType(rangeName, rangeOid, type); + types.Add(type); + types.Add(new PostgresArrayType(name.ToArrayName(), arrayOid, type)); + types.Add(rangeType); + types.Add(new PostgresArrayType(rangeName.ToArrayName(), rangeArrayOid, rangeType)); + if (withMultiranges) + { + var multirangeType = new PostgresMultirangeType(rangeName.ToDefaultMultirangeName(), multirangeOid, rangeType); + types.Add(multirangeType); + types.Add(new PostgresArrayType(multirangeType.DataTypeName.ToArrayName(), multirangeArrayOid, multirangeType)); + } + } } - class PostgresMinimalDatabaseInfo : PostgresDatabaseInfo + protected override IEnumerable GetTypes() + => SupportsMultirangeTypes + ? _typesWithMultiranges ??= CreateTypes(withMultiranges: true) + : _typesWithoutMultiranges ??= CreateTypes(withMultiranges: false); + + internal PostgresMinimalDatabaseInfo(NpgsqlConnector conn) + : base(conn) { - static readonly PostgresBaseType[] Types = typeof(NpgsqlDbType).GetFields() - .Select(f => f.GetCustomAttribute()) - .OfType() - .Select(a => new PostgresBaseType("pg_catalog", a.Name, a.OID)) - .ToArray(); + HasIntegerDateTimes = !conn.PostgresParameters.TryGetValue("integer_datetimes", out var intDateTimes) || + intDateTimes == "on"; + } - protected override IEnumerable GetTypes() => Types; + // TODO, split database info and type catalog. + internal PostgresMinimalDatabaseInfo() + : base("minimal", 5432, "minimal", "14") + { + } - internal PostgresMinimalDatabaseInfo(NpgsqlConnection conn) - : base(conn) + static PostgresMinimalDatabaseInfo? 
_defaultTypeCatalog; + internal static PostgresMinimalDatabaseInfo DefaultTypeCatalog + { + get { - HasIntegerDateTimes = !conn.PostgresParameters.TryGetValue("integer_datetimes", out var intDateTimes) || - intDateTimes == "on"; + if (_defaultTypeCatalog is not null) + return _defaultTypeCatalog; + + var catalog = new PostgresMinimalDatabaseInfo(); + catalog.ProcessTypes(); + return _defaultTypeCatalog = catalog; } } } diff --git a/src/Npgsql/PostgresNotice.cs b/src/Npgsql/PostgresNotice.cs index f775dd642c..ef55ad4e13 100644 --- a/src/Npgsql/PostgresNotice.cs +++ b/src/Npgsql/PostgresNotice.cs @@ -1,214 +1,204 @@ using System; -using JetBrains.Annotations; +using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; - -namespace Npgsql +using Npgsql.Internal; + +namespace Npgsql; + +/// +/// PostgreSQL notices are non-critical messages generated by PostgreSQL, either as a result of a user query +/// (e.g. as a warning or informational notice), or due to outside activity (e.g. if the database administrator +/// initiates a "fast" database shutdown). +/// +/// +/// https://www.postgresql.org/docs/current/static/protocol-flow.html#PROTOCOL-ASYNC +/// +public sealed class PostgresNotice { + #region Message Fields + + /// + /// Severity of the error or notice. + /// Always present. + /// + public string Severity { get; set; } + /// - /// PostgreSQL notices are non-critical messages generated by PostgreSQL, either as a result of a user query - /// (e.g. as a warning or informational notice), or due to outside activity (e.g. if the database administrator - /// initiates a "fast" database shutdown). + /// Severity of the error or notice, not localized. + /// Always present since PostgreSQL 9.6. + /// + public string InvariantSeverity { get; } + + /// + /// The SQLSTATE code for the error. /// /// - /// https://www.postgresql.org/docs/current/static/protocol-flow.html#PROTOCOL-ASYNC + /// Always present. 
+ /// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html /// - public sealed class PostgresNotice + public string SqlState { get; set; } + + /// + /// The primary human-readable error message. This should be accurate but terse. + /// + /// + /// Always present. + /// + public string MessageText { get; set; } + + /// + /// An optional secondary error message carrying more detail about the problem. + /// May run to multiple lines. + /// + public string? Detail { get; set; } + + /// + /// An optional suggestion what to do about the problem. + /// This is intended to differ from Detail in that it offers advice (potentially inappropriate) rather than hard facts. + /// May run to multiple lines. + /// + public string? Hint { get; set; } + + /// + /// The field value is a decimal ASCII integer, indicating an error cursor position as an index into the original query string. + /// The first character has index 1, and positions are measured in characters not bytes. + /// 0 means not provided. + /// + public int Position { get; set; } + + /// + /// This is defined the same as the field, but it is used when the cursor position refers to an internally generated command rather than the one submitted by the client. + /// The field will always appear when this field appears. + /// 0 means not provided. + /// + public int InternalPosition { get; set; } + + /// + /// The text of a failed internally-generated command. + /// This could be, for example, a SQL query issued by a PL/pgSQL function. + /// + public string? InternalQuery { get; set; } + + /// + /// An indication of the context in which the error occurred. + /// Presently this includes a call stack traceback of active PL functions. + /// The trace is one entry per line, most recent first. + /// + public string? Where { get; set; } + + /// + /// If the error was associated with a specific database object, the name of the schema containing that object, if any. + /// + /// PostgreSQL 9.3 and up. 
+ public string? SchemaName { get; set; } + + /// + /// Table name: if the error was associated with a specific table, the name of the table. + /// (Refer to the schema name field for the name of the table's schema.) + /// + /// PostgreSQL 9.3 and up. + public string? TableName { get; set; } + + /// + /// If the error was associated with a specific table column, the name of the column. + /// (Refer to the schema and table name fields to identify the table.) + /// + /// PostgreSQL 9.3 and up. + public string? ColumnName { get; set; } + + /// + /// If the error was associated with a specific data type, the name of the data type. + /// (Refer to the schema name field for the name of the data type's schema.) + /// + /// PostgreSQL 9.3 and up. + public string? DataTypeName { get; set; } + + /// + /// If the error was associated with a specific constraint, the name of the constraint. + /// Refer to fields listed above for the associated table or domain. + /// (For this purpose, indexes are treated as constraints, even if they weren't created with constraint syntax.) + /// + /// PostgreSQL 9.3 and up. + public string? ConstraintName { get; set; } + + /// + /// The file name of the source-code location where the error was reported. + /// + /// PostgreSQL 9.3 and up. + public string? File { get; set; } + + /// + /// The line number of the source-code location where the error was reported. + /// + public string? Line { get; set; } + + /// + /// The name of the source-code routine reporting the error. + /// + public string? Routine { get; set; } + + #endregion + + /// + /// Creates a new instance. + /// + public PostgresNotice(string severity, string invariantSeverity, string sqlState, string messageText) + : this(messageText, severity, invariantSeverity, sqlState, detail: null) {} + + /// + /// Creates a new instance. + /// + public PostgresNotice( + string messageText, string severity, string invariantSeverity, string sqlState, + string? detail = null, string? 
hint = null, int position = 0, int internalPosition = 0, + string? internalQuery = null, string? where = null, string? schemaName = null, string? tableName = null, + string? columnName = null, string? dataTypeName = null, string? constraintName = null, string? file = null, + string? line = null, string? routine = null) { - #region Message Fields - - /// - /// Severity of the error or notice. - /// Always present. - /// - public string Severity { get; set; } - - /// - /// Severity of the error or notice, not localized. - /// Always present since PostgreSQL 9.6. - /// - public string InvariantSeverity { get; } - - /// - /// The SQLSTATE code for the error. - /// - /// - /// Always present. - /// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html - /// - public string SqlState { get; set; } - - /// - /// The SQLSTATE code for the error. - /// - /// - /// Always present. - /// See https://www.postgresql.org/docs/current/static/errcodes-appendix.html - /// - [Obsolete("Use SqlState instead")] - public string Code => SqlState; - - /// - /// The primary human-readable error message. This should be accurate but terse. - /// - /// - /// Always present. - /// - public string MessageText { get; set; } - - /// - /// An optional secondary error message carrying more detail about the problem. - /// May run to multiple lines. - /// - public string? Detail { get; set; } - - /// - /// An optional suggestion what to do about the problem. - /// This is intended to differ from Detail in that it offers advice (potentially inappropriate) rather than hard facts. - /// May run to multiple lines. - /// - public string? Hint { get; set; } - - /// - /// The field value is a decimal ASCII integer, indicating an error cursor position as an index into the original query string. - /// The first character has index 1, and positions are measured in characters not bytes. - /// 0 means not provided. 
- /// - public int Position { get; set; } - - /// - /// This is defined the same as the field, but it is used when the cursor position refers to an internally generated command rather than the one submitted by the client. - /// The field will always appear when this field appears. - /// 0 means not provided. - /// - public int InternalPosition { get; set; } - - /// - /// The text of a failed internally-generated command. - /// This could be, for example, a SQL query issued by a PL/pgSQL function. - /// - public string? InternalQuery { get; set; } - - /// - /// An indication of the context in which the error occurred. - /// Presently this includes a call stack traceback of active PL functions. - /// The trace is one entry per line, most recent first. - /// - public string? Where { get; set; } - - /// - /// If the error was associated with a specific database object, the name of the schema containing that object, if any. - /// - /// PostgreSQL 9.3 and up. - public string? SchemaName { get; set; } - - /// - /// Table name: if the error was associated with a specific table, the name of the table. - /// (Refer to the schema name field for the name of the table's schema.) - /// - /// PostgreSQL 9.3 and up. - public string? TableName { get; set; } - - /// - /// If the error was associated with a specific table column, the name of the column. - /// (Refer to the schema and table name fields to identify the table.) - /// - /// PostgreSQL 9.3 and up. - public string? ColumnName { get; set; } - - /// - /// If the error was associated with a specific data type, the name of the data type. - /// (Refer to the schema name field for the name of the data type's schema.) - /// - /// PostgreSQL 9.3 and up. - public string? DataTypeName { get; set; } - - /// - /// If the error was associated with a specific constraint, the name of the constraint. - /// Refer to fields listed above for the associated table or domain. 
- /// (For this purpose, indexes are treated as constraints, even if they weren't created with constraint syntax.) - /// - /// PostgreSQL 9.3 and up. - public string? ConstraintName { get; set; } - - /// - /// The file name of the source-code location where the error was reported. - /// - /// PostgreSQL 9.3 and up. - public string? File { get; set; } - - /// - /// The line number of the source-code location where the error was reported. - /// - public string? Line { get; set; } - - /// - /// The name of the source-code routine reporting the error. - /// - public string? Routine { get; set; } - - #endregion - - /// - /// Creates a new instance. - /// - public PostgresNotice(string severity, string invariantSeverity, string sqlState, string messageText) - : this(messageText, severity, invariantSeverity, sqlState, detail: null) {} - - /// - /// Creates a new instance. - /// - public PostgresNotice( - string messageText, string severity, string invariantSeverity, string sqlState, - string? detail = null, string? hint = null, int position = 0, int internalPosition = 0, - string? internalQuery = null, string? where = null, string? schemaName = null, string? tableName = null, - string? columnName = null, string? dataTypeName = null, string? constraintName = null, string? file = null, - string? line = null, string? 
routine = null) - { - MessageText = messageText; - Severity = severity; - InvariantSeverity = invariantSeverity; - SqlState = sqlState; - - Detail = detail; - Hint = hint; - Position = position; - InternalPosition = internalPosition; - InternalQuery = internalQuery; - Where = where; - SchemaName = schemaName; - TableName = tableName; - ColumnName = columnName; - DataTypeName = dataTypeName; - ConstraintName = constraintName; - File = file; - Line = line; - Routine = routine; - } - - PostgresNotice(ErrorOrNoticeMessage msg) - : this( - msg.Message, msg.Severity, msg.InvariantSeverity, msg.SqlState, - msg.Detail, msg.Hint, msg.Position, msg.InternalPosition, msg.InternalQuery, - msg.Where, msg.SchemaName, msg.TableName, msg.ColumnName, msg.DataTypeName, - msg.ConstraintName, msg.File, msg.Line, msg.Routine) {} - - internal static PostgresNotice Load(NpgsqlReadBuffer buf, bool includeDetail) - => new PostgresNotice(ErrorOrNoticeMessage.Load(buf, includeDetail)); + MessageText = messageText; + Severity = severity; + InvariantSeverity = invariantSeverity; + SqlState = sqlState; + + Detail = detail; + Hint = hint; + Position = position; + InternalPosition = internalPosition; + InternalQuery = internalQuery; + Where = where; + SchemaName = schemaName; + TableName = tableName; + ColumnName = columnName; + DataTypeName = dataTypeName; + ConstraintName = constraintName; + File = file; + Line = line; + Routine = routine; } + PostgresNotice(ErrorOrNoticeMessage msg) + : this( + msg.Message, msg.Severity, msg.InvariantSeverity, msg.SqlState, + msg.Detail, msg.Hint, msg.Position, msg.InternalPosition, msg.InternalQuery, + msg.Where, msg.SchemaName, msg.TableName, msg.ColumnName, msg.DataTypeName, + msg.ConstraintName, msg.File, msg.Line, msg.Routine) {} + + internal static PostgresNotice Load(NpgsqlReadBuffer buf, bool includeDetail, ILogger exceptionLogger) + => new(ErrorOrNoticeMessage.Load(buf, includeDetail, exceptionLogger)); +} + +/// +/// Provides data for a PostgreSQL 
notice event. +/// +public sealed class NpgsqlNoticeEventArgs : EventArgs +{ /// - /// Provides data for a PostgreSQL notice event. + /// The Notice that was sent from the database. /// - public sealed class NpgsqlNoticeEventArgs : EventArgs + public PostgresNotice Notice { get; } + + internal NpgsqlNoticeEventArgs(PostgresNotice notice) { - /// - /// The Notice that was sent from the database. - /// - public PostgresNotice Notice { get; } - - internal NpgsqlNoticeEventArgs(PostgresNotice notice) - { - Notice = notice; - } + Notice = notice; } } diff --git a/src/Npgsql/PostgresTypeOIDs.cs b/src/Npgsql/PostgresTypeOIDs.cs deleted file mode 100644 index 3970dfa475..0000000000 --- a/src/Npgsql/PostgresTypeOIDs.cs +++ /dev/null @@ -1,87 +0,0 @@ -namespace Npgsql -{ - /// - /// Holds well-known, built-in PostgreSQL type OIDs. - /// - static class PostgresTypeOIDs - { - // Numeric - internal const uint Int8 = 20; - internal const uint Float8 = 701; - internal const uint Int4 = 23; - internal const uint Numeric = 1700; - internal const uint Float4 = 700; - internal const uint Int2 = 21; - internal const uint Money = 790; - - // Boolean - internal const uint Bool = 16; - - // Geometric - internal const uint Box = 603; - internal const uint Circle = 718; - internal const uint Line = 628; - internal const uint LSeg = 601; - internal const uint Path = 602; - internal const uint Point = 600; - internal const uint Polygon = 604; - - // Character - internal const uint BPChar = 1042; - internal const uint Text = 25; - internal const uint Varchar = 1043; - internal const uint Name = 19; - internal const uint Char = 18; - - // Binary data - internal const uint Bytea = 17; - - // Date/Time - internal const uint Date = 1082; - internal const uint Time = 1083; - internal const uint Timestamp = 1114; - internal const uint TimestampTz = 1184; - internal const uint Interval = 1186; - internal const uint TimeTz = 1266; - internal const uint Abstime = 702; - - // Network address - internal 
const uint Inet = 869; - internal const uint Cidr = 650; - internal const uint Macaddr = 829; - internal const uint Macaddr8 = 774; - - // Bit string - internal const uint Bit = 1560; - internal const uint Varbit = 1562; - - // Text search - internal const uint TsVector = 3614; - internal const uint TsQuery = 3615; - internal const uint Regconfig = 3734; - - // UUID - internal const uint Uuid = 2950; - - // XML - internal const uint Xml = 142; - - // JSON - internal const uint Json = 114; - internal const uint Jsonb = 3802; - internal const uint JsonPath = 4072; - - // Internal - internal const uint Refcursor = 1790; - internal const uint Oidvector = 30; - internal const uint Int2vector = 22; - internal const uint Oid = 26; - internal const uint Xid = 28; - internal const uint Cid = 29; - internal const uint Regtype = 2206; - internal const uint Tid = 27; - - // Special - internal const uint Unknown = 705; - } -} diff --git a/src/Npgsql/PostgresTypes/PostgresArrayType.cs b/src/Npgsql/PostgresTypes/PostgresArrayType.cs index 1e34ad4d45..2f46d31cf2 100644 --- a/src/Npgsql/PostgresTypes/PostgresArrayType.cs +++ b/src/Npgsql/PostgresTypes/PostgresArrayType.cs @@ -1,39 +1,46 @@ -using System.Diagnostics; -using JetBrains.Annotations; +using Npgsql.Internal.Postgres; -namespace Npgsql.PostgresTypes +namespace Npgsql.PostgresTypes; + +/// +/// Represents a PostgreSQL array data type, which can hold several multiple values in a single column. +/// +/// +/// See https://www.postgresql.org/docs/current/static/arrays.html. +/// +public class PostgresArrayType : PostgresType { /// - /// Represents a PostgreSQL array data type, which can hold several multiple values in a single column. + /// The PostgreSQL data type of the element contained within this array. /// - /// - /// See https://www.postgresql.org/docs/current/static/arrays.html. 
- /// - public class PostgresArrayType : PostgresType + public PostgresType Element { get; } + + /// + /// Constructs a representation of a PostgreSQL array data type. + /// + protected internal PostgresArrayType(string ns, string name, uint oid, PostgresType elementPostgresType) + : base(ns, name, oid) { - /// - /// The PostgreSQL data type of the element contained within this array. - /// - public PostgresType Element { get; } + Element = elementPostgresType; + Element.Array = this; + } - /// - /// Constructs a representation of a PostgreSQL array data type. - /// - protected internal PostgresArrayType(string ns, string internalName, uint oid, PostgresType elementPostgresType) - : base(ns, elementPostgresType.Name + "[]", internalName, oid) - { - Debug.Assert(internalName == '_' + elementPostgresType.InternalName); - Element = elementPostgresType; - Element.Array = this; - } + /// + /// Constructs a representation of a PostgreSQL array data type. + /// + internal PostgresArrayType(DataTypeName dataTypeName, Oid oid, PostgresType elementPostgresType) + : base(dataTypeName, oid) + { + Element = elementPostgresType; + Element.Array = this; + } - // PostgreSQL array types have an underscore-prefixed name (_text), but we - // want to return the public text[] instead - /// - internal override string GetPartialNameWithFacets(int typeModifier) - => Element.GetPartialNameWithFacets(typeModifier) + "[]"; + // PostgreSQL array types have an underscore-prefixed name (_text), but we + // want to return the public text[] instead + /// + internal override string GetPartialNameWithFacets(int typeModifier) + => Element.GetPartialNameWithFacets(typeModifier) + "[]"; - internal override PostgresFacets GetFacets(int typeModifier) - => Element.GetFacets(typeModifier); - } + internal override PostgresFacets GetFacets(int typeModifier) + => Element.GetFacets(typeModifier); } diff --git a/src/Npgsql/PostgresTypes/PostgresBaseType.cs b/src/Npgsql/PostgresTypes/PostgresBaseType.cs index 
2019cdd53a..11c289b1a8 100644 --- a/src/Npgsql/PostgresTypes/PostgresBaseType.cs +++ b/src/Npgsql/PostgresTypes/PostgresBaseType.cs @@ -1,91 +1,80 @@  -namespace Npgsql.PostgresTypes +using Npgsql.Internal.Postgres; + +namespace Npgsql.PostgresTypes; + +/// +/// Represents a PostgreSQL base data type, which is a simple scalar value. +/// +public class PostgresBaseType : PostgresType { /// - /// Represents a PostgreSQL base data type, which is a simple scalar value. + /// Constructs a representation of a PostgreSQL base data type. /// - public class PostgresBaseType : PostgresType + protected internal PostgresBaseType(string ns, string name, uint oid) + : base(ns, name, oid) {} + + /// + /// Constructs a representation of a PostgreSQL base data type. + /// + internal PostgresBaseType(DataTypeName dataTypeName, Oid oid) + : base(dataTypeName, oid) {} + + /// + internal override string GetPartialNameWithFacets(int typeModifier) { - /// - protected internal PostgresBaseType(string ns, string internalName, uint oid) - : base(ns, TranslateInternalName(internalName), internalName, oid) - {} + var facets = GetFacets(typeModifier); + if (facets == PostgresFacets.None) + return Name; - /// - internal override string GetPartialNameWithFacets(int typeModifier) + return Name switch { - var facets = GetFacets(typeModifier); - if (facets == PostgresFacets.None) - return Name; + // Special case for time, timestamp, timestamptz and timetz where the facet is embedded in the middle + "timestamp without time zone" => $"timestamp{facets} without time zone", + "time without time zone" => $"time{facets} without time zone", + "timestamp with time zone" => $"timestamp{facets} with time zone", + "time with time zone" => $"time{facets} with time zone", - return Name switch - { - // Special case for time, timestamp, timestamptz and timetz where the facet is embedded in the middle - "timestamp without time zone" => $"timestamp{facets} without time zone", - "time without time zone" => 
$"time{facets} without time zone", - "timestamp with time zone" => $"timestamp{facets} with time zone", - "time with time zone" => $"time{facets} with time zone", - _ => $"{Name}{facets}" - }; - } + // We normalize character(1) to character - they mean the same + "character" when facets.Size == 1 => "character", - internal override PostgresFacets GetFacets(int typeModifier) - { - if (typeModifier == -1) - return PostgresFacets.None; + _ => $"{Name}{facets}" + }; + } - switch (Name) - { - case "character": - return new PostgresFacets(typeModifier - 4, null, null); - case "character varying": - return new PostgresFacets(typeModifier - 4, null, null); // Max length - case "numeric": - case "decimal": - // See https://stackoverflow.com/questions/3350148/where-are-numeric-precision-and-scale-for-a-field-found-in-the-pg-catalog-tables - var precision = ((typeModifier - 4) >> 16) & 65535; - var scale = (typeModifier - 4) & 65535; - return new PostgresFacets(null, precision, scale == 0 ? (int?)null : scale); - case "timestamp without time zone": - case "time without time zone": - case "interval": - precision = typeModifier & 0xFFFF; - return new PostgresFacets(null, precision, null); - case "timestamp with time zone": - precision = typeModifier & 0xFFFF; - return new PostgresFacets(null, precision, null); - case "time with time zone": - precision = typeModifier & 0xFFFF; - return new PostgresFacets(null, precision, null); - case "bit": - case "bit varying": - return new PostgresFacets(typeModifier, null, null); - default: - return PostgresFacets.None; - } - } + internal override PostgresFacets GetFacets(int typeModifier) + { + if (typeModifier == -1) + return PostgresFacets.None; - // The type names returned by PostgreSQL are internal names (int4 instead of - // integer). We perform translation to the user-facing standard names. 
- // https://www.postgresql.org/docs/current/static/datatype.html#DATATYPE-TABLE - static string TranslateInternalName(string internalName) - => internalName switch - { - "bool" => "boolean", - "bpchar" => "character", - "decimal" => "numeric", - "float4" => "real", - "float8" => "double precision", - "int2" => "smallint", - "int4" => "integer", - "int8" => "bigint", - "time" => "time without time zone", - "timestamp" => "timestamp without time zone", - "timetz" => "time with time zone", - "timestamptz" => "timestamp with time zone", - "varbit" => "bit varying", - "varchar" => "character varying", - _ => internalName - }; + switch (Name) + { + case "character": + return new PostgresFacets(typeModifier - 4, null, null); + case "character varying": + return new PostgresFacets(typeModifier - 4, null, null); // Max length + case "numeric": + case "decimal": + // See https://stackoverflow.com/questions/3350148/where-are-numeric-precision-and-scale-for-a-field-found-in-the-pg-catalog-tables + var precision = ((typeModifier - 4) >> 16) & 65535; + var scale = (typeModifier - 4) & 65535; + return new PostgresFacets(null, precision, scale == 0 ? 
(int?)null : scale); + case "timestamp without time zone": + case "time without time zone": + case "interval": + precision = typeModifier & 0xFFFF; + return new PostgresFacets(null, precision, null); + case "timestamp with time zone": + precision = typeModifier & 0xFFFF; + return new PostgresFacets(null, precision, null); + case "time with time zone": + precision = typeModifier & 0xFFFF; + return new PostgresFacets(null, precision, null); + case "bit": + case "bit varying": + return new PostgresFacets(typeModifier, null, null); + default: + return PostgresFacets.None; + } } } diff --git a/src/Npgsql/PostgresTypes/PostgresCompositeType.cs b/src/Npgsql/PostgresTypes/PostgresCompositeType.cs index 4144bc6fb3..2d53199e6f 100644 --- a/src/Npgsql/PostgresTypes/PostgresCompositeType.cs +++ b/src/Npgsql/PostgresTypes/PostgresCompositeType.cs @@ -1,52 +1,56 @@ using System.Collections.Generic; +using Npgsql.Internal.Postgres; -namespace Npgsql.PostgresTypes +namespace Npgsql.PostgresTypes; + +/// +/// Represents a PostgreSQL composite data type, which can hold multiple fields of varying types in a single column. +/// +/// +/// See https://www.postgresql.org/docs/current/static/rowtypes.html. +/// +public class PostgresCompositeType : PostgresType { /// - /// Represents a PostgreSQL composite data type, which can hold multiple fields of varying types in a single column. + /// Holds the name and types for all fields. /// - /// - /// See https://www.postgresql.org/docs/current/static/rowtypes.html. - /// - public class PostgresCompositeType : PostgresType - { - /// - /// Holds the name and types for all fields. - /// - public IReadOnlyList Fields => MutableFields; + public IReadOnlyList Fields => MutableFields; + + internal List MutableFields { get; } = new(); + + /// + /// Constructs a representation of a PostgreSQL array data type. 
+ /// + internal PostgresCompositeType(string ns, string name, uint oid) + : base(ns, name, oid) {} + + /// + /// Constructs a representation of a PostgreSQL domain data type. + /// + internal PostgresCompositeType(DataTypeName dataTypeName, Oid oid) + : base(dataTypeName, oid) {} - internal List MutableFields { get; } = new List(); + /// + /// Represents a field in a PostgreSQL composite data type. + /// + public class Field + { + internal Field(string name, PostgresType type) + { + Name = name; + Type = type; + } /// - /// Constructs a representation of a PostgreSQL array data type. + /// The name of the composite field. /// -#pragma warning disable CA2222 // Do not decrease inherited member visibility - internal PostgresCompositeType(string ns, string name, uint oid) - : base(ns, name, oid) {} -#pragma warning restore CA2222 // Do not decrease inherited member visibility - + public string Name { get; } /// - /// Represents a field in a PostgreSQL composite data type. + /// The type of the composite field. /// - public class Field - { - internal Field(string name, PostgresType type) - { - Name = name; - Type = type; - } - - /// - /// The name of the composite field. - /// - public string Name { get; } - /// - /// The type of the composite field. - /// - public PostgresType Type { get; } - - /// - public override string ToString() => $"{Name} => {Type}"; - } + public PostgresType Type { get; } + + /// + public override string ToString() => $"{Name} => {Type}"; } } diff --git a/src/Npgsql/PostgresTypes/PostgresDomainType.cs b/src/Npgsql/PostgresTypes/PostgresDomainType.cs index 95b508a02d..cab9323015 100644 --- a/src/Npgsql/PostgresTypes/PostgresDomainType.cs +++ b/src/Npgsql/PostgresTypes/PostgresDomainType.cs @@ -1,41 +1,50 @@ -using JetBrains.Annotations; +using Npgsql.Internal.Postgres; -namespace Npgsql.PostgresTypes +namespace Npgsql.PostgresTypes; + +/// +/// Represents a PostgreSQL domain type. 
+/// +/// +/// See https://www.postgresql.org/docs/current/static/sql-createdomain.html. +/// +/// When PostgreSQL returns a RowDescription for a domain type, the type OID is the base type's +/// (so fetching a domain type over text returns a RowDescription for text). +/// However, when a composite type is returned, the type OID there is that of the domain, +/// so we provide "clean" support for domain types. +/// +public class PostgresDomainType : PostgresType { /// - /// Represents a PostgreSQL domain type. + /// The PostgreSQL data type of the base type, i.e. the type this domain is based on. /// - /// - /// See https://www.postgresql.org/docs/current/static/sql-createdomain.html. - /// - /// When PostgreSQL returns a RowDescription for a domain type, the type OID is the base type's - /// (so fetching a domain type over text returns a RowDescription for text). - /// However, when a composite type is returned, the type OID there is that of the domain, - /// so we provide "clean" support for domain types. - /// - public class PostgresDomainType : PostgresType - { - /// - /// The PostgreSQL data type of the base type, i.e. the type this domain is based on. - /// - public PostgresType BaseType { get; } + public PostgresType BaseType { get; } - /// - /// True if the domain has a NOT NULL constraint, otherwise false. - /// - public bool NotNull { get; } + /// + /// True if the domain has a NOT NULL constraint, otherwise false. + /// + public bool NotNull { get; } - /// - /// Constructs a representation of a PostgreSQL domain data type. - /// - protected internal PostgresDomainType(string ns, string name, uint oid, PostgresType baseType, bool notNull) - : base(ns, name, oid) - { - BaseType = baseType; - NotNull = notNull; - } + /// + /// Constructs a representation of a PostgreSQL domain data type. 
+ /// + protected internal PostgresDomainType(string ns, string name, uint oid, PostgresType baseType, bool notNull) + : base(ns, name, oid) + { + BaseType = baseType; + NotNull = notNull; + } - internal override PostgresFacets GetFacets(int typeModifier) - => BaseType.GetFacets(typeModifier); + /// + /// Constructs a representation of a PostgreSQL domain data type. + /// + internal PostgresDomainType(DataTypeName dataTypeName, Oid oid, PostgresType baseType, bool notNull) + : base(dataTypeName, oid) + { + BaseType = baseType; + NotNull = notNull; } + + internal override PostgresFacets GetFacets(int typeModifier) + => BaseType.GetFacets(typeModifier); } diff --git a/src/Npgsql/PostgresTypes/PostgresEnumType.cs b/src/Npgsql/PostgresTypes/PostgresEnumType.cs index 62ddd837b7..7e4440252e 100644 --- a/src/Npgsql/PostgresTypes/PostgresEnumType.cs +++ b/src/Npgsql/PostgresTypes/PostgresEnumType.cs @@ -1,27 +1,33 @@ using System.Collections.Generic; +using Npgsql.Internal.Postgres; -namespace Npgsql.PostgresTypes +namespace Npgsql.PostgresTypes; + +/// +/// Represents a PostgreSQL enum data type. +/// +/// +/// See https://www.postgresql.org/docs/current/static/datatype-enum.html. +/// +public class PostgresEnumType : PostgresType { /// - /// Represents a PostgreSQL enum data type. + /// The enum's fields. /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-enum.html. - /// - public class PostgresEnumType : PostgresType - { - /// - /// The enum's fields. - /// - public IReadOnlyList Labels => MutableLabels; + public IReadOnlyList Labels => MutableLabels; + + internal List MutableLabels { get; } = new(); - internal List MutableLabels { get; } = new List(); + /// + /// Constructs a representation of a PostgreSQL enum data type. + /// + protected internal PostgresEnumType(string ns, string name, uint oid) + : base(ns, name, oid) {} + + /// + /// Constructs a representation of a PostgreSQL enum data type. 
+ /// + internal PostgresEnumType(DataTypeName dataTypeName, Oid oid) + : base(dataTypeName, oid) {} - /// - /// Constructs a representation of a PostgreSQL enum data type. - /// - protected internal PostgresEnumType(string ns, string name, uint oid) - : base(ns, name, oid) - {} - } } diff --git a/src/Npgsql/PostgresTypes/PostgresFacets.cs b/src/Npgsql/PostgresTypes/PostgresFacets.cs index 8fea86c6b5..4c88724965 100644 --- a/src/Npgsql/PostgresTypes/PostgresFacets.cs +++ b/src/Npgsql/PostgresTypes/PostgresFacets.cs @@ -1,71 +1,70 @@ using System; using System.Text; -namespace Npgsql.PostgresTypes +namespace Npgsql.PostgresTypes; + +readonly struct PostgresFacets : IEquatable { - readonly struct PostgresFacets : IEquatable + internal static readonly PostgresFacets None = new(null, null, null); + + internal PostgresFacets(int? size, int? precision, int? scale) { - internal static readonly PostgresFacets None = new PostgresFacets(null, null, null); + Size = size; + Precision = precision; + Scale = scale; + } - internal PostgresFacets(int? size, int? precision, int? scale) - { - Size = size; - Precision = precision; - Scale = scale; - } + public readonly int? Size; + public readonly int? Precision; + public readonly int? Scale; - public readonly int? Size; - public readonly int? Precision; - public readonly int? Scale; + public override bool Equals(object? o) + => o is PostgresFacets otherFacets && Equals(otherFacets); - public override bool Equals(object? 
o) - => o is PostgresFacets otherFacets && Equals(otherFacets); + public bool Equals(PostgresFacets o) + => Size == o.Size && Precision == o.Precision && Scale == o.Scale; - public bool Equals(PostgresFacets o) - => Size == o.Size && Precision == o.Precision && Scale == o.Scale; + public static bool operator ==(PostgresFacets x, PostgresFacets y) => x.Equals(y); - public static bool operator ==(PostgresFacets x, PostgresFacets y) => x.Equals(y); + public static bool operator !=(PostgresFacets x, PostgresFacets y) => !(x == y); - public static bool operator !=(PostgresFacets x, PostgresFacets y) => !(x == y); + public override int GetHashCode() + { + var hashcode = Size?.GetHashCode() ?? 0; + hashcode = (hashcode * 397) ^ (Precision?.GetHashCode() ?? 0); + hashcode = (hashcode * 397) ^ (Scale?.GetHashCode() ?? 0); + return hashcode; + } + + public override string ToString() + { + if (Size == null && Precision == null && Scale == null) + return string.Empty; - public override int GetHashCode() + var sb = new StringBuilder().Append('('); + var needComma = false; + + if (Size != null) { - var hashcode = Size?.GetHashCode() ?? 0; - hashcode = (hashcode * 397) ^ (Precision?.GetHashCode() ?? 0); - hashcode = (hashcode * 397) ^ (Scale?.GetHashCode() ?? 
0); - return hashcode; + sb.Append(Size); + needComma = true; } - public override string ToString() + if (Precision != null) { - if (Size == null && Precision == null && Scale == null) - return string.Empty; - - var sb = new StringBuilder().Append('('); - var needComma = false; - - if (Size != null) - { - sb.Append(Size); - needComma = true; - } - - if (Precision != null) - { - if (needComma) - sb.Append(", "); - sb.Append(Precision); - needComma = true; - } - - if (Scale != null) - { - if (needComma) - sb.Append(", "); - sb.Append(Scale); - } + if (needComma) + sb.Append(", "); + sb.Append(Precision); + needComma = true; + } - return sb.Append(')').ToString(); + if (Scale != null) + { + if (needComma) + sb.Append(", "); + sb.Append(Scale); } + + return sb.Append(')').ToString(); } -} +} \ No newline at end of file diff --git a/src/Npgsql/PostgresTypes/PostgresMultirangeType.cs b/src/Npgsql/PostgresTypes/PostgresMultirangeType.cs new file mode 100644 index 0000000000..2769df87f8 --- /dev/null +++ b/src/Npgsql/PostgresTypes/PostgresMultirangeType.cs @@ -0,0 +1,38 @@ +using Npgsql.Internal.Postgres; + +namespace Npgsql.PostgresTypes; + +/// +/// Represents a PostgreSQL multirange data type. +/// +/// +///

See https://www.postgresql.org/docs/current/static/rangetypes.html.

+///

Multirange types were introduced in PostgreSQL 14.

+///
+public class PostgresMultirangeType : PostgresType +{ + /// + /// The PostgreSQL data type of the range of this multirange. + /// + public PostgresRangeType Subrange { get; } + + /// + /// Constructs a representation of a PostgreSQL multirange data type. + /// + protected internal PostgresMultirangeType(string ns, string name, uint oid, PostgresRangeType rangePostgresType) + : base(ns, name, oid) + { + Subrange = rangePostgresType; + Subrange.Multirange = this; + } + + /// + /// Constructs a representation of a PostgreSQL multirange data type. + /// + internal PostgresMultirangeType(DataTypeName dataTypeName, Oid oid, PostgresRangeType rangePostgresType) + : base(dataTypeName, oid) + { + Subrange = rangePostgresType; + Subrange.Multirange = this; + } +} diff --git a/src/Npgsql/PostgresTypes/PostgresRangeType.cs b/src/Npgsql/PostgresTypes/PostgresRangeType.cs index 877ad19198..a26a71afae 100644 --- a/src/Npgsql/PostgresTypes/PostgresRangeType.cs +++ b/src/Npgsql/PostgresTypes/PostgresRangeType.cs @@ -1,28 +1,43 @@ -using JetBrains.Annotations; +using Npgsql.Internal.Postgres; -namespace Npgsql.PostgresTypes +namespace Npgsql.PostgresTypes; + +/// +/// Represents a PostgreSQL range data type. +/// +/// +/// See https://www.postgresql.org/docs/current/static/rangetypes.html. +/// +public class PostgresRangeType : PostgresType { /// - /// Represents a PostgreSQL range data type. + /// The PostgreSQL data type of the subtype of this range. + /// + public PostgresType Subtype { get; } + + /// + /// The PostgreSQL data type of the multirange of this range. /// - /// - /// See https://www.postgresql.org/docs/current/static/rangetypes.html. - /// - public class PostgresRangeType : PostgresType + public PostgresMultirangeType? Multirange { get; internal set; } + + /// + /// Constructs a representation of a PostgreSQL range data type. 
+ /// + protected internal PostgresRangeType( + string ns, string name, uint oid, PostgresType subtypePostgresType) + : base(ns, name, oid) { - /// - /// The PostgreSQL data type of the subtype of this range. - /// - public PostgresType Subtype { get; } + Subtype = subtypePostgresType; + Subtype.Range = this; + } - /// - /// Constructs a representation of a PostgreSQL range data type. - /// - protected internal PostgresRangeType(string ns, string name, uint oid, PostgresType subtypePostgresType) - : base(ns, name, oid) - { - Subtype = subtypePostgresType; - Subtype.Range = this; - } + /// + /// Constructs a representation of a PostgreSQL range data type. + /// + internal PostgresRangeType(DataTypeName dataTypeName, Oid oid, PostgresType subtypePostgresType) + : base(dataTypeName, oid) + { + Subtype = subtypePostgresType; + Subtype.Range = this; } } diff --git a/src/Npgsql/PostgresTypes/PostgresType.cs b/src/Npgsql/PostgresTypes/PostgresType.cs index 1149216a2b..1182588c8c 100644 --- a/src/Npgsql/PostgresTypes/PostgresType.cs +++ b/src/Npgsql/PostgresTypes/PostgresType.cs @@ -1,114 +1,133 @@ -using JetBrains.Annotations; - -namespace Npgsql.PostgresTypes +using System; +using Npgsql.Internal.Postgres; + +namespace Npgsql.PostgresTypes; + +/// +/// Represents a PostgreSQL data type, such as int4 or text, as discovered from pg_type. +/// This class is abstract, see derived classes for concrete types of PostgreSQL types. +/// +/// +/// Instances of this class are shared between connections to the same databases. +/// For more info about what this class and its subclasses represent, see +/// https://www.postgresql.org/docs/current/static/catalog-pg-type.html. +/// +public abstract class PostgresType { + #region Constructors + + /// + /// Constructs a representation of a PostgreSQL data type. + /// + /// The data type's namespace (or schema). + /// The data type's name. + /// The data type's OID. 
+ private protected PostgresType(string ns, string name, uint oid) + { + DataTypeName = DataTypeName.FromDisplayName(name, ns); + OID = oid; + FullName = Namespace + "." + Name; + } + + /// + /// Constructs a representation of a PostgreSQL data type. + /// + /// The data type's fully qualified name. + /// The data type's OID. + private protected PostgresType(DataTypeName dataTypeName, Oid oid) + { + DataTypeName = dataTypeName; + OID = oid.Value; + FullName = Namespace + "." + Name; + } + + #endregion + + #region Public Properties + + /// + /// The data type's OID - a unique id identifying the data type in a given database (in pg_type). + /// + public uint OID { get; } + + /// + /// The data type's namespace (or schema). + /// + public string Namespace => DataTypeName.Schema; + /// - /// Represents a PostgreSQL data type, such as int4 or text, as discovered from pg_type. - /// This class is abstract, see derived classes for concrete types of PostgreSQL types. + /// The data type's name. /// /// - /// Instances of this class are shared between connections to the same databases. - /// For more info about what this class and its subclasses represent, see - /// https://www.postgresql.org/docs/current/static/catalog-pg-type.html. + /// Note that this is the standard, user-displayable type name (e.g. integer[]) rather than the internal + /// PostgreSQL name as it is in pg_type (_int4). See for the latter. /// - public abstract class PostgresType + public string Name => DataTypeName.UnqualifiedDisplayName; + + /// + /// The full name of the backend type, including its namespace. + /// + public string FullName { get; } + + internal DataTypeName DataTypeName { get; } + + /// + /// A display name for this backend type, including the namespace unless it is pg_catalog (the namespace + /// for all built-in types). + /// + public string DisplayName => DataTypeName.DisplayName; + + /// + /// The data type's internal PostgreSQL name (e.g. _int4 not integer[]). 
+ /// See for a more user-friendly name. + /// + public string InternalName => DataTypeName.UnqualifiedName; + + /// + /// If a PostgreSQL array type exists for this type, it will be referenced here. + /// Otherwise null. + /// + public PostgresArrayType? Array { get; internal set; } + + /// + /// If a PostgreSQL range type exists for this type, it will be referenced here. + /// Otherwise null. + /// + public PostgresRangeType? Range { get; internal set; } + + #endregion + + internal virtual string GetPartialNameWithFacets(int typeModifier) => Name; + + /// + /// Generates the type name including any facts (size, precision, scale), given the PostgreSQL type modifier. + /// + internal string GetDisplayNameWithFacets(int typeModifier) + => Namespace == "pg_catalog" + ? GetPartialNameWithFacets(typeModifier) + : Namespace + '.' + GetPartialNameWithFacets(typeModifier); + + internal virtual PostgresFacets GetFacets(int typeModifier) => PostgresFacets.None; + + /// + /// Returns a string that represents the current object. + /// + public override string ToString() => DisplayName; + + PostgresType? _representationalType; + + /// Canonizes (nested) domain types to underlying types, does not handle composites. + internal PostgresType GetRepresentationalType() { - #region Constructors - - /// - /// Constructs a representation of a PostgreSQL data type. - /// - /// The data type's namespace (or schema). - /// The data type's name. - /// The data type's OID. - protected PostgresType(string ns, string name, uint oid) - : this(ns, name, name, oid) {} - - /// - /// Constructs a representation of a PostgreSQL data type. - /// - /// The data type's namespace (or schema). - /// The data type's name. - /// The data type's internal name (e.g. _int4 for integer[]). - /// The data type's OID. - protected PostgresType(string ns, string name, string internalName, uint oid) - { - Namespace = ns; - Name = name; - FullName = Namespace + '.' 
+ Name; - InternalName = internalName; - OID = oid; - } - - #endregion - - #region Public Properties - - /// - /// The data type's OID - a unique id identifying the data type in a given database (in pg_type). - /// - public uint OID { get; } - - /// - /// The data type's namespace (or schema). - /// - public string Namespace { get; } - - /// - /// The data type's name. - /// - /// - /// Note that this is the standard, user-displayable type name (e.g. integer[]) rather than the internal - /// PostgreSQL name as it is in pg_type (_int4). See for the latter. - /// - public string Name { get; } - - /// - /// The full name of the backend type, including its namespace. - /// - public string FullName { get; } - - /// - /// A display name for this backend type, including the namespace unless it is pg_catalog (the namespace - /// for all built-in types). - /// - public string DisplayName => Namespace == "pg_catalog" ? Name : FullName; - - /// - /// The data type's internal PostgreSQL name (e.g. integer[] not _int4). - /// See for a more user-friendly name. - /// - public string InternalName { get; } - - /// - /// If a PostgreSQL array type exists for this type, it will be referenced here. - /// Otherwise null. - /// - public PostgresArrayType? Array { get; internal set; } - - /// - /// If a PostgreSQL range type exists for this type, it will be referenced here. - /// Otherwise null. - /// - public PostgresRangeType? Range { get; internal set; } - - #endregion - - internal virtual string GetPartialNameWithFacets(int typeModifier) => Name; - - /// - /// Generates the type name including any facts (size, precision, scale), given the PostgreSQL type modifier. - /// - internal string GetDisplayNameWithFacets(int typeModifier) - => Namespace == "pg_catalog" - ? GetPartialNameWithFacets(typeModifier) - : Namespace + '.' 
+ GetPartialNameWithFacets(typeModifier); - - internal virtual PostgresFacets GetFacets(int typeModifier) => PostgresFacets.None; - - /// - /// Returns a string that represents the current object. - /// - public override string ToString() => DisplayName; + return _representationalType ??= Core(this) ?? throw new InvalidOperationException("Couldn't map type to representational type"); + + static PostgresType? Core(PostgresType? postgresType) + => (postgresType as PostgresDomainType)?.BaseType ?? postgresType switch + { + PostgresArrayType { Element: PostgresDomainType domain } => Core(domain.BaseType)?.Array, + PostgresMultirangeType { Subrange.Subtype: PostgresDomainType domain } => domain.BaseType.Range?.Multirange, + PostgresRangeType { Subtype: PostgresDomainType domain } => domain.Range, + var type => type + }; } } diff --git a/src/Npgsql/PostgresTypes/PostgresTypeKind.cs b/src/Npgsql/PostgresTypes/PostgresTypeKind.cs new file mode 100644 index 0000000000..03330f9050 --- /dev/null +++ b/src/Npgsql/PostgresTypes/PostgresTypeKind.cs @@ -0,0 +1,21 @@ +namespace Npgsql.PostgresTypes; + +enum PostgresTypeKind +{ + /// A base type. + Base, + /// An enum carrying its variants. + Enum, + /// A pseudo type like anyarray. + Pseudo, + // An array carrying its element type. + Array, + // A range carrying its element type. + Range, + // A multi-range carrying its element type. + Multirange, + // A domain carrying its underlying type. + Domain, + // A composite carrying its constituent fields. + Composite +} diff --git a/src/Npgsql/PostgresTypes/PostgresUnknownType.cs b/src/Npgsql/PostgresTypes/PostgresUnknownType.cs index eed53b85a2..bbe952726d 100644 --- a/src/Npgsql/PostgresTypes/PostgresUnknownType.cs +++ b/src/Npgsql/PostgresTypes/PostgresUnknownType.cs @@ -1,17 +1,16 @@ -namespace Npgsql.PostgresTypes +namespace Npgsql.PostgresTypes; + +/// +/// Represents a PostgreSQL data type that isn't known to Npgsql and cannot be handled. 
+/// +public sealed class UnknownBackendType : PostgresType { + internal static readonly PostgresType Instance = new UnknownBackendType(); + /// - /// Represents a PostgreSQL data type that isn't known to Npgsql and cannot be handled. + /// Constructs a the unknown backend type. /// - public class UnknownBackendType : PostgresType - { - internal static readonly PostgresType Instance = new UnknownBackendType(); - - /// - /// Constructs a the unknown backend type. - /// #pragma warning disable CA2222 // Do not decrease inherited member visibility - UnknownBackendType() : base("", "", 0) { } + UnknownBackendType() : base("", "", 0) { } #pragma warning restore CA2222 // Do not decrease inherited member visibility - } } diff --git a/src/Npgsql/PregeneratedMessages.cs b/src/Npgsql/PregeneratedMessages.cs index eefbf70c96..b6d2e4dd02 100644 --- a/src/Npgsql/PregeneratedMessages.cs +++ b/src/Npgsql/PregeneratedMessages.cs @@ -1,57 +1,54 @@ -using System.Diagnostics; -using System.IO; -using System.Linq; +using System.IO; using System.Text; +using Npgsql.Internal; +using Npgsql.Util; -namespace Npgsql +namespace Npgsql; + +static class PregeneratedMessages { - static class PregeneratedMessages + static PregeneratedMessages() { - static PregeneratedMessages() - { #pragma warning disable CS8625 - // This is the only use of a write buffer without a connector, for in-memory construction of - // pregenerated messages. - using var buf = new NpgsqlWriteBuffer(null, new MemoryStream(), null, NpgsqlWriteBuffer.MinimumSize, Encoding.ASCII); + // This is the only use of a write buffer without a connector, for in-memory construction of + // pregenerated messages. 
+ using var buf = new NpgsqlWriteBuffer(null, new MemoryStream(), null, NpgsqlWriteBuffer.MinimumSize, Encoding.ASCII); #pragma warning restore CS8625 - BeginTransRepeatableRead = Generate(buf, "BEGIN ISOLATION LEVEL REPEATABLE READ"); - BeginTransSerializable = Generate(buf, "BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE"); - BeginTransReadCommitted = Generate(buf, "BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"); - BeginTransReadUncommitted = Generate(buf, "BEGIN TRANSACTION ISOLATION LEVEL READ UNCOMMITTED"); - CommitTransaction = Generate(buf, "COMMIT"); - RollbackTransaction = Generate(buf, "ROLLBACK"); - KeepAlive = Generate(buf, "SELECT NULL"); - DiscardAll = Generate(buf, "DISCARD ALL"); - } - - internal static byte[] Generate(NpgsqlWriteBuffer buf, string query) - { - Debug.Assert(query.All(c => c < 128)); - - var queryByteLen = Encoding.ASCII.GetByteCount(query); - - buf.WriteByte(FrontendMessageCode.Query); - buf.WriteInt32(4 + // Message length (including self excluding code) - queryByteLen + // Query byte length - 1); // Null terminator - - buf.WriteString(query, queryByteLen, false).Wait(); - buf.WriteByte(0); - - var bytes = buf.GetContents(); - buf.Clear(); - return bytes; - } - - internal static readonly byte[] BeginTransRepeatableRead; - internal static readonly byte[] BeginTransSerializable; - internal static readonly byte[] BeginTransReadCommitted; - internal static readonly byte[] BeginTransReadUncommitted; - internal static readonly byte[] CommitTransaction; - internal static readonly byte[] RollbackTransaction; - internal static readonly byte[] KeepAlive; - - internal static readonly byte[] DiscardAll; + BeginTransRepeatableRead = Generate(buf, "BEGIN ISOLATION LEVEL REPEATABLE READ"); + BeginTransSerializable = Generate(buf, "BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE"); + BeginTransReadCommitted = Generate(buf, "BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"); + BeginTransReadUncommitted = Generate(buf, "BEGIN TRANSACTION 
ISOLATION LEVEL READ UNCOMMITTED"); + CommitTransaction = Generate(buf, "COMMIT"); + RollbackTransaction = Generate(buf, "ROLLBACK"); + DiscardAll = Generate(buf, "DISCARD ALL"); + } + + internal static byte[] Generate(NpgsqlWriteBuffer buf, string query) + { + NpgsqlWriteBuffer.AssertASCIIOnly(query); + + var queryByteLen = Encoding.ASCII.GetByteCount(query); + + buf.WriteByte(FrontendMessageCode.Query); + buf.WriteInt32(4 + // Message length (including self excluding code) + queryByteLen + // Query byte length + 1); // Null terminator + + buf.WriteString(query, queryByteLen, false).Wait(); + buf.WriteByte(0); + + var bytes = buf.GetContents(); + buf.Clear(); + return bytes; } + + internal static readonly byte[] BeginTransRepeatableRead; + internal static readonly byte[] BeginTransSerializable; + internal static readonly byte[] BeginTransReadCommitted; + internal static readonly byte[] BeginTransReadUncommitted; + internal static readonly byte[] CommitTransaction; + internal static readonly byte[] RollbackTransaction; + + internal static readonly byte[] DiscardAll; } diff --git a/src/Npgsql/PreparedStatement.cs b/src/Npgsql/PreparedStatement.cs index 334b58cbdc..f24905eb41 100644 --- a/src/Npgsql/PreparedStatement.cs +++ b/src/Npgsql/PreparedStatement.cs @@ -1,157 +1,176 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Text; using Npgsql.BackendMessages; +using Npgsql.Internal.Postgres; -namespace Npgsql +namespace Npgsql; + +/// +/// Internally represents a statement has been prepared, is in the process of being prepared, or is a +/// candidate for preparation (i.e. awaiting further usages). +/// +[DebuggerDisplay("{Name} ({State}): {Sql}")] +sealed class PreparedStatement { - /// - /// Internally represents a statement has been prepared, is in the process of being prepared, or is a - /// candidate for preparation (i.e. awaiting further usages). 
- /// - [DebuggerDisplay("{Name} ({State}): {Sql}")] - class PreparedStatement - { - readonly PreparedStatementManager _manager; + readonly PreparedStatementManager _manager; - internal string Sql { get; } + internal string Sql { get; } - internal string? Name; + internal byte[]? Name; - internal RowDescriptionMessage? Description; + internal RowDescriptionMessage? Description; - internal int Usages; + internal int Usages; - internal PreparedState State { get; set; } + internal PreparedState State { get; set; } - internal bool IsPrepared => State == PreparedState.Prepared; + internal bool IsPrepared => State == PreparedState.Prepared; - /// - /// If true, the user explicitly requested this statement be prepared. It does not get closed as part of - /// the automatic preparation LRU mechanism. - /// - internal bool IsExplicit { get; } + /// + /// If true, the user explicitly requested this statement be prepared. It does not get closed as part of + /// the automatic preparation LRU mechanism. + /// + internal bool IsExplicit { get; } - /// - /// If this statement is about to be prepared, but replaces a previous statement which needs to be closed, - /// this holds the name of the previous statement. Otherwise null. - /// - internal PreparedStatement? StatementBeingReplaced; + /// + /// If this statement is about to be prepared, but replaces a previous statement which needs to be closed, + /// this holds the name of the previous statement. Otherwise null. + /// + internal PreparedStatement? StatementBeingReplaced; - internal DateTime LastUsed { get; set; } + internal int AutoPreparedSlotIndex { get; set; } - /// - /// Contains the handler types for a prepared statement's parameters, for overloaded cases (same SQL, different param types) - /// Only populated after the statement has been prepared (i.e. null for candidates). - /// - internal Type[]? 
HandlerParamTypes { get; private set; } + internal long LastUsed { get; set; } - static readonly Type[] EmptyParamTypes = Type.EmptyTypes; + internal void RefreshLastUsed() => LastUsed = Stopwatch.GetTimestamp(); - internal static PreparedStatement CreateExplicit( - PreparedStatementManager manager, - string sql, - string name, - List parameters, - PreparedStatement? statementBeingReplaced) + /// + /// Contains the handler types for a prepared statement's parameters, for overloaded cases (same SQL, different param types) + /// Only populated after the statement has been prepared (i.e. null for candidates). + /// + PgTypeId[]? ConverterParamTypes { get; set; } + + internal static PreparedStatement CreateExplicit( + PreparedStatementManager manager, + string sql, + string name, + List parameters, + PreparedStatement? statementBeingReplaced) + { + var pStatement = new PreparedStatement(manager, sql, true) { - var pStatement = new PreparedStatement(manager, sql, true) - { - Name = name, - StatementBeingReplaced = statementBeingReplaced - }; - pStatement.SetParamTypes(parameters); - return pStatement; - } + Name = Encoding.ASCII.GetBytes(name), + StatementBeingReplaced = statementBeingReplaced + }; + pStatement.SetParamTypes(parameters); + return pStatement; + } - internal static PreparedStatement CreateAutoPrepareCandidate(PreparedStatementManager manager, string sql) - => new PreparedStatement(manager, sql, false); + internal static PreparedStatement CreateAutoPrepareCandidate(PreparedStatementManager manager, string sql) + => new(manager, sql, false); - PreparedStatement(PreparedStatementManager manager, string sql, bool isExplicit) - { - _manager = manager; - Sql = sql; - IsExplicit = isExplicit; - State = PreparedState.NotPrepared; - } + internal PreparedStatement(PreparedStatementManager manager, string sql, bool isExplicit) + { + _manager = manager; + Sql = sql; + IsExplicit = isExplicit; + State = PreparedState.NotPrepared; + } - internal void SetParamTypes(List 
parameters) + internal void SetParamTypes(List parameters) + { + if (parameters.Count == 0) { - Debug.Assert(HandlerParamTypes == null); - if (parameters.Count == 0) - { - HandlerParamTypes = EmptyParamTypes; - return; - } - - HandlerParamTypes = new Type[parameters.Count]; - for (var i = 0; i < parameters.Count; i++) - HandlerParamTypes[i] = parameters[i].Handler!.GetType(); + ConverterParamTypes = Array.Empty(); + return; } - internal bool DoParametersMatch(List parameters) - { - if (HandlerParamTypes!.Length != parameters.Count) + ConverterParamTypes = new PgTypeId[parameters.Count]; + for (var i = 0; i < parameters.Count; i++) + ConverterParamTypes[i] = parameters[i].PgTypeId; + } + + internal bool DoParametersMatch(List parameters) + { + var paramTypes = ConverterParamTypes!; + if (paramTypes.Length != parameters.Count) + return false; + + for (var i = 0; i < paramTypes.Length; i++) + if (paramTypes[i] != parameters[i].PgTypeId) return false; - for (var i = 0; i < HandlerParamTypes.Length; i++) - if (HandlerParamTypes[i] != parameters[i].Handler!.GetType()) - return false; + return true; + } - return true; - } + internal void AbortPrepare() + { + Debug.Assert(State == PreparedState.BeingPrepared); - internal void CompletePrepare() - { - Debug.Assert(HandlerParamTypes != null); - _manager.BySql[Sql] = this; - _manager.NumPrepared++; - State = PreparedState.Prepared; - } + // We were planned for preparation, but a failure occurred and we did not carry that out. + // Remove it from the BySql dictionary, and place back the statement we were planned to replace (possibly null), setting + // its state back to prepared. 
+ _manager.BySql.Remove(Sql); - internal void CompleteUnprepare() + if (!IsExplicit) { - _manager.BySql.Remove(Sql); - if (IsPrepared || State == PreparedState.BeingUnprepared) - _manager.NumPrepared--; - State = PreparedState.Unprepared; + _manager.AutoPrepared[AutoPreparedSlotIndex] = StatementBeingReplaced; + if (StatementBeingReplaced is not null) + StatementBeingReplaced.State = PreparedState.Prepared; } - public override string ToString() => Sql; + State = PreparedState.Unprepared; } - /// - /// The state of a . - /// - enum PreparedState + internal void CompleteUnprepare() { - /// - /// The statement hasn't been prepared yet, nor is it in the process of being prepared. - /// This is the value for autoprepare candidates which haven't been prepared yet, and is also - /// a temporary state during preparation. - /// - NotPrepared, - - /// - /// The statement is in the process of being prepared. - /// - BeingPrepared, - - /// - /// The statement has been fully prepared and can be executed. - /// - Prepared, - - /// - /// The statement is in the process of being unprepared. This is a temporary state that only occurs during - /// unprepare. Specifically, it means that a Close message for the statement has already been written - /// to the write buffer. - /// - BeingUnprepared, - - /// - /// The statement has been unprepared and is no longer usable. - /// - Unprepared + _manager.BySql.Remove(Sql); + _manager.NumPrepared--; + + State = PreparedState.Unprepared; } + + public override string ToString() => Sql; +} + +/// +/// The state of a . +/// +enum PreparedState +{ + /// + /// The statement hasn't been prepared yet, nor is it in the process of being prepared. + /// This is the value for autoprepare candidates which haven't been prepared yet, and is also + /// a temporary state during preparation. + /// + NotPrepared, + + /// + /// The statement is in the process of being prepared. 
+ /// + BeingPrepared, + + /// + /// The statement has been fully prepared and can be executed. + /// + Prepared, + + /// + /// The statement is in the process of being unprepared. This is a temporary state that only occurs during + /// unprepare. Specifically, it means that a Close message for the statement has already been written + /// to the write buffer. + /// + BeingUnprepared, + + /// + /// The statement has been unprepared and is no longer usable. + /// + Unprepared, + + /// + /// The statement was invalidated because e.g. table schema has changed since preparation. + /// + Invalidated } diff --git a/src/Npgsql/PreparedStatementManager.cs b/src/Npgsql/PreparedStatementManager.cs index 1ddda58c5a..ef72879c6d 100644 --- a/src/Npgsql/PreparedStatementManager.cs +++ b/src/Npgsql/PreparedStatementManager.cs @@ -1,241 +1,284 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using Npgsql.Logging; +using System.Text; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; -namespace Npgsql +namespace Npgsql; + +sealed class PreparedStatementManager { - class PreparedStatementManager - { - internal int MaxAutoPrepared { get; } - internal int UsagesBeforePrepare { get; } + internal int MaxAutoPrepared { get; } + internal int UsagesBeforePrepare { get; } + + internal Dictionary BySql { get; } = new(); + internal PreparedStatement?[] AutoPrepared { get; } - internal Dictionary BySql { get; } = new Dictionary(); - readonly PreparedStatement[] _autoPrepared; - int _numAutoPrepared; + readonly PreparedStatement?[] _candidates; - readonly PreparedStatement?[] _candidates; + static readonly List EmptyParameters = new(); - /// - /// Total number of current prepared statements (whether explicit or automatic). - /// - internal int NumPrepared; + /// + /// Total number of current prepared statements (whether explicit or automatic). 
+ /// + internal int NumPrepared; - readonly NpgsqlConnector _connector; + readonly NpgsqlConnector _connector; - internal string NextPreparedStatementName() => "_p" + (++_preparedStatementIndex); - ulong _preparedStatementIndex; + internal string NextPreparedStatementName() => "_p" + (++_preparedStatementIndex); + ulong _preparedStatementIndex; - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(PreparedStatementManager)); + readonly ILogger _commandLogger; + + internal const int CandidateCount = 100; + + internal PreparedStatementManager(NpgsqlConnector connector) + { + _connector = connector; + _commandLogger = connector.LoggingConfiguration.CommandLogger; - internal const int CandidateCount = 100; + MaxAutoPrepared = connector.Settings.MaxAutoPrepare; + UsagesBeforePrepare = connector.Settings.AutoPrepareMinUsages; - internal PreparedStatementManager(NpgsqlConnector connector) + if (MaxAutoPrepared > 0) { - _connector = connector; - MaxAutoPrepared = connector.Settings.MaxAutoPrepare; - UsagesBeforePrepare = connector.Settings.AutoPrepareMinUsages; - if (MaxAutoPrepared > 0) + if (MaxAutoPrepared > 256) + _commandLogger.LogWarning($"{nameof(MaxAutoPrepared)} is over 256, performance degradation may occur. Please report via an issue."); + AutoPrepared = new PreparedStatement[MaxAutoPrepared]; + _candidates = new PreparedStatement[CandidateCount]; + } + else + { + AutoPrepared = null!; + _candidates = null!; + } + } + + internal PreparedStatement? GetOrAddExplicit(NpgsqlBatchCommand batchCommand) + { + var sql = batchCommand.FinalCommandText!; + + PreparedStatement? statementBeingReplaced = null; + if (BySql.TryGetValue(sql, out var pStatement)) + { + Debug.Assert(pStatement.State != PreparedState.Unprepared); + if (pStatement.IsExplicit) { - if (MaxAutoPrepared > 256) - Log.Warn($"{nameof(MaxAutoPrepared)} is over 256, performance degradation may occur. 
Please report via an issue.", connector.Id); - _autoPrepared = new PreparedStatement[MaxAutoPrepared]; - _candidates = new PreparedStatement[CandidateCount]; + // Great, we've found an explicit prepared statement. + // We just need to check that the parameter types correspond, since prepared statements are + // only keyed by SQL (to prevent pointless allocations). If we have a mismatch, simply run unprepared. + return pStatement.DoParametersMatch(batchCommand.CurrentParametersReadOnly) + ? pStatement + : null; } - else + + // We've found an autoprepare statement (candidate or otherwise) + switch (pStatement.State) { - _autoPrepared = null!; - _candidates = null!; + case PreparedState.NotPrepared: + // Found a candidate for autopreparation. Remove it and prepare explicitly. + RemoveCandidate(pStatement); + break; + case PreparedState.Prepared: + // The statement has already been autoprepared. We need to "promote" it to explicit. + statementBeingReplaced = pStatement; + break; + case PreparedState.Unprepared: + throw new InvalidOperationException($"Found unprepared statement in {nameof(PreparedStatementManager)}"); + default: + throw new ArgumentOutOfRangeException(); } } - internal PreparedStatement? GetOrAddExplicit(NpgsqlStatement statement) - { - var sql = statement.SQL; + // Statement hasn't been prepared yet + return BySql[sql] = PreparedStatement.CreateExplicit(this, sql, NextPreparedStatementName(), batchCommand.CurrentParametersReadOnly, statementBeingReplaced); + } - PreparedStatement? statementBeingReplaced = null; - if (BySql.TryGetValue(sql, out var pStatement)) + internal PreparedStatement? TryGetAutoPrepared(NpgsqlBatchCommand batchCommand) + { + var sql = batchCommand.FinalCommandText!; + if (!BySql.TryGetValue(sql, out var pStatement)) + { + // New candidate. Find an empty candidate slot or eject a least-used one. 
+ int slotIndex = -1, leastUsages = int.MaxValue; + var lastUsed = long.MaxValue; + for (var i = 0; i < _candidates.Length; i++) { - Debug.Assert(pStatement.State != PreparedState.Unprepared); - if (pStatement.IsExplicit) + var candidate = _candidates[i]; + // ReSharper disable once ConditionIsAlwaysTrueOrFalse + // ReSharper disable HeuristicUnreachableCode + if (candidate == null) // Found an unused candidate slot, return immediately + { + slotIndex = i; + break; + } + // ReSharper restore HeuristicUnreachableCode + if (candidate.Usages < leastUsages) { - // Great, we've found an explicit prepared statement. - // We just need to check that the parameter types correspond, since prepared statements are - // only keyed by SQL (to prevent pointless allocations). If we have a mismatch, simply run unprepared. - return pStatement.DoParametersMatch(statement.InputParameters) - ? pStatement - : null; + leastUsages = candidate.Usages; + slotIndex = i; + lastUsed = candidate.LastUsed; } - - // We've found an autoprepare statement (candidate or otherwise) - switch (pStatement.State) + else if (candidate.Usages == leastUsages && candidate.LastUsed < lastUsed) { - case PreparedState.NotPrepared: - // Found a candidate for autopreparation. Remove it and prepare explicitly. - RemoveCandidate(pStatement); - break; - case PreparedState.Prepared: - // The statement has already been autoprepared. We need to "promote" it to explicit. 
- statementBeingReplaced = pStatement; - break; - case PreparedState.Unprepared: - throw new InvalidOperationException($"Found unprepared statement in {nameof(PreparedStatementManager)}"); - default: - throw new ArgumentOutOfRangeException(); + slotIndex = i; + lastUsed = candidate.LastUsed; } } - // Statement hasn't been prepared yet - return BySql[sql] = PreparedStatement.CreateExplicit(this, sql, NextPreparedStatementName(), statement.InputParameters, statementBeingReplaced); + var leastUsed = _candidates[slotIndex]; + // ReSharper disable once ConditionIsAlwaysTrueOrFalse + if (leastUsed != null) + BySql.Remove(leastUsed.Sql); + pStatement = BySql[sql] = _candidates[slotIndex] = PreparedStatement.CreateAutoPrepareCandidate(this, sql); } - internal PreparedStatement? TryGetAutoPrepared(NpgsqlStatement statement) + switch (pStatement.State) { - var sql = statement.SQL; - if (!BySql.TryGetValue(sql, out var pStatement)) - { - // New candidate. Find an empty candidate slot or eject a least-used one. 
- int slotIndex = -1, leastUsages = int.MaxValue; - var lastUsed = DateTime.MaxValue; - for (var i = 0; i < _candidates.Length; i++) - { - var candidate = _candidates[i]; - // ReSharper disable once ConditionIsAlwaysTrueOrFalse - // ReSharper disable HeuristicUnreachableCode - if (candidate == null) // Found an unused candidate slot, return immediately - { - slotIndex = i; - break; - } - // ReSharper restore HeuristicUnreachableCode - if (candidate.Usages < leastUsages) - { - leastUsages = candidate.Usages; - slotIndex = i; - lastUsed = candidate.LastUsed; - } - else if (candidate.Usages == leastUsages && candidate.LastUsed < lastUsed) - { - slotIndex = i; - lastUsed = candidate.LastUsed; - } - } + case PreparedState.NotPrepared: + case PreparedState.Invalidated: + break; - var leastUsed = _candidates[slotIndex]; - // ReSharper disable once ConditionIsAlwaysTrueOrFalse - if (leastUsed != null) - BySql.Remove(leastUsed.Sql); - pStatement = BySql[sql] = _candidates[slotIndex] = PreparedStatement.CreateAutoPrepareCandidate(this, sql); - } + case PreparedState.Prepared: + case PreparedState.BeingPrepared: + // The statement has already been prepared (explicitly or automatically), or has been selected + // for preparation (earlier identical statement in the same command). + // We just need to check that the parameter types correspond, since prepared statements are + // only keyed by SQL (to prevent pointless allocations). If we have a mismatch, simply run unprepared. + if (!pStatement.DoParametersMatch(batchCommand.CurrentParametersReadOnly)) + return null; + // Prevent this statement from being replaced within this batch + pStatement.LastUsed = long.MaxValue; + return pStatement; - switch (pStatement.State) - { - case PreparedState.NotPrepared: - break; + case PreparedState.BeingUnprepared: + // The statement is being replaced by an earlier statement in this same batch. 
+ return null; - case PreparedState.Prepared: - case PreparedState.BeingPrepared: - // The statement has already been prepared (explicitly or automatically), or has been selected - // for preparation (earlier identical statement in the same command). - // We just need to check that the parameter types correspond, since prepared statements are - // only keyed by SQL (to prevent pointless allocations). If we have a mismatch, simply run unprepared. - if (!pStatement.DoParametersMatch(statement.InputParameters)) - return null; - // Prevent this statement from being replaced within this batch - pStatement.LastUsed = DateTime.MaxValue; - return pStatement; - - case PreparedState.BeingUnprepared: - // The statement is being replaced by an earlier statement in this same batch. - return null; + default: + Debug.Fail($"Unexpected {nameof(PreparedState)} in auto-preparation: {pStatement.State}"); + break; + } - default: - Debug.Fail($"Unexpected {nameof(PreparedState)} in auto-preparation: {pStatement.State}"); - break; - } + if (++pStatement.Usages < UsagesBeforePrepare) + { + // Statement still hasn't passed the usage threshold, no automatic preparation. + // Return null for unprepared execution. + pStatement.RefreshLastUsed(); + return null; + } - if (++pStatement.Usages < UsagesBeforePrepare) - { - // Statement still hasn't passed the usage threshold, no automatic preparation. - // Return null for unprepared execution. - pStatement.LastUsed = DateTime.UtcNow; - return null; - } + // Bingo, we've just passed the usage threshold, statement should get prepared + LogMessages.AutoPreparingStatement(_commandLogger, sql, _connector.Id); - // Bingo, we've just passed the usage threshold, statement should get prepared - Log.Trace($"Automatically preparing statement: {sql}", _connector.Id); + // Look for either an empty autoprepare slot, or the least recently used prepared statement which we'll replace it. 
+ var oldestLastUsed = long.MaxValue; + var selectedIndex = -1; + for (var i = 0; i < AutoPrepared.Length; i++) + { + var slot = AutoPrepared[i]; - if (_numAutoPrepared < MaxAutoPrepared) + if (slot is null or { State: PreparedState.Invalidated }) { - // We still have free slots - _autoPrepared[_numAutoPrepared++] = pStatement; - pStatement.Name = "_auto" + _numAutoPrepared; + // We found a free or invalidated slot, exit the loop immediately + selectedIndex = i; + break; } - else + + switch (slot.State) { - // We already have the maximum number of prepared statements. - // Find the least recently used prepared statement and replace it. - var oldestTimestamp = DateTime.MaxValue; - var oldestIndex = -1; - for (var i = 0; i < _autoPrepared.Length; i++) + case PreparedState.Prepared: + if (slot.LastUsed < oldestLastUsed) { - if (_autoPrepared[i].LastUsed < oldestTimestamp) - { - oldestIndex = i; - oldestTimestamp = _autoPrepared[i].LastUsed; - } + selectedIndex = i; + oldestLastUsed = slot.LastUsed; } + break; - if (oldestIndex == -1) - { - // We're here if we couldn't find a prepared statement to replace, because all of them are already - // being prepared in this batch. - return null; - } + case PreparedState.BeingPrepared: + // Slot has already been selected for preparation by an earlier statement in this batch. Skip it. + continue; - var lru = _autoPrepared[oldestIndex]; - pStatement.Name = lru.Name; - pStatement.StatementBeingReplaced = lru; - lru.State = PreparedState.BeingUnprepared; - _autoPrepared[oldestIndex] = pStatement; + default: + ThrowHelper.ThrowInvalidOperationException($"Invalid {nameof(PreparedState)} state {slot.State} encountered when scanning prepared statement slots"); + return null; } + } - RemoveCandidate(pStatement); + if (selectedIndex == -1) + { + // We're here if we couldn't find a free slot or a prepared statement to replace - this means all slots are taken by + // statements being prepared in this batch. 
+ return null; + } - // Make sure this statement isn't replaced by a later statement in the same batch. - pStatement.LastUsed = DateTime.MaxValue; + if (pStatement.State != PreparedState.Invalidated) + RemoveCandidate(pStatement); - // Note that the parameter types are only set at the moment of preparation - in the candidate phase - // there's no differentiation between overloaded statements, which are a pretty rare case, saving - // allocations. - pStatement.SetParamTypes(statement.InputParameters); + var oldPreparedStatement = AutoPrepared[selectedIndex]; - return pStatement; + if (oldPreparedStatement is null) + { + pStatement.Name = Encoding.ASCII.GetBytes("_auto" + selectedIndex); } - - void RemoveCandidate(PreparedStatement candidate) + else { - var i = 0; - for (; i < _candidates.Length; i++) + // When executing an invalidated prepared statement, the old and the new statements are the same instance. + // Create a copy so that we have two distinct instances with their own states. + if (oldPreparedStatement == pStatement) { - if (_candidates[i] == candidate) + oldPreparedStatement = new PreparedStatement(this, oldPreparedStatement.Sql, isExplicit: false) { - _candidates[i] = null; - return; - } + Name = oldPreparedStatement.Name + }; } - Debug.Assert(i < _candidates.Length); + + pStatement.Name = oldPreparedStatement.Name; + pStatement.State = PreparedState.NotPrepared; + pStatement.StatementBeingReplaced = oldPreparedStatement; + oldPreparedStatement.State = PreparedState.BeingUnprepared; } - internal void ClearAll() + pStatement.AutoPreparedSlotIndex = selectedIndex; + AutoPrepared[selectedIndex] = pStatement; + + + // Make sure this statement isn't replaced by a later statement in the same batch. + pStatement.LastUsed = long.MaxValue; + + // Note that the parameter types are only set at the moment of preparation - in the candidate phase + // there's no differentiation between overloaded statements, which are a pretty rare case, saving + // allocations. 
+ pStatement.SetParamTypes(batchCommand.CurrentParametersReadOnly); + + return pStatement; + } + + void RemoveCandidate(PreparedStatement candidate) + { + var i = 0; + for (; i < _candidates.Length; i++) { - BySql.Clear(); - NumPrepared = 0; - _preparedStatementIndex = 0; - _numAutoPrepared = 0; - if (_candidates != null) - for (var i = 0; i < _candidates.Length; i++) - _candidates[i] = null; + if (_candidates[i] == candidate) + { + _candidates[i] = null; + return; + } } + Debug.Assert(i < _candidates.Length); + } + + internal void ClearAll() + { + BySql.Clear(); + NumPrepared = 0; + _preparedStatementIndex = 0; + if (AutoPrepared is not null) + for (var i = 0; i < AutoPrepared.Length; i++) + AutoPrepared[i] = null; + if (_candidates != null) + for (var i = 0; i < _candidates.Length; i++) + _candidates[i] = null; } } diff --git a/src/Npgsql/PreparedTextReader.cs b/src/Npgsql/PreparedTextReader.cs new file mode 100644 index 0000000000..8862daa3e7 --- /dev/null +++ b/src/Npgsql/PreparedTextReader.cs @@ -0,0 +1,127 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; + +namespace Npgsql; + +sealed class PreparedTextReader : TextReader +{ + string _str = null!; + NpgsqlReadBuffer.ColumnStream _stream = null!; + + int _position; + bool _disposed; + + public void Init(string str, NpgsqlReadBuffer.ColumnStream stream) + { + _str = str; + _stream = stream; + _disposed = false; + _position = 0; + } + + public bool IsDisposed => _disposed; + + public override int Peek() + { + CheckDisposed(); + + return _position < _str.Length + ? _str[_position] + : -1; + } + + public override int Read() + { + CheckDisposed(); + + return _position < _str.Length + ? 
_str[_position++] + : -1; + } + +#if NETSTANDARD2_0 + public int Read(Span buffer) +#else + public override int Read(Span buffer) +#endif + { + CheckDisposed(); + + var toRead = Math.Min(buffer.Length, _str.Length - _position); + if (toRead == 0) + return 0; + + _str.AsSpan(_position, toRead).CopyTo(buffer); + _position += toRead; + return toRead; + } + + public override int Read(char[] buffer, int index, int count) + { + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + if (index < 0 || count < 0) + { + throw new ArgumentOutOfRangeException(index < 0 ? nameof(index) : nameof(count)); + } + if (buffer.Length - index < count) + { + throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + } + + return Read(buffer.AsSpan(index, count)); + } + + public override Task ReadAsync(char[] buffer, int index, int count) + => Task.FromResult(Read(buffer, index, count)); + + public +#if !NETSTANDARD2_0 + override +#endif + ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) => new(Read(buffer.Span)); + + public override Task ReadLineAsync() => Task.FromResult(ReadLine()); + + public override string ReadToEnd() + { + CheckDisposed(); + + if (_position == _str.Length) + return string.Empty; + + var str = _str.Substring(_position); + _position = _str.Length; + return str; + } + + public override Task ReadToEndAsync() => Task.FromResult(ReadToEnd()); + + void CheckDisposed() + { + if (_disposed || _stream.IsDisposed) + ThrowHelper.ThrowObjectDisposedException(nameof(PreparedTextReader)); + } + + public void Restart() + { + CheckDisposed(); + _position = 0; + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + if (disposing) + { + _disposed = true; + _stream.Dispose(); + } + } +} diff --git a/src/Npgsql/Properties/AssemblyInfo.cs 
b/src/Npgsql/Properties/AssemblyInfo.cs index 98d24d11ef..80500e0028 100644 --- a/src/Npgsql/Properties/AssemblyInfo.cs +++ b/src/Npgsql/Properties/AssemblyInfo.cs @@ -8,35 +8,18 @@ [assembly: AssemblyTrademark("")] [assembly: SecurityRules(SecurityRuleSet.Level1)] -[assembly: InternalsVisibleTo("Npgsql.EntityFrameworkCore.PostgreSQL, PublicKey=" + -"0024000004800000940000000602000000240000525341310004000001000100" + -"2b3c590b2a4e3d347e6878dc0ff4d21eb056a50420250c6617044330701d35c9" + -"8078a5df97a62d83c9a2db2d072523a8fc491398254c6b89329b8c1dcef43a1e" + -"7aa16153bcea2ae9a471145624826f60d7c8e71cd025b554a0177bd935a78096" + -"29f0a7afc778ebb4ad033e1bf512c1a9c6ceea26b077bc46cac93800435e77ee")] +#if NET5_0_OR_GREATER +[module: SkipLocalsInit] +#endif -[assembly: InternalsVisibleTo("Npgsql.EntityFrameworkCore.PostgreSQL.Design, PublicKey=" + -"0024000004800000940000000602000000240000525341310004000001000100" + -"2b3c590b2a4e3d347e6878dc0ff4d21eb056a50420250c6617044330701d35c9" + -"8078a5df97a62d83c9a2db2d072523a8fc491398254c6b89329b8c1dcef43a1e" + -"7aa16153bcea2ae9a471145624826f60d7c8e71cd025b554a0177bd935a78096" + -"29f0a7afc778ebb4ad033e1bf512c1a9c6ceea26b077bc46cac93800435e77ee")] - -[assembly: InternalsVisibleTo("EntityFramework6.Npgsql, PublicKey=" + -"0024000004800000940000000602000000240000525341310004000001000100" + -"2b3c590b2a4e3d347e6878dc0ff4d21eb056a50420250c6617044330701d35c9" + -"8078a5df97a62d83c9a2db2d072523a8fc491398254c6b89329b8c1dcef43a1e" + -"7aa16153bcea2ae9a471145624826f60d7c8e71cd025b554a0177bd935a78096" + -"29f0a7afc778ebb4ad033e1bf512c1a9c6ceea26b077bc46cac93800435e77ee")] - -[assembly: InternalsVisibleTo("EntityFramework5.Npgsql, PublicKey=" + +[assembly: InternalsVisibleTo("Npgsql.Tests, PublicKey=" + "0024000004800000940000000602000000240000525341310004000001000100" + "2b3c590b2a4e3d347e6878dc0ff4d21eb056a50420250c6617044330701d35c9" + "8078a5df97a62d83c9a2db2d072523a8fc491398254c6b89329b8c1dcef43a1e" + 
"7aa16153bcea2ae9a471145624826f60d7c8e71cd025b554a0177bd935a78096" + "29f0a7afc778ebb4ad033e1bf512c1a9c6ceea26b077bc46cac93800435e77ee")] -[assembly: InternalsVisibleTo("Npgsql.Tests, PublicKey=" + +[assembly: InternalsVisibleTo("Npgsql.PluginTests, PublicKey=" + "0024000004800000940000000602000000240000525341310004000001000100" + "2b3c590b2a4e3d347e6878dc0ff4d21eb056a50420250c6617044330701d35c9" + "8078a5df97a62d83c9a2db2d072523a8fc491398254c6b89329b8c1dcef43a1e" + diff --git a/src/Npgsql/Properties/NpgsqlStrings.Designer.cs b/src/Npgsql/Properties/NpgsqlStrings.Designer.cs new file mode 100644 index 0000000000..f00370da48 --- /dev/null +++ b/src/Npgsql/Properties/NpgsqlStrings.Designer.cs @@ -0,0 +1,222 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Npgsql.Properties { + using System; + + + [System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] + [System.Diagnostics.DebuggerNonUserCodeAttribute()] + [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class NpgsqlStrings { + + private static System.Resources.ResourceManager resourceMan; + + private static System.Globalization.CultureInfo resourceCulture; + + [System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal NpgsqlStrings() { + } + + [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] + internal static System.Resources.ResourceManager ResourceManager { + get { + if (object.Equals(null, resourceMan)) { + System.Resources.ResourceManager temp = new System.Resources.ResourceManager("Npgsql.Properties.NpgsqlStrings", 
typeof(NpgsqlStrings).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] + internal static System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + internal static string CannotUseSslVerifyWithUserCallback { + get { + return ResourceManager.GetString("CannotUseSslVerifyWithUserCallback", resourceCulture); + } + } + + internal static string CannotUseSslRootCertificateWithUserCallback { + get { + return ResourceManager.GetString("CannotUseSslRootCertificateWithUserCallback", resourceCulture); + } + } + + internal static string TransportSecurityDisabled { + get { + return ResourceManager.GetString("TransportSecurityDisabled", resourceCulture); + } + } + + internal static string IntegratedSecurityDisabled { + get { + return ResourceManager.GetString("IntegratedSecurityDisabled", resourceCulture); + } + } + + internal static string NoMultirangeTypeFound { + get { + return ResourceManager.GetString("NoMultirangeTypeFound", resourceCulture); + } + } + + internal static string NotSupportedOnDataSourceCommand { + get { + return ResourceManager.GetString("NotSupportedOnDataSourceCommand", resourceCulture); + } + } + + internal static string NotSupportedOnDataSourceBatch { + get { + return ResourceManager.GetString("NotSupportedOnDataSourceBatch", resourceCulture); + } + } + + internal static string CannotSetBothPasswordProviderAndPassword { + get { + return ResourceManager.GetString("CannotSetBothPasswordProviderAndPassword", resourceCulture); + } + } + + internal static string CannotSetMultiplePasswordProviderKinds { + get { + return ResourceManager.GetString("CannotSetMultiplePasswordProviderKinds", resourceCulture); + } + } + + internal static string SyncAndAsyncPasswordProvidersRequired { + get { + return ResourceManager.GetString("SyncAndAsyncPasswordProvidersRequired", 
resourceCulture); + } + } + + internal static string PasswordProviderMissing { + get { + return ResourceManager.GetString("PasswordProviderMissing", resourceCulture); + } + } + + internal static string ArgumentMustBePositive { + get { + return ResourceManager.GetString("ArgumentMustBePositive", resourceCulture); + } + } + + internal static string CannotSpecifyTargetSessionAttributes { + get { + return ResourceManager.GetString("CannotSpecifyTargetSessionAttributes", resourceCulture); + } + } + + internal static string CannotReadIntervalWithMonthsAsTimeSpan { + get { + return ResourceManager.GetString("CannotReadIntervalWithMonthsAsTimeSpan", resourceCulture); + } + } + + internal static string PositionalParameterAfterNamed { + get { + return ResourceManager.GetString("PositionalParameterAfterNamed", resourceCulture); + } + } + + internal static string CannotReadInfinityValue { + get { + return ResourceManager.GetString("CannotReadInfinityValue", resourceCulture); + } + } + + internal static string SyncAndAsyncConnectionInitializersRequired { + get { + return ResourceManager.GetString("SyncAndAsyncConnectionInitializersRequired", resourceCulture); + } + } + + internal static string CannotUseValidationRootCertificateCallbackWithUserCallback { + get { + return ResourceManager.GetString("CannotUseValidationRootCertificateCallbackWithUserCallback", resourceCulture); + } + } + + internal static string RecordsNotEnabled { + get { + return ResourceManager.GetString("RecordsNotEnabled", resourceCulture); + } + } + + internal static string FullTextSearchNotEnabled { + get { + return ResourceManager.GetString("FullTextSearchNotEnabled", resourceCulture); + } + } + + internal static string LTreeNotEnabled { + get { + return ResourceManager.GetString("LTreeNotEnabled", resourceCulture); + } + } + + internal static string RangesNotEnabled { + get { + return ResourceManager.GetString("RangesNotEnabled", resourceCulture); + } + } + + internal static string MultirangesNotEnabled { 
+ get { + return ResourceManager.GetString("MultirangesNotEnabled", resourceCulture); + } + } + + internal static string ArraysNotEnabled { + get { + return ResourceManager.GetString("ArraysNotEnabled", resourceCulture); + } + } + + internal static string TimestampTzNoDateTimeUnspecified { + get { + return ResourceManager.GetString("TimestampTzNoDateTimeUnspecified", resourceCulture); + } + } + + internal static string TimestampNoDateTimeUtc { + get { + return ResourceManager.GetString("TimestampNoDateTimeUtc", resourceCulture); + } + } + + internal static string DynamicJsonNotEnabled { + get { + return ResourceManager.GetString("DynamicJsonNotEnabled", resourceCulture); + } + } + + internal static string UnmappedEnumsNotEnabled { + get { + return ResourceManager.GetString("UnmappedEnumsNotEnabled", resourceCulture); + } + } + + internal static string UnmappedRangesNotEnabled { + get { + return ResourceManager.GetString("UnmappedRangesNotEnabled", resourceCulture); + } + } + } +} diff --git a/src/Npgsql/Properties/NpgsqlStrings.resx b/src/Npgsql/Properties/NpgsqlStrings.resx new file mode 100644 index 0000000000..5dbc58acdf --- /dev/null +++ b/src/Npgsql/Properties/NpgsqlStrings.resx @@ -0,0 +1,110 @@ + + + + + + + + + + text/microsoft-resx + + + 1.3 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + SslMode.{0} cannot be used in conjunction with UserCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback. + + + RootCertificate cannot be used in conjunction with UserCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback. 
+ + + Transport security hasn't been enabled; please call {0} on NpgsqlSlimDataSourceBuilder to enable it. + + + Integrated security hasn't been enabled; please call {0} on NpgsqlSlimDataSourceBuilder to enable it. + + + No multirange type could be found in the database for subtype {0}. + + + Connection and transaction access is not supported on commands created from DbDataSource. + + + Connection and transaction access is not supported on batches created from DbDataSource. + + + When registering a password provider, a password or password file may not be set. + + + Multiple kinds of password providers were found, only one kind may be configured per DbDataSource. + + + Both sync and async password providers must be provided. + + + The right type of password provider (sync or async) was not found. + + + '{0}' must be positive. + + + When creating a multi-host data source, TargetSessionAttributes cannot be specified. Create without TargetSessionAttributes, and then obtain DataSource wrappers from it. Consult the docs for more information. + + + Cannot read interval values with non-zero months as TimeSpan, since that type doesn't support months. Consider using NodaTime Period which better corresponds to PostgreSQL interval, or read the value as NpgsqlInterval, or transform the interval to not contain months or years in PostgreSQL before reading it. + + + When using CommandType.StoredProcedure, all positional parameters must come before named parameters. + + + Cannot read infinity value since Npgsql.DisableDateTimeInfinityConversions is enabled. + + + Both sync and async connection initializers must be provided. + + + ValidationRootCertificateCallback cannot be used in conjunction with UserCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback. + + + Could not read a PostgreSQL record. 
If you're attempting to read a record as a .NET tuple, call '{0}' on '{1}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/basic.html and the 8.0 release notes for more details). If you're reading a record as a .NET object array using NpgsqlSlimDataSourceBuilder, call '{2}'. + + + + Full-text search isn't enabled; please call {0} on {1} to enable full-text search. + + + Ltree isn't enabled; please call {0} on {1} to enable LTree. + + + Ranges aren't enabled; please call {0} on {1} to enable ranges. + + + Multiranges aren't enabled; please call {0} on {1} to enable multiranges. + + + Arrays aren't enabled; please call {0} on {1} to enable arrays. + + + Cannot write DateTime with Kind={0} to PostgreSQL type '{1}', only UTC is supported. Note that it's not possible to mix DateTimes with different Kinds in an array, range, or multirange. + + + Cannot write DateTime with Kind=UTC to PostgreSQL type '{0}', consider using '{1}'. Note that it's not possible to mix DateTimes with different Kinds in an array, range, or multirange. + + + Type '{0}' required dynamic JSON serialization, which requires an explicit opt-in; call '{1}' on '{2}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/json.html and the 8.0 release notes for more details). Alternatively, if you meant to use Newtonsoft JSON.NET instead of System.Text.Json, call UseJsonNet() instead. + + + + Reading and writing unmapped enums requires an explicit opt-in; call '{0}' on '{1}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/enums_and_composites.html and the 8.0 release notes for more details). + + + Reading and writing unmapped ranges and multiranges requires an explicit opt-in; call '{0}' on '{1}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/ranges.html and the 8.0 release notes for more details). 
+ + diff --git a/src/Npgsql/PublicAPI.Shipped.txt b/src/Npgsql/PublicAPI.Shipped.txt new file mode 100644 index 0000000000..3ec604ddc0 --- /dev/null +++ b/src/Npgsql/PublicAPI.Shipped.txt @@ -0,0 +1,1938 @@ +#nullable enable +abstract Npgsql.Replication.PgOutput.Messages.UpdateMessage.NewRow.get -> Npgsql.Replication.PgOutput.ReplicationTuple! +abstract NpgsqlTypes.NpgsqlTsQuery.Equals(NpgsqlTypes.NpgsqlTsQuery? other) -> bool +const Npgsql.NpgsqlConnection.DefaultPort = 5432 -> int +const Npgsql.PostgresErrorCodes.ActiveSqlTransaction = "25001" -> string! +const Npgsql.PostgresErrorCodes.AdminShutdown = "57P01" -> string! +const Npgsql.PostgresErrorCodes.AmbiguousAlias = "42P09" -> string! +const Npgsql.PostgresErrorCodes.AmbiguousColumn = "42702" -> string! +const Npgsql.PostgresErrorCodes.AmbiguousFunction = "42725" -> string! +const Npgsql.PostgresErrorCodes.AmbiguousParameter = "42P08" -> string! +const Npgsql.PostgresErrorCodes.ArraySubscriptError = "2202E" -> string! +const Npgsql.PostgresErrorCodes.AssertFailure = "P0004" -> string! +const Npgsql.PostgresErrorCodes.BadCopyFileFormat = "22P04" -> string! +const Npgsql.PostgresErrorCodes.BranchTransactionAlreadyActive = "25002" -> string! +const Npgsql.PostgresErrorCodes.CannotCoerce = "42846" -> string! +const Npgsql.PostgresErrorCodes.CannotConnectNow = "57P03" -> string! +const Npgsql.PostgresErrorCodes.CantChangeRuntimeParam = "55P02" -> string! +const Npgsql.PostgresErrorCodes.CardinalityViolation = "21000" -> string! +const Npgsql.PostgresErrorCodes.CaseNotFound = "20000" -> string! +const Npgsql.PostgresErrorCodes.CharacterNotInRepertoire = "22021" -> string! +const Npgsql.PostgresErrorCodes.CheckViolation = "23514" -> string! +const Npgsql.PostgresErrorCodes.CollationMismatch = "42P21" -> string! +const Npgsql.PostgresErrorCodes.ConfigFileError = "F0000" -> string! +const Npgsql.PostgresErrorCodes.ConfigurationLimitExceeded = "53400" -> string! 
+const Npgsql.PostgresErrorCodes.ConnectionDoesNotExist = "08003" -> string! +const Npgsql.PostgresErrorCodes.ConnectionException = "08000" -> string! +const Npgsql.PostgresErrorCodes.ConnectionFailure = "08006" -> string! +const Npgsql.PostgresErrorCodes.ContainingSqlNotPermittedExternalRoutineException = "38001" -> string! +const Npgsql.PostgresErrorCodes.CrashShutdown = "57P02" -> string! +const Npgsql.PostgresErrorCodes.DatabaseDropped = "57P04" -> string! +const Npgsql.PostgresErrorCodes.DataCorrupted = "XX001" -> string! +const Npgsql.PostgresErrorCodes.DataException = "22000" -> string! +const Npgsql.PostgresErrorCodes.DatatypeMismatch = "42804" -> string! +const Npgsql.PostgresErrorCodes.DatetimeFieldOverflow = "22008" -> string! +const Npgsql.PostgresErrorCodes.DeadlockDetected = "40P01" -> string! +const Npgsql.PostgresErrorCodes.DependentObjectsStillExist = "2BP01" -> string! +const Npgsql.PostgresErrorCodes.DependentPrivilegeDescriptorsStillExist = "2B000" -> string! +const Npgsql.PostgresErrorCodes.DeprecatedFeatureWarning = "01P01" -> string! +const Npgsql.PostgresErrorCodes.DiagnosticsException = "0Z000" -> string! +const Npgsql.PostgresErrorCodes.DiskFull = "53100" -> string! +const Npgsql.PostgresErrorCodes.DivisionByZero = "22012" -> string! +const Npgsql.PostgresErrorCodes.DuplicateAlias = "42712" -> string! +const Npgsql.PostgresErrorCodes.DuplicateColumn = "42701" -> string! +const Npgsql.PostgresErrorCodes.DuplicateCursor = "42P03" -> string! +const Npgsql.PostgresErrorCodes.DuplicateDatabase = "42P04" -> string! +const Npgsql.PostgresErrorCodes.DuplicateFile = "58P02" -> string! +const Npgsql.PostgresErrorCodes.DuplicateFunction = "42723" -> string! +const Npgsql.PostgresErrorCodes.DuplicateObject = "42710" -> string! +const Npgsql.PostgresErrorCodes.DuplicatePreparedStatement = "42P05" -> string! +const Npgsql.PostgresErrorCodes.DuplicateSchema = "42P06" -> string! +const Npgsql.PostgresErrorCodes.DuplicateTable = "42P07" -> string! 
+const Npgsql.PostgresErrorCodes.DynamicResultSetsReturnedWarning = "0100C" -> string! +const Npgsql.PostgresErrorCodes.ErrorInAssignment = "22005" -> string! +const Npgsql.PostgresErrorCodes.EscapeCharacterConflict = "2200B" -> string! +const Npgsql.PostgresErrorCodes.EventTriggerProtocolViolatedExternalRoutineInvocationException = "39P03" -> string! +const Npgsql.PostgresErrorCodes.ExclusionViolation = "23P01" -> string! +const Npgsql.PostgresErrorCodes.ExternalRoutineException = "38000" -> string! +const Npgsql.PostgresErrorCodes.ExternalRoutineInvocationException = "39000" -> string! +const Npgsql.PostgresErrorCodes.FdwColumnNameNotFound = "HV005" -> string! +const Npgsql.PostgresErrorCodes.FdwDynamicParameterValueNeeded = "HV002" -> string! +const Npgsql.PostgresErrorCodes.FdwError = "HV000" -> string! +const Npgsql.PostgresErrorCodes.FdwFunctionSequenceError = "HV010" -> string! +const Npgsql.PostgresErrorCodes.FdwInconsistentDescriptorInformation = "HV021" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidAttributeValue = "HV024" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidColumnName = "HV007" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidColumnNumber = "HV008" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidDataType = "HV004" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidDataTypeDescriptors = "HV006" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidDescriptorFieldIdentifier = "HV091" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidHandle = "HV00B" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidOptionIndex = "HV00C" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidOptionName = "HV00D" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidStringFormat = "HV00A" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidStringLengthOrBufferLength = "HV090" -> string! +const Npgsql.PostgresErrorCodes.FdwInvalidUseOfNullPointer = "HV009" -> string! 
+const Npgsql.PostgresErrorCodes.FdwNoSchemas = "HV00P" -> string! +const Npgsql.PostgresErrorCodes.FdwOptionNameNotFound = "HV00J" -> string! +const Npgsql.PostgresErrorCodes.FdwOutOfMemory = "HV001" -> string! +const Npgsql.PostgresErrorCodes.FdwReplyHandle = "HV00K" -> string! +const Npgsql.PostgresErrorCodes.FdwSchemaNotFound = "HV00Q" -> string! +const Npgsql.PostgresErrorCodes.FdwTableNotFound = "HV00R" -> string! +const Npgsql.PostgresErrorCodes.FdwTooManyHandles = "HV014" -> string! +const Npgsql.PostgresErrorCodes.FdwUnableToCreateExecution = "HV00L" -> string! +const Npgsql.PostgresErrorCodes.FdwUnableToCreateReply = "HV00M" -> string! +const Npgsql.PostgresErrorCodes.FdwUnableToEstablishConnection = "HV00N" -> string! +const Npgsql.PostgresErrorCodes.FeatureNotSupported = "0A000" -> string! +const Npgsql.PostgresErrorCodes.FloatingPointException = "22P01" -> string! +const Npgsql.PostgresErrorCodes.ForeignKeyViolation = "23503" -> string! +const Npgsql.PostgresErrorCodes.FunctionExecutedNoReturnStatementSqlRoutineException = "2F005" -> string! +const Npgsql.PostgresErrorCodes.GroupingError = "42803" -> string! +const Npgsql.PostgresErrorCodes.HeldCursorRequiresSameIsolationLevel = "25008" -> string! +const Npgsql.PostgresErrorCodes.IdleSessionTimeout = "57P05" -> string! +const Npgsql.PostgresErrorCodes.ImplicitZeroBitPaddingWarning = "01008" -> string! +const Npgsql.PostgresErrorCodes.InappropriateAccessModeForBranchTransaction = "25003" -> string! +const Npgsql.PostgresErrorCodes.InappropriateIsolationLevelForBranchTransaction = "25004" -> string! +const Npgsql.PostgresErrorCodes.IndeterminateCollation = "42P22" -> string! +const Npgsql.PostgresErrorCodes.IndeterminateDatatype = "42P18" -> string! +const Npgsql.PostgresErrorCodes.IndexCorrupted = "XX002" -> string! +const Npgsql.PostgresErrorCodes.IndicatorOverflow = "22022" -> string! +const Npgsql.PostgresErrorCodes.InFailedSqlTransaction = "25P02" -> string! 
+const Npgsql.PostgresErrorCodes.InsufficientPrivilege = "42501" -> string! +const Npgsql.PostgresErrorCodes.InsufficientResources = "53000" -> string! +const Npgsql.PostgresErrorCodes.IntegrityConstraintViolation = "23000" -> string! +const Npgsql.PostgresErrorCodes.InternalError = "XX000" -> string! +const Npgsql.PostgresErrorCodes.IntervalFieldOverflow = "22015" -> string! +const Npgsql.PostgresErrorCodes.InvalidArgumentForLogarithm = "2201E" -> string! +const Npgsql.PostgresErrorCodes.InvalidArgumentForNthValueFunction = "22016" -> string! +const Npgsql.PostgresErrorCodes.InvalidArgumentForNtileFunction = "22014" -> string! +const Npgsql.PostgresErrorCodes.InvalidArgumentForPowerFunction = "2201F" -> string! +const Npgsql.PostgresErrorCodes.InvalidArgumentForWidthBucketFunction = "2201G" -> string! +const Npgsql.PostgresErrorCodes.InvalidAuthorizationSpecification = "28000" -> string! +const Npgsql.PostgresErrorCodes.InvalidBinaryRepresentation = "22P03" -> string! +const Npgsql.PostgresErrorCodes.InvalidCatalogName = "3D000" -> string! +const Npgsql.PostgresErrorCodes.InvalidCharacterValueForCast = "22018" -> string! +const Npgsql.PostgresErrorCodes.InvalidColumnDefinition = "42611" -> string! +const Npgsql.PostgresErrorCodes.InvalidColumnReference = "42P10" -> string! +const Npgsql.PostgresErrorCodes.InvalidCursorDefinition = "42P11" -> string! +const Npgsql.PostgresErrorCodes.InvalidCursorName = "34000" -> string! +const Npgsql.PostgresErrorCodes.InvalidCursorState = "24000" -> string! +const Npgsql.PostgresErrorCodes.InvalidDatabaseDefinition = "42P12" -> string! +const Npgsql.PostgresErrorCodes.InvalidDatetimeFormat = "22007" -> string! +const Npgsql.PostgresErrorCodes.InvalidEscapeCharacter = "22019" -> string! +const Npgsql.PostgresErrorCodes.InvalidEscapeOctet = "2200D" -> string! +const Npgsql.PostgresErrorCodes.InvalidEscapeSequence = "22025" -> string! +const Npgsql.PostgresErrorCodes.InvalidForeignKey = "42830" -> string! 
+const Npgsql.PostgresErrorCodes.InvalidFunctionDefinition = "42P13" -> string! +const Npgsql.PostgresErrorCodes.InvalidGrantOperation = "0LP01" -> string! +const Npgsql.PostgresErrorCodes.InvalidGrantor = "0L000" -> string! +const Npgsql.PostgresErrorCodes.InvalidIndicatorParameterValue = "22010" -> string! +const Npgsql.PostgresErrorCodes.InvalidLocatorSpecification = "0F001" -> string! +const Npgsql.PostgresErrorCodes.InvalidName = "42602" -> string! +const Npgsql.PostgresErrorCodes.InvalidObjectDefinition = "42P17" -> string! +const Npgsql.PostgresErrorCodes.InvalidParameterValue = "22023" -> string! +const Npgsql.PostgresErrorCodes.InvalidPassword = "28P01" -> string! +const Npgsql.PostgresErrorCodes.InvalidPreparedStatementDefinition = "42P14" -> string! +const Npgsql.PostgresErrorCodes.InvalidRecursion = "42P19" -> string! +const Npgsql.PostgresErrorCodes.InvalidRegularExpression = "2201B" -> string! +const Npgsql.PostgresErrorCodes.InvalidRoleSpecification = "0P000" -> string! +const Npgsql.PostgresErrorCodes.InvalidRowCountInLimitClause = "2201W" -> string! +const Npgsql.PostgresErrorCodes.InvalidRowCountInResultOffsetClause = "2201X" -> string! +const Npgsql.PostgresErrorCodes.InvalidSavepointSpecification = "3B001" -> string! +const Npgsql.PostgresErrorCodes.InvalidSchemaDefinition = "42P15" -> string! +const Npgsql.PostgresErrorCodes.InvalidSchemaName = "3F000" -> string! +const Npgsql.PostgresErrorCodes.InvalidSqlStatementName = "26000" -> string! +const Npgsql.PostgresErrorCodes.InvalidSqlstateReturnedExternalRoutineInvocationException = "39001" -> string! +const Npgsql.PostgresErrorCodes.InvalidTableDefinition = "42P16" -> string! +const Npgsql.PostgresErrorCodes.InvalidTablesampleArgument = "2202H" -> string! +const Npgsql.PostgresErrorCodes.InvalidTablesampleRepeat = "2202G" -> string! +const Npgsql.PostgresErrorCodes.InvalidTextRepresentation = "22P02" -> string! 
+const Npgsql.PostgresErrorCodes.InvalidTimeZoneDisplacementValue = "22009" -> string! +const Npgsql.PostgresErrorCodes.InvalidTransactionInitiation = "0B000" -> string! +const Npgsql.PostgresErrorCodes.InvalidTransactionState = "25000" -> string! +const Npgsql.PostgresErrorCodes.InvalidTransactionTermination = "2D000" -> string! +const Npgsql.PostgresErrorCodes.InvalidUseOfEscapeCharacter = "2200C" -> string! +const Npgsql.PostgresErrorCodes.InvalidXmlComment = "2200S" -> string! +const Npgsql.PostgresErrorCodes.InvalidXmlContent = "2200N" -> string! +const Npgsql.PostgresErrorCodes.InvalidXmlDocument = "2200M" -> string! +const Npgsql.PostgresErrorCodes.InvalidXmlProcessingInstruction = "2200T" -> string! +const Npgsql.PostgresErrorCodes.IoError = "58030" -> string! +const Npgsql.PostgresErrorCodes.LocatorException = "0F000" -> string! +const Npgsql.PostgresErrorCodes.LockFileExists = "F0001" -> string! +const Npgsql.PostgresErrorCodes.LockNotAvailable = "55P03" -> string! +const Npgsql.PostgresErrorCodes.ModifyingSqlDataNotPermittedExternalRoutineException = "38002" -> string! +const Npgsql.PostgresErrorCodes.ModifyingSqlDataNotPermittedSqlRoutineException = "2F002" -> string! +const Npgsql.PostgresErrorCodes.MostSpecificTypeMismatch = "2200G" -> string! +const Npgsql.PostgresErrorCodes.NameTooLong = "42622" -> string! +const Npgsql.PostgresErrorCodes.NoActiveSqlTransaction = "25P01" -> string! +const Npgsql.PostgresErrorCodes.NoActiveSqlTransactionForBranchTransaction = "25005" -> string! +const Npgsql.PostgresErrorCodes.NoAdditionalDynamicResultSetsReturned = "02001" -> string! +const Npgsql.PostgresErrorCodes.NoData = "02000" -> string! +const Npgsql.PostgresErrorCodes.NoDataFound = "P0002" -> string! +const Npgsql.PostgresErrorCodes.NonstandardUseOfEscapeCharacter = "22P06" -> string! +const Npgsql.PostgresErrorCodes.NotAnXmlDocument = "2200L" -> string! +const Npgsql.PostgresErrorCodes.NotNullViolation = "23502" -> string! 
+const Npgsql.PostgresErrorCodes.NullValueEliminatedInSetFunctionWarning = "01003" -> string! +const Npgsql.PostgresErrorCodes.NullValueNoIndicatorParameter = "22002" -> string! +const Npgsql.PostgresErrorCodes.NullValueNotAllowed = "22004" -> string! +const Npgsql.PostgresErrorCodes.NullValueNotAllowedExternalRoutineInvocationException = "39004" -> string! +const Npgsql.PostgresErrorCodes.NumericValueOutOfRange = "22003" -> string! +const Npgsql.PostgresErrorCodes.ObjectInUse = "55006" -> string! +const Npgsql.PostgresErrorCodes.ObjectNotInPrerequisiteState = "55000" -> string! +const Npgsql.PostgresErrorCodes.OperatorIntervention = "57000" -> string! +const Npgsql.PostgresErrorCodes.OutOfMemory = "53200" -> string! +const Npgsql.PostgresErrorCodes.PlpgsqlError = "P0000" -> string! +const Npgsql.PostgresErrorCodes.PrivilegeNotGrantedWarning = "01007" -> string! +const Npgsql.PostgresErrorCodes.PrivilegeNotRevokedWarning = "01006" -> string! +const Npgsql.PostgresErrorCodes.ProgramLimitExceeded = "54000" -> string! +const Npgsql.PostgresErrorCodes.ProhibitedSqlStatementAttemptedExternalRoutineException = "38003" -> string! +const Npgsql.PostgresErrorCodes.ProhibitedSqlStatementAttemptedSqlRoutineException = "2F003" -> string! +const Npgsql.PostgresErrorCodes.ProtocolViolation = "08P01" -> string! +const Npgsql.PostgresErrorCodes.QueryCanceled = "57014" -> string! +const Npgsql.PostgresErrorCodes.RaiseException = "P0001" -> string! +const Npgsql.PostgresErrorCodes.ReadingSqlDataNotPermittedExternalRoutineException = "38004" -> string! +const Npgsql.PostgresErrorCodes.ReadingSqlDataNotPermittedSqlRoutineException = "2F004" -> string! +const Npgsql.PostgresErrorCodes.ReadOnlySqlTransaction = "25006" -> string! +const Npgsql.PostgresErrorCodes.ReservedName = "42939" -> string! +const Npgsql.PostgresErrorCodes.RestrictViolation = "23001" -> string! +const Npgsql.PostgresErrorCodes.SavepointException = "3B000" -> string! 
+const Npgsql.PostgresErrorCodes.SchemaAndDataStatementMixingNotSupported = "25007" -> string! +const Npgsql.PostgresErrorCodes.SerializationFailure = "40001" -> string! +const Npgsql.PostgresErrorCodes.SnapshotFailure = "72000" -> string! +const Npgsql.PostgresErrorCodes.SqlClientUnableToEstablishSqlConnection = "08001" -> string! +const Npgsql.PostgresErrorCodes.SqlRoutineException = "2F000" -> string! +const Npgsql.PostgresErrorCodes.SqlServerRejectedEstablishmentOfSqlConnection = "08004" -> string! +const Npgsql.PostgresErrorCodes.SqlStatementNotYetComplete = "03000" -> string! +const Npgsql.PostgresErrorCodes.SrfProtocolViolatedExternalRoutineInvocationException = "39P02" -> string! +const Npgsql.PostgresErrorCodes.StackedDiagnosticsAccessedWithoutActiveHandler = "0Z002" -> string! +const Npgsql.PostgresErrorCodes.StatementCompletionUnknown = "40003" -> string! +const Npgsql.PostgresErrorCodes.StatementTooComplex = "54001" -> string! +const Npgsql.PostgresErrorCodes.StringDataLengthMismatch = "22026" -> string! +const Npgsql.PostgresErrorCodes.StringDataRightTruncation = "22001" -> string! +const Npgsql.PostgresErrorCodes.StringDataRightTruncationWarning = "01004" -> string! +const Npgsql.PostgresErrorCodes.SubstringError = "22011" -> string! +const Npgsql.PostgresErrorCodes.SuccessfulCompletion = "00000" -> string! +const Npgsql.PostgresErrorCodes.SyntaxError = "42601" -> string! +const Npgsql.PostgresErrorCodes.SyntaxErrorOrAccessRuleViolation = "42000" -> string! +const Npgsql.PostgresErrorCodes.SystemError = "58000" -> string! +const Npgsql.PostgresErrorCodes.TooManyArguments = "54023" -> string! +const Npgsql.PostgresErrorCodes.TooManyColumns = "54011" -> string! +const Npgsql.PostgresErrorCodes.TooManyConnections = "53300" -> string! +const Npgsql.PostgresErrorCodes.TooManyRows = "P0003" -> string! +const Npgsql.PostgresErrorCodes.TransactionIntegrityConstraintViolation = "40002" -> string! 
+const Npgsql.PostgresErrorCodes.TransactionResolutionUnknown = "08007" -> string! +const Npgsql.PostgresErrorCodes.TransactionRollback = "40000" -> string! +const Npgsql.PostgresErrorCodes.TriggeredActionException = "09000" -> string! +const Npgsql.PostgresErrorCodes.TriggeredDataChangeViolation = "27000" -> string! +const Npgsql.PostgresErrorCodes.TriggerProtocolViolatedExternalRoutineInvocationException = "39P01" -> string! +const Npgsql.PostgresErrorCodes.TrimError = "22027" -> string! +const Npgsql.PostgresErrorCodes.UndefinedColumn = "42703" -> string! +const Npgsql.PostgresErrorCodes.UndefinedFile = "58P01" -> string! +const Npgsql.PostgresErrorCodes.UndefinedFunction = "42883" -> string! +const Npgsql.PostgresErrorCodes.UndefinedObject = "42704" -> string! +const Npgsql.PostgresErrorCodes.UndefinedParameter = "42P02" -> string! +const Npgsql.PostgresErrorCodes.UndefinedTable = "42P01" -> string! +const Npgsql.PostgresErrorCodes.UniqueViolation = "23505" -> string! +const Npgsql.PostgresErrorCodes.UnterminatedCString = "22024" -> string! +const Npgsql.PostgresErrorCodes.UntranslatableCharacter = "22P05" -> string! +const Npgsql.PostgresErrorCodes.Warning = "01000" -> string! +const Npgsql.PostgresErrorCodes.WindowingError = "42P20" -> string! +const Npgsql.PostgresErrorCodes.WithCheckOptionViolation = "44000" -> string! +const Npgsql.PostgresErrorCodes.WrongObjectType = "42809" -> string! +const Npgsql.PostgresErrorCodes.ZeroLengthCharacterString = "2200F" -> string! 
+Npgsql.ArrayNullabilityMode +Npgsql.ArrayNullabilityMode.Always = 1 -> Npgsql.ArrayNullabilityMode +Npgsql.ArrayNullabilityMode.Never = 0 -> Npgsql.ArrayNullabilityMode +Npgsql.ArrayNullabilityMode.PerInstance = 2 -> Npgsql.ArrayNullabilityMode +Npgsql.BackendMessages.FieldDescription +Npgsql.BackendMessages.FieldDescription.TypeModifier.get -> int +Npgsql.BackendMessages.FieldDescription.TypeModifier.set -> void +Npgsql.BackendMessages.FieldDescription.TypeSize.get -> short +Npgsql.BackendMessages.FieldDescription.TypeSize.set -> void +Npgsql.ChannelBinding +Npgsql.ChannelBinding.Disable = 0 -> Npgsql.ChannelBinding +Npgsql.ChannelBinding.Prefer = 1 -> Npgsql.ChannelBinding +Npgsql.ChannelBinding.Require = 2 -> Npgsql.ChannelBinding +Npgsql.INpgsqlNameTranslator +Npgsql.INpgsqlNameTranslator.TranslateMemberName(string! clrName) -> string! +Npgsql.INpgsqlNameTranslator.TranslateTypeName(string! clrName) -> string! +Npgsql.NameTranslation.NpgsqlNullNameTranslator +Npgsql.NameTranslation.NpgsqlNullNameTranslator.NpgsqlNullNameTranslator() -> void +Npgsql.NameTranslation.NpgsqlNullNameTranslator.TranslateMemberName(string! clrName) -> string! +Npgsql.NameTranslation.NpgsqlNullNameTranslator.TranslateTypeName(string! clrName) -> string! +Npgsql.NameTranslation.NpgsqlSnakeCaseNameTranslator +Npgsql.NameTranslation.NpgsqlSnakeCaseNameTranslator.NpgsqlSnakeCaseNameTranslator(bool legacyMode, System.Globalization.CultureInfo? culture = null) -> void +Npgsql.NameTranslation.NpgsqlSnakeCaseNameTranslator.NpgsqlSnakeCaseNameTranslator(System.Globalization.CultureInfo? culture = null) -> void +Npgsql.NameTranslation.NpgsqlSnakeCaseNameTranslator.TranslateMemberName(string! clrName) -> string! +Npgsql.NameTranslation.NpgsqlSnakeCaseNameTranslator.TranslateTypeName(string! clrName) -> string! +Npgsql.NoticeEventHandler +Npgsql.NotificationEventHandler +Npgsql.NpgsqlBatch +Npgsql.NpgsqlBatch.BatchCommands.get -> Npgsql.NpgsqlBatchCommandCollection! 
+Npgsql.NpgsqlBatch.Connection.get -> Npgsql.NpgsqlConnection? +Npgsql.NpgsqlBatch.Connection.set -> void +Npgsql.NpgsqlBatch.CreateBatchCommand() -> Npgsql.NpgsqlBatchCommand! +Npgsql.NpgsqlBatch.EnableErrorBarriers.get -> bool +Npgsql.NpgsqlBatch.EnableErrorBarriers.set -> void +Npgsql.NpgsqlBatch.ExecuteReader(System.Data.CommandBehavior behavior = System.Data.CommandBehavior.Default) -> Npgsql.NpgsqlDataReader! +Npgsql.NpgsqlBatch.ExecuteReaderAsync(System.Data.CommandBehavior behavior, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlBatch.ExecuteReaderAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlBatch.NpgsqlBatch(Npgsql.NpgsqlConnection? connection = null, Npgsql.NpgsqlTransaction? transaction = null) -> void +Npgsql.NpgsqlBatch.Transaction.get -> Npgsql.NpgsqlTransaction? +Npgsql.NpgsqlBatch.Transaction.set -> void +Npgsql.NpgsqlBatchCommand +Npgsql.NpgsqlBatchCommand.AppendErrorBarrier.get -> bool? +Npgsql.NpgsqlBatchCommand.AppendErrorBarrier.set -> void +Npgsql.NpgsqlBatchCommand.NpgsqlBatchCommand() -> void +Npgsql.NpgsqlBatchCommand.NpgsqlBatchCommand(string! commandText) -> void +Npgsql.NpgsqlBatchCommand.OID.get -> uint +Npgsql.NpgsqlBatchCommand.Parameters.get -> Npgsql.NpgsqlParameterCollection! +Npgsql.NpgsqlBatchCommand.Rows.get -> ulong +Npgsql.NpgsqlBatchCommand.StatementType.get -> Npgsql.StatementType +Npgsql.NpgsqlBatchCommandCollection +Npgsql.NpgsqlBatchCommandCollection.Add(Npgsql.NpgsqlBatchCommand! item) -> void +Npgsql.NpgsqlBatchCommandCollection.Contains(Npgsql.NpgsqlBatchCommand! item) -> bool +Npgsql.NpgsqlBatchCommandCollection.CopyTo(Npgsql.NpgsqlBatchCommand![]! array, int arrayIndex) -> void +Npgsql.NpgsqlBatchCommandCollection.IndexOf(Npgsql.NpgsqlBatchCommand! 
item) -> int +Npgsql.NpgsqlBatchCommandCollection.Insert(int index, Npgsql.NpgsqlBatchCommand! item) -> void +Npgsql.NpgsqlBatchCommandCollection.Remove(Npgsql.NpgsqlBatchCommand! item) -> bool +Npgsql.NpgsqlBatchCommandCollection.this[int index].get -> Npgsql.NpgsqlBatchCommand! +Npgsql.NpgsqlBatchCommandCollection.this[int index].set -> void +Npgsql.NpgsqlBinaryExporter +Npgsql.NpgsqlBinaryExporter.Cancel() -> void +Npgsql.NpgsqlBinaryExporter.CancelAsync() -> System.Threading.Tasks.Task! +Npgsql.NpgsqlBinaryExporter.Dispose() -> void +Npgsql.NpgsqlBinaryExporter.DisposeAsync() -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlBinaryExporter.IsNull.get -> bool +Npgsql.NpgsqlBinaryExporter.Read() -> T +Npgsql.NpgsqlBinaryExporter.Read(NpgsqlTypes.NpgsqlDbType type) -> T +Npgsql.NpgsqlBinaryExporter.ReadAsync(NpgsqlTypes.NpgsqlDbType type, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlBinaryExporter.ReadAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlBinaryExporter.Skip() -> void +Npgsql.NpgsqlBinaryExporter.SkipAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+Npgsql.NpgsqlBinaryExporter.StartRow() -> int +Npgsql.NpgsqlBinaryExporter.StartRowAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlBinaryExporter.Timeout.set -> void +Npgsql.NpgsqlBinaryImporter +Npgsql.NpgsqlBinaryImporter.Close() -> void +Npgsql.NpgsqlBinaryImporter.CloseAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlBinaryImporter.Complete() -> ulong +Npgsql.NpgsqlBinaryImporter.CompleteAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlBinaryImporter.Dispose() -> void +Npgsql.NpgsqlBinaryImporter.DisposeAsync() -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlBinaryImporter.StartRow() -> void +Npgsql.NpgsqlBinaryImporter.StartRowAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlBinaryImporter.Timeout.set -> void +Npgsql.NpgsqlBinaryImporter.Write(T value) -> void +Npgsql.NpgsqlBinaryImporter.Write(T value, NpgsqlTypes.NpgsqlDbType npgsqlDbType) -> void +Npgsql.NpgsqlBinaryImporter.Write(T value, string! dataTypeName) -> void +Npgsql.NpgsqlBinaryImporter.WriteAsync(T value, NpgsqlTypes.NpgsqlDbType npgsqlDbType, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlBinaryImporter.WriteAsync(T value, string! dataTypeName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlBinaryImporter.WriteAsync(T value, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+Npgsql.NpgsqlBinaryImporter.WriteNull() -> void +Npgsql.NpgsqlBinaryImporter.WriteNullAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlBinaryImporter.WriteRow(params object?[]! values) -> void +Npgsql.NpgsqlBinaryImporter.WriteRowAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken), params object?[]! values) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlCommand +Npgsql.NpgsqlCommand.AllResultTypesAreUnknown.get -> bool +Npgsql.NpgsqlCommand.AllResultTypesAreUnknown.set -> void +Npgsql.NpgsqlCommand.Connection.get -> Npgsql.NpgsqlConnection? +Npgsql.NpgsqlCommand.Connection.set -> void +Npgsql.NpgsqlCommand.CreateParameter() -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlCommand.Disposed -> System.EventHandler? +Npgsql.NpgsqlCommand.ExecuteReader(System.Data.CommandBehavior behavior = System.Data.CommandBehavior.Default) -> Npgsql.NpgsqlDataReader! +Npgsql.NpgsqlCommand.ExecuteReaderAsync(System.Data.CommandBehavior behavior, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlCommand.ExecuteReaderAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlCommand.IsPrepared.get -> bool +Npgsql.NpgsqlCommand.NpgsqlCommand() -> void +Npgsql.NpgsqlCommand.NpgsqlCommand(string? cmdText) -> void +Npgsql.NpgsqlCommand.NpgsqlCommand(string? cmdText, Npgsql.NpgsqlConnection? connection) -> void +Npgsql.NpgsqlCommand.NpgsqlCommand(string? cmdText, Npgsql.NpgsqlConnection? connection, Npgsql.NpgsqlTransaction? transaction) -> void +Npgsql.NpgsqlCommand.Parameters.get -> Npgsql.NpgsqlParameterCollection! +Npgsql.NpgsqlCommand.Statements.get -> System.Collections.Generic.IReadOnlyList! 
+Npgsql.NpgsqlCommand.Transaction.get -> Npgsql.NpgsqlTransaction? +Npgsql.NpgsqlCommand.Transaction.set -> void +Npgsql.NpgsqlCommand.UnknownResultTypeList.get -> bool[]? +Npgsql.NpgsqlCommand.UnknownResultTypeList.set -> void +Npgsql.NpgsqlCommand.Unprepare() -> void +Npgsql.NpgsqlCommand.UnprepareAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlCommandBuilder +Npgsql.NpgsqlCommandBuilder.GetDeleteCommand() -> Npgsql.NpgsqlCommand! +Npgsql.NpgsqlCommandBuilder.GetDeleteCommand(bool useColumnsForParameterNames) -> Npgsql.NpgsqlCommand! +Npgsql.NpgsqlCommandBuilder.GetInsertCommand() -> Npgsql.NpgsqlCommand! +Npgsql.NpgsqlCommandBuilder.GetInsertCommand(bool useColumnsForParameterNames) -> Npgsql.NpgsqlCommand! +Npgsql.NpgsqlCommandBuilder.GetUpdateCommand() -> Npgsql.NpgsqlCommand! +Npgsql.NpgsqlCommandBuilder.GetUpdateCommand(bool useColumnsForParameterNames) -> Npgsql.NpgsqlCommand! +Npgsql.NpgsqlCommandBuilder.NpgsqlCommandBuilder() -> void +Npgsql.NpgsqlCommandBuilder.NpgsqlCommandBuilder(Npgsql.NpgsqlDataAdapter? adapter) -> void +Npgsql.NpgsqlConnection +Npgsql.NpgsqlConnection.BeginBinaryExport(string! copyToCommand) -> Npgsql.NpgsqlBinaryExporter! +Npgsql.NpgsqlConnection.BeginBinaryExportAsync(string! copyToCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlConnection.BeginBinaryImport(string! copyFromCommand) -> Npgsql.NpgsqlBinaryImporter! +Npgsql.NpgsqlConnection.BeginBinaryImportAsync(string! copyFromCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlConnection.BeginRawBinaryCopy(string! copyCommand) -> Npgsql.NpgsqlRawCopyStream! +Npgsql.NpgsqlConnection.BeginRawBinaryCopyAsync(string! 
copyCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlConnection.BeginTextExport(string! copyToCommand) -> System.IO.TextReader! +Npgsql.NpgsqlConnection.BeginTextExportAsync(string! copyToCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlConnection.BeginTextImport(string! copyFromCommand) -> System.IO.TextWriter! +Npgsql.NpgsqlConnection.BeginTextImportAsync(string! copyFromCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlConnection.BeginTransaction() -> Npgsql.NpgsqlTransaction! +Npgsql.NpgsqlConnection.BeginTransaction(System.Data.IsolationLevel level) -> Npgsql.NpgsqlTransaction! +Npgsql.NpgsqlConnection.BeginTransactionAsync(System.Data.IsolationLevel level, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlConnection.BeginTransactionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlConnection.CloneWith(string! connectionString) -> Npgsql.NpgsqlConnection! +Npgsql.NpgsqlConnection.CommandTimeout.get -> int +Npgsql.NpgsqlConnection.CreateBatch() -> Npgsql.NpgsqlBatch! +Npgsql.NpgsqlConnection.CreateCommand() -> Npgsql.NpgsqlCommand! +Npgsql.NpgsqlConnection.Disposed -> System.EventHandler? +Npgsql.NpgsqlConnection.FullState.get -> System.Data.ConnectionState +Npgsql.NpgsqlConnection.HasIntegerDateTimes.get -> bool +Npgsql.NpgsqlConnection.Host.get -> string? +Npgsql.NpgsqlConnection.Notice -> Npgsql.NoticeEventHandler? +Npgsql.NpgsqlConnection.Notification -> Npgsql.NotificationEventHandler? 
+Npgsql.NpgsqlConnection.NpgsqlConnection() -> void +Npgsql.NpgsqlConnection.NpgsqlConnection(string? connectionString) -> void +Npgsql.NpgsqlConnection.Port.get -> int +Npgsql.NpgsqlConnection.PostgresParameters.get -> System.Collections.Generic.IReadOnlyDictionary! +Npgsql.NpgsqlConnection.PostgreSqlVersion.get -> System.Version! +Npgsql.NpgsqlConnection.ProcessID.get -> int +Npgsql.NpgsqlConnection.ProvideClientCertificatesCallback.get -> Npgsql.ProvideClientCertificatesCallback? +Npgsql.NpgsqlConnection.ProvideClientCertificatesCallback.set -> void +Npgsql.NpgsqlConnection.ProvidePasswordCallback.get -> Npgsql.ProvidePasswordCallback? +Npgsql.NpgsqlConnection.ProvidePasswordCallback.set -> void +Npgsql.NpgsqlConnection.ReloadTypes() -> void +Npgsql.NpgsqlConnection.ReloadTypesAsync() -> System.Threading.Tasks.Task! +Npgsql.NpgsqlConnection.Timezone.get -> string! +Npgsql.NpgsqlConnection.TypeMapper.get -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.NpgsqlConnection.UnprepareAll() -> void +Npgsql.NpgsqlConnection.UserCertificateValidationCallback.get -> System.Net.Security.RemoteCertificateValidationCallback? +Npgsql.NpgsqlConnection.UserCertificateValidationCallback.set -> void +Npgsql.NpgsqlConnection.UserName.get -> string? +Npgsql.NpgsqlConnection.Wait() -> void +Npgsql.NpgsqlConnection.Wait(int timeout) -> bool +Npgsql.NpgsqlConnection.Wait(System.TimeSpan timeout) -> bool +Npgsql.NpgsqlConnection.WaitAsync(int timeout, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlConnection.WaitAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlConnection.WaitAsync(System.TimeSpan timeout, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+Npgsql.NpgsqlConnectionStringBuilder +Npgsql.NpgsqlConnectionStringBuilder.Add(System.Collections.Generic.KeyValuePair item) -> void +Npgsql.NpgsqlConnectionStringBuilder.ApplicationName.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.ApplicationName.set -> void +Npgsql.NpgsqlConnectionStringBuilder.ArrayNullabilityMode.get -> Npgsql.ArrayNullabilityMode +Npgsql.NpgsqlConnectionStringBuilder.ArrayNullabilityMode.set -> void +Npgsql.NpgsqlConnectionStringBuilder.AutoPrepareMinUsages.get -> int +Npgsql.NpgsqlConnectionStringBuilder.AutoPrepareMinUsages.set -> void +Npgsql.NpgsqlConnectionStringBuilder.CancellationTimeout.get -> int +Npgsql.NpgsqlConnectionStringBuilder.CancellationTimeout.set -> void +Npgsql.NpgsqlConnectionStringBuilder.ChannelBinding.get -> Npgsql.ChannelBinding +Npgsql.NpgsqlConnectionStringBuilder.ChannelBinding.set -> void +Npgsql.NpgsqlConnectionStringBuilder.CheckCertificateRevocation.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.CheckCertificateRevocation.set -> void +Npgsql.NpgsqlConnectionStringBuilder.ClientEncoding.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.ClientEncoding.set -> void +Npgsql.NpgsqlConnectionStringBuilder.CommandTimeout.get -> int +Npgsql.NpgsqlConnectionStringBuilder.CommandTimeout.set -> void +Npgsql.NpgsqlConnectionStringBuilder.ConnectionIdleLifetime.get -> int +Npgsql.NpgsqlConnectionStringBuilder.ConnectionIdleLifetime.set -> void +Npgsql.NpgsqlConnectionStringBuilder.ConnectionLifetime.get -> int +Npgsql.NpgsqlConnectionStringBuilder.ConnectionLifetime.set -> void +Npgsql.NpgsqlConnectionStringBuilder.ConnectionPruningInterval.get -> int +Npgsql.NpgsqlConnectionStringBuilder.ConnectionPruningInterval.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Contains(System.Collections.Generic.KeyValuePair item) -> bool +Npgsql.NpgsqlConnectionStringBuilder.CopyTo(System.Collections.Generic.KeyValuePair[]! array, int arrayIndex) -> void +Npgsql.NpgsqlConnectionStringBuilder.Database.get -> string? 
+Npgsql.NpgsqlConnectionStringBuilder.Database.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Encoding.get -> string! +Npgsql.NpgsqlConnectionStringBuilder.Encoding.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Enlist.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.Enlist.set -> void +Npgsql.NpgsqlConnectionStringBuilder.GetEnumerator() -> System.Collections.Generic.IEnumerator>! +Npgsql.NpgsqlConnectionStringBuilder.Host.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.Host.set -> void +Npgsql.NpgsqlConnectionStringBuilder.HostRecheckSeconds.get -> int +Npgsql.NpgsqlConnectionStringBuilder.HostRecheckSeconds.set -> void +Npgsql.NpgsqlConnectionStringBuilder.IncludeErrorDetail.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.IncludeErrorDetail.set -> void +Npgsql.NpgsqlConnectionStringBuilder.IncludeRealm.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.IncludeRealm.set -> void +Npgsql.NpgsqlConnectionStringBuilder.InternalCommandTimeout.get -> int +Npgsql.NpgsqlConnectionStringBuilder.InternalCommandTimeout.set -> void +Npgsql.NpgsqlConnectionStringBuilder.KeepAlive.get -> int +Npgsql.NpgsqlConnectionStringBuilder.KeepAlive.set -> void +Npgsql.NpgsqlConnectionStringBuilder.KerberosServiceName.get -> string! +Npgsql.NpgsqlConnectionStringBuilder.KerberosServiceName.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Keys.get -> System.Collections.Generic.ICollection! 
+Npgsql.NpgsqlConnectionStringBuilder.LoadBalanceHosts.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.LoadBalanceHosts.set -> void +Npgsql.NpgsqlConnectionStringBuilder.LoadTableComposites.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.LoadTableComposites.set -> void +Npgsql.NpgsqlConnectionStringBuilder.LogParameters.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.LogParameters.set -> void +Npgsql.NpgsqlConnectionStringBuilder.MaxAutoPrepare.get -> int +Npgsql.NpgsqlConnectionStringBuilder.MaxAutoPrepare.set -> void +Npgsql.NpgsqlConnectionStringBuilder.MaxPoolSize.get -> int +Npgsql.NpgsqlConnectionStringBuilder.MaxPoolSize.set -> void +Npgsql.NpgsqlConnectionStringBuilder.MinPoolSize.get -> int +Npgsql.NpgsqlConnectionStringBuilder.MinPoolSize.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Multiplexing.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.Multiplexing.set -> void +Npgsql.NpgsqlConnectionStringBuilder.NoResetOnClose.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.NoResetOnClose.set -> void +Npgsql.NpgsqlConnectionStringBuilder.NpgsqlConnectionStringBuilder() -> void +Npgsql.NpgsqlConnectionStringBuilder.NpgsqlConnectionStringBuilder(bool useOdbcRules) -> void +Npgsql.NpgsqlConnectionStringBuilder.NpgsqlConnectionStringBuilder(string? connectionString) -> void +Npgsql.NpgsqlConnectionStringBuilder.Options.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.Options.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Passfile.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.Passfile.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Password.get -> string? 
+Npgsql.NpgsqlConnectionStringBuilder.Password.set -> void +Npgsql.NpgsqlConnectionStringBuilder.PersistSecurityInfo.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.PersistSecurityInfo.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Pooling.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.Pooling.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Port.get -> int +Npgsql.NpgsqlConnectionStringBuilder.Port.set -> void +Npgsql.NpgsqlConnectionStringBuilder.ReadBufferSize.get -> int +Npgsql.NpgsqlConnectionStringBuilder.ReadBufferSize.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Remove(System.Collections.Generic.KeyValuePair item) -> bool +Npgsql.NpgsqlConnectionStringBuilder.RootCertificate.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.RootCertificate.set -> void +Npgsql.NpgsqlConnectionStringBuilder.SearchPath.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.SearchPath.set -> void +Npgsql.NpgsqlConnectionStringBuilder.ServerCompatibilityMode.get -> Npgsql.ServerCompatibilityMode +Npgsql.NpgsqlConnectionStringBuilder.ServerCompatibilityMode.set -> void +Npgsql.NpgsqlConnectionStringBuilder.SocketReceiveBufferSize.get -> int +Npgsql.NpgsqlConnectionStringBuilder.SocketReceiveBufferSize.set -> void +Npgsql.NpgsqlConnectionStringBuilder.SocketSendBufferSize.get -> int +Npgsql.NpgsqlConnectionStringBuilder.SocketSendBufferSize.set -> void +Npgsql.NpgsqlConnectionStringBuilder.SslCertificate.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.SslCertificate.set -> void +Npgsql.NpgsqlConnectionStringBuilder.SslKey.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.SslKey.set -> void +Npgsql.NpgsqlConnectionStringBuilder.SslMode.get -> Npgsql.SslMode +Npgsql.NpgsqlConnectionStringBuilder.SslMode.set -> void +Npgsql.NpgsqlConnectionStringBuilder.SslPassword.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.SslPassword.set -> void +Npgsql.NpgsqlConnectionStringBuilder.TargetSessionAttributes.get -> string? 
+Npgsql.NpgsqlConnectionStringBuilder.TargetSessionAttributes.set -> void +Npgsql.NpgsqlConnectionStringBuilder.TcpKeepAlive.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.TcpKeepAlive.set -> void +Npgsql.NpgsqlConnectionStringBuilder.TcpKeepAliveInterval.get -> int +Npgsql.NpgsqlConnectionStringBuilder.TcpKeepAliveInterval.set -> void +Npgsql.NpgsqlConnectionStringBuilder.TcpKeepAliveTime.get -> int +Npgsql.NpgsqlConnectionStringBuilder.TcpKeepAliveTime.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Timeout.get -> int +Npgsql.NpgsqlConnectionStringBuilder.Timeout.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Timezone.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.Timezone.set -> void +Npgsql.NpgsqlConnectionStringBuilder.TrustServerCertificate.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.TrustServerCertificate.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Username.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.Username.set -> void +Npgsql.NpgsqlConnectionStringBuilder.Values.get -> System.Collections.Generic.ICollection! +Npgsql.NpgsqlConnectionStringBuilder.WriteBufferSize.get -> int +Npgsql.NpgsqlConnectionStringBuilder.WriteBufferSize.set -> void +Npgsql.NpgsqlConnectionStringBuilder.WriteCoalescingBufferThresholdBytes.get -> int +Npgsql.NpgsqlConnectionStringBuilder.WriteCoalescingBufferThresholdBytes.set -> void +Npgsql.NpgsqlCopyTextReader +Npgsql.NpgsqlCopyTextReader.Cancel() -> void +Npgsql.NpgsqlCopyTextReader.CancelAsync() -> System.Threading.Tasks.Task! +Npgsql.NpgsqlCopyTextReader.DisposeAsync() -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlCopyTextWriter +Npgsql.NpgsqlCopyTextWriter.Cancel() -> void +Npgsql.NpgsqlCopyTextWriter.CancelAsync() -> System.Threading.Tasks.Task! +Npgsql.NpgsqlDataAdapter +Npgsql.NpgsqlDataAdapter.DeleteCommand.get -> Npgsql.NpgsqlCommand? +Npgsql.NpgsqlDataAdapter.DeleteCommand.set -> void +Npgsql.NpgsqlDataAdapter.InsertCommand.get -> Npgsql.NpgsqlCommand? 
+Npgsql.NpgsqlDataAdapter.InsertCommand.set -> void +Npgsql.NpgsqlDataAdapter.NpgsqlDataAdapter() -> void +Npgsql.NpgsqlDataAdapter.NpgsqlDataAdapter(Npgsql.NpgsqlCommand! selectCommand) -> void +Npgsql.NpgsqlDataAdapter.NpgsqlDataAdapter(string! selectCommandText, Npgsql.NpgsqlConnection! selectConnection) -> void +Npgsql.NpgsqlDataAdapter.NpgsqlDataAdapter(string! selectCommandText, string! selectConnectionString) -> void +Npgsql.NpgsqlDataAdapter.RowUpdated -> Npgsql.NpgsqlRowUpdatedEventHandler? +Npgsql.NpgsqlDataAdapter.RowUpdating -> Npgsql.NpgsqlRowUpdatingEventHandler? +Npgsql.NpgsqlDataAdapter.SelectCommand.get -> Npgsql.NpgsqlCommand? +Npgsql.NpgsqlDataAdapter.SelectCommand.set -> void +Npgsql.NpgsqlDataAdapter.UpdateCommand.get -> Npgsql.NpgsqlCommand? +Npgsql.NpgsqlDataAdapter.UpdateCommand.set -> void +Npgsql.NpgsqlDataReader +Npgsql.NpgsqlDataReader.GetColumnSchema() -> System.Collections.ObjectModel.ReadOnlyCollection! +Npgsql.NpgsqlDataReader.GetColumnSchemaAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!>! +Npgsql.NpgsqlDataReader.GetData(int ordinal) -> Npgsql.NpgsqlNestedDataReader! +Npgsql.NpgsqlDataReader.GetDataTypeOID(int ordinal) -> uint +Npgsql.NpgsqlDataReader.GetPostgresType(int ordinal) -> Npgsql.PostgresTypes.PostgresType! +Npgsql.NpgsqlDataReader.GetStreamAsync(int ordinal, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlDataReader.GetTextReaderAsync(int ordinal, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlDataReader.GetTimeSpan(int ordinal) -> System.TimeSpan +Npgsql.NpgsqlDataReader.IsOnRow.get -> bool +Npgsql.NpgsqlDataReader.ReaderClosed -> System.EventHandler? 
+Npgsql.NpgsqlDataReader.Rows.get -> ulong +Npgsql.NpgsqlDataReader.Statements.get -> System.Collections.Generic.IReadOnlyList! +Npgsql.NpgsqlDataSource +Npgsql.NpgsqlDataSource.CreateBatch() -> Npgsql.NpgsqlBatch! +Npgsql.NpgsqlDataSource.CreateCommand(string? commandText = null) -> Npgsql.NpgsqlCommand! +Npgsql.NpgsqlDataSource.CreateConnection() -> Npgsql.NpgsqlConnection! +Npgsql.NpgsqlDataSource.OpenConnection() -> Npgsql.NpgsqlConnection! +Npgsql.NpgsqlDataSource.OpenConnectionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlDataSource.Password.set -> void +Npgsql.NpgsqlDataSourceBuilder +Npgsql.NpgsqlDataSourceBuilder.AddTypeInfoResolverFactory(Npgsql.Internal.PgTypeInfoResolverFactory! factory) -> void +Npgsql.NpgsqlDataSourceBuilder.Build() -> Npgsql.NpgsqlDataSource! +Npgsql.NpgsqlDataSourceBuilder.BuildMultiHost() -> Npgsql.NpgsqlMultiHostDataSource! +Npgsql.NpgsqlDataSourceBuilder.ConfigureJsonOptions(System.Text.Json.JsonSerializerOptions! serializerOptions) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.ConnectionString.get -> string! +Npgsql.NpgsqlDataSourceBuilder.ConnectionStringBuilder.get -> Npgsql.NpgsqlConnectionStringBuilder! +Npgsql.NpgsqlDataSourceBuilder.DefaultNameTranslator.get -> Npgsql.INpgsqlNameTranslator! +Npgsql.NpgsqlDataSourceBuilder.DefaultNameTranslator.set -> void +Npgsql.NpgsqlDataSourceBuilder.EnableDynamicJson(System.Type![]? jsonbClrTypes = null, System.Type![]? jsonClrTypes = null) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.EnableParameterLogging(bool parameterLoggingEnabled = true) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.EnableRecordsAsTuples() -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.EnableUnmappedTypes() -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.MapComposite(System.Type! clrType, string? 
pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.NpgsqlDataSourceBuilder.MapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.NpgsqlDataSourceBuilder.MapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.NpgsqlDataSourceBuilder.MapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.NpgsqlDataSourceBuilder.Name.get -> string? +Npgsql.NpgsqlDataSourceBuilder.Name.set -> void +Npgsql.NpgsqlDataSourceBuilder.NpgsqlDataSourceBuilder(string? connectionString = null) -> void +Npgsql.NpgsqlDataSourceBuilder.UnmapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.NpgsqlDataSourceBuilder.UnmapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.NpgsqlDataSourceBuilder.UnmapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.NpgsqlDataSourceBuilder.UnmapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.NpgsqlDataSourceBuilder.UseClientCertificate(System.Security.Cryptography.X509Certificates.X509Certificate? clientCertificate) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseClientCertificates(System.Security.Cryptography.X509Certificates.X509CertificateCollection? clientCertificates) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseClientCertificatesCallback(System.Action? clientCertificatesCallback) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseLoggerFactory(Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) -> Npgsql.NpgsqlDataSourceBuilder! 
+Npgsql.NpgsqlDataSourceBuilder.UsePasswordProvider(System.Func? passwordProvider, System.Func>? passwordProviderAsync) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UsePeriodicPasswordProvider(System.Func>? passwordProvider, System.TimeSpan successRefreshInterval, System.TimeSpan failureRefreshInterval) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UsePhysicalConnectionInitializer(System.Action? connectionInitializer, System.Func? connectionInitializerAsync) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseRootCertificate(System.Security.Cryptography.X509Certificates.X509Certificate2? rootCertificate) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseRootCertificateCallback(System.Func? rootCertificateCallback) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseUserCertificateValidationCallback(System.Net.Security.RemoteCertificateValidationCallback! userCertificateValidationCallback) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlException +Npgsql.NpgsqlException.BatchCommand.get -> Npgsql.NpgsqlBatchCommand? +Npgsql.NpgsqlException.BatchCommand.set -> void +Npgsql.NpgsqlException.NpgsqlException() -> void +Npgsql.NpgsqlException.NpgsqlException(string? message) -> void +Npgsql.NpgsqlException.NpgsqlException(string? message, System.Exception? innerException) -> void +Npgsql.NpgsqlException.NpgsqlException(System.Runtime.Serialization.SerializationInfo! info, System.Runtime.Serialization.StreamingContext context) -> void +Npgsql.NpgsqlFactory +Npgsql.NpgsqlFactory.GetService(System.Type! serviceType) -> object? +Npgsql.NpgsqlLargeObjectManager +Npgsql.NpgsqlLargeObjectManager.Create(uint preferredOid = 0) -> uint +Npgsql.NpgsqlLargeObjectManager.CreateAsync(uint preferredOid, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+Npgsql.NpgsqlLargeObjectManager.ExportRemote(uint oid, string! path) -> void +Npgsql.NpgsqlLargeObjectManager.ExportRemoteAsync(uint oid, string! path, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlLargeObjectManager.Has64BitSupport.get -> bool +Npgsql.NpgsqlLargeObjectManager.ImportRemote(string! path, uint oid = 0) -> void +Npgsql.NpgsqlLargeObjectManager.ImportRemoteAsync(string! path, uint oid, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlLargeObjectManager.MaxTransferBlockSize.get -> int +Npgsql.NpgsqlLargeObjectManager.MaxTransferBlockSize.set -> void +Npgsql.NpgsqlLargeObjectManager.NpgsqlLargeObjectManager(Npgsql.NpgsqlConnection! connection) -> void +Npgsql.NpgsqlLargeObjectManager.OpenRead(uint oid) -> Npgsql.NpgsqlLargeObjectStream! +Npgsql.NpgsqlLargeObjectManager.OpenReadAsync(uint oid, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlLargeObjectManager.OpenReadWrite(uint oid) -> Npgsql.NpgsqlLargeObjectStream! +Npgsql.NpgsqlLargeObjectManager.OpenReadWriteAsync(uint oid, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlLargeObjectManager.Unlink(uint oid) -> void +Npgsql.NpgsqlLargeObjectManager.UnlinkAsync(uint oid, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlLargeObjectStream +Npgsql.NpgsqlLargeObjectStream.GetLengthAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+Npgsql.NpgsqlLargeObjectStream.Has64BitSupport.get -> bool +Npgsql.NpgsqlLargeObjectStream.SeekAsync(long offset, System.IO.SeekOrigin origin, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlLargeObjectStream.SetLength(long value, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlLoggingConfiguration +Npgsql.NpgsqlMultiHostDataSource +Npgsql.NpgsqlMultiHostDataSource.ClearDatabaseStates() -> void +Npgsql.NpgsqlMultiHostDataSource.CreateConnection(Npgsql.TargetSessionAttributes targetSessionAttributes) -> Npgsql.NpgsqlConnection! +Npgsql.NpgsqlMultiHostDataSource.OpenConnection(Npgsql.TargetSessionAttributes targetSessionAttributes) -> Npgsql.NpgsqlConnection! +Npgsql.NpgsqlMultiHostDataSource.OpenConnectionAsync(Npgsql.TargetSessionAttributes targetSessionAttributes, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlMultiHostDataSource.WithTargetSession(Npgsql.TargetSessionAttributes targetSessionAttributes) -> Npgsql.NpgsqlDataSource! +Npgsql.NpgsqlNestedDataReader +Npgsql.NpgsqlNestedDataReader.GetData(int ordinal) -> Npgsql.NpgsqlNestedDataReader! +Npgsql.NpgsqlNoticeEventArgs +Npgsql.NpgsqlNoticeEventArgs.Notice.get -> Npgsql.PostgresNotice! +Npgsql.NpgsqlNotificationEventArgs +Npgsql.NpgsqlNotificationEventArgs.Channel.get -> string! +Npgsql.NpgsqlNotificationEventArgs.Payload.get -> string! +Npgsql.NpgsqlNotificationEventArgs.PID.get -> int +Npgsql.NpgsqlOperationInProgressException +Npgsql.NpgsqlOperationInProgressException.CommandInProgress.get -> Npgsql.NpgsqlCommand? +Npgsql.NpgsqlOperationInProgressException.NpgsqlOperationInProgressException(Npgsql.NpgsqlCommand! command) -> void +Npgsql.NpgsqlParameter +Npgsql.NpgsqlParameter.Clone() -> Npgsql.NpgsqlParameter! 
+Npgsql.NpgsqlParameter.Collection.get -> Npgsql.NpgsqlParameterCollection? +Npgsql.NpgsqlParameter.Collection.set -> void +Npgsql.NpgsqlParameter.DataTypeName.get -> string? +Npgsql.NpgsqlParameter.DataTypeName.set -> void +Npgsql.NpgsqlParameter.NpgsqlDbType.get -> NpgsqlTypes.NpgsqlDbType +Npgsql.NpgsqlParameter.NpgsqlDbType.set -> void +Npgsql.NpgsqlParameter.NpgsqlParameter() -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string! parameterName, NpgsqlTypes.NpgsqlDbType parameterType, int size, string? sourceColumn, System.Data.ParameterDirection direction, bool isNullable, byte precision, byte scale, System.Data.DataRowVersion sourceVersion, object! value) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string! parameterName, System.Data.DbType parameterType, int size, string? sourceColumn, System.Data.ParameterDirection direction, bool isNullable, byte precision, byte scale, System.Data.DataRowVersion sourceVersion, object! value) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string? parameterName, NpgsqlTypes.NpgsqlDbType parameterType) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string? parameterName, NpgsqlTypes.NpgsqlDbType parameterType, int size) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string? parameterName, NpgsqlTypes.NpgsqlDbType parameterType, int size, string? sourceColumn) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string? parameterName, object? value) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string? parameterName, System.Data.DbType parameterType) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string? parameterName, System.Data.DbType parameterType, int size) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string? parameterName, System.Data.DbType parameterType, int size, string? sourceColumn) -> void +Npgsql.NpgsqlParameter.NpgsqlValue.get -> object? +Npgsql.NpgsqlParameter.NpgsqlValue.set -> void +Npgsql.NpgsqlParameter.PostgresType.get -> Npgsql.PostgresTypes.PostgresType? 
+Npgsql.NpgsqlParameter.Precision.get -> byte +Npgsql.NpgsqlParameter.Precision.set -> void +Npgsql.NpgsqlParameter.Scale.get -> byte +Npgsql.NpgsqlParameter.Scale.set -> void +Npgsql.NpgsqlParameter +Npgsql.NpgsqlParameter.NpgsqlParameter() -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string! parameterName, NpgsqlTypes.NpgsqlDbType npgsqlDbType) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string! parameterName, System.Data.DbType dbType) -> void +Npgsql.NpgsqlParameter.NpgsqlParameter(string! parameterName, T value) -> void +Npgsql.NpgsqlParameter.TypedValue.get -> T? +Npgsql.NpgsqlParameter.TypedValue.set -> void +Npgsql.NpgsqlParameterCollection +Npgsql.NpgsqlParameterCollection.Add(Npgsql.NpgsqlParameter! value) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.Add(string! parameterName, NpgsqlTypes.NpgsqlDbType parameterType) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.Add(string! parameterName, NpgsqlTypes.NpgsqlDbType parameterType, int size) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.Add(string! parameterName, NpgsqlTypes.NpgsqlDbType parameterType, int size, string! sourceColumn) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.AddWithValue(NpgsqlTypes.NpgsqlDbType parameterType, object! value) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.AddWithValue(object! value) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.AddWithValue(string! parameterName, NpgsqlTypes.NpgsqlDbType parameterType, int size, object! value) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.AddWithValue(string! parameterName, NpgsqlTypes.NpgsqlDbType parameterType, int size, string? sourceColumn, object! value) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.AddWithValue(string! parameterName, NpgsqlTypes.NpgsqlDbType parameterType, object! value) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.AddWithValue(string! parameterName, object! 
value) -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.Contains(Npgsql.NpgsqlParameter! item) -> bool +Npgsql.NpgsqlParameterCollection.CopyTo(Npgsql.NpgsqlParameter![]! array, int arrayIndex) -> void +Npgsql.NpgsqlParameterCollection.IndexOf(Npgsql.NpgsqlParameter! item) -> int +Npgsql.NpgsqlParameterCollection.Insert(int index, Npgsql.NpgsqlParameter! item) -> void +Npgsql.NpgsqlParameterCollection.Remove(Npgsql.NpgsqlParameter! item) -> bool +Npgsql.NpgsqlParameterCollection.Remove(string! parameterName) -> void +Npgsql.NpgsqlParameterCollection.this[int index].get -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.this[int index].set -> void +Npgsql.NpgsqlParameterCollection.this[string! parameterName].get -> Npgsql.NpgsqlParameter! +Npgsql.NpgsqlParameterCollection.this[string! parameterName].set -> void +Npgsql.NpgsqlParameterCollection.ToArray() -> Npgsql.NpgsqlParameter![]! +Npgsql.NpgsqlParameterCollection.TryGetValue(string! parameterName, out Npgsql.NpgsqlParameter? parameter) -> bool +Npgsql.NpgsqlRawCopyStream +Npgsql.NpgsqlRawCopyStream.Cancel() -> void +Npgsql.NpgsqlRawCopyStream.CancelAsync() -> System.Threading.Tasks.Task! +Npgsql.NpgsqlRowUpdatedEventArgs +Npgsql.NpgsqlRowUpdatedEventArgs.NpgsqlRowUpdatedEventArgs(System.Data.DataRow! dataRow, System.Data.IDbCommand? command, System.Data.StatementType statementType, System.Data.Common.DataTableMapping! tableMapping) -> void +Npgsql.NpgsqlRowUpdatedEventHandler +Npgsql.NpgsqlRowUpdatingEventArgs +Npgsql.NpgsqlRowUpdatingEventArgs.NpgsqlRowUpdatingEventArgs(System.Data.DataRow! dataRow, System.Data.IDbCommand? command, System.Data.StatementType statementType, System.Data.Common.DataTableMapping! tableMapping) -> void +Npgsql.NpgsqlRowUpdatingEventHandler +Npgsql.NpgsqlSlimDataSourceBuilder +Npgsql.NpgsqlSlimDataSourceBuilder.AddTypeInfoResolverFactory(Npgsql.Internal.PgTypeInfoResolverFactory! 
factory) -> void +Npgsql.NpgsqlSlimDataSourceBuilder.Build() -> Npgsql.NpgsqlDataSource! +Npgsql.NpgsqlSlimDataSourceBuilder.BuildMultiHost() -> Npgsql.NpgsqlMultiHostDataSource! +Npgsql.NpgsqlSlimDataSourceBuilder.ConfigureJsonOptions(System.Text.Json.JsonSerializerOptions! serializerOptions) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.ConnectionString.get -> string! +Npgsql.NpgsqlSlimDataSourceBuilder.ConnectionStringBuilder.get -> Npgsql.NpgsqlConnectionStringBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.DefaultNameTranslator.get -> Npgsql.INpgsqlNameTranslator! +Npgsql.NpgsqlSlimDataSourceBuilder.DefaultNameTranslator.set -> void +Npgsql.NpgsqlSlimDataSourceBuilder.EnableArrays() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableDynamicJson(System.Type![]? jsonbClrTypes = null, System.Type![]? jsonClrTypes = null) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableExtraConversions() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableFullTextSearch() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableIntegratedSecurity() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableLTree() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableMultiranges() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableParameterLogging(bool parameterLoggingEnabled = true) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableRanges() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableRecords() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableRecordsAsTuples() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableTransportSecurity() -> Npgsql.NpgsqlSlimDataSourceBuilder! 
+Npgsql.NpgsqlSlimDataSourceBuilder.EnableUnmappedTypes() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.MapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.NpgsqlSlimDataSourceBuilder.MapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.NpgsqlSlimDataSourceBuilder.MapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.NpgsqlSlimDataSourceBuilder.MapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.NpgsqlSlimDataSourceBuilder.Name.get -> string? +Npgsql.NpgsqlSlimDataSourceBuilder.Name.set -> void +Npgsql.NpgsqlSlimDataSourceBuilder.NpgsqlSlimDataSourceBuilder(string? connectionString = null) -> void +Npgsql.NpgsqlSlimDataSourceBuilder.UnmapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.NpgsqlSlimDataSourceBuilder.UnmapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.NpgsqlSlimDataSourceBuilder.UnmapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.NpgsqlSlimDataSourceBuilder.UnmapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.NpgsqlSlimDataSourceBuilder.UseClientCertificate(System.Security.Cryptography.X509Certificates.X509Certificate? clientCertificate) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UseClientCertificates(System.Security.Cryptography.X509Certificates.X509CertificateCollection? clientCertificates) -> Npgsql.NpgsqlSlimDataSourceBuilder! 
+Npgsql.NpgsqlSlimDataSourceBuilder.UseClientCertificatesCallback(System.Action? clientCertificatesCallback) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UseLoggerFactory(Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UsePasswordProvider(System.Func? passwordProvider, System.Func>? passwordProviderAsync) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UsePeriodicPasswordProvider(System.Func>? passwordProvider, System.TimeSpan successRefreshInterval, System.TimeSpan failureRefreshInterval) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UsePhysicalConnectionInitializer(System.Action? connectionInitializer, System.Func? connectionInitializerAsync) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UseRootCertificate(System.Security.Cryptography.X509Certificates.X509Certificate2? rootCertificate) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UseRootCertificateCallback(System.Func? rootCertificateCallback) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UseUserCertificateValidationCallback(System.Net.Security.RemoteCertificateValidationCallback! userCertificateValidationCallback) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlTracingOptions +Npgsql.NpgsqlTracingOptions.NpgsqlTracingOptions() -> void +Npgsql.NpgsqlTransaction +Npgsql.NpgsqlTransaction.Connection.get -> Npgsql.NpgsqlConnection? +Npgsql.PostgresErrorCodes +Npgsql.PostgresException +Npgsql.PostgresException.ColumnName.get -> string? +Npgsql.PostgresException.ConstraintName.get -> string? +Npgsql.PostgresException.DataTypeName.get -> string? +Npgsql.PostgresException.Detail.get -> string? +Npgsql.PostgresException.File.get -> string? +Npgsql.PostgresException.Hint.get -> string? 
+Npgsql.PostgresException.InternalPosition.get -> int +Npgsql.PostgresException.InternalQuery.get -> string? +Npgsql.PostgresException.InvariantSeverity.get -> string! +Npgsql.PostgresException.Line.get -> string? +Npgsql.PostgresException.MessageText.get -> string! +Npgsql.PostgresException.Position.get -> int +Npgsql.PostgresException.PostgresException(string! messageText, string! severity, string! invariantSeverity, string! sqlState) -> void +Npgsql.PostgresException.PostgresException(string! messageText, string! severity, string! invariantSeverity, string! sqlState, string? detail = null, string? hint = null, int position = 0, int internalPosition = 0, string? internalQuery = null, string? where = null, string? schemaName = null, string? tableName = null, string? columnName = null, string? dataTypeName = null, string? constraintName = null, string? file = null, string? line = null, string? routine = null) -> void +Npgsql.PostgresException.Routine.get -> string? +Npgsql.PostgresException.SchemaName.get -> string? +Npgsql.PostgresException.Severity.get -> string! +Npgsql.PostgresException.TableName.get -> string? +Npgsql.PostgresException.Where.get -> string? +Npgsql.PostgresNotice +Npgsql.PostgresNotice.ColumnName.get -> string? +Npgsql.PostgresNotice.ColumnName.set -> void +Npgsql.PostgresNotice.ConstraintName.get -> string? +Npgsql.PostgresNotice.ConstraintName.set -> void +Npgsql.PostgresNotice.DataTypeName.get -> string? +Npgsql.PostgresNotice.DataTypeName.set -> void +Npgsql.PostgresNotice.Detail.get -> string? +Npgsql.PostgresNotice.Detail.set -> void +Npgsql.PostgresNotice.File.get -> string? +Npgsql.PostgresNotice.File.set -> void +Npgsql.PostgresNotice.Hint.get -> string? +Npgsql.PostgresNotice.Hint.set -> void +Npgsql.PostgresNotice.InternalPosition.get -> int +Npgsql.PostgresNotice.InternalPosition.set -> void +Npgsql.PostgresNotice.InternalQuery.get -> string? 
+Npgsql.PostgresNotice.InternalQuery.set -> void +Npgsql.PostgresNotice.InvariantSeverity.get -> string! +Npgsql.PostgresNotice.Line.get -> string? +Npgsql.PostgresNotice.Line.set -> void +Npgsql.PostgresNotice.MessageText.get -> string! +Npgsql.PostgresNotice.MessageText.set -> void +Npgsql.PostgresNotice.Position.get -> int +Npgsql.PostgresNotice.Position.set -> void +Npgsql.PostgresNotice.PostgresNotice(string! messageText, string! severity, string! invariantSeverity, string! sqlState, string? detail = null, string? hint = null, int position = 0, int internalPosition = 0, string? internalQuery = null, string? where = null, string? schemaName = null, string? tableName = null, string? columnName = null, string? dataTypeName = null, string? constraintName = null, string? file = null, string? line = null, string? routine = null) -> void +Npgsql.PostgresNotice.PostgresNotice(string! severity, string! invariantSeverity, string! sqlState, string! messageText) -> void +Npgsql.PostgresNotice.Routine.get -> string? +Npgsql.PostgresNotice.Routine.set -> void +Npgsql.PostgresNotice.SchemaName.get -> string? +Npgsql.PostgresNotice.SchemaName.set -> void +Npgsql.PostgresNotice.Severity.get -> string! +Npgsql.PostgresNotice.Severity.set -> void +Npgsql.PostgresNotice.SqlState.get -> string! +Npgsql.PostgresNotice.SqlState.set -> void +Npgsql.PostgresNotice.TableName.get -> string? +Npgsql.PostgresNotice.TableName.set -> void +Npgsql.PostgresNotice.Where.get -> string? +Npgsql.PostgresNotice.Where.set -> void +Npgsql.PostgresTypes.PostgresArrayType +Npgsql.PostgresTypes.PostgresArrayType.Element.get -> Npgsql.PostgresTypes.PostgresType! +Npgsql.PostgresTypes.PostgresArrayType.PostgresArrayType(string! ns, string! name, uint oid, Npgsql.PostgresTypes.PostgresType! elementPostgresType) -> void +Npgsql.PostgresTypes.PostgresBaseType +Npgsql.PostgresTypes.PostgresBaseType.PostgresBaseType(string! ns, string! 
name, uint oid) -> void +Npgsql.PostgresTypes.PostgresCompositeType +Npgsql.PostgresTypes.PostgresCompositeType.Field +Npgsql.PostgresTypes.PostgresCompositeType.Field.Name.get -> string! +Npgsql.PostgresTypes.PostgresCompositeType.Field.Type.get -> Npgsql.PostgresTypes.PostgresType! +Npgsql.PostgresTypes.PostgresCompositeType.Fields.get -> System.Collections.Generic.IReadOnlyList! +Npgsql.PostgresTypes.PostgresDomainType +Npgsql.PostgresTypes.PostgresDomainType.BaseType.get -> Npgsql.PostgresTypes.PostgresType! +Npgsql.PostgresTypes.PostgresDomainType.NotNull.get -> bool +Npgsql.PostgresTypes.PostgresDomainType.PostgresDomainType(string! ns, string! name, uint oid, Npgsql.PostgresTypes.PostgresType! baseType, bool notNull) -> void +Npgsql.PostgresTypes.PostgresEnumType +Npgsql.PostgresTypes.PostgresEnumType.Labels.get -> System.Collections.Generic.IReadOnlyList! +Npgsql.PostgresTypes.PostgresEnumType.PostgresEnumType(string! ns, string! name, uint oid) -> void +Npgsql.PostgresTypes.PostgresMultirangeType +Npgsql.PostgresTypes.PostgresMultirangeType.PostgresMultirangeType(string! ns, string! name, uint oid, Npgsql.PostgresTypes.PostgresRangeType! rangePostgresType) -> void +Npgsql.PostgresTypes.PostgresMultirangeType.Subrange.get -> Npgsql.PostgresTypes.PostgresRangeType! +Npgsql.PostgresTypes.PostgresRangeType +Npgsql.PostgresTypes.PostgresRangeType.Multirange.get -> Npgsql.PostgresTypes.PostgresMultirangeType? +Npgsql.PostgresTypes.PostgresRangeType.PostgresRangeType(string! ns, string! name, uint oid, Npgsql.PostgresTypes.PostgresType! subtypePostgresType) -> void +Npgsql.PostgresTypes.PostgresRangeType.Subtype.get -> Npgsql.PostgresTypes.PostgresType! +Npgsql.PostgresTypes.PostgresType +Npgsql.PostgresTypes.PostgresType.Array.get -> Npgsql.PostgresTypes.PostgresArrayType? +Npgsql.PostgresTypes.PostgresType.DisplayName.get -> string! +Npgsql.PostgresTypes.PostgresType.FullName.get -> string! +Npgsql.PostgresTypes.PostgresType.InternalName.get -> string! 
+Npgsql.PostgresTypes.PostgresType.Name.get -> string! +Npgsql.PostgresTypes.PostgresType.Namespace.get -> string! +Npgsql.PostgresTypes.PostgresType.OID.get -> uint +Npgsql.PostgresTypes.PostgresType.Range.get -> Npgsql.PostgresTypes.PostgresRangeType? +Npgsql.PostgresTypes.UnknownBackendType +Npgsql.ProvideClientCertificatesCallback +Npgsql.ProvidePasswordCallback +Npgsql.Replication.Internal.LogicalReplicationConnectionExtensions +Npgsql.Replication.Internal.LogicalReplicationSlot +Npgsql.Replication.Internal.LogicalReplicationSlot.ConsistentPoint.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.Internal.LogicalReplicationSlot.LogicalReplicationSlot(string! outputPlugin, Npgsql.Replication.ReplicationSlotOptions replicationSlotOptions) -> void +Npgsql.Replication.Internal.LogicalReplicationSlot.OutputPlugin.get -> string! +Npgsql.Replication.Internal.LogicalReplicationSlot.SnapshotName.get -> string? +Npgsql.Replication.LogicalReplicationConnection +Npgsql.Replication.LogicalReplicationConnection.LogicalReplicationConnection() -> void +Npgsql.Replication.LogicalReplicationConnection.LogicalReplicationConnection(string? 
connectionString) -> void +Npgsql.Replication.LogicalSlotSnapshotInitMode +Npgsql.Replication.LogicalSlotSnapshotInitMode.Export = 0 -> Npgsql.Replication.LogicalSlotSnapshotInitMode +Npgsql.Replication.LogicalSlotSnapshotInitMode.NoExport = 2 -> Npgsql.Replication.LogicalSlotSnapshotInitMode +Npgsql.Replication.LogicalSlotSnapshotInitMode.Use = 1 -> Npgsql.Replication.LogicalSlotSnapshotInitMode +Npgsql.Replication.PgOutput.Messages.BeginMessage +Npgsql.Replication.PgOutput.Messages.BeginMessage.TransactionCommitTimestamp.get -> System.DateTime +Npgsql.Replication.PgOutput.Messages.BeginMessage.TransactionFinalLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.BeginPrepareMessage +Npgsql.Replication.PgOutput.Messages.CommitMessage +Npgsql.Replication.PgOutput.Messages.CommitMessage.CommitFlags +Npgsql.Replication.PgOutput.Messages.CommitMessage.CommitFlags.None = 0 -> Npgsql.Replication.PgOutput.Messages.CommitMessage.CommitFlags +Npgsql.Replication.PgOutput.Messages.CommitMessage.CommitLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.CommitMessage.Flags.get -> Npgsql.Replication.PgOutput.Messages.CommitMessage.CommitFlags +Npgsql.Replication.PgOutput.Messages.CommitMessage.TransactionCommitTimestamp.get -> System.DateTime +Npgsql.Replication.PgOutput.Messages.CommitMessage.TransactionEndLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.CommitPreparedMessage +Npgsql.Replication.PgOutput.Messages.CommitPreparedMessage.CommitPreparedEndLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.CommitPreparedMessage.CommitPreparedFlags +Npgsql.Replication.PgOutput.Messages.CommitPreparedMessage.CommitPreparedFlags.None = 0 -> Npgsql.Replication.PgOutput.Messages.CommitPreparedMessage.CommitPreparedFlags +Npgsql.Replication.PgOutput.Messages.CommitPreparedMessage.CommitPreparedLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber 
+Npgsql.Replication.PgOutput.Messages.CommitPreparedMessage.Flags.get -> Npgsql.Replication.PgOutput.Messages.CommitPreparedMessage.CommitPreparedFlags +Npgsql.Replication.PgOutput.Messages.CommitPreparedMessage.TransactionCommitTimestamp.get -> System.DateTime +Npgsql.Replication.PgOutput.Messages.DefaultUpdateMessage +Npgsql.Replication.PgOutput.Messages.DeleteMessage +Npgsql.Replication.PgOutput.Messages.DeleteMessage.Relation.get -> Npgsql.Replication.PgOutput.Messages.RelationMessage! +Npgsql.Replication.PgOutput.Messages.FullDeleteMessage +Npgsql.Replication.PgOutput.Messages.FullDeleteMessage.OldRow.get -> Npgsql.Replication.PgOutput.ReplicationTuple! +Npgsql.Replication.PgOutput.Messages.FullUpdateMessage +Npgsql.Replication.PgOutput.Messages.FullUpdateMessage.OldRow.get -> Npgsql.Replication.PgOutput.ReplicationTuple! +Npgsql.Replication.PgOutput.Messages.IndexUpdateMessage +Npgsql.Replication.PgOutput.Messages.IndexUpdateMessage.Key.get -> Npgsql.Replication.PgOutput.ReplicationTuple! +Npgsql.Replication.PgOutput.Messages.InsertMessage +Npgsql.Replication.PgOutput.Messages.InsertMessage.NewRow.get -> Npgsql.Replication.PgOutput.ReplicationTuple! +Npgsql.Replication.PgOutput.Messages.InsertMessage.Relation.get -> Npgsql.Replication.PgOutput.Messages.RelationMessage! +Npgsql.Replication.PgOutput.Messages.KeyDeleteMessage +Npgsql.Replication.PgOutput.Messages.KeyDeleteMessage.Key.get -> Npgsql.Replication.PgOutput.ReplicationTuple! +Npgsql.Replication.PgOutput.Messages.LogicalDecodingMessage +Npgsql.Replication.PgOutput.Messages.LogicalDecodingMessage.Data.get -> System.IO.Stream! +Npgsql.Replication.PgOutput.Messages.LogicalDecodingMessage.Flags.get -> byte +Npgsql.Replication.PgOutput.Messages.LogicalDecodingMessage.MessageLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.LogicalDecodingMessage.Prefix.get -> string! 
+Npgsql.Replication.PgOutput.Messages.OriginMessage +Npgsql.Replication.PgOutput.Messages.OriginMessage.OriginCommitLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.OriginMessage.OriginName.get -> string! +Npgsql.Replication.PgOutput.Messages.PgOutputReplicationMessage +Npgsql.Replication.PgOutput.Messages.PgOutputReplicationMessage.PgOutputReplicationMessage() -> void +Npgsql.Replication.PgOutput.Messages.PreparedTransactionControlMessage +Npgsql.Replication.PgOutput.Messages.PreparedTransactionControlMessage.TransactionGid.get -> string! +Npgsql.Replication.PgOutput.Messages.PrepareMessage +Npgsql.Replication.PgOutput.Messages.PrepareMessage.Flags.get -> Npgsql.Replication.PgOutput.Messages.PrepareMessage.PrepareFlags +Npgsql.Replication.PgOutput.Messages.PrepareMessage.PrepareFlags +Npgsql.Replication.PgOutput.Messages.PrepareMessage.PrepareFlags.None = 0 -> Npgsql.Replication.PgOutput.Messages.PrepareMessage.PrepareFlags +Npgsql.Replication.PgOutput.Messages.PrepareMessageBase +Npgsql.Replication.PgOutput.Messages.PrepareMessageBase.PrepareEndLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.PrepareMessageBase.PrepareLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.PrepareMessageBase.TransactionPrepareTimestamp.get -> System.DateTime +Npgsql.Replication.PgOutput.Messages.RelationMessage +Npgsql.Replication.PgOutput.Messages.RelationMessage.Column +Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.Column() -> void +Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.ColumnFlags +Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.ColumnFlags.None = 0 -> Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.ColumnFlags +Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.ColumnFlags.PartOfKey = 1 -> Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.ColumnFlags 
+Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.ColumnName.get -> string! +Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.DataTypeId.get -> uint +Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.Flags.get -> Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.ColumnFlags +Npgsql.Replication.PgOutput.Messages.RelationMessage.Column.TypeModifier.get -> int +Npgsql.Replication.PgOutput.Messages.RelationMessage.Columns.get -> System.Collections.Generic.IReadOnlyList! +Npgsql.Replication.PgOutput.Messages.RelationMessage.Namespace.get -> string! +Npgsql.Replication.PgOutput.Messages.RelationMessage.RelationId.get -> uint +Npgsql.Replication.PgOutput.Messages.RelationMessage.RelationName.get -> string! +Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentity.get -> Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting +Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting +Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting.AllColumns = 102 -> Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting +Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting.Default = 100 -> Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting +Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting.IndexWithIndIsReplIdent = 105 -> Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting +Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting.Nothing = 110 -> Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting +Npgsql.Replication.PgOutput.Messages.RelationMessageColumn +Npgsql.Replication.PgOutput.Messages.RelationMessageColumn.ColumnName.get -> string! 
+Npgsql.Replication.PgOutput.Messages.RelationMessageColumn.DataTypeId.get -> uint +Npgsql.Replication.PgOutput.Messages.RelationMessageColumn.Flags.get -> byte +Npgsql.Replication.PgOutput.Messages.RelationMessageColumn.RelationMessageColumn() -> void +Npgsql.Replication.PgOutput.Messages.RelationMessageColumn.TypeModifier.get -> int +Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage +Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage.Flags.get -> Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage.RollbackPreparedFlags +Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage.PreparedTransactionEndLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage.RollbackPreparedEndLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage.RollbackPreparedFlags +Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage.RollbackPreparedFlags.None = 0 -> Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage.RollbackPreparedFlags +Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage.TransactionPrepareTimestamp.get -> System.DateTime +Npgsql.Replication.PgOutput.Messages.RollbackPreparedMessage.TransactionRollbackTimestamp.get -> System.DateTime +Npgsql.Replication.PgOutput.Messages.StreamAbortMessage +Npgsql.Replication.PgOutput.Messages.StreamAbortMessage.SubtransactionXid.get -> uint +Npgsql.Replication.PgOutput.Messages.StreamCommitMessage +Npgsql.Replication.PgOutput.Messages.StreamCommitMessage.CommitLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.StreamCommitMessage.Flags.get -> byte +Npgsql.Replication.PgOutput.Messages.StreamCommitMessage.TransactionCommitTimestamp.get -> System.DateTime +Npgsql.Replication.PgOutput.Messages.StreamCommitMessage.TransactionEndLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.StreamPrepareMessage 
+Npgsql.Replication.PgOutput.Messages.StreamPrepareMessage.Flags.get -> Npgsql.Replication.PgOutput.Messages.StreamPrepareMessage.StreamPrepareFlags +Npgsql.Replication.PgOutput.Messages.StreamPrepareMessage.StreamPrepareFlags +Npgsql.Replication.PgOutput.Messages.StreamPrepareMessage.StreamPrepareFlags.None = 0 -> Npgsql.Replication.PgOutput.Messages.StreamPrepareMessage.StreamPrepareFlags +Npgsql.Replication.PgOutput.Messages.StreamStartMessage +Npgsql.Replication.PgOutput.Messages.StreamStartMessage.StreamSegmentIndicator.get -> byte +Npgsql.Replication.PgOutput.Messages.StreamStopMessage +Npgsql.Replication.PgOutput.Messages.TransactionalMessage +Npgsql.Replication.PgOutput.Messages.TransactionalMessage.TransactionalMessage() -> void +Npgsql.Replication.PgOutput.Messages.TransactionalMessage.TransactionXid.get -> uint? +Npgsql.Replication.PgOutput.Messages.TransactionControlMessage +Npgsql.Replication.PgOutput.Messages.TransactionControlMessage.TransactionControlMessage() -> void +Npgsql.Replication.PgOutput.Messages.TransactionControlMessage.TransactionXid.get -> uint +Npgsql.Replication.PgOutput.Messages.TruncateMessage +Npgsql.Replication.PgOutput.Messages.TruncateMessage.Options.get -> Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions +Npgsql.Replication.PgOutput.Messages.TruncateMessage.Relations.get -> System.Collections.Generic.IReadOnlyList! 
+Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions +Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions.Cascade = 1 -> Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions +Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions.None = 0 -> Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions +Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions.RestartIdentity = 2 -> Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions +Npgsql.Replication.PgOutput.Messages.TypeMessage +Npgsql.Replication.PgOutput.Messages.TypeMessage.Name.get -> string! +Npgsql.Replication.PgOutput.Messages.TypeMessage.Namespace.get -> string! +Npgsql.Replication.PgOutput.Messages.TypeMessage.TypeId.get -> uint +Npgsql.Replication.PgOutput.Messages.UpdateMessage +Npgsql.Replication.PgOutput.Messages.UpdateMessage.Relation.get -> Npgsql.Replication.PgOutput.Messages.RelationMessage! +Npgsql.Replication.PgOutput.PgOutputReplicationOptions +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.Binary.get -> bool? +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.Equals(Npgsql.Replication.PgOutput.PgOutputReplicationOptions? other) -> bool +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.Messages.get -> bool? +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.PgOutputReplicationOptions(string! publicationName, ulong protocolVersion, bool? binary = null, bool? streaming = null, bool? messages = null, bool? twoPhase = null) -> void +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.PgOutputReplicationOptions(System.Collections.Generic.IEnumerable! publicationNames, ulong protocolVersion, bool? binary = null, bool? streaming = null, bool? messages = null, bool? 
twoPhase = null) -> void +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.ProtocolVersion.get -> ulong +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.PublicationNames.get -> System.Collections.Generic.List! +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.Streaming.get -> bool? +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.TwoPhase.get -> bool? +Npgsql.Replication.PgOutput.PgOutputReplicationSlot +Npgsql.Replication.PgOutput.PgOutputReplicationSlot.PgOutputReplicationSlot(Npgsql.Replication.PgOutput.PgOutputReplicationSlot! slot) -> void +Npgsql.Replication.PgOutput.PgOutputReplicationSlot.PgOutputReplicationSlot(Npgsql.Replication.ReplicationSlotOptions options) -> void +Npgsql.Replication.PgOutput.PgOutputReplicationSlot.PgOutputReplicationSlot(string! slotName) -> void +Npgsql.Replication.PgOutput.ReplicationTuple +Npgsql.Replication.PgOutput.ReplicationTuple.NumColumns.get -> ushort +Npgsql.Replication.PgOutput.ReplicationValue +Npgsql.Replication.PgOutput.ReplicationValue.Get(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.Replication.PgOutput.ReplicationValue.Get(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.Replication.PgOutput.ReplicationValue.GetDataTypeName() -> string! +Npgsql.Replication.PgOutput.ReplicationValue.GetFieldType() -> System.Type! +Npgsql.Replication.PgOutput.ReplicationValue.GetPostgresType() -> Npgsql.PostgresTypes.PostgresType! +Npgsql.Replication.PgOutput.ReplicationValue.GetStream() -> System.IO.Stream! +Npgsql.Replication.PgOutput.ReplicationValue.GetTextReader() -> System.IO.TextReader! 
+Npgsql.Replication.PgOutput.ReplicationValue.IsDBNull.get -> bool +Npgsql.Replication.PgOutput.ReplicationValue.IsUnchangedToastedValue.get -> bool +Npgsql.Replication.PgOutput.ReplicationValue.Kind.get -> Npgsql.Replication.PgOutput.TupleDataKind +Npgsql.Replication.PgOutput.ReplicationValue.Length.get -> int +Npgsql.Replication.PgOutput.TupleDataKind +Npgsql.Replication.PgOutput.TupleDataKind.BinaryValue = 98 -> Npgsql.Replication.PgOutput.TupleDataKind +Npgsql.Replication.PgOutput.TupleDataKind.Null = 110 -> Npgsql.Replication.PgOutput.TupleDataKind +Npgsql.Replication.PgOutput.TupleDataKind.TextValue = 116 -> Npgsql.Replication.PgOutput.TupleDataKind +Npgsql.Replication.PgOutput.TupleDataKind.UnchangedToastedValue = 117 -> Npgsql.Replication.PgOutput.TupleDataKind +Npgsql.Replication.PgOutputConnectionExtensions +Npgsql.Replication.PhysicalReplicationConnection +Npgsql.Replication.PhysicalReplicationConnection.CreateReplicationSlot(string! slotName, bool isTemporary = false, bool reserveWal = false, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.Replication.PhysicalReplicationConnection.PhysicalReplicationConnection() -> void +Npgsql.Replication.PhysicalReplicationConnection.PhysicalReplicationConnection(string? connectionString) -> void +Npgsql.Replication.PhysicalReplicationConnection.ReadReplicationSlot(string! slotName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.Replication.PhysicalReplicationConnection.StartReplication(Npgsql.Replication.PhysicalReplicationSlot! slot, System.Threading.CancellationToken cancellationToken) -> System.Collections.Generic.IAsyncEnumerable! +Npgsql.Replication.PhysicalReplicationConnection.StartReplication(Npgsql.Replication.PhysicalReplicationSlot? 
slot, NpgsqlTypes.NpgsqlLogSequenceNumber walLocation, System.Threading.CancellationToken cancellationToken, uint timeline = 0) -> System.Collections.Generic.IAsyncEnumerable! +Npgsql.Replication.PhysicalReplicationConnection.StartReplication(NpgsqlTypes.NpgsqlLogSequenceNumber walLocation, System.Threading.CancellationToken cancellationToken, uint timeline = 0) -> System.Collections.Generic.IAsyncEnumerable! +Npgsql.Replication.PhysicalReplicationSlot +Npgsql.Replication.PhysicalReplicationSlot.PhysicalReplicationSlot(string! slotName, NpgsqlTypes.NpgsqlLogSequenceNumber? restartLsn = null, uint? restartTimeline = null) -> void +Npgsql.Replication.PhysicalReplicationSlot.RestartLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber? +Npgsql.Replication.PhysicalReplicationSlot.RestartTimeline.get -> uint? +Npgsql.Replication.ReplicationConnection +Npgsql.Replication.ReplicationConnection.CommandTimeout.get -> System.TimeSpan +Npgsql.Replication.ReplicationConnection.CommandTimeout.set -> void +Npgsql.Replication.ReplicationConnection.ConnectionString.get -> string! +Npgsql.Replication.ReplicationConnection.ConnectionString.set -> void +Npgsql.Replication.ReplicationConnection.DisposeAsync() -> System.Threading.Tasks.ValueTask +Npgsql.Replication.ReplicationConnection.DropReplicationSlot(string! slotName, bool wait = false, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.Replication.ReplicationConnection.Encoding.get -> System.Text.Encoding! +Npgsql.Replication.ReplicationConnection.IdentifySystem(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+Npgsql.Replication.ReplicationConnection.LastAppliedLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.ReplicationConnection.LastAppliedLsn.set -> void +Npgsql.Replication.ReplicationConnection.LastFlushedLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.ReplicationConnection.LastFlushedLsn.set -> void +Npgsql.Replication.ReplicationConnection.LastReceivedLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.ReplicationConnection.Open(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.Replication.ReplicationConnection.PostgreSqlVersion.get -> System.Version! +Npgsql.Replication.ReplicationConnection.ProcessID.get -> int +Npgsql.Replication.ReplicationConnection.SendStatusUpdate(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.Replication.ReplicationConnection.ServerVersion.get -> string! +Npgsql.Replication.ReplicationConnection.SetReplicationStatus(NpgsqlTypes.NpgsqlLogSequenceNumber lastAppliedAndFlushedLsn) -> void +Npgsql.Replication.ReplicationConnection.Show(string! parameterName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.Replication.ReplicationConnection.TimelineHistory(uint tli, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+Npgsql.Replication.ReplicationConnection.WalReceiverStatusInterval.get -> System.TimeSpan +Npgsql.Replication.ReplicationConnection.WalReceiverStatusInterval.set -> void +Npgsql.Replication.ReplicationConnection.WalReceiverTimeout.get -> System.TimeSpan +Npgsql.Replication.ReplicationConnection.WalReceiverTimeout.set -> void +Npgsql.Replication.ReplicationMessage +Npgsql.Replication.ReplicationMessage.ReplicationMessage() -> void +Npgsql.Replication.ReplicationMessage.ServerClock.get -> System.DateTime +Npgsql.Replication.ReplicationMessage.WalEnd.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.ReplicationMessage.WalStart.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.ReplicationSlot +Npgsql.Replication.ReplicationSlot.Name.get -> string! +Npgsql.Replication.ReplicationSlotOptions +Npgsql.Replication.ReplicationSlotOptions.ConsistentPoint.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.ReplicationSlotOptions.ReplicationSlotOptions() -> void +Npgsql.Replication.ReplicationSlotOptions.ReplicationSlotOptions(string! slotName, NpgsqlTypes.NpgsqlLogSequenceNumber consistentPoint) -> void +Npgsql.Replication.ReplicationSlotOptions.ReplicationSlotOptions(string! slotName, string? consistentPoint = null) -> void +Npgsql.Replication.ReplicationSlotOptions.SlotName.get -> string! +Npgsql.Replication.ReplicationSystemIdentification +Npgsql.Replication.ReplicationSystemIdentification.DbName.get -> string? +Npgsql.Replication.ReplicationSystemIdentification.SystemId.get -> string! +Npgsql.Replication.ReplicationSystemIdentification.Timeline.get -> uint +Npgsql.Replication.ReplicationSystemIdentification.XLogPos.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.TestDecoding.TestDecodingData +Npgsql.Replication.TestDecoding.TestDecodingData.Clone() -> Npgsql.Replication.TestDecoding.TestDecodingData! +Npgsql.Replication.TestDecoding.TestDecodingData.Data.get -> string! 
+Npgsql.Replication.TestDecoding.TestDecodingData.TestDecodingData() -> void +Npgsql.Replication.TestDecoding.TestDecodingOptions +Npgsql.Replication.TestDecoding.TestDecodingOptions.Equals(Npgsql.Replication.TestDecoding.TestDecodingOptions? other) -> bool +Npgsql.Replication.TestDecoding.TestDecodingOptions.ForceBinary.get -> bool? +Npgsql.Replication.TestDecoding.TestDecodingOptions.IncludeRewrites.get -> bool? +Npgsql.Replication.TestDecoding.TestDecodingOptions.IncludeTimestamp.get -> bool? +Npgsql.Replication.TestDecoding.TestDecodingOptions.IncludeXids.get -> bool? +Npgsql.Replication.TestDecoding.TestDecodingOptions.OnlyLocal.get -> bool? +Npgsql.Replication.TestDecoding.TestDecodingOptions.SkipEmptyXacts.get -> bool? +Npgsql.Replication.TestDecoding.TestDecodingOptions.StreamChanges.get -> bool? +Npgsql.Replication.TestDecoding.TestDecodingOptions.TestDecodingOptions(bool? includeXids = null, bool? includeTimestamp = null, bool? forceBinary = null, bool? skipEmptyXacts = null, bool? onlyLocal = null, bool? includeRewrites = null, bool? streamChanges = null) -> void +Npgsql.Replication.TestDecoding.TestDecodingReplicationSlot +Npgsql.Replication.TestDecoding.TestDecodingReplicationSlot.TestDecodingReplicationSlot(Npgsql.Replication.ReplicationSlotOptions options) -> void +Npgsql.Replication.TestDecoding.TestDecodingReplicationSlot.TestDecodingReplicationSlot(string! slotName) -> void +Npgsql.Replication.TestDecodingConnectionExtensions +Npgsql.Replication.TimelineHistoryFile +Npgsql.Replication.TimelineHistoryFile.Content.get -> byte[]! +Npgsql.Replication.TimelineHistoryFile.FileName.get -> string! +Npgsql.Replication.TimelineHistoryFile.TimelineHistoryFile() -> void +Npgsql.Replication.XLogDataMessage +Npgsql.Replication.XLogDataMessage.Data.get -> System.IO.Stream! +Npgsql.Replication.XLogDataMessage.XLogDataMessage() -> void +Npgsql.Schema.NpgsqlDbColumn +Npgsql.Schema.NpgsqlDbColumn.AllowDBNull.get -> bool? 
+Npgsql.Schema.NpgsqlDbColumn.AllowDBNull.set -> void +Npgsql.Schema.NpgsqlDbColumn.BaseCatalogName.get -> string! +Npgsql.Schema.NpgsqlDbColumn.BaseCatalogName.set -> void +Npgsql.Schema.NpgsqlDbColumn.BaseColumnName.get -> string? +Npgsql.Schema.NpgsqlDbColumn.BaseColumnName.set -> void +Npgsql.Schema.NpgsqlDbColumn.BaseSchemaName.get -> string? +Npgsql.Schema.NpgsqlDbColumn.BaseSchemaName.set -> void +Npgsql.Schema.NpgsqlDbColumn.BaseServerName.get -> string! +Npgsql.Schema.NpgsqlDbColumn.BaseServerName.set -> void +Npgsql.Schema.NpgsqlDbColumn.BaseTableName.get -> string? +Npgsql.Schema.NpgsqlDbColumn.BaseTableName.set -> void +Npgsql.Schema.NpgsqlDbColumn.ColumnAttributeNumber.get -> short? +Npgsql.Schema.NpgsqlDbColumn.ColumnName.get -> string! +Npgsql.Schema.NpgsqlDbColumn.ColumnName.set -> void +Npgsql.Schema.NpgsqlDbColumn.ColumnOrdinal.get -> int? +Npgsql.Schema.NpgsqlDbColumn.ColumnOrdinal.set -> void +Npgsql.Schema.NpgsqlDbColumn.ColumnSize.get -> int? +Npgsql.Schema.NpgsqlDbColumn.ColumnSize.set -> void +Npgsql.Schema.NpgsqlDbColumn.DataType.get -> System.Type? +Npgsql.Schema.NpgsqlDbColumn.DataType.set -> void +Npgsql.Schema.NpgsqlDbColumn.DataTypeName.get -> string! +Npgsql.Schema.NpgsqlDbColumn.DataTypeName.set -> void +Npgsql.Schema.NpgsqlDbColumn.DefaultValue.get -> string? +Npgsql.Schema.NpgsqlDbColumn.IsAliased.get -> bool? +Npgsql.Schema.NpgsqlDbColumn.IsAliased.set -> void +Npgsql.Schema.NpgsqlDbColumn.IsAutoIncrement.get -> bool? +Npgsql.Schema.NpgsqlDbColumn.IsAutoIncrement.set -> void +Npgsql.Schema.NpgsqlDbColumn.IsIdentity.get -> bool? +Npgsql.Schema.NpgsqlDbColumn.IsIdentity.set -> void +Npgsql.Schema.NpgsqlDbColumn.IsKey.get -> bool? +Npgsql.Schema.NpgsqlDbColumn.IsKey.set -> void +Npgsql.Schema.NpgsqlDbColumn.IsLong.get -> bool? +Npgsql.Schema.NpgsqlDbColumn.IsLong.set -> void +Npgsql.Schema.NpgsqlDbColumn.IsReadOnly.get -> bool? +Npgsql.Schema.NpgsqlDbColumn.IsReadOnly.set -> void +Npgsql.Schema.NpgsqlDbColumn.IsUnique.get -> bool? 
+Npgsql.Schema.NpgsqlDbColumn.IsUnique.set -> void +Npgsql.Schema.NpgsqlDbColumn.NpgsqlDbColumn() -> void +Npgsql.Schema.NpgsqlDbColumn.NpgsqlDbType.get -> NpgsqlTypes.NpgsqlDbType? +Npgsql.Schema.NpgsqlDbColumn.NumericPrecision.get -> int? +Npgsql.Schema.NpgsqlDbColumn.NumericPrecision.set -> void +Npgsql.Schema.NpgsqlDbColumn.NumericScale.get -> int? +Npgsql.Schema.NpgsqlDbColumn.NumericScale.set -> void +Npgsql.Schema.NpgsqlDbColumn.PostgresType.get -> Npgsql.PostgresTypes.PostgresType! +Npgsql.Schema.NpgsqlDbColumn.TableOID.get -> uint +Npgsql.Schema.NpgsqlDbColumn.TypeOID.get -> uint +Npgsql.Schema.NpgsqlDbColumn.UdtAssemblyQualifiedName.get -> string? +Npgsql.Schema.NpgsqlDbColumn.UdtAssemblyQualifiedName.set -> void +Npgsql.ServerCompatibilityMode +Npgsql.ServerCompatibilityMode.None = 0 -> Npgsql.ServerCompatibilityMode +Npgsql.ServerCompatibilityMode.NoTypeLoading = 2 -> Npgsql.ServerCompatibilityMode +Npgsql.ServerCompatibilityMode.Redshift = 1 -> Npgsql.ServerCompatibilityMode +Npgsql.SslMode +Npgsql.SslMode.Allow = 1 -> Npgsql.SslMode +Npgsql.SslMode.Disable = 0 -> Npgsql.SslMode +Npgsql.SslMode.Prefer = 2 -> Npgsql.SslMode +Npgsql.SslMode.Require = 3 -> Npgsql.SslMode +Npgsql.SslMode.VerifyCA = 4 -> Npgsql.SslMode +Npgsql.SslMode.VerifyFull = 5 -> Npgsql.SslMode +Npgsql.StatementType +Npgsql.StatementType.Call = 11 -> Npgsql.StatementType +Npgsql.StatementType.Copy = 8 -> Npgsql.StatementType +Npgsql.StatementType.CreateTableAs = 5 -> Npgsql.StatementType +Npgsql.StatementType.Delete = 3 -> Npgsql.StatementType +Npgsql.StatementType.Fetch = 7 -> Npgsql.StatementType +Npgsql.StatementType.Insert = 2 -> Npgsql.StatementType +Npgsql.StatementType.Merge = 10 -> Npgsql.StatementType +Npgsql.StatementType.Move = 6 -> Npgsql.StatementType +Npgsql.StatementType.Other = 9 -> Npgsql.StatementType +Npgsql.StatementType.Select = 1 -> Npgsql.StatementType +Npgsql.StatementType.Unknown = 0 -> Npgsql.StatementType +Npgsql.StatementType.Update = 4 -> 
Npgsql.StatementType +Npgsql.TypeMapping.INpgsqlTypeMapper +Npgsql.TypeMapping.INpgsqlTypeMapper.AddTypeInfoResolverFactory(Npgsql.Internal.PgTypeInfoResolverFactory! factory) -> void +Npgsql.TypeMapping.INpgsqlTypeMapper.ConfigureJsonOptions(System.Text.Json.JsonSerializerOptions! serializerOptions) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.TypeMapping.INpgsqlTypeMapper.DefaultNameTranslator.get -> Npgsql.INpgsqlNameTranslator! +Npgsql.TypeMapping.INpgsqlTypeMapper.DefaultNameTranslator.set -> void +Npgsql.TypeMapping.INpgsqlTypeMapper.EnableDynamicJson(System.Type![]? jsonbClrTypes = null, System.Type![]? jsonClrTypes = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.TypeMapping.INpgsqlTypeMapper.EnableRecordsAsTuples() -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.TypeMapping.INpgsqlTypeMapper.EnableUnmappedTypes() -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.TypeMapping.INpgsqlTypeMapper.MapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.TypeMapping.INpgsqlTypeMapper.MapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.TypeMapping.INpgsqlTypeMapper.MapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.TypeMapping.INpgsqlTypeMapper.MapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +Npgsql.TypeMapping.INpgsqlTypeMapper.Reset() -> void +Npgsql.TypeMapping.INpgsqlTypeMapper.UnmapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.TypeMapping.INpgsqlTypeMapper.UnmapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? 
nameTranslator = null) -> bool +Npgsql.TypeMapping.INpgsqlTypeMapper.UnmapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.TypeMapping.INpgsqlTypeMapper.UnmapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> bool +Npgsql.TypeMapping.UserTypeMapping +Npgsql.TypeMapping.UserTypeMapping.ClrType.get -> System.Type! +Npgsql.TypeMapping.UserTypeMapping.PgTypeName.get -> string! +Npgsql.Util.NpgsqlTimeout +Npgsql.Util.NpgsqlTimeout.NpgsqlTimeout() -> void +NpgsqlTypes.NpgsqlBox +NpgsqlTypes.NpgsqlBox.Bottom.get -> double +NpgsqlTypes.NpgsqlBox.Equals(NpgsqlTypes.NpgsqlBox other) -> bool +NpgsqlTypes.NpgsqlBox.Height.get -> double +NpgsqlTypes.NpgsqlBox.IsEmpty.get -> bool +NpgsqlTypes.NpgsqlBox.Left.get -> double +NpgsqlTypes.NpgsqlBox.LowerLeft.get -> NpgsqlTypes.NpgsqlPoint +NpgsqlTypes.NpgsqlBox.LowerLeft.set -> void +NpgsqlTypes.NpgsqlBox.NpgsqlBox() -> void +NpgsqlTypes.NpgsqlBox.NpgsqlBox(double top, double right, double bottom, double left) -> void +NpgsqlTypes.NpgsqlBox.NpgsqlBox(NpgsqlTypes.NpgsqlPoint upperRight, NpgsqlTypes.NpgsqlPoint lowerLeft) -> void +NpgsqlTypes.NpgsqlBox.Right.get -> double +NpgsqlTypes.NpgsqlBox.Top.get -> double +NpgsqlTypes.NpgsqlBox.UpperRight.get -> NpgsqlTypes.NpgsqlPoint +NpgsqlTypes.NpgsqlBox.UpperRight.set -> void +NpgsqlTypes.NpgsqlBox.Width.get -> double +NpgsqlTypes.NpgsqlCidr +NpgsqlTypes.NpgsqlCidr.Address.get -> System.Net.IPAddress! +NpgsqlTypes.NpgsqlCidr.Deconstruct(out System.Net.IPAddress! address, out byte netmask) -> void +NpgsqlTypes.NpgsqlCidr.Netmask.get -> byte +NpgsqlTypes.NpgsqlCidr.NpgsqlCidr() -> void +NpgsqlTypes.NpgsqlCidr.NpgsqlCidr(string! addr) -> void +NpgsqlTypes.NpgsqlCidr.NpgsqlCidr(System.Net.IPAddress! 
address, byte netmask) -> void +NpgsqlTypes.NpgsqlCircle +NpgsqlTypes.NpgsqlCircle.Center.get -> NpgsqlTypes.NpgsqlPoint +NpgsqlTypes.NpgsqlCircle.Center.set -> void +NpgsqlTypes.NpgsqlCircle.Equals(NpgsqlTypes.NpgsqlCircle other) -> bool +NpgsqlTypes.NpgsqlCircle.NpgsqlCircle() -> void +NpgsqlTypes.NpgsqlCircle.NpgsqlCircle(double x, double y, double radius) -> void +NpgsqlTypes.NpgsqlCircle.NpgsqlCircle(NpgsqlTypes.NpgsqlPoint center, double radius) -> void +NpgsqlTypes.NpgsqlCircle.Radius.get -> double +NpgsqlTypes.NpgsqlCircle.Radius.set -> void +NpgsqlTypes.NpgsqlCircle.X.get -> double +NpgsqlTypes.NpgsqlCircle.X.set -> void +NpgsqlTypes.NpgsqlCircle.Y.get -> double +NpgsqlTypes.NpgsqlCircle.Y.set -> void +NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Abstime = 33 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Array = -2147483648 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Bigint = 1 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.BigIntMultirange = 536870913 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.BigIntRange = 1073741825 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Bit = 25 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Boolean = 2 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Box = 3 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Bytea = 4 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Char = 6 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Cid = 43 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Cidr = 44 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Circle = 5 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Citext = 51 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Date = 7 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.DateMultirange = 536870919 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.DateRange = 1073741831 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Double = 8 -> NpgsqlTypes.NpgsqlDbType 
+NpgsqlTypes.NpgsqlDbType.Geography = 55 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Geometry = 50 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Hstore = 37 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Inet = 24 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Int2Vector = 52 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Integer = 9 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.IntegerMultirange = 536870921 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.IntegerRange = 1073741833 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.InternalChar = 38 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Interval = 30 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Json = 35 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Jsonb = 36 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.JsonPath = 57 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Line = 10 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.LQuery = 61 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.LSeg = 11 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.LTree = 60 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.LTxtQuery = 62 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.MacAddr = 34 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.MacAddr8 = 54 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Money = 12 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Multirange = 536870912 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Name = 32 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Numeric = 13 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.NumericMultirange = 536870925 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.NumericRange = 1073741837 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Oid = 41 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Oidvector = 29 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Path = 14 -> NpgsqlTypes.NpgsqlDbType 
+NpgsqlTypes.NpgsqlDbType.PgLsn = 59 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Point = 15 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Polygon = 16 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Range = 1073741824 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Real = 17 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Refcursor = 23 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Regconfig = 56 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Regtype = 49 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Smallint = 18 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Text = 19 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Tid = 53 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Time = 20 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Timestamp = 21 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.TimestampMultirange = 536870933 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.TimestampRange = 1073741845 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.TimestampTz = 26 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.TimestampTzMultirange = 536870938 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.TimestampTzRange = 1073741850 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.TimeTz = 31 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.TsQuery = 46 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.TsVector = 45 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Unknown = 40 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Uuid = 27 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Varbit = 39 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Varchar = 22 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Xid = 42 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Xid8 = 64 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlDbType.Xml = 28 -> NpgsqlTypes.NpgsqlDbType +NpgsqlTypes.NpgsqlInet +NpgsqlTypes.NpgsqlInet.Address.get -> 
System.Net.IPAddress! +NpgsqlTypes.NpgsqlInet.Deconstruct(out System.Net.IPAddress! address, out byte netmask) -> void +NpgsqlTypes.NpgsqlInet.Netmask.get -> byte +NpgsqlTypes.NpgsqlInet.NpgsqlInet() -> void +NpgsqlTypes.NpgsqlInet.NpgsqlInet(string! addr) -> void +NpgsqlTypes.NpgsqlInet.NpgsqlInet(System.Net.IPAddress! address) -> void +NpgsqlTypes.NpgsqlInet.NpgsqlInet(System.Net.IPAddress! address, byte netmask) -> void +NpgsqlTypes.NpgsqlInterval +NpgsqlTypes.NpgsqlInterval.Days.get -> int +NpgsqlTypes.NpgsqlInterval.Equals(NpgsqlTypes.NpgsqlInterval other) -> bool +NpgsqlTypes.NpgsqlInterval.Months.get -> int +NpgsqlTypes.NpgsqlInterval.NpgsqlInterval() -> void +NpgsqlTypes.NpgsqlInterval.NpgsqlInterval(int months, int days, long time) -> void +NpgsqlTypes.NpgsqlInterval.Time.get -> long +NpgsqlTypes.NpgsqlLine +NpgsqlTypes.NpgsqlLine.A.get -> double +NpgsqlTypes.NpgsqlLine.A.set -> void +NpgsqlTypes.NpgsqlLine.B.get -> double +NpgsqlTypes.NpgsqlLine.B.set -> void +NpgsqlTypes.NpgsqlLine.C.get -> double +NpgsqlTypes.NpgsqlLine.C.set -> void +NpgsqlTypes.NpgsqlLine.Equals(NpgsqlTypes.NpgsqlLine other) -> bool +NpgsqlTypes.NpgsqlLine.NpgsqlLine() -> void +NpgsqlTypes.NpgsqlLine.NpgsqlLine(double a, double b, double c) -> void +NpgsqlTypes.NpgsqlLogSequenceNumber +NpgsqlTypes.NpgsqlLogSequenceNumber.CompareTo(NpgsqlTypes.NpgsqlLogSequenceNumber value) -> int +NpgsqlTypes.NpgsqlLogSequenceNumber.Equals(NpgsqlTypes.NpgsqlLogSequenceNumber other) -> bool +NpgsqlTypes.NpgsqlLogSequenceNumber.NpgsqlLogSequenceNumber() -> void +NpgsqlTypes.NpgsqlLogSequenceNumber.NpgsqlLogSequenceNumber(ulong value) -> void +NpgsqlTypes.NpgsqlLSeg +NpgsqlTypes.NpgsqlLSeg.End.get -> NpgsqlTypes.NpgsqlPoint +NpgsqlTypes.NpgsqlLSeg.End.set -> void +NpgsqlTypes.NpgsqlLSeg.Equals(NpgsqlTypes.NpgsqlLSeg other) -> bool +NpgsqlTypes.NpgsqlLSeg.NpgsqlLSeg() -> void +NpgsqlTypes.NpgsqlLSeg.NpgsqlLSeg(double startx, double starty, double endx, double endy) -> void 
+NpgsqlTypes.NpgsqlLSeg.NpgsqlLSeg(NpgsqlTypes.NpgsqlPoint start, NpgsqlTypes.NpgsqlPoint end) -> void +NpgsqlTypes.NpgsqlLSeg.Start.get -> NpgsqlTypes.NpgsqlPoint +NpgsqlTypes.NpgsqlLSeg.Start.set -> void +NpgsqlTypes.NpgsqlPath +NpgsqlTypes.NpgsqlPath.Add(NpgsqlTypes.NpgsqlPoint item) -> void +NpgsqlTypes.NpgsqlPath.Capacity.get -> int +NpgsqlTypes.NpgsqlPath.Clear() -> void +NpgsqlTypes.NpgsqlPath.Contains(NpgsqlTypes.NpgsqlPoint item) -> bool +NpgsqlTypes.NpgsqlPath.CopyTo(NpgsqlTypes.NpgsqlPoint[]! array, int arrayIndex) -> void +NpgsqlTypes.NpgsqlPath.Count.get -> int +NpgsqlTypes.NpgsqlPath.Equals(NpgsqlTypes.NpgsqlPath other) -> bool +NpgsqlTypes.NpgsqlPath.GetEnumerator() -> System.Collections.Generic.IEnumerator! +NpgsqlTypes.NpgsqlPath.IndexOf(NpgsqlTypes.NpgsqlPoint item) -> int +NpgsqlTypes.NpgsqlPath.Insert(int index, NpgsqlTypes.NpgsqlPoint item) -> void +NpgsqlTypes.NpgsqlPath.IsReadOnly.get -> bool +NpgsqlTypes.NpgsqlPath.NpgsqlPath() -> void +NpgsqlTypes.NpgsqlPath.NpgsqlPath(bool open) -> void +NpgsqlTypes.NpgsqlPath.NpgsqlPath(int capacity) -> void +NpgsqlTypes.NpgsqlPath.NpgsqlPath(int capacity, bool open) -> void +NpgsqlTypes.NpgsqlPath.NpgsqlPath(params NpgsqlTypes.NpgsqlPoint[]! points) -> void +NpgsqlTypes.NpgsqlPath.NpgsqlPath(System.Collections.Generic.IEnumerable! points) -> void +NpgsqlTypes.NpgsqlPath.NpgsqlPath(System.Collections.Generic.IEnumerable! 
points, bool open) -> void +NpgsqlTypes.NpgsqlPath.Open.get -> bool +NpgsqlTypes.NpgsqlPath.Open.set -> void +NpgsqlTypes.NpgsqlPath.Remove(NpgsqlTypes.NpgsqlPoint item) -> bool +NpgsqlTypes.NpgsqlPath.RemoveAt(int index) -> void +NpgsqlTypes.NpgsqlPath.this[int index].get -> NpgsqlTypes.NpgsqlPoint +NpgsqlTypes.NpgsqlPath.this[int index].set -> void +NpgsqlTypes.NpgsqlPoint +NpgsqlTypes.NpgsqlPoint.Equals(NpgsqlTypes.NpgsqlPoint other) -> bool +NpgsqlTypes.NpgsqlPoint.NpgsqlPoint() -> void +NpgsqlTypes.NpgsqlPoint.NpgsqlPoint(double x, double y) -> void +NpgsqlTypes.NpgsqlPoint.X.get -> double +NpgsqlTypes.NpgsqlPoint.X.set -> void +NpgsqlTypes.NpgsqlPoint.Y.get -> double +NpgsqlTypes.NpgsqlPoint.Y.set -> void +NpgsqlTypes.NpgsqlPolygon +NpgsqlTypes.NpgsqlPolygon.Add(NpgsqlTypes.NpgsqlPoint item) -> void +NpgsqlTypes.NpgsqlPolygon.Capacity.get -> int +NpgsqlTypes.NpgsqlPolygon.Clear() -> void +NpgsqlTypes.NpgsqlPolygon.Contains(NpgsqlTypes.NpgsqlPoint item) -> bool +NpgsqlTypes.NpgsqlPolygon.CopyTo(NpgsqlTypes.NpgsqlPoint[]! array, int arrayIndex) -> void +NpgsqlTypes.NpgsqlPolygon.Count.get -> int +NpgsqlTypes.NpgsqlPolygon.Equals(NpgsqlTypes.NpgsqlPolygon other) -> bool +NpgsqlTypes.NpgsqlPolygon.GetEnumerator() -> System.Collections.Generic.IEnumerator! +NpgsqlTypes.NpgsqlPolygon.IndexOf(NpgsqlTypes.NpgsqlPoint item) -> int +NpgsqlTypes.NpgsqlPolygon.Insert(int index, NpgsqlTypes.NpgsqlPoint item) -> void +NpgsqlTypes.NpgsqlPolygon.IsReadOnly.get -> bool +NpgsqlTypes.NpgsqlPolygon.NpgsqlPolygon() -> void +NpgsqlTypes.NpgsqlPolygon.NpgsqlPolygon(int capacity) -> void +NpgsqlTypes.NpgsqlPolygon.NpgsqlPolygon(params NpgsqlTypes.NpgsqlPoint[]! points) -> void +NpgsqlTypes.NpgsqlPolygon.NpgsqlPolygon(System.Collections.Generic.IEnumerable! 
points) -> void +NpgsqlTypes.NpgsqlPolygon.Remove(NpgsqlTypes.NpgsqlPoint item) -> bool +NpgsqlTypes.NpgsqlPolygon.RemoveAt(int index) -> void +NpgsqlTypes.NpgsqlPolygon.this[int index].get -> NpgsqlTypes.NpgsqlPoint +NpgsqlTypes.NpgsqlPolygon.this[int index].set -> void +NpgsqlTypes.NpgsqlRange +NpgsqlTypes.NpgsqlRange.Equals(NpgsqlTypes.NpgsqlRange other) -> bool +NpgsqlTypes.NpgsqlRange.IsEmpty.get -> bool +NpgsqlTypes.NpgsqlRange.LowerBound.get -> T +NpgsqlTypes.NpgsqlRange.LowerBoundInfinite.get -> bool +NpgsqlTypes.NpgsqlRange.LowerBoundIsInclusive.get -> bool +NpgsqlTypes.NpgsqlRange.NpgsqlRange() -> void +NpgsqlTypes.NpgsqlRange.NpgsqlRange(T lowerBound, bool lowerBoundIsInclusive, bool lowerBoundInfinite, T upperBound, bool upperBoundIsInclusive, bool upperBoundInfinite) -> void +NpgsqlTypes.NpgsqlRange.NpgsqlRange(T lowerBound, bool lowerBoundIsInclusive, T upperBound, bool upperBoundIsInclusive) -> void +NpgsqlTypes.NpgsqlRange.NpgsqlRange(T lowerBound, T upperBound) -> void +NpgsqlTypes.NpgsqlRange.RangeTypeConverter +NpgsqlTypes.NpgsqlRange.RangeTypeConverter.RangeTypeConverter() -> void +NpgsqlTypes.NpgsqlRange.UpperBound.get -> T +NpgsqlTypes.NpgsqlRange.UpperBoundInfinite.get -> bool +NpgsqlTypes.NpgsqlRange.UpperBoundIsInclusive.get -> bool +NpgsqlTypes.NpgsqlTid +NpgsqlTypes.NpgsqlTid.BlockNumber.get -> uint +NpgsqlTypes.NpgsqlTid.Equals(NpgsqlTypes.NpgsqlTid other) -> bool +NpgsqlTypes.NpgsqlTid.NpgsqlTid() -> void +NpgsqlTypes.NpgsqlTid.NpgsqlTid(uint blockNumber, ushort offsetNumber) -> void +NpgsqlTypes.NpgsqlTid.OffsetNumber.get -> ushort +NpgsqlTypes.NpgsqlTsQuery +NpgsqlTypes.NpgsqlTsQuery.Kind.get -> NpgsqlTypes.NpgsqlTsQuery.NodeKind +NpgsqlTypes.NpgsqlTsQuery.NodeKind +NpgsqlTypes.NpgsqlTsQuery.NodeKind.And = 2 -> NpgsqlTypes.NpgsqlTsQuery.NodeKind +NpgsqlTypes.NpgsqlTsQuery.NodeKind.Empty = -1 -> NpgsqlTypes.NpgsqlTsQuery.NodeKind +NpgsqlTypes.NpgsqlTsQuery.NodeKind.Lexeme = 0 -> NpgsqlTypes.NpgsqlTsQuery.NodeKind 
+NpgsqlTypes.NpgsqlTsQuery.NodeKind.Not = 1 -> NpgsqlTypes.NpgsqlTsQuery.NodeKind +NpgsqlTypes.NpgsqlTsQuery.NodeKind.Or = 3 -> NpgsqlTypes.NpgsqlTsQuery.NodeKind +NpgsqlTypes.NpgsqlTsQuery.NodeKind.Phrase = 4 -> NpgsqlTypes.NpgsqlTsQuery.NodeKind +NpgsqlTypes.NpgsqlTsQuery.NpgsqlTsQuery(NpgsqlTypes.NpgsqlTsQuery.NodeKind kind) -> void +NpgsqlTypes.NpgsqlTsQuery.Write(System.Text.StringBuilder! stringBuilder) -> void +NpgsqlTypes.NpgsqlTsQueryAnd +NpgsqlTypes.NpgsqlTsQueryAnd.NpgsqlTsQueryAnd(NpgsqlTypes.NpgsqlTsQuery! left, NpgsqlTypes.NpgsqlTsQuery! right) -> void +NpgsqlTypes.NpgsqlTsQueryBinOp +NpgsqlTypes.NpgsqlTsQueryBinOp.Left.get -> NpgsqlTypes.NpgsqlTsQuery! +NpgsqlTypes.NpgsqlTsQueryBinOp.Left.set -> void +NpgsqlTypes.NpgsqlTsQueryBinOp.NpgsqlTsQueryBinOp(NpgsqlTypes.NpgsqlTsQuery.NodeKind kind, NpgsqlTypes.NpgsqlTsQuery! left, NpgsqlTypes.NpgsqlTsQuery! right) -> void +NpgsqlTypes.NpgsqlTsQueryBinOp.Right.get -> NpgsqlTypes.NpgsqlTsQuery! +NpgsqlTypes.NpgsqlTsQueryBinOp.Right.set -> void +NpgsqlTypes.NpgsqlTsQueryEmpty +NpgsqlTypes.NpgsqlTsQueryEmpty.NpgsqlTsQueryEmpty() -> void +NpgsqlTypes.NpgsqlTsQueryFollowedBy +NpgsqlTypes.NpgsqlTsQueryFollowedBy.Distance.get -> short +NpgsqlTypes.NpgsqlTsQueryFollowedBy.Distance.set -> void +NpgsqlTypes.NpgsqlTsQueryFollowedBy.NpgsqlTsQueryFollowedBy(NpgsqlTypes.NpgsqlTsQuery! left, short distance, NpgsqlTypes.NpgsqlTsQuery! right) -> void +NpgsqlTypes.NpgsqlTsQueryLexeme +NpgsqlTypes.NpgsqlTsQueryLexeme.IsPrefixSearch.get -> bool +NpgsqlTypes.NpgsqlTsQueryLexeme.IsPrefixSearch.set -> void +NpgsqlTypes.NpgsqlTsQueryLexeme.NpgsqlTsQueryLexeme(string! text) -> void +NpgsqlTypes.NpgsqlTsQueryLexeme.NpgsqlTsQueryLexeme(string! text, NpgsqlTypes.NpgsqlTsQueryLexeme.Weight weights) -> void +NpgsqlTypes.NpgsqlTsQueryLexeme.NpgsqlTsQueryLexeme(string! text, NpgsqlTypes.NpgsqlTsQueryLexeme.Weight weights, bool isPrefixSearch) -> void +NpgsqlTypes.NpgsqlTsQueryLexeme.Text.get -> string! 
+NpgsqlTypes.NpgsqlTsQueryLexeme.Text.set -> void +NpgsqlTypes.NpgsqlTsQueryLexeme.Weight +NpgsqlTypes.NpgsqlTsQueryLexeme.Weight.A = 8 -> NpgsqlTypes.NpgsqlTsQueryLexeme.Weight +NpgsqlTypes.NpgsqlTsQueryLexeme.Weight.B = 4 -> NpgsqlTypes.NpgsqlTsQueryLexeme.Weight +NpgsqlTypes.NpgsqlTsQueryLexeme.Weight.C = 2 -> NpgsqlTypes.NpgsqlTsQueryLexeme.Weight +NpgsqlTypes.NpgsqlTsQueryLexeme.Weight.D = 1 -> NpgsqlTypes.NpgsqlTsQueryLexeme.Weight +NpgsqlTypes.NpgsqlTsQueryLexeme.Weight.None = 0 -> NpgsqlTypes.NpgsqlTsQueryLexeme.Weight +NpgsqlTypes.NpgsqlTsQueryLexeme.Weights.get -> NpgsqlTypes.NpgsqlTsQueryLexeme.Weight +NpgsqlTypes.NpgsqlTsQueryLexeme.Weights.set -> void +NpgsqlTypes.NpgsqlTsQueryNot +NpgsqlTypes.NpgsqlTsQueryNot.Child.get -> NpgsqlTypes.NpgsqlTsQuery! +NpgsqlTypes.NpgsqlTsQueryNot.Child.set -> void +NpgsqlTypes.NpgsqlTsQueryNot.NpgsqlTsQueryNot(NpgsqlTypes.NpgsqlTsQuery! child) -> void +NpgsqlTypes.NpgsqlTsQueryOr +NpgsqlTypes.NpgsqlTsQueryOr.NpgsqlTsQueryOr(NpgsqlTypes.NpgsqlTsQuery! left, NpgsqlTypes.NpgsqlTsQuery! right) -> void +NpgsqlTypes.NpgsqlTsVector +NpgsqlTypes.NpgsqlTsVector.Count.get -> int +NpgsqlTypes.NpgsqlTsVector.Equals(NpgsqlTypes.NpgsqlTsVector? other) -> bool +NpgsqlTypes.NpgsqlTsVector.GetEnumerator() -> System.Collections.Generic.IEnumerator! +NpgsqlTypes.NpgsqlTsVector.Lexeme +NpgsqlTypes.NpgsqlTsVector.Lexeme.Count.get -> int +NpgsqlTypes.NpgsqlTsVector.Lexeme.Equals(NpgsqlTypes.NpgsqlTsVector.Lexeme o) -> bool +NpgsqlTypes.NpgsqlTsVector.Lexeme.Lexeme() -> void +NpgsqlTypes.NpgsqlTsVector.Lexeme.Lexeme(string! text) -> void +NpgsqlTypes.NpgsqlTsVector.Lexeme.Lexeme(string! text, System.Collections.Generic.List? wordEntryPositions) -> void +NpgsqlTypes.NpgsqlTsVector.Lexeme.Text.get -> string! 
+NpgsqlTypes.NpgsqlTsVector.Lexeme.Text.set -> void +NpgsqlTypes.NpgsqlTsVector.Lexeme.this[int index].get -> NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos +NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight +NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight.A = 3 -> NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight +NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight.B = 2 -> NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight +NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight.C = 1 -> NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight +NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight.D = 0 -> NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight +NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos +NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.Equals(NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos o) -> bool +NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.Pos.get -> int +NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.Weight.get -> NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight +NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.WordEntryPos() -> void +NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.WordEntryPos(int pos, NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight weight = NpgsqlTypes.NpgsqlTsVector.Lexeme.Weight.D) -> void +NpgsqlTypes.NpgsqlTsVector.this[int index].get -> NpgsqlTypes.NpgsqlTsVector.Lexeme +NpgsqlTypes.PgNameAttribute +NpgsqlTypes.PgNameAttribute.PgName.get -> string! +NpgsqlTypes.PgNameAttribute.PgNameAttribute(string! pgName) -> void +override Npgsql.BackendMessages.FieldDescription.ToString() -> string! +override Npgsql.NpgsqlBatch.Cancel() -> void +override Npgsql.NpgsqlBatch.CreateDbBatchCommand() -> System.Data.Common.DbBatchCommand! +override Npgsql.NpgsqlBatch.DbBatchCommands.get -> System.Data.Common.DbBatchCommandCollection! +override Npgsql.NpgsqlBatch.DbConnection.get -> System.Data.Common.DbConnection? +override Npgsql.NpgsqlBatch.DbConnection.set -> void +override Npgsql.NpgsqlBatch.DbTransaction.get -> System.Data.Common.DbTransaction? 
+override Npgsql.NpgsqlBatch.DbTransaction.set -> void +override Npgsql.NpgsqlBatch.Dispose() -> void +override Npgsql.NpgsqlBatch.ExecuteDbDataReader(System.Data.CommandBehavior behavior) -> System.Data.Common.DbDataReader! +override Npgsql.NpgsqlBatch.ExecuteDbDataReaderAsync(System.Data.CommandBehavior behavior, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlBatch.ExecuteNonQuery() -> int +override Npgsql.NpgsqlBatch.ExecuteNonQueryAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlBatch.ExecuteScalar() -> object? +override Npgsql.NpgsqlBatch.ExecuteScalarAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlBatch.Prepare() -> void +override Npgsql.NpgsqlBatch.PrepareAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlBatch.Timeout.get -> int +override Npgsql.NpgsqlBatch.Timeout.set -> void +override Npgsql.NpgsqlBatchCommand.CanCreateParameter.get -> bool +override Npgsql.NpgsqlBatchCommand.CommandText.get -> string! +override Npgsql.NpgsqlBatchCommand.CommandText.set -> void +override Npgsql.NpgsqlBatchCommand.CommandType.get -> System.Data.CommandType +override Npgsql.NpgsqlBatchCommand.CommandType.set -> void +override Npgsql.NpgsqlBatchCommand.CreateParameter() -> Npgsql.NpgsqlParameter! +override Npgsql.NpgsqlBatchCommand.RecordsAffected.get -> int +override Npgsql.NpgsqlBatchCommand.ToString() -> string! +override Npgsql.NpgsqlBatchCommandCollection.Add(System.Data.Common.DbBatchCommand! item) -> void +override Npgsql.NpgsqlBatchCommandCollection.Clear() -> void +override Npgsql.NpgsqlBatchCommandCollection.Contains(System.Data.Common.DbBatchCommand! 
item) -> bool +override Npgsql.NpgsqlBatchCommandCollection.CopyTo(System.Data.Common.DbBatchCommand![]! array, int arrayIndex) -> void +override Npgsql.NpgsqlBatchCommandCollection.Count.get -> int +override Npgsql.NpgsqlBatchCommandCollection.GetEnumerator() -> System.Collections.Generic.IEnumerator! +override Npgsql.NpgsqlBatchCommandCollection.IndexOf(System.Data.Common.DbBatchCommand! item) -> int +override Npgsql.NpgsqlBatchCommandCollection.Insert(int index, System.Data.Common.DbBatchCommand! item) -> void +override Npgsql.NpgsqlBatchCommandCollection.IsReadOnly.get -> bool +override Npgsql.NpgsqlBatchCommandCollection.Remove(System.Data.Common.DbBatchCommand! item) -> bool +override Npgsql.NpgsqlBatchCommandCollection.RemoveAt(int index) -> void +override Npgsql.NpgsqlCommand.Cancel() -> void +override Npgsql.NpgsqlCommand.CommandText.get -> string! +override Npgsql.NpgsqlCommand.CommandText.set -> void +override Npgsql.NpgsqlCommand.CommandTimeout.get -> int +override Npgsql.NpgsqlCommand.CommandTimeout.set -> void +override Npgsql.NpgsqlCommand.CommandType.get -> System.Data.CommandType +override Npgsql.NpgsqlCommand.CommandType.set -> void +override Npgsql.NpgsqlCommand.CreateDbParameter() -> System.Data.Common.DbParameter! +override Npgsql.NpgsqlCommand.DbConnection.get -> System.Data.Common.DbConnection? +override Npgsql.NpgsqlCommand.DbConnection.set -> void +override Npgsql.NpgsqlCommand.DbParameterCollection.get -> System.Data.Common.DbParameterCollection! +override Npgsql.NpgsqlCommand.DbTransaction.get -> System.Data.Common.DbTransaction? +override Npgsql.NpgsqlCommand.DbTransaction.set -> void +override Npgsql.NpgsqlCommand.DesignTimeVisible.get -> bool +override Npgsql.NpgsqlCommand.DesignTimeVisible.set -> void +override Npgsql.NpgsqlCommand.Dispose(bool disposing) -> void +override Npgsql.NpgsqlCommand.ExecuteDbDataReader(System.Data.CommandBehavior behavior) -> System.Data.Common.DbDataReader! 
+override Npgsql.NpgsqlCommand.ExecuteDbDataReaderAsync(System.Data.CommandBehavior behavior, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlCommand.ExecuteNonQuery() -> int +override Npgsql.NpgsqlCommand.ExecuteNonQueryAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlCommand.ExecuteScalar() -> object? +override Npgsql.NpgsqlCommand.ExecuteScalarAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlCommand.Prepare() -> void +override Npgsql.NpgsqlCommand.PrepareAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlCommand.UpdatedRowSource.get -> System.Data.UpdateRowSource +override Npgsql.NpgsqlCommand.UpdatedRowSource.set -> void +override Npgsql.NpgsqlCommandBuilder.QuoteIdentifier(string! unquotedIdentifier) -> string! +override Npgsql.NpgsqlCommandBuilder.QuotePrefix.get -> string! +override Npgsql.NpgsqlCommandBuilder.QuotePrefix.set -> void +override Npgsql.NpgsqlCommandBuilder.QuoteSuffix.get -> string! +override Npgsql.NpgsqlCommandBuilder.QuoteSuffix.set -> void +override Npgsql.NpgsqlCommandBuilder.UnquoteIdentifier(string! quotedIdentifier) -> string! +override Npgsql.NpgsqlConnection.CanCreateBatch.get -> bool +override Npgsql.NpgsqlConnection.ChangeDatabase(string! dbName) -> void +override Npgsql.NpgsqlConnection.Close() -> void +override Npgsql.NpgsqlConnection.CloseAsync() -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlConnection.ConnectionString.get -> string! +override Npgsql.NpgsqlConnection.ConnectionString.set -> void +override Npgsql.NpgsqlConnection.ConnectionTimeout.get -> int +override Npgsql.NpgsqlConnection.Database.get -> string! +override Npgsql.NpgsqlConnection.DataSource.get -> string! 
+override Npgsql.NpgsqlConnection.DisposeAsync() -> System.Threading.Tasks.ValueTask +override Npgsql.NpgsqlConnection.EnlistTransaction(System.Transactions.Transaction? transaction) -> void +override Npgsql.NpgsqlConnection.GetSchema() -> System.Data.DataTable! +override Npgsql.NpgsqlConnection.GetSchema(string? collectionName) -> System.Data.DataTable! +override Npgsql.NpgsqlConnection.GetSchema(string? collectionName, string?[]? restrictions) -> System.Data.DataTable! +override Npgsql.NpgsqlConnection.GetSchemaAsync(string! collectionName, string?[]? restrictions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlConnection.GetSchemaAsync(string! collectionName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlConnection.GetSchemaAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlConnection.Open() -> void +override Npgsql.NpgsqlConnection.OpenAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlConnection.ServerVersion.get -> string! +override Npgsql.NpgsqlConnection.State.get -> System.Data.ConnectionState +override Npgsql.NpgsqlConnectionStringBuilder.Clear() -> void +override Npgsql.NpgsqlConnectionStringBuilder.ContainsKey(string! keyword) -> bool +override Npgsql.NpgsqlConnectionStringBuilder.Equals(object? obj) -> bool +override Npgsql.NpgsqlConnectionStringBuilder.GetHashCode() -> int +override Npgsql.NpgsqlConnectionStringBuilder.Remove(string! keyword) -> bool +override Npgsql.NpgsqlConnectionStringBuilder.this[string! keyword].get -> object! +override Npgsql.NpgsqlConnectionStringBuilder.this[string! 
keyword].set -> void +override Npgsql.NpgsqlConnectionStringBuilder.TryGetValue(string! keyword, out object? value) -> bool +override Npgsql.NpgsqlDataReader.Close() -> void +override Npgsql.NpgsqlDataReader.CloseAsync() -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlDataReader.Depth.get -> int +override Npgsql.NpgsqlDataReader.DisposeAsync() -> System.Threading.Tasks.ValueTask +override Npgsql.NpgsqlDataReader.FieldCount.get -> int +override Npgsql.NpgsqlDataReader.GetBoolean(int ordinal) -> bool +override Npgsql.NpgsqlDataReader.GetByte(int ordinal) -> byte +override Npgsql.NpgsqlDataReader.GetBytes(int ordinal, long dataOffset, byte[]? buffer, int bufferOffset, int length) -> long +override Npgsql.NpgsqlDataReader.GetChar(int ordinal) -> char +override Npgsql.NpgsqlDataReader.GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) -> long +override Npgsql.NpgsqlDataReader.GetDataTypeName(int ordinal) -> string! +override Npgsql.NpgsqlDataReader.GetDateTime(int ordinal) -> System.DateTime +override Npgsql.NpgsqlDataReader.GetDecimal(int ordinal) -> decimal +override Npgsql.NpgsqlDataReader.GetDouble(int ordinal) -> double +override Npgsql.NpgsqlDataReader.GetEnumerator() -> System.Collections.IEnumerator! +override Npgsql.NpgsqlDataReader.GetFieldType(int ordinal) -> System.Type! +override Npgsql.NpgsqlDataReader.GetFieldValue(int ordinal) -> T +override Npgsql.NpgsqlDataReader.GetFieldValueAsync(int ordinal, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlDataReader.GetFloat(int ordinal) -> float +override Npgsql.NpgsqlDataReader.GetGuid(int ordinal) -> System.Guid +override Npgsql.NpgsqlDataReader.GetInt16(int ordinal) -> short +override Npgsql.NpgsqlDataReader.GetInt32(int ordinal) -> int +override Npgsql.NpgsqlDataReader.GetInt64(int ordinal) -> long +override Npgsql.NpgsqlDataReader.GetName(int ordinal) -> string! 
+override Npgsql.NpgsqlDataReader.GetOrdinal(string! name) -> int +override Npgsql.NpgsqlDataReader.GetSchemaTable() -> System.Data.DataTable? +override Npgsql.NpgsqlDataReader.GetSchemaTableAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlDataReader.GetStream(int ordinal) -> System.IO.Stream! +override Npgsql.NpgsqlDataReader.GetString(int ordinal) -> string! +override Npgsql.NpgsqlDataReader.GetTextReader(int ordinal) -> System.IO.TextReader! +override Npgsql.NpgsqlDataReader.GetValue(int ordinal) -> object! +override Npgsql.NpgsqlDataReader.GetValues(object![]! values) -> int +override Npgsql.NpgsqlDataReader.HasRows.get -> bool +override Npgsql.NpgsqlDataReader.IsClosed.get -> bool +override Npgsql.NpgsqlDataReader.IsDBNull(int ordinal) -> bool +override Npgsql.NpgsqlDataReader.IsDBNullAsync(int ordinal, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlDataReader.NextResult() -> bool +override Npgsql.NpgsqlDataReader.NextResultAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlDataReader.Read() -> bool +override Npgsql.NpgsqlDataReader.ReadAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlDataReader.RecordsAffected.get -> int +override Npgsql.NpgsqlDataReader.this[int ordinal].get -> object! +override Npgsql.NpgsqlDataReader.this[string! name].get -> object! +override Npgsql.NpgsqlDataSource.ConnectionString.get -> string! +override Npgsql.NpgsqlException.DbBatchCommand.get -> System.Data.Common.DbBatchCommand? 
+override Npgsql.NpgsqlException.IsTransient.get -> bool +override Npgsql.NpgsqlFactory.CanCreateBatch.get -> bool +override Npgsql.NpgsqlFactory.CanCreateCommandBuilder.get -> bool +override Npgsql.NpgsqlFactory.CanCreateDataAdapter.get -> bool +override Npgsql.NpgsqlFactory.CreateBatch() -> System.Data.Common.DbBatch! +override Npgsql.NpgsqlFactory.CreateBatchCommand() -> System.Data.Common.DbBatchCommand! +override Npgsql.NpgsqlFactory.CreateCommand() -> System.Data.Common.DbCommand! +override Npgsql.NpgsqlFactory.CreateCommandBuilder() -> System.Data.Common.DbCommandBuilder! +override Npgsql.NpgsqlFactory.CreateConnection() -> System.Data.Common.DbConnection! +override Npgsql.NpgsqlFactory.CreateConnectionStringBuilder() -> System.Data.Common.DbConnectionStringBuilder! +override Npgsql.NpgsqlFactory.CreateDataAdapter() -> System.Data.Common.DbDataAdapter! +override Npgsql.NpgsqlFactory.CreateDataSource(string! connectionString) -> System.Data.Common.DbDataSource! +override Npgsql.NpgsqlFactory.CreateParameter() -> System.Data.Common.DbParameter! +override Npgsql.NpgsqlLargeObjectStream.CanRead.get -> bool +override Npgsql.NpgsqlLargeObjectStream.CanSeek.get -> bool +override Npgsql.NpgsqlLargeObjectStream.CanTimeout.get -> bool +override Npgsql.NpgsqlLargeObjectStream.CanWrite.get -> bool +override Npgsql.NpgsqlLargeObjectStream.Close() -> void +override Npgsql.NpgsqlLargeObjectStream.Flush() -> void +override Npgsql.NpgsqlLargeObjectStream.Length.get -> long +override Npgsql.NpgsqlLargeObjectStream.Position.get -> long +override Npgsql.NpgsqlLargeObjectStream.Position.set -> void +override Npgsql.NpgsqlLargeObjectStream.Read(byte[]! buffer, int offset, int count) -> int +override Npgsql.NpgsqlLargeObjectStream.ReadAsync(byte[]! buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! 
+override Npgsql.NpgsqlLargeObjectStream.Seek(long offset, System.IO.SeekOrigin origin) -> long +override Npgsql.NpgsqlLargeObjectStream.SetLength(long value) -> void +override Npgsql.NpgsqlLargeObjectStream.Write(byte[]! buffer, int offset, int count) -> void +override Npgsql.NpgsqlLargeObjectStream.WriteAsync(byte[]! buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlNestedDataReader.Close() -> void +override Npgsql.NpgsqlNestedDataReader.Depth.get -> int +override Npgsql.NpgsqlNestedDataReader.FieldCount.get -> int +override Npgsql.NpgsqlNestedDataReader.GetBoolean(int ordinal) -> bool +override Npgsql.NpgsqlNestedDataReader.GetByte(int ordinal) -> byte +override Npgsql.NpgsqlNestedDataReader.GetBytes(int ordinal, long dataOffset, byte[]? buffer, int bufferOffset, int length) -> long +override Npgsql.NpgsqlNestedDataReader.GetChar(int ordinal) -> char +override Npgsql.NpgsqlNestedDataReader.GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) -> long +override Npgsql.NpgsqlNestedDataReader.GetDataTypeName(int ordinal) -> string! +override Npgsql.NpgsqlNestedDataReader.GetDateTime(int ordinal) -> System.DateTime +override Npgsql.NpgsqlNestedDataReader.GetDecimal(int ordinal) -> decimal +override Npgsql.NpgsqlNestedDataReader.GetDouble(int ordinal) -> double +override Npgsql.NpgsqlNestedDataReader.GetEnumerator() -> System.Collections.IEnumerator! +override Npgsql.NpgsqlNestedDataReader.GetFieldType(int ordinal) -> System.Type! 
+override Npgsql.NpgsqlNestedDataReader.GetFieldValue(int ordinal) -> T +override Npgsql.NpgsqlNestedDataReader.GetFloat(int ordinal) -> float +override Npgsql.NpgsqlNestedDataReader.GetGuid(int ordinal) -> System.Guid +override Npgsql.NpgsqlNestedDataReader.GetInt16(int ordinal) -> short +override Npgsql.NpgsqlNestedDataReader.GetInt32(int ordinal) -> int +override Npgsql.NpgsqlNestedDataReader.GetInt64(int ordinal) -> long +override Npgsql.NpgsqlNestedDataReader.GetName(int ordinal) -> string! +override Npgsql.NpgsqlNestedDataReader.GetOrdinal(string! name) -> int +override Npgsql.NpgsqlNestedDataReader.GetString(int ordinal) -> string! +override Npgsql.NpgsqlNestedDataReader.GetValue(int ordinal) -> object! +override Npgsql.NpgsqlNestedDataReader.GetValues(object![]! values) -> int +override Npgsql.NpgsqlNestedDataReader.HasRows.get -> bool +override Npgsql.NpgsqlNestedDataReader.IsClosed.get -> bool +override Npgsql.NpgsqlNestedDataReader.IsDBNull(int ordinal) -> bool +override Npgsql.NpgsqlNestedDataReader.NextResult() -> bool +override Npgsql.NpgsqlNestedDataReader.Read() -> bool +override Npgsql.NpgsqlNestedDataReader.RecordsAffected.get -> int +override Npgsql.NpgsqlNestedDataReader.this[int ordinal].get -> object! +override Npgsql.NpgsqlNestedDataReader.this[string! name].get -> object! +override Npgsql.NpgsqlParameter.ResetDbType() -> void +override Npgsql.NpgsqlParameter.Value.get -> object? +override Npgsql.NpgsqlParameter.Value.set -> void +override Npgsql.NpgsqlParameter.Value.get -> object? +override Npgsql.NpgsqlParameter.Value.set -> void +override Npgsql.NpgsqlParameterCollection.Add(object! value) -> int +override Npgsql.NpgsqlParameterCollection.AddRange(System.Array! values) -> void +override Npgsql.NpgsqlParameterCollection.Clear() -> void +override Npgsql.NpgsqlParameterCollection.Contains(object! value) -> bool +override Npgsql.NpgsqlParameterCollection.Contains(string! 
parameterName) -> bool +override Npgsql.NpgsqlParameterCollection.CopyTo(System.Array! array, int index) -> void +override Npgsql.NpgsqlParameterCollection.Count.get -> int +override Npgsql.NpgsqlParameterCollection.GetEnumerator() -> System.Collections.IEnumerator! +override Npgsql.NpgsqlParameterCollection.IndexOf(object! value) -> int +override Npgsql.NpgsqlParameterCollection.IndexOf(string! parameterName) -> int +override Npgsql.NpgsqlParameterCollection.Insert(int index, object! value) -> void +override Npgsql.NpgsqlParameterCollection.IsFixedSize.get -> bool +override Npgsql.NpgsqlParameterCollection.IsReadOnly.get -> bool +override Npgsql.NpgsqlParameterCollection.IsSynchronized.get -> bool +override Npgsql.NpgsqlParameterCollection.Remove(object! value) -> void +override Npgsql.NpgsqlParameterCollection.RemoveAt(int index) -> void +override Npgsql.NpgsqlParameterCollection.RemoveAt(string! parameterName) -> void +override Npgsql.NpgsqlParameterCollection.SyncRoot.get -> object! +override Npgsql.NpgsqlRawCopyStream.CanRead.get -> bool +override Npgsql.NpgsqlRawCopyStream.CanSeek.get -> bool +override Npgsql.NpgsqlRawCopyStream.CanTimeout.get -> bool +override Npgsql.NpgsqlRawCopyStream.CanWrite.get -> bool +override Npgsql.NpgsqlRawCopyStream.DisposeAsync() -> System.Threading.Tasks.ValueTask +override Npgsql.NpgsqlRawCopyStream.Flush() -> void +override Npgsql.NpgsqlRawCopyStream.FlushAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlRawCopyStream.Length.get -> long +override Npgsql.NpgsqlRawCopyStream.Position.get -> long +override Npgsql.NpgsqlRawCopyStream.Position.set -> void +override Npgsql.NpgsqlRawCopyStream.Read(byte[]! buffer, int offset, int count) -> int +override Npgsql.NpgsqlRawCopyStream.Read(System.Span span) -> int +override Npgsql.NpgsqlRawCopyStream.ReadAsync(byte[]! 
buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlRawCopyStream.ReadAsync(System.Memory buffer, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask +override Npgsql.NpgsqlRawCopyStream.ReadTimeout.get -> int +override Npgsql.NpgsqlRawCopyStream.ReadTimeout.set -> void +override Npgsql.NpgsqlRawCopyStream.Seek(long offset, System.IO.SeekOrigin origin) -> long +override Npgsql.NpgsqlRawCopyStream.SetLength(long value) -> void +override Npgsql.NpgsqlRawCopyStream.Write(byte[]! buffer, int offset, int count) -> void +override Npgsql.NpgsqlRawCopyStream.Write(System.ReadOnlySpan buffer) -> void +override Npgsql.NpgsqlRawCopyStream.WriteAsync(byte[]! buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlRawCopyStream.WriteAsync(System.ReadOnlyMemory buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +override Npgsql.NpgsqlRawCopyStream.WriteTimeout.get -> int +override Npgsql.NpgsqlRawCopyStream.WriteTimeout.set -> void +override Npgsql.NpgsqlTransaction.Commit() -> void +override Npgsql.NpgsqlTransaction.CommitAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlTransaction.DisposeAsync() -> System.Threading.Tasks.ValueTask +override Npgsql.NpgsqlTransaction.IsolationLevel.get -> System.Data.IsolationLevel +override Npgsql.NpgsqlTransaction.Release(string! name) -> void +override Npgsql.NpgsqlTransaction.ReleaseAsync(string! name, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+override Npgsql.NpgsqlTransaction.Rollback() -> void +override Npgsql.NpgsqlTransaction.Rollback(string! name) -> void +override Npgsql.NpgsqlTransaction.RollbackAsync(string! name, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlTransaction.RollbackAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlTransaction.Save(string! name) -> void +override Npgsql.NpgsqlTransaction.SaveAsync(string! name, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Npgsql.NpgsqlTransaction.SupportsSavepoints.get -> bool +override Npgsql.PostgresException.GetObjectData(System.Runtime.Serialization.SerializationInfo! info, System.Runtime.Serialization.StreamingContext context) -> void +override Npgsql.PostgresException.IsTransient.get -> bool +override Npgsql.PostgresException.SqlState.get -> string! +override Npgsql.PostgresException.ToString() -> string! +override Npgsql.PostgresTypes.PostgresCompositeType.Field.ToString() -> string! +override Npgsql.PostgresTypes.PostgresType.ToString() -> string! +override Npgsql.Replication.PgOutput.Messages.DefaultUpdateMessage.NewRow.get -> Npgsql.Replication.PgOutput.ReplicationTuple! +override Npgsql.Replication.PgOutput.Messages.FullUpdateMessage.NewRow.get -> Npgsql.Replication.PgOutput.ReplicationTuple! +override Npgsql.Replication.PgOutput.Messages.IndexUpdateMessage.NewRow.get -> Npgsql.Replication.PgOutput.ReplicationTuple! +override Npgsql.Replication.PgOutput.Messages.PgOutputReplicationMessage.ToString() -> string! +override Npgsql.Replication.PgOutput.PgOutputReplicationOptions.Equals(object? 
obj) -> bool +override Npgsql.Replication.PgOutput.PgOutputReplicationOptions.GetHashCode() -> int +override Npgsql.Replication.TestDecoding.TestDecodingData.ToString() -> string! +override Npgsql.Replication.TestDecoding.TestDecodingOptions.Equals(object? obj) -> bool +override Npgsql.Replication.TestDecoding.TestDecodingOptions.GetHashCode() -> int +override Npgsql.Schema.NpgsqlDbColumn.this[string! propertyName].get -> object? +override NpgsqlTypes.NpgsqlBox.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlBox.GetHashCode() -> int +override NpgsqlTypes.NpgsqlBox.ToString() -> string! +override NpgsqlTypes.NpgsqlCidr.ToString() -> string! +override NpgsqlTypes.NpgsqlCircle.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlCircle.GetHashCode() -> int +override NpgsqlTypes.NpgsqlCircle.ToString() -> string! +override NpgsqlTypes.NpgsqlInet.ToString() -> string! +override NpgsqlTypes.NpgsqlInterval.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlInterval.GetHashCode() -> int +override NpgsqlTypes.NpgsqlLine.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlLine.GetHashCode() -> int +override NpgsqlTypes.NpgsqlLine.ToString() -> string! +override NpgsqlTypes.NpgsqlLogSequenceNumber.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlLogSequenceNumber.GetHashCode() -> int +override NpgsqlTypes.NpgsqlLogSequenceNumber.ToString() -> string! +override NpgsqlTypes.NpgsqlLSeg.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlLSeg.GetHashCode() -> int +override NpgsqlTypes.NpgsqlLSeg.ToString() -> string! +override NpgsqlTypes.NpgsqlPath.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlPath.GetHashCode() -> int +override NpgsqlTypes.NpgsqlPath.ToString() -> string! +override NpgsqlTypes.NpgsqlPoint.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlPoint.GetHashCode() -> int +override NpgsqlTypes.NpgsqlPoint.ToString() -> string! +override NpgsqlTypes.NpgsqlPolygon.Equals(object? 
obj) -> bool +override NpgsqlTypes.NpgsqlPolygon.GetHashCode() -> int +override NpgsqlTypes.NpgsqlPolygon.ToString() -> string! +override NpgsqlTypes.NpgsqlRange.Equals(object? o) -> bool +override NpgsqlTypes.NpgsqlRange.GetHashCode() -> int +override NpgsqlTypes.NpgsqlRange.RangeTypeConverter.CanConvertFrom(System.ComponentModel.ITypeDescriptorContext? context, System.Type! sourceType) -> bool +override NpgsqlTypes.NpgsqlRange.RangeTypeConverter.CanConvertTo(System.ComponentModel.ITypeDescriptorContext? context, System.Type? destinationType) -> bool +override NpgsqlTypes.NpgsqlRange.RangeTypeConverter.ConvertFrom(System.ComponentModel.ITypeDescriptorContext? context, System.Globalization.CultureInfo? culture, object! value) -> object? +override NpgsqlTypes.NpgsqlRange.RangeTypeConverter.ConvertTo(System.ComponentModel.ITypeDescriptorContext? context, System.Globalization.CultureInfo? culture, object? value, System.Type! destinationType) -> object? +override NpgsqlTypes.NpgsqlRange.ToString() -> string! +override NpgsqlTypes.NpgsqlTid.Equals(object? o) -> bool +override NpgsqlTypes.NpgsqlTid.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTid.ToString() -> string! +override NpgsqlTypes.NpgsqlTsQuery.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlTsQuery.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsQuery.ToString() -> string! +override NpgsqlTypes.NpgsqlTsQueryAnd.Equals(NpgsqlTypes.NpgsqlTsQuery? other) -> bool +override NpgsqlTypes.NpgsqlTsQueryAnd.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsQueryEmpty.Equals(NpgsqlTypes.NpgsqlTsQuery? other) -> bool +override NpgsqlTypes.NpgsqlTsQueryEmpty.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsQueryFollowedBy.Equals(NpgsqlTypes.NpgsqlTsQuery? other) -> bool +override NpgsqlTypes.NpgsqlTsQueryFollowedBy.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsQueryLexeme.Equals(NpgsqlTypes.NpgsqlTsQuery? 
other) -> bool +override NpgsqlTypes.NpgsqlTsQueryLexeme.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsQueryNot.Equals(NpgsqlTypes.NpgsqlTsQuery? other) -> bool +override NpgsqlTypes.NpgsqlTsQueryNot.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsQueryOr.Equals(NpgsqlTypes.NpgsqlTsQuery? other) -> bool +override NpgsqlTypes.NpgsqlTsQueryOr.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsVector.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlTsVector.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsVector.Lexeme.Equals(object? o) -> bool +override NpgsqlTypes.NpgsqlTsVector.Lexeme.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsVector.Lexeme.ToString() -> string! +override NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.Equals(object? o) -> bool +override NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.GetHashCode() -> int +override NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.ToString() -> string! +override NpgsqlTypes.NpgsqlTsVector.ToString() -> string! +override sealed Npgsql.NpgsqlParameter.DbType.get -> System.Data.DbType +override sealed Npgsql.NpgsqlParameter.DbType.set -> void +override sealed Npgsql.NpgsqlParameter.Direction.get -> System.Data.ParameterDirection +override sealed Npgsql.NpgsqlParameter.Direction.set -> void +override sealed Npgsql.NpgsqlParameter.IsNullable.get -> bool +override sealed Npgsql.NpgsqlParameter.IsNullable.set -> void +override sealed Npgsql.NpgsqlParameter.ParameterName.get -> string! +override sealed Npgsql.NpgsqlParameter.ParameterName.set -> void +override sealed Npgsql.NpgsqlParameter.Size.get -> int +override sealed Npgsql.NpgsqlParameter.Size.set -> void +override sealed Npgsql.NpgsqlParameter.SourceColumn.get -> string! 
+override sealed Npgsql.NpgsqlParameter.SourceColumn.set -> void +override sealed Npgsql.NpgsqlParameter.SourceColumnNullMapping.get -> bool +override sealed Npgsql.NpgsqlParameter.SourceColumnNullMapping.set -> void +override sealed Npgsql.NpgsqlParameter.SourceVersion.get -> System.Data.DataRowVersion +override sealed Npgsql.NpgsqlParameter.SourceVersion.set -> void +static Npgsql.NameTranslation.NpgsqlSnakeCaseNameTranslator.ConvertToSnakeCase(string! name, System.Globalization.CultureInfo? culture = null) -> string! +static Npgsql.NpgsqlCommandBuilder.DeriveParameters(Npgsql.NpgsqlCommand! command) -> void +static Npgsql.NpgsqlConnection.ClearAllPools() -> void +static Npgsql.NpgsqlConnection.ClearPool(Npgsql.NpgsqlConnection! connection) -> void +static Npgsql.NpgsqlConnection.GlobalTypeMapper.get -> Npgsql.TypeMapping.INpgsqlTypeMapper! +static Npgsql.NpgsqlDataSource.Create(Npgsql.NpgsqlConnectionStringBuilder! connectionStringBuilder) -> Npgsql.NpgsqlDataSource! +static Npgsql.NpgsqlDataSource.Create(string! connectionString) -> Npgsql.NpgsqlDataSource! +static Npgsql.NpgsqlLoggingConfiguration.InitializeLogging(Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory, bool parameterLoggingEnabled = false) -> void +static Npgsql.Replication.Internal.LogicalReplicationConnectionExtensions.CreateLogicalReplicationSlot(this Npgsql.Replication.LogicalReplicationConnection! connection, string! slotName, string! outputPlugin, bool isTemporary = false, Npgsql.Replication.LogicalSlotSnapshotInitMode? slotSnapshotInitMode = null, bool twoPhase = false, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Npgsql.Replication.Internal.LogicalReplicationConnectionExtensions.StartLogicalReplication(this Npgsql.Replication.LogicalReplicationConnection! connection, Npgsql.Replication.Internal.LogicalReplicationSlot! 
slot, System.Threading.CancellationToken cancellationToken, NpgsqlTypes.NpgsqlLogSequenceNumber? walLocation = null, System.Collections.Generic.IEnumerable>? options = null, bool bypassingStream = false) -> System.Collections.Generic.IAsyncEnumerable! +static Npgsql.Replication.PgOutputConnectionExtensions.CreatePgOutputReplicationSlot(this Npgsql.Replication.LogicalReplicationConnection! connection, string! slotName, bool temporarySlot = false, Npgsql.Replication.LogicalSlotSnapshotInitMode? slotSnapshotInitMode = null, bool twoPhase = false, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Npgsql.Replication.PgOutputConnectionExtensions.StartReplication(this Npgsql.Replication.LogicalReplicationConnection! connection, Npgsql.Replication.PgOutput.PgOutputReplicationSlot! slot, Npgsql.Replication.PgOutput.PgOutputReplicationOptions! options, System.Threading.CancellationToken cancellationToken, NpgsqlTypes.NpgsqlLogSequenceNumber? walLocation = null) -> System.Collections.Generic.IAsyncEnumerable! +static Npgsql.Replication.TestDecodingConnectionExtensions.CreateTestDecodingReplicationSlot(this Npgsql.Replication.LogicalReplicationConnection! connection, string! slotName, bool temporarySlot = false, Npgsql.Replication.LogicalSlotSnapshotInitMode? slotSnapshotInitMode = null, bool twoPhase = false, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Npgsql.Replication.TestDecodingConnectionExtensions.StartReplication(this Npgsql.Replication.LogicalReplicationConnection! connection, Npgsql.Replication.TestDecoding.TestDecodingReplicationSlot! slot, System.Threading.CancellationToken cancellationToken, Npgsql.Replication.TestDecoding.TestDecodingOptions? options = null, NpgsqlTypes.NpgsqlLogSequenceNumber? walLocation = null) -> System.Collections.Generic.IAsyncEnumerable! 
+static NpgsqlTypes.NpgsqlBox.operator !=(NpgsqlTypes.NpgsqlBox x, NpgsqlTypes.NpgsqlBox y) -> bool +static NpgsqlTypes.NpgsqlBox.operator ==(NpgsqlTypes.NpgsqlBox x, NpgsqlTypes.NpgsqlBox y) -> bool +static NpgsqlTypes.NpgsqlCidr.explicit operator System.Net.IPAddress!(NpgsqlTypes.NpgsqlCidr cidr) -> System.Net.IPAddress! +static NpgsqlTypes.NpgsqlCidr.implicit operator NpgsqlTypes.NpgsqlInet(NpgsqlTypes.NpgsqlCidr cidr) -> NpgsqlTypes.NpgsqlInet +static NpgsqlTypes.NpgsqlCircle.operator !=(NpgsqlTypes.NpgsqlCircle x, NpgsqlTypes.NpgsqlCircle y) -> bool +static NpgsqlTypes.NpgsqlCircle.operator ==(NpgsqlTypes.NpgsqlCircle x, NpgsqlTypes.NpgsqlCircle y) -> bool +static NpgsqlTypes.NpgsqlInet.explicit operator System.Net.IPAddress!(NpgsqlTypes.NpgsqlInet inet) -> System.Net.IPAddress! +static NpgsqlTypes.NpgsqlInet.implicit operator NpgsqlTypes.NpgsqlInet(System.Net.IPAddress! ip) -> NpgsqlTypes.NpgsqlInet +static NpgsqlTypes.NpgsqlLine.operator !=(NpgsqlTypes.NpgsqlLine x, NpgsqlTypes.NpgsqlLine y) -> bool +static NpgsqlTypes.NpgsqlLine.operator ==(NpgsqlTypes.NpgsqlLine x, NpgsqlTypes.NpgsqlLine y) -> bool +static NpgsqlTypes.NpgsqlLogSequenceNumber.explicit operator NpgsqlTypes.NpgsqlLogSequenceNumber(ulong value) -> NpgsqlTypes.NpgsqlLogSequenceNumber +static NpgsqlTypes.NpgsqlLogSequenceNumber.explicit operator ulong(NpgsqlTypes.NpgsqlLogSequenceNumber value) -> ulong +static NpgsqlTypes.NpgsqlLogSequenceNumber.Larger(NpgsqlTypes.NpgsqlLogSequenceNumber value1, NpgsqlTypes.NpgsqlLogSequenceNumber value2) -> NpgsqlTypes.NpgsqlLogSequenceNumber +static NpgsqlTypes.NpgsqlLogSequenceNumber.operator !=(NpgsqlTypes.NpgsqlLogSequenceNumber value1, NpgsqlTypes.NpgsqlLogSequenceNumber value2) -> bool +static NpgsqlTypes.NpgsqlLogSequenceNumber.operator +(NpgsqlTypes.NpgsqlLogSequenceNumber lsn, double nbytes) -> NpgsqlTypes.NpgsqlLogSequenceNumber +static NpgsqlTypes.NpgsqlLogSequenceNumber.operator -(NpgsqlTypes.NpgsqlLogSequenceNumber first, 
NpgsqlTypes.NpgsqlLogSequenceNumber second) -> ulong +static NpgsqlTypes.NpgsqlLogSequenceNumber.operator -(NpgsqlTypes.NpgsqlLogSequenceNumber lsn, double nbytes) -> NpgsqlTypes.NpgsqlLogSequenceNumber +static NpgsqlTypes.NpgsqlLogSequenceNumber.operator <(NpgsqlTypes.NpgsqlLogSequenceNumber value1, NpgsqlTypes.NpgsqlLogSequenceNumber value2) -> bool +static NpgsqlTypes.NpgsqlLogSequenceNumber.operator <=(NpgsqlTypes.NpgsqlLogSequenceNumber value1, NpgsqlTypes.NpgsqlLogSequenceNumber value2) -> bool +static NpgsqlTypes.NpgsqlLogSequenceNumber.operator ==(NpgsqlTypes.NpgsqlLogSequenceNumber value1, NpgsqlTypes.NpgsqlLogSequenceNumber value2) -> bool +static NpgsqlTypes.NpgsqlLogSequenceNumber.operator >(NpgsqlTypes.NpgsqlLogSequenceNumber value1, NpgsqlTypes.NpgsqlLogSequenceNumber value2) -> bool +static NpgsqlTypes.NpgsqlLogSequenceNumber.operator >=(NpgsqlTypes.NpgsqlLogSequenceNumber value1, NpgsqlTypes.NpgsqlLogSequenceNumber value2) -> bool +static NpgsqlTypes.NpgsqlLogSequenceNumber.Parse(string! s) -> NpgsqlTypes.NpgsqlLogSequenceNumber +static NpgsqlTypes.NpgsqlLogSequenceNumber.Parse(System.ReadOnlySpan s) -> NpgsqlTypes.NpgsqlLogSequenceNumber +static NpgsqlTypes.NpgsqlLogSequenceNumber.Smaller(NpgsqlTypes.NpgsqlLogSequenceNumber value1, NpgsqlTypes.NpgsqlLogSequenceNumber value2) -> NpgsqlTypes.NpgsqlLogSequenceNumber +static NpgsqlTypes.NpgsqlLogSequenceNumber.TryParse(string! 
s, out NpgsqlTypes.NpgsqlLogSequenceNumber result) -> bool +static NpgsqlTypes.NpgsqlLogSequenceNumber.TryParse(System.ReadOnlySpan s, out NpgsqlTypes.NpgsqlLogSequenceNumber result) -> bool +static NpgsqlTypes.NpgsqlLSeg.operator !=(NpgsqlTypes.NpgsqlLSeg x, NpgsqlTypes.NpgsqlLSeg y) -> bool +static NpgsqlTypes.NpgsqlLSeg.operator ==(NpgsqlTypes.NpgsqlLSeg x, NpgsqlTypes.NpgsqlLSeg y) -> bool +static NpgsqlTypes.NpgsqlPath.operator !=(NpgsqlTypes.NpgsqlPath x, NpgsqlTypes.NpgsqlPath y) -> bool +static NpgsqlTypes.NpgsqlPath.operator ==(NpgsqlTypes.NpgsqlPath x, NpgsqlTypes.NpgsqlPath y) -> bool +static NpgsqlTypes.NpgsqlPoint.operator !=(NpgsqlTypes.NpgsqlPoint x, NpgsqlTypes.NpgsqlPoint y) -> bool +static NpgsqlTypes.NpgsqlPoint.operator ==(NpgsqlTypes.NpgsqlPoint x, NpgsqlTypes.NpgsqlPoint y) -> bool +static NpgsqlTypes.NpgsqlPolygon.operator !=(NpgsqlTypes.NpgsqlPolygon x, NpgsqlTypes.NpgsqlPolygon y) -> bool +static NpgsqlTypes.NpgsqlPolygon.operator ==(NpgsqlTypes.NpgsqlPolygon x, NpgsqlTypes.NpgsqlPolygon y) -> bool +static NpgsqlTypes.NpgsqlRange.operator !=(NpgsqlTypes.NpgsqlRange x, NpgsqlTypes.NpgsqlRange y) -> bool +static NpgsqlTypes.NpgsqlRange.operator ==(NpgsqlTypes.NpgsqlRange x, NpgsqlTypes.NpgsqlRange y) -> bool +static NpgsqlTypes.NpgsqlRange.Parse(string! value) -> NpgsqlTypes.NpgsqlRange +static NpgsqlTypes.NpgsqlRange.RangeTypeConverter.Register() -> void +static NpgsqlTypes.NpgsqlTid.operator !=(NpgsqlTypes.NpgsqlTid left, NpgsqlTypes.NpgsqlTid right) -> bool +static NpgsqlTypes.NpgsqlTid.operator ==(NpgsqlTypes.NpgsqlTid left, NpgsqlTypes.NpgsqlTid right) -> bool +static NpgsqlTypes.NpgsqlTsQuery.operator !=(NpgsqlTypes.NpgsqlTsQuery? left, NpgsqlTypes.NpgsqlTsQuery? right) -> bool +static NpgsqlTypes.NpgsqlTsQuery.operator ==(NpgsqlTypes.NpgsqlTsQuery? left, NpgsqlTypes.NpgsqlTsQuery? right) -> bool +static NpgsqlTypes.NpgsqlTsQuery.Parse(string! value) -> NpgsqlTypes.NpgsqlTsQuery! 
+static NpgsqlTypes.NpgsqlTsVector.Lexeme.operator !=(NpgsqlTypes.NpgsqlTsVector.Lexeme left, NpgsqlTypes.NpgsqlTsVector.Lexeme right) -> bool +static NpgsqlTypes.NpgsqlTsVector.Lexeme.operator ==(NpgsqlTypes.NpgsqlTsVector.Lexeme left, NpgsqlTypes.NpgsqlTsVector.Lexeme right) -> bool +static NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.operator !=(NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos left, NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos right) -> bool +static NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos.operator ==(NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos left, NpgsqlTypes.NpgsqlTsVector.Lexeme.WordEntryPos right) -> bool +static NpgsqlTypes.NpgsqlTsVector.Parse(string! value) -> NpgsqlTypes.NpgsqlTsVector! +static readonly Npgsql.NpgsqlFactory.Instance -> Npgsql.NpgsqlFactory! +static readonly NpgsqlTypes.NpgsqlLogSequenceNumber.Invalid -> NpgsqlTypes.NpgsqlLogSequenceNumber +static readonly NpgsqlTypes.NpgsqlRange.Empty -> NpgsqlTypes.NpgsqlRange +virtual Npgsql.NpgsqlCommand.Clone() -> Npgsql.NpgsqlCommand! +virtual Npgsql.Replication.PgOutput.ReplicationTuple.GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Collections.Generic.IAsyncEnumerator! diff --git a/src/Npgsql/PublicAPI.Unshipped.txt b/src/Npgsql/PublicAPI.Unshipped.txt new file mode 100644 index 0000000000..ab058de62d --- /dev/null +++ b/src/Npgsql/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +#nullable enable diff --git a/src/Npgsql/README.md b/src/Npgsql/README.md new file mode 100644 index 0000000000..8a80f79588 --- /dev/null +++ b/src/Npgsql/README.md @@ -0,0 +1,44 @@ +Npgsql is the open source .NET data provider for PostgreSQL. It allows you to connect and interact with PostgreSQL server using .NET. 
+ +## Quickstart + +Here's a basic code snippet to get you started: + +```csharp +var connString = "Host=myserver;Username=mylogin;Password=mypass;Database=mydatabase"; + +await using var conn = new NpgsqlConnection(connString); +await conn.OpenAsync(); + +// Insert some data +await using (var cmd = new NpgsqlCommand("INSERT INTO data (some_field) VALUES (@p)", conn)) +{ + cmd.Parameters.AddWithValue("p", "Hello world"); + await cmd.ExecuteNonQueryAsync(); +} + +// Retrieve all rows +await using (var cmd = new NpgsqlCommand("SELECT some_field FROM data", conn)) +await using (var reader = await cmd.ExecuteReaderAsync()) +{ +while (await reader.ReadAsync()) + Console.WriteLine(reader.GetString(0)); +} +``` + +## Key features + +* High-performance PostgreSQL driver. Regularly figures in the top contenders on the [TechEmpower Web Framework Benchmarks](https://www.techempower.com/benchmarks/). +* Full support of most PostgreSQL types, including advanced ones such as arrays, enums, ranges, multiranges, composites, JSON, PostGIS and others. +* Highly-efficient bulk import/export API. +* Failover, load balancing and general multi-host support. +* Great integration with Entity Framework Core via [Npgsql.EntityFrameworkCore.PostgreSQL](https://www.nuget.org/packages/Npgsql.EntityFrameworkCore.PostgreSQL). + +For the full documentation, please visit [the Npgsql website](https://www.npgsql.org). + +## Related packages + +* The Entity Framework Core provider that works with this provider is [Npgsql.EntityFrameworkCore.PostgreSQL](https://www.nuget.org/packages/Npgsql.EntityFrameworkCore.PostgreSQL). 
+* Spatial plugin to work with PostgreSQL PostGIS: [Npgsql.NetTopologySuite](https://www.nuget.org/packages/Npgsql.NetTopologySuite) +* NodaTime plugin to use better date/time types with PostgreSQL: [Npgsql.NodaTime](https://www.nuget.org/packages/Npgsql.NodaTime) +* OpenTelemetry support can be set up with [Npgsql.OpenTelemetry](https://www.nuget.org/packages/Npgsql.OpenTelemetry) \ No newline at end of file diff --git a/src/Npgsql/Replication/Internal/LogicalReplicationConnectionExtensions.cs b/src/Npgsql/Replication/Internal/LogicalReplicationConnectionExtensions.cs index f4a18ef729..6f703970de 100644 --- a/src/Npgsql/Replication/Internal/LogicalReplicationConnectionExtensions.cs +++ b/src/Npgsql/Replication/Internal/LogicalReplicationConnectionExtensions.cs @@ -1,130 +1,184 @@ using NpgsqlTypes; using System; using System.Collections.Generic; -using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; -namespace Npgsql.Replication.Internal +namespace Npgsql.Replication.Internal; + +/// +/// This API is for internal use and for implementing logical replication plugins. +/// It is not meant to be consumed in common Npgsql usage scenarios. +/// +public static class LogicalReplicationConnectionExtensions { /// /// This API is for internal use and for implementing logical replication plugins. /// It is not meant to be consumed in common Npgsql usage scenarios. /// - public static class LogicalReplicationConnectionExtensions + /// + /// Creates a new replication slot and returns information about the newly-created slot. + /// + /// The to use for creating the + /// replication slot + /// The name of the slot to create. Must be a valid replication slot name (see + /// + /// https://www.postgresql.org/docs/current/warm-standby.html#STREAMING-REPLICATION-SLOTS-MANIPULATION). 
+ /// + /// The name of the output plugin used for logical decoding (see + /// + /// https://www.postgresql.org/docs/current/logicaldecoding-output-plugin.html). + /// + /// if this replication slot shall be temporary one; otherwise + /// . Temporary slots are not saved to disk and are automatically dropped on error or + /// when the session has finished. + /// A to specify what to do with the + /// snapshot created during logical slot initialization. , which is + /// also the default, will export the snapshot for use in other sessions. This option can't be used inside a + /// transaction. will use the snapshot for the current transaction + /// executing the command. This option must be used in a transaction, and + /// must be the first command run in that transaction. Finally, will + /// just use the snapshot for logical decoding as normal but won't do anything else with it. + /// + /// If , this logical replication slot supports decoding of two-phase transactions. With this option, + /// two-phase commands like PREPARE TRANSACTION, COMMIT PREPARED and ROLLBACK PREPARED are decoded and transmitted. + /// The transaction will be decoded and transmitted at PREPARE TRANSACTION time. The default is . + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A representing a class that + /// can be used to initialize instances of subclasses. + public static Task CreateLogicalReplicationSlot( + this LogicalReplicationConnection connection, + string slotName, + string outputPlugin, + bool isTemporary = false, + LogicalSlotSnapshotInitMode? slotSnapshotInitMode = null, + bool twoPhase = false, + CancellationToken cancellationToken = default) { - /// - /// This API is for internal use and for implementing logical replication plugins. - /// It is not meant to be consumed in common Npgsql usage scenarios. - /// - /// - /// Creates a new replication slot and returns information about the newly-created slot. 
- /// - /// The to use for creating the - /// replication slot - /// The name of the slot to create. Must be a valid replication slot name (see - /// - /// https://www.postgresql.org/docs/current/warm-standby.html#STREAMING-REPLICATION-SLOTS-MANIPULATION). - /// - /// The name of the output plugin used for logical decoding (see - /// - /// https://www.postgresql.org/docs/current/logicaldecoding-output-plugin.html). - /// - /// if this replication slot shall be temporary one; otherwise - /// . Temporary slots are not saved to disk and are automatically dropped on error or - /// when the session has finished. - /// A to specify what to do with the - /// snapshot created during logical slot initialization. , which is - /// also the default, will export the snapshot for use in other sessions. This option can't be used inside a - /// transaction. will use the snapshot for the current transaction - /// executing the command. This option must be used in a transaction, and - /// must be the first command run in that transaction. Finally, will - /// just use the snapshot for logical decoding as normal but won't do anything else with it. - /// The token to monitor for cancellation requests. - /// The default value is . - /// A representing a class that - /// can be used to initialize instances of subclasses. - public static Task CreateLogicalReplicationSlot( - this LogicalReplicationConnection connection, - string slotName, - string outputPlugin, - bool isTemporary = false, - LogicalSlotSnapshotInitMode? 
slotSnapshotInitMode = null, - CancellationToken cancellationToken = default) - { - using var _ = NoSynchronizationContextScope.Enter(); - return CreateLogicalReplicationSlotCore(); + connection.CheckDisposed(); + if (slotName is null) + throw new ArgumentNullException(nameof(slotName)); + if (outputPlugin is null) + throw new ArgumentNullException(nameof(outputPlugin)); - Task CreateLogicalReplicationSlotCore() - { - if (slotName is null) - throw new ArgumentNullException(nameof(slotName)); - if (outputPlugin is null) - throw new ArgumentNullException(nameof(outputPlugin)); + cancellationToken.ThrowIfCancellationRequested(); - cancellationToken.ThrowIfCancellationRequested(); - - var builder = new StringBuilder("CREATE_REPLICATION_SLOT ").Append(slotName); - if (isTemporary) - builder.Append(" TEMPORARY"); - builder.Append(" LOGICAL ").Append(outputPlugin); + var builder = new StringBuilder("CREATE_REPLICATION_SLOT ").Append(slotName); + if (isTemporary) + builder.Append(" TEMPORARY"); + builder.Append(" LOGICAL ").Append(outputPlugin); + if (connection.PostgreSqlVersion.Major >= 15 && (slotSnapshotInitMode.HasValue || twoPhase)) + { + builder.Append('('); + if (slotSnapshotInitMode.HasValue) + { builder.Append(slotSnapshotInitMode switch { - // EXPORT_SNAPSHOT is the default since it has been introduced. - // We don't set it unless it is explicitly requested so that older backends can digest the query too. 
- null => string.Empty, - LogicalSlotSnapshotInitMode.Export => " EXPORT_SNAPSHOT", - LogicalSlotSnapshotInitMode.Use => " USE_SNAPSHOT", - LogicalSlotSnapshotInitMode.NoExport => " NOEXPORT_SNAPSHOT", + LogicalSlotSnapshotInitMode.Export => "SNAPSHOT 'export'", + LogicalSlotSnapshotInitMode.Use => "SNAPSHOT 'use'", + LogicalSlotSnapshotInitMode.NoExport => "SNAPSHOT 'nothing'", _ => throw new ArgumentOutOfRangeException(nameof(slotSnapshotInitMode), slotSnapshotInitMode, $"Unexpected value {slotSnapshotInitMode} for argument {nameof(slotSnapshotInitMode)}.") }); - - return connection.CreateReplicationSlot(builder.ToString(), isTemporary, cancellationToken); + if (twoPhase) + builder.Append(",TWO_PHASE"); } + else + builder.Append("TWO_PHASE"); + builder.Append(')'); + } + else + { + builder.Append(slotSnapshotInitMode switch + { + // EXPORT_SNAPSHOT is the default since it has been introduced. + // We don't set it unless it is explicitly requested so that older backends can digest the query too. + null => string.Empty, + LogicalSlotSnapshotInitMode.Export => " EXPORT_SNAPSHOT", + LogicalSlotSnapshotInitMode.Use => " USE_SNAPSHOT", + LogicalSlotSnapshotInitMode.NoExport => " NOEXPORT_SNAPSHOT", + _ => throw new ArgumentOutOfRangeException(nameof(slotSnapshotInitMode), + slotSnapshotInitMode, + $"Unexpected value {slotSnapshotInitMode} for argument {nameof(slotSnapshotInitMode)}.") + }); + if (twoPhase) + builder.Append(" TWO_PHASE"); } + var command = builder.ToString(); - /// - /// Instructs the server to start streaming the WAL for logical replication, starting at WAL location - /// or at the slot's consistent point if isn't specified. - /// The server can reply with an error, for example if the requested section of the WAL has already been recycled. - /// - /// The to use for starting replication - /// The replication slot that will be updated as replication progresses so that the server - /// knows which WAL segments are still needed by the standby. 
- /// - /// The token to monitor for stopping the replication. - /// The WAL location to begin streaming at. - /// The collection of options passed to the slot's logical decoding plugin. - /// - /// Whether the plugin will be bypassing and reading directly from the buffer. - /// - /// A representing an that - /// can be used to stream WAL entries in form of instances. - public static IAsyncEnumerable StartLogicalReplication( - this LogicalReplicationConnection connection, + LogMessages.CreatingReplicationSlot(connection.ReplicationLogger, slotName, command, connection.Connector.Id); + + return connection.CreateReplicationSlot(command, cancellationToken); + } + + /// + /// Instructs the server to start streaming the WAL for logical replication, starting at WAL location + /// or at the slot's consistent point if isn't specified. + /// The server can reply with an error, for example if the requested section of the WAL has already been recycled. + /// + /// The to use for starting replication + /// The replication slot that will be updated as replication progresses so that the server + /// knows which WAL segments are still needed by the standby. + /// + /// The token to monitor for stopping the replication. + /// The WAL location to begin streaming at. + /// The collection of options passed to the slot's logical decoding plugin. + /// + /// Whether the plugin will be bypassing and reading directly from the buffer. + /// + /// A representing an that + /// can be used to stream WAL entries in form of instances. + public static IAsyncEnumerable StartLogicalReplication( + this LogicalReplicationConnection connection, + LogicalReplicationSlot slot, + CancellationToken cancellationToken, + NpgsqlLogSequenceNumber? walLocation = null, + IEnumerable>? 
options = null, + bool bypassingStream = false) + { + return StartLogicalReplicationInternal(connection, slot, cancellationToken, walLocation, options, bypassingStream); + + // Local method to avoid having to add the EnumeratorCancellation attribute to the public signature. + static async IAsyncEnumerable StartLogicalReplicationInternal( + LogicalReplicationConnection connection, LogicalReplicationSlot slot, - CancellationToken cancellationToken, - NpgsqlLogSequenceNumber? walLocation = null, - IEnumerable>? options = null, - bool bypassingStream = false) + [EnumeratorCancellation] CancellationToken cancellationToken, + NpgsqlLogSequenceNumber? walLocation, + IEnumerable>? options, + bool bypassingStream) { var builder = new StringBuilder("START_REPLICATION ") .Append("SLOT ").Append(slot.Name) .Append(" LOGICAL ") .Append(walLocation ?? slot.ConsistentPoint); - if (options?.Any() == true) + var opts = new List>(options ?? Array.Empty>()); + if (opts.Count > 0) { - builder - .Append(" (") - .Append(string.Join(", ", options.Select(kv => @$"""{kv.Key}""{(kv.Value is null ? "" : $" '{kv.Value}'")}"))) - .Append(')'); + builder.Append(" ("); + var stringOptions = new string[opts.Count]; + for (var i = 0; i < opts.Count; i++) + { + var kv = opts[i]; + stringOptions[i] = @$"""{kv.Key}""{(kv.Value is null ? 
"" : $" '{kv.Value}'")}"; + } + builder.Append(string.Join(", ", stringOptions)); + builder.Append(')'); } - return connection.StartReplicationInternal(builder.ToString(), bypassingStream, cancellationToken); + var command = builder.ToString(); + + LogMessages.StartingLogicalReplication(connection.ReplicationLogger, slot.Name, command, connection.Connector.Id); + + var enumerator = connection.StartReplicationInternalWrapper(command, bypassingStream, cancellationToken); + while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + yield return enumerator.Current; } } } diff --git a/src/Npgsql/Replication/Internal/LogicalReplicationSlot.cs b/src/Npgsql/Replication/Internal/LogicalReplicationSlot.cs index b182f4ff13..5edfa5d823 100644 --- a/src/Npgsql/Replication/Internal/LogicalReplicationSlot.cs +++ b/src/Npgsql/Replication/Internal/LogicalReplicationSlot.cs @@ -1,41 +1,40 @@ using NpgsqlTypes; using System; -namespace Npgsql.Replication.Internal +namespace Npgsql.Replication.Internal; + +/// +/// Contains information about a newly-created logical replication slot. +/// +public abstract class LogicalReplicationSlot : ReplicationSlot { /// - /// Contains information about a newly-created logical replication slot. + /// Creates a new logical replication slot /// - public abstract class LogicalReplicationSlot : ReplicationSlot + /// The logical decoding output plugin to the corresponding replication slot was created for. + /// A struct with information to create the replication slot. + protected LogicalReplicationSlot(string outputPlugin, ReplicationSlotOptions replicationSlotOptions) + : base(replicationSlotOptions.SlotName) { - /// - /// Creates a new logical replication slot - /// - /// The logical decoding output plugin to the corresponding replication slot was created for. - /// A struct with information to create the replication slot. 
- protected LogicalReplicationSlot(string outputPlugin, ReplicationSlotOptions replicationSlotOptions) - : base(replicationSlotOptions.SlotName) - { - OutputPlugin = outputPlugin ?? throw new ArgumentNullException(nameof(outputPlugin), $"The {nameof(outputPlugin)} argument can not be null."); - SnapshotName = replicationSlotOptions.SnapshotName; - ConsistentPoint = replicationSlotOptions.ConsistentPoint; - } + OutputPlugin = outputPlugin ?? throw new ArgumentNullException(nameof(outputPlugin), $"The {nameof(outputPlugin)} argument can not be null."); + SnapshotName = replicationSlotOptions.SnapshotName; + ConsistentPoint = replicationSlotOptions.ConsistentPoint; + } - /// - /// The identifier of the snapshot exported by the command. - /// The snapshot is valid until a new command is executed on this connection or the replication connection is closed. - /// - public string? SnapshotName { get; } + /// + /// The identifier of the snapshot exported by the command. + /// The snapshot is valid until a new command is executed on this connection or the replication connection is closed. + /// + public string? SnapshotName { get; } - /// - /// The name of the output plugin used by the newly-created logical replication slot. - /// - public string OutputPlugin { get; } + /// + /// The name of the output plugin used by the newly-created logical replication slot. + /// + public string OutputPlugin { get; } - /// - /// The WAL location at which the slot became consistent. - /// This is the earliest location from which streaming can start on this replication slot. - /// - public NpgsqlLogSequenceNumber ConsistentPoint { get; } - } -} + /// + /// The WAL location at which the slot became consistent. + /// This is the earliest location from which streaming can start on this replication slot. 
+ /// + public NpgsqlLogSequenceNumber ConsistentPoint { get; } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/LogicalReplicationConnection.cs b/src/Npgsql/Replication/LogicalReplicationConnection.cs index 0f23d3f2eb..7172b8a060 100644 --- a/src/Npgsql/Replication/LogicalReplicationConnection.cs +++ b/src/Npgsql/Replication/LogicalReplicationConnection.cs @@ -1,21 +1,20 @@ -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Represents a logical replication connection to a PostgreSQL server. +/// +public sealed class LogicalReplicationConnection : ReplicationConnection { + private protected override ReplicationMode ReplicationMode => ReplicationMode.Logical; + /// - /// Represents a logical replication connection to a PostgreSQL server. + /// Initializes a new instance of . /// - public sealed class LogicalReplicationConnection : ReplicationConnection - { - private protected override ReplicationMode ReplicationMode => ReplicationMode.Logical; - - /// - /// Initializes a new instance of . - /// - public LogicalReplicationConnection() {} + public LogicalReplicationConnection() {} - /// - /// Initializes a new instance of with the given connection string. - /// - /// The connection used to open the PostgreSQL database. - public LogicalReplicationConnection(string? connectionString) : base(connectionString) {} - } -} + /// + /// Initializes a new instance of with the given connection string. + /// + /// The connection used to open the PostgreSQL database. + public LogicalReplicationConnection(string? 
connectionString) : base(connectionString) {} +} \ No newline at end of file diff --git a/src/Npgsql/Replication/LogicalSlotSnapshotInitMode.cs b/src/Npgsql/Replication/LogicalSlotSnapshotInitMode.cs index 7fa4e3eece..3e71c7ca7b 100644 --- a/src/Npgsql/Replication/LogicalSlotSnapshotInitMode.cs +++ b/src/Npgsql/Replication/LogicalSlotSnapshotInitMode.cs @@ -1,26 +1,25 @@ -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Decides what to do with the snapshot created during logical slot initialization. +/// +public enum LogicalSlotSnapshotInitMode { /// - /// Decides what to do with the snapshot created during logical slot initialization. + /// Export the snapshot for use in other sessions. This is the default. + /// This option can't be used inside a transaction. /// - public enum LogicalSlotSnapshotInitMode - { - /// - /// Export the snapshot for use in other sessions. This is the default. - /// This option can't be used inside a transaction. - /// - Export = 0, + Export = 0, - /// - /// Use the snapshot for the current transaction executing the command. - /// This option must be used in a transaction, and CREATE_REPLICATION_SLOT must be the first command run - /// in that transaction. - /// - Use = 1, + /// + /// Use the snapshot for the current transaction executing the command. + /// This option must be used in a transaction, and CREATE_REPLICATION_SLOT must be the first command run + /// in that transaction. + /// + Use = 1, - /// - /// Just use the snapshot for logical decoding as normal but don't do anything else with it. - /// - NoExport = 2 - } -} + /// + /// Just use the snapshot for logical decoding as normal but don't do anything else with it. 
+ /// + NoExport = 2 +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgDateTime.cs b/src/Npgsql/Replication/PgDateTime.cs new file mode 100644 index 0000000000..aa68bda7f6 --- /dev/null +++ b/src/Npgsql/Replication/PgDateTime.cs @@ -0,0 +1,16 @@ +using System; + +namespace Npgsql.Replication; + +static class PgDateTime +{ + const long PostgresTimestampOffsetTicks = 630822816000000000L; + + public static DateTime DecodeTimestamp(long value, DateTimeKind kind) + => new(value * 10 + PostgresTimestampOffsetTicks, kind); + + public static long EncodeTimestamp(DateTime value) + // Rounding here would cause problems because we would round up DateTime.MaxValue + // which would make it impossible to retrieve it back from the database, so we just drop the additional precision + => (value.Ticks - PostgresTimestampOffsetTicks) / 10; +} diff --git a/src/Npgsql/Replication/PgOutput/Messages/BeginMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/BeginMessage.cs index edcbab43e2..6fbfcb2c37 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/BeginMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/BeginMessage.cs @@ -1,51 +1,32 @@ using NpgsqlTypes; using System; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol begin message +/// +public sealed class BeginMessage : TransactionControlMessage { /// - /// Logical Replication Protocol begin message + /// The final LSN of the transaction. /// - public sealed class BeginMessage : PgOutputReplicationMessage - { - /// - /// The final LSN of the transaction. - /// - public NpgsqlLogSequenceNumber TransactionFinalLsn { get; private set; } - - /// - /// Commit timestamp of the transaction. - /// The value is in number of microseconds since PostgreSQL epoch (2000-01-01). - /// - public DateTime TransactionCommitTimestamp { get; private set; } - - /// - /// Xid of the transaction. 
- /// - public uint TransactionXid { get; private set; } + public NpgsqlLogSequenceNumber TransactionFinalLsn { get; private set; } - internal BeginMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, - NpgsqlLogSequenceNumber transactionFinalLsn, DateTime transactionCommitTimestamp, uint transactionXid) - { - base.Populate(walStart, walEnd, serverClock); - - TransactionFinalLsn = transactionFinalLsn; - TransactionCommitTimestamp = transactionCommitTimestamp; - TransactionXid = transactionXid; + /// + /// Commit timestamp of the transaction. + /// The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + /// + public DateTime TransactionCommitTimestamp { get; private set; } - return this; - } + internal BeginMessage() {} - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override BeginMessage Clone() -#endif - { - var clone = new BeginMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, TransactionFinalLsn, TransactionCommitTimestamp, TransactionXid); - return clone; - } + internal BeginMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + NpgsqlLogSequenceNumber transactionFinalLsn, DateTime transactionCommitTimestamp, uint transactionXid) + { + base.Populate(walStart, walEnd, serverClock, transactionXid); + TransactionFinalLsn = transactionFinalLsn; + TransactionCommitTimestamp = transactionCommitTimestamp; + return this; } -} +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/BeginPrepareMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/BeginPrepareMessage.cs new file mode 100644 index 0000000000..288bff1e03 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/BeginPrepareMessage.cs @@ -0,0 +1,27 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical 
Replication Protocol begin prepare message +/// +public sealed class BeginPrepareMessage : PrepareMessageBase +{ + internal BeginPrepareMessage() {} + + internal new BeginPrepareMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + NpgsqlLogSequenceNumber prepareLsn, NpgsqlLogSequenceNumber prepareEndLsn, DateTime transactionPrepareTimestamp, + uint transactionXid, string transactionGid) + { + base.Populate(walStart, walEnd, serverClock, + prepareLsn: prepareLsn, + prepareEndLsn: prepareEndLsn, + transactionPrepareTimestamp: transactionPrepareTimestamp, + transactionXid: transactionXid, + transactionGid: transactionGid); + return this; + } +} + diff --git a/src/Npgsql/Replication/PgOutput/Messages/CommitMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/CommitMessage.cs index 26c9d3094e..f2f0b16525 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/CommitMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/CommitMessage.cs @@ -1,56 +1,58 @@ using NpgsqlTypes; using System; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol commit message +/// +public sealed class CommitMessage : PgOutputReplicationMessage { /// - /// Logical Replication Protocol commit message + /// Flags; currently unused. + /// + public CommitFlags Flags { get; private set; } + + /// + /// The LSN of the commit. /// - public sealed class CommitMessage : PgOutputReplicationMessage + public NpgsqlLogSequenceNumber CommitLsn { get; private set; } + + /// + /// The end LSN of the transaction. + /// + public NpgsqlLogSequenceNumber TransactionEndLsn { get; private set; } + + /// + /// Commit timestamp of the transaction. 
+ /// + public DateTime TransactionCommitTimestamp { get; private set; } + + internal CommitMessage() {} + + internal CommitMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + CommitFlags flags, NpgsqlLogSequenceNumber commitLsn, NpgsqlLogSequenceNumber transactionEndLsn, + DateTime transactionCommitTimestamp) { - /// - /// Flags; currently unused (must be 0). - /// - public byte Flags { get; private set; } + base.Populate(walStart, walEnd, serverClock); - /// - /// The LSN of the commit. - /// - public NpgsqlLogSequenceNumber CommitLsn { get; private set; } + Flags = flags; + CommitLsn = commitLsn; + TransactionEndLsn = transactionEndLsn; + TransactionCommitTimestamp = transactionCommitTimestamp; - /// - /// The end LSN of the transaction. - /// - public NpgsqlLogSequenceNumber TransactionEndLsn { get; private set; } + return this; + } + /// + /// Flags for the commit. + /// + [Flags] + public enum CommitFlags : byte + { /// - /// Commit timestamp of the transaction. + /// No flags. 
/// - public DateTime TransactionCommitTimestamp { get; private set; } - - internal CommitMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, byte flags, - NpgsqlLogSequenceNumber commitLsn, NpgsqlLogSequenceNumber transactionEndLsn, DateTime transactionCommitTimestamp) - { - base.Populate(walStart, walEnd, serverClock); - - Flags = flags; - CommitLsn = commitLsn; - TransactionEndLsn = transactionEndLsn; - TransactionCommitTimestamp = transactionCommitTimestamp; - - return this; - } - - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override CommitMessage Clone() -#endif - { - var clone = new CommitMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, Flags, CommitLsn, TransactionEndLsn, TransactionCommitTimestamp); - return clone; - } + None = 0 } -} +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/CommitPreparedMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/CommitPreparedMessage.cs new file mode 100644 index 0000000000..7ed189a981 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/CommitPreparedMessage.cs @@ -0,0 +1,59 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol commit prepared message +/// +public sealed class CommitPreparedMessage : PreparedTransactionControlMessage +{ + /// + /// Flags for the commit prepared; currently unused. + /// + public CommitPreparedFlags Flags { get; private set; } + + /// + /// The LSN of the commit prepared. + /// + public NpgsqlLogSequenceNumber CommitPreparedLsn => FirstLsn; + + /// + /// The end LSN of the commit prepared transaction. + /// + public NpgsqlLogSequenceNumber CommitPreparedEndLsn => SecondLsn; + + /// + /// Commit timestamp of the transaction. 
+ /// + public DateTime TransactionCommitTimestamp => Timestamp; + + internal CommitPreparedMessage() {} + + internal CommitPreparedMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, CommitPreparedFlags flags, + NpgsqlLogSequenceNumber commitPreparedLsn, NpgsqlLogSequenceNumber commitPreparedEndLsn, DateTime transactionCommitTimestamp, + uint transactionXid, string transactionGid) + { + base.Populate(walStart, walEnd, serverClock, + firstLsn: commitPreparedLsn, + secondLsn: commitPreparedEndLsn, + timestamp: transactionCommitTimestamp, + transactionXid: transactionXid, + transactionGid: transactionGid); + Flags = flags; + return this; + } + + /// + /// Flags for the commit prepared; currently unused. + /// + [Flags] + public enum CommitPreparedFlags : byte + { + /// + /// No flags. + /// + None = 0 + } +} diff --git a/src/Npgsql/Replication/PgOutput/Messages/DefaultUpdateMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/DefaultUpdateMessage.cs new file mode 100644 index 0000000000..6fd36d7ea0 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/DefaultUpdateMessage.cs @@ -0,0 +1,37 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; +using NpgsqlTypes; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol update message for tables with REPLICA IDENTITY set to DEFAULT. +/// +public class DefaultUpdateMessage : UpdateMessage +{ + readonly ReplicationTuple _newRow; + + /// + /// Columns representing the new row. + /// + public override ReplicationTuple NewRow => _newRow; + + internal DefaultUpdateMessage(NpgsqlConnector connector) + => _newRow = new(connector); + + internal UpdateMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? 
transactionXid, + RelationMessage relation, ushort numColumns) + { + base.Populate(walStart, walEnd, serverClock, transactionXid, relation); + + _newRow.Reset(numColumns, relation.RowDescription); + + return this; + } + + internal Task Consume(CancellationToken cancellationToken) + => _newRow.Consume(cancellationToken); +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/DeleteMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/DeleteMessage.cs index 50d004628c..c1057dabdd 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/DeleteMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/DeleteMessage.cs @@ -1,26 +1,28 @@ using NpgsqlTypes; using System; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Abstract base class for Logical Replication Protocol delete message types. +/// +public abstract class DeleteMessage : TransactionalMessage { /// - /// Abstract base class for Logical Replication Protocol delete message types. + /// The relation for this . /// - public abstract class DeleteMessage : PgOutputReplicationMessage - { - /// - /// ID of the relation corresponding to the ID in the relation message. - /// - public uint RelationId { get; private set; } + public RelationMessage Relation { get; private set; } = null!; - private protected DeleteMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint relationId) - { - base.Populate(walStart, walEnd, serverClock); + private protected DeleteMessage() {} + + private protected DeleteMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? 
transactionXid, + RelationMessage relation) + { + base.Populate(walStart, walEnd, serverClock, transactionXid); - RelationId = relationId; + Relation = relation; - return this; - } + return this; } } diff --git a/src/Npgsql/Replication/PgOutput/Messages/FullDeleteMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/FullDeleteMessage.cs index 7d57bb5534..a426a2b6ad 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/FullDeleteMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/FullDeleteMessage.cs @@ -1,38 +1,37 @@ using NpgsqlTypes; using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol delete message for tables with REPLICA IDENTITY REPLICA IDENTITY set to FULL. +/// +public sealed class FullDeleteMessage : DeleteMessage { + readonly ReplicationTuple _tupleEnumerable; + /// - /// Logical Replication Protocol delete message for tables with REPLICA IDENTITY REPLICA IDENTITY set to FULL. + /// Columns representing the deleted row. /// - public sealed class FullDeleteMessage : DeleteMessage + public ReplicationTuple OldRow => _tupleEnumerable; + + internal FullDeleteMessage(NpgsqlConnector connector) + => _tupleEnumerable = new(connector); + + internal FullDeleteMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? transactionXid, + RelationMessage relation, ushort numColumns) { - /// - /// Columns representing the old values. 
- /// - public ReadOnlyMemory OldRow { get; private set; } = default!; - - internal FullDeleteMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint relationId, ReadOnlyMemory oldRow) - { - base.Populate(walStart, walEnd, serverClock, relationId); - - OldRow = oldRow; - - return this; - } - - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override FullDeleteMessage Clone() -#endif - { - var clone = new FullDeleteMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, RelationId, OldRow.ToArray()); - return clone; - } + base.Populate(walStart, walEnd, serverClock, transactionXid, relation); + + _tupleEnumerable.Reset(numColumns, relation.RowDescription); + + return this; } -} + + internal Task Consume(CancellationToken cancellationToken) + => _tupleEnumerable.Consume(cancellationToken); +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/FullUpdateMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/FullUpdateMessage.cs index 0907ccf82b..814780cf37 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/FullUpdateMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/FullUpdateMessage.cs @@ -1,39 +1,47 @@ -using NpgsqlTypes; -using System; +using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; +using NpgsqlTypes; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol update message for tables with REPLICA IDENTITY set to FULL. +/// +public sealed class FullUpdateMessage : UpdateMessage { + readonly ReplicationTuple _oldRow; + readonly SecondRowTupleEnumerable _newRow; + + /// + /// Columns representing the old row. 
+ /// + public ReplicationTuple OldRow => _oldRow; + /// - /// Logical Replication Protocol update message for tables with REPLICA IDENTITY REPLICA IDENTITY set to FULL. + /// Columns representing the new row. /// - public sealed class FullUpdateMessage : UpdateMessage + public override ReplicationTuple NewRow => _newRow; + + internal FullUpdateMessage(NpgsqlConnector connector) { - /// - /// Columns representing the old values. - /// - public ReadOnlyMemory OldRow { get; private set; } = default!; - - internal FullUpdateMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint relationId, - ReadOnlyMemory newRow, ReadOnlyMemory oldRow) - { - base.Populate(walStart, walEnd, serverClock, relationId, newRow); - - OldRow = oldRow; - - return this; - } - - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override FullUpdateMessage Clone() -#endif - { - var clone = new FullUpdateMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, RelationId, NewRow.ToArray(), OldRow.ToArray()); - return clone; - } + _oldRow = new(connector); + _newRow = new(connector, _oldRow); } -} + + internal UpdateMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? 
transactionXid, + RelationMessage relation, ushort numColumns) + { + base.Populate(walStart, walEnd, serverClock, transactionXid, relation); + + _oldRow.Reset(numColumns, relation.RowDescription); + _newRow.Reset(numColumns, relation.RowDescription); + + return this; + } + + internal Task Consume(CancellationToken cancellationToken) + => _newRow.Consume(cancellationToken); +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/IndexUpdateMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/IndexUpdateMessage.cs index 76dbbaf4ef..021458140d 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/IndexUpdateMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/IndexUpdateMessage.cs @@ -1,39 +1,47 @@ -using NpgsqlTypes; -using System; +using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; +using NpgsqlTypes; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol update message for tables with REPLICA IDENTITY set to USING INDEX. +/// +public sealed class IndexUpdateMessage : UpdateMessage { + readonly ReplicationTuple _key; + readonly SecondRowTupleEnumerable _newRow; + + /// + /// Columns representing the key. + /// + public ReplicationTuple Key => _key; + /// - /// Logical Replication Protocol update message for tables with REPLICA IDENTITY set to USING INDEX. + /// Columns representing the new row. /// - public sealed class IndexUpdateMessage : UpdateMessage + public override ReplicationTuple NewRow => _newRow; + + internal IndexUpdateMessage(NpgsqlConnector connector) { - /// - /// Columns representing the key. 
- /// - public ReadOnlyMemory KeyRow { get; private set; } = default!; - - internal IndexUpdateMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint relationId, - ReadOnlyMemory newRow, ReadOnlyMemory keyRow) - { - base.Populate(walStart, walEnd, serverClock, relationId, newRow); - - KeyRow = keyRow; - - return this; - } - - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override IndexUpdateMessage Clone() -#endif - { - var clone = new IndexUpdateMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, RelationId, NewRow.ToArray(), KeyRow.ToArray()); - return clone; - } + _key = new(connector); + _newRow = new(connector, _key); } -} + + internal UpdateMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? transactionXid, + RelationMessage relation, ushort numColumns) + { + base.Populate(walStart, walEnd, serverClock, transactionXid, relation); + + _key.Reset(numColumns, relation.RowDescription); + _newRow.Reset(numColumns, relation.RowDescription); + + return this; + } + + internal Task Consume(CancellationToken cancellationToken) + => _newRow.Consume(cancellationToken); +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/InsertMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/InsertMessage.cs index 9de22f54b2..df413f6b21 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/InsertMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/InsertMessage.cs @@ -1,45 +1,43 @@ using NpgsqlTypes; using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol insert message +/// +public sealed class InsertMessage : TransactionalMessage { + readonly ReplicationTuple 
_tupleEnumerable; + + /// + /// The relation for this . + /// + public RelationMessage Relation { get; private set; } = null!; + /// - /// Logical Replication Protocol insert message + /// Columns representing the new row. /// - public sealed class InsertMessage : PgOutputReplicationMessage + public ReplicationTuple NewRow => _tupleEnumerable; + + internal InsertMessage(NpgsqlConnector connector) + => _tupleEnumerable = new(connector); + + internal InsertMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? transactionXid, + RelationMessage relation, ushort numColumns) { - /// - /// ID of the relation corresponding to the ID in the relation message. - /// - public uint RelationId { get; private set; } - - /// - /// Columns representing the new row. - /// - public ReadOnlyMemory NewRow { get; private set; } = default!; - - internal InsertMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint relationId, - ReadOnlyMemory newRow) - { - base.Populate(walStart, walEnd, serverClock); - - RelationId = relationId; - NewRow = newRow; - - return this; - } - - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override InsertMessage Clone() -#endif - { - var clone = new InsertMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, RelationId, NewRow.ToArray()); - return clone; - } + base.Populate(walStart, walEnd, serverClock, transactionXid); + + Relation = relation; + _tupleEnumerable.Reset(numColumns, relation.RowDescription); + + return this; } + + internal Task Consume(CancellationToken cancellationToken) + => _tupleEnumerable.Consume(cancellationToken); } diff --git a/src/Npgsql/Replication/PgOutput/Messages/KeyDeleteMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/KeyDeleteMessage.cs index 6efdfbefc7..9b30b3e1df 100644 --- 
a/src/Npgsql/Replication/PgOutput/Messages/KeyDeleteMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/KeyDeleteMessage.cs @@ -1,39 +1,37 @@ using NpgsqlTypes; using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol delete message for tables with REPLICA IDENTITY set to DEFAULT or USING INDEX. +/// +public sealed class KeyDeleteMessage : DeleteMessage { + readonly ReplicationTuple _tupleEnumerable; + /// - /// Logical Replication Protocol delete message for tables with REPLICA IDENTITY set to DEFAULT or USING INDEX. + /// Columns representing the key. /// - public sealed class KeyDeleteMessage : DeleteMessage + public ReplicationTuple Key => _tupleEnumerable; + + internal KeyDeleteMessage(NpgsqlConnector connector) + => _tupleEnumerable = new(connector); + + internal KeyDeleteMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? transactionXid, + RelationMessage relation, ushort numColumns) { - /// - /// Columns representing the primary key. 
- /// - public ReadOnlyMemory KeyRow { get; private set; } = default!; - - internal KeyDeleteMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint relationId, - ReadOnlyMemory keyRow) - { - base.Populate(walStart, walEnd, serverClock, relationId); - - KeyRow = keyRow; - - return this; - } - - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override KeyDeleteMessage Clone() -#endif - { - var clone = new KeyDeleteMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, RelationId, KeyRow.ToArray()); - return clone; - } + base.Populate(walStart, walEnd, serverClock, transactionXid, relation); + + _tupleEnumerable.Reset(numColumns, relation.RowDescription); + + return this; } -} + + internal Task Consume(CancellationToken cancellationToken) + => _tupleEnumerable.Consume(cancellationToken); +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/LogicalDecodingMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/LogicalDecodingMessage.cs new file mode 100644 index 0000000000..0add6103e6 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/LogicalDecodingMessage.cs @@ -0,0 +1,44 @@ +using System; +using System.IO; +using NpgsqlTypes; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol logical decoding message +/// +public sealed class LogicalDecodingMessage : TransactionalMessage +{ + /// + /// Flags; Either 0 for no flags or 1 if the logical decoding message is transactional. + /// + public byte Flags { get; private set; } + + /// + /// The LSN of the logical decoding message. + /// + public NpgsqlLogSequenceNumber MessageLsn { get; private set; } + + /// + /// The prefix of the logical decoding message. + /// + public string Prefix { get; private set; } = default!; + + /// + /// The content of the logical decoding message. 
+ /// + public Stream Data { get; private set; } = default!; + + internal LogicalDecodingMessage() {} + + internal LogicalDecodingMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + uint? transactionXid, byte flags, NpgsqlLogSequenceNumber messageLsn, string prefix, Stream data) + { + base.Populate(walStart, walEnd, serverClock, transactionXid); + Flags = flags; + MessageLsn = messageLsn; + Prefix = prefix; + Data = data; + return this; + } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/OriginMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/OriginMessage.cs index aa7a977173..8356cc997a 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/OriginMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/OriginMessage.cs @@ -1,45 +1,34 @@ using NpgsqlTypes; using System; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol origin message +/// +public sealed class OriginMessage : PgOutputReplicationMessage { /// - /// Logical Replication Protocol origin message + /// The LSN of the commit on the origin server. /// - public sealed class OriginMessage : PgOutputReplicationMessage - { - /// - /// The LSN of the commit on the origin server. - /// - public NpgsqlLogSequenceNumber OriginCommitLsn { get; private set; } + public NpgsqlLogSequenceNumber OriginCommitLsn { get; private set; } - /// - /// Name of the origin. - /// - public string OriginName { get; private set; } = string.Empty; + /// + /// Name of the origin. 
+ /// + public string OriginName { get; private set; } = string.Empty; - internal OriginMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, NpgsqlLogSequenceNumber originCommitLsn, - string originName) - { - base.Populate(walStart, walEnd, serverClock); + internal OriginMessage() {} - OriginCommitLsn = originCommitLsn; - OriginName = originName; + internal OriginMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, NpgsqlLogSequenceNumber originCommitLsn, + string originName) + { + base.Populate(walStart, walEnd, serverClock); - return this; - } + OriginCommitLsn = originCommitLsn; + OriginName = originName; - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override OriginMessage Clone() -#endif - { - var clone = new OriginMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, OriginCommitLsn, OriginName); - return clone; - } + return this; } -} +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/PgOutputReplicationMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/PgOutputReplicationMessage.cs index ca863cc154..24de9e201f 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/PgOutputReplicationMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/PgOutputReplicationMessage.cs @@ -1,23 +1,14 @@ -using NpgsqlTypes; -using System; +namespace Npgsql.Replication.PgOutput.Messages; -namespace Npgsql.Replication.PgOutput.Messages +/// +/// The base class of all Logical Replication Protocol Messages +/// +/// +/// See https://www.postgresql.org/docs/current/protocol-logicalrep-message-formats.html for details about the +/// protocol. 
+/// +public abstract class PgOutputReplicationMessage : ReplicationMessage { - /// - /// The base class of all Logical Replication Protocol Messages - /// - /// - /// See https://www.postgresql.org/docs/current/protocol-logicalrep-message-formats.html for details about the - /// protocol. - /// - public abstract class PgOutputReplicationMessage : ReplicationMessage - { - /// - public override string ToString() => GetType().Name; - - /// - /// Returns a clone of this message, which can be accessed after other replication messages have been retrieved. - /// - public abstract PgOutputReplicationMessage Clone(); - } -} + /// + public override string ToString() => GetType().Name; +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/PrepareMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/PrepareMessage.cs new file mode 100644 index 0000000000..16cd8fa36b --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/PrepareMessage.cs @@ -0,0 +1,45 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol prepare message +/// +public sealed class PrepareMessage : PrepareMessageBase +{ + /// + /// Flags for the prepare; currently unused. + /// + public PrepareFlags Flags { get; private set; } + + internal PrepareMessage() {} + + internal PrepareMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, PrepareFlags flags, + NpgsqlLogSequenceNumber prepareLsn, NpgsqlLogSequenceNumber prepareEndLsn, DateTime transactionPrepareTimestamp, + uint transactionXid, string transactionGid) + { + base.Populate(walStart, walEnd, serverClock, + prepareLsn: prepareLsn, + prepareEndLsn: prepareEndLsn, + transactionPrepareTimestamp: transactionPrepareTimestamp, + transactionXid: transactionXid, + transactionGid: transactionGid); + Flags = flags; + + return this; + } + + /// + /// Flags for the prepare; currently unused. 
+ /// + [Flags] + public enum PrepareFlags : byte + { + /// + /// No flags. + /// + None = 0 + } +} diff --git a/src/Npgsql/Replication/PgOutput/Messages/PrepareMessageBase.cs b/src/Npgsql/Replication/PgOutput/Messages/PrepareMessageBase.cs new file mode 100644 index 0000000000..0eda1b18d3 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/PrepareMessageBase.cs @@ -0,0 +1,41 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Abstract base class for the logical replication protocol begin prepare and prepare message +/// +public abstract class PrepareMessageBase : PreparedTransactionControlMessage +{ + /// + /// The LSN of the prepare. + /// + public NpgsqlLogSequenceNumber PrepareLsn => FirstLsn; + + /// + /// The end LSN of the prepared transaction. + /// + public NpgsqlLogSequenceNumber PrepareEndLsn => SecondLsn; + + /// + /// Prepare timestamp of the transaction. + /// + public DateTime TransactionPrepareTimestamp => Timestamp; + + private protected PrepareMessageBase() {} + + internal new PrepareMessageBase Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + NpgsqlLogSequenceNumber prepareLsn, NpgsqlLogSequenceNumber prepareEndLsn, DateTime transactionPrepareTimestamp, + uint transactionXid, string transactionGid) + { + base.Populate(walStart, walEnd, serverClock, + firstLsn: prepareLsn, + secondLsn: prepareEndLsn, + timestamp: transactionPrepareTimestamp, + transactionXid: transactionXid, + transactionGid: transactionGid); + return this; + } +} diff --git a/src/Npgsql/Replication/PgOutput/Messages/PreparedTransactionControlMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/PreparedTransactionControlMessage.cs new file mode 100644 index 0000000000..04f98be920 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/PreparedTransactionControlMessage.cs @@ -0,0 +1,37 @@ +using NpgsqlTypes; +using System; + +namespace 
Npgsql.Replication.PgOutput.Messages; + +/// +/// Abstract base class for Logical Replication Protocol prepare and begin prepare message +/// +public abstract class PreparedTransactionControlMessage : TransactionControlMessage +{ + private protected NpgsqlLogSequenceNumber FirstLsn; + private protected NpgsqlLogSequenceNumber SecondLsn; + private protected DateTime Timestamp; + + /// + /// The user defined GID of the two-phase transaction. + /// + public string TransactionGid { get; private set; } = null!; + + private protected PreparedTransactionControlMessage() {} + + private protected PreparedTransactionControlMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + NpgsqlLogSequenceNumber firstLsn, NpgsqlLogSequenceNumber secondLsn, DateTime timestamp, + uint transactionXid, string transactionGid) + { + base.Populate(walStart, walEnd, serverClock, transactionXid); + + FirstLsn = firstLsn; + SecondLsn = secondLsn; + Timestamp = timestamp; + TransactionGid = transactionGid; + + return this; + } +} + diff --git a/src/Npgsql/Replication/PgOutput/Messages/RelationMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/RelationMessage.cs index aedaeefee6..85d83debb7 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/RelationMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/RelationMessage.cs @@ -1,98 +1,138 @@ using NpgsqlTypes; using System; using System.Collections.Generic; +using Npgsql.BackendMessages; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol relation message +/// +public sealed class RelationMessage : TransactionalMessage { /// - /// Logical Replication Protocol relation message + /// ID of the relation. + /// + public uint RelationId { get; private set; } + + /// + /// Namespace (empty string for pg_catalog). 
/// - public sealed class RelationMessage : PgOutputReplicationMessage + public string Namespace { get; private set; } = string.Empty; + + /// + /// Relation name. + /// + public string RelationName { get; private set; } = string.Empty; + + /// + /// Replica identity setting for the relation (same as relreplident in pg_class): + /// columns used to form “replica identity” for rows. + /// + public ReplicaIdentitySetting ReplicaIdentity { get; private set; } + + /// + /// Relation columns + /// + public IReadOnlyList Columns => InternalColumns; + + internal ReadOnlyArrayBuffer InternalColumns { get; } = new(); + + internal RowDescriptionMessage RowDescription { get; set; } = null!; + + internal RelationMessage() {} + + internal RelationMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? transactionXid, uint relationId, string ns, + string relationName, ReplicaIdentitySetting relationReplicaIdentitySetting) { - /// - /// ID of the relation. - /// - public uint RelationId { get; private set; } + base.Populate(walStart, walEnd, serverClock, transactionXid); + + RelationId = relationId; + Namespace = ns; + RelationName = relationName; + ReplicaIdentity = relationReplicaIdentitySetting; + + return this; + } + + /// + /// Represents a column in a Logical Replication Protocol relation message + /// + public readonly struct Column + { + internal Column(ColumnFlags flags, string columnName, uint dataTypeId, int typeModifier) + { + Flags = flags; + ColumnName = columnName; + DataTypeId = dataTypeId; + TypeModifier = typeModifier; + } /// - /// Namespace (empty string for pg_catalog). + /// Flags for the column. /// - public string Namespace { get; private set; } = string.Empty; + public ColumnFlags Flags { get; } /// - /// Relation name. + /// Name of the column. 
/// - public string RelationName { get; private set; } = string.Empty; + public string ColumnName { get; } /// - /// Replica identity setting for the relation (same as relreplident in pg_class). + /// ID of the column's data type. /// - public char RelationReplicaIdentitySetting { get; private set; } + public uint DataTypeId { get; } /// - /// Relation columns + /// Type modifier of the column (atttypmod). /// - public ReadOnlyMemory Columns { get; private set; } = default!; - - internal RelationMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint relationId, string ns, - string relationName, char relationReplicaIdentitySetting, ReadOnlyMemory columns) - { - base.Populate(walStart, walEnd, serverClock); - - RelationId = relationId; - Namespace = ns; - RelationName = relationName; - RelationReplicaIdentitySetting = relationReplicaIdentitySetting; - Columns = columns; - - return this; - } - - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override RelationMessage Clone() -#endif - { - var clone = new RelationMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, RelationId, Namespace, RelationName, RelationReplicaIdentitySetting, Columns.ToArray()); - return clone; - } + public int TypeModifier { get; } /// - /// Represents a column in a Logical Replication Protocol relation message + /// Flags for the column. /// - public readonly struct Column + [Flags] + public enum ColumnFlags { - internal Column(byte flags, string columnName, uint dataTypeId, int typeModifier) - { - Flags = flags; - ColumnName = columnName; - DataTypeId = dataTypeId; - TypeModifier = typeModifier; - } - /// - /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as part of the key. + /// No flags. /// - public byte Flags { get; } + None = 0, /// - /// Name of the column. + /// Marks the column as part of the key. 
/// - public string ColumnName { get; } + PartOfKey = 1 + } + } - /// - /// ID of the column's data type. - /// - public uint DataTypeId { get; } + /// + /// Replica identity setting for the relation (same as relreplident in pg_class). + /// + /// + /// See + /// + public enum ReplicaIdentitySetting : byte + { + /// + /// Default (primary key, if any). + /// + Default = (byte)'d', - /// - /// Type modifier of the column (atttypmod). - /// - public int TypeModifier { get; } - } + /// + /// Nothing. + /// + Nothing = (byte)'n', + + /// + /// All columns. + /// + AllColumns = (byte)'f', + + /// + /// Index with indisreplident set (same as nothing if the index used has been dropped) + /// + IndexWithIndIsReplIdent = (byte)'i' } } diff --git a/src/Npgsql/Replication/PgOutput/Messages/RelationMessageColumn.cs b/src/Npgsql/Replication/PgOutput/Messages/RelationMessageColumn.cs index 21a98ffc74..4692e4e6c4 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/RelationMessageColumn.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/RelationMessageColumn.cs @@ -1,36 +1,35 @@ -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Represents a column in a Logical Replication Protocol relation message +/// +public readonly struct RelationMessageColumn { - /// - /// Represents a column in a Logical Replication Protocol relation message - /// - public readonly struct RelationMessageColumn + internal RelationMessageColumn(byte flags, string columnName, uint dataTypeId, int typeModifier) { - internal RelationMessageColumn(byte flags, string columnName, uint dataTypeId, int typeModifier) - { - Flags = flags; - ColumnName = columnName; - DataTypeId = dataTypeId; - TypeModifier = typeModifier; - } + Flags = flags; + ColumnName = columnName; + DataTypeId = dataTypeId; + TypeModifier = typeModifier; + } - /// - /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as part of the key. 
- /// - public byte Flags { get; } + /// + /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as part of the key. + /// + public byte Flags { get; } - /// - /// Name of the column. - /// - public string ColumnName { get; } + /// + /// Name of the column. + /// + public string ColumnName { get; } - /// - /// ID of the column's data type. - /// - public uint DataTypeId { get; } + /// + /// ID of the column's data type. + /// + public uint DataTypeId { get; } - /// - /// Type modifier of the column (atttypmod). - /// - public int TypeModifier { get; } - } -} + /// + /// Type modifier of the column (atttypmod). + /// + public int TypeModifier { get; } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/RollbackPreparedMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/RollbackPreparedMessage.cs new file mode 100644 index 0000000000..681e7af4b6 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/RollbackPreparedMessage.cs @@ -0,0 +1,65 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol rollback prepared message +/// +public sealed class RollbackPreparedMessage : PreparedTransactionControlMessage +{ + /// + /// Flags for the rollback prepared; currently unused. + /// + public RollbackPreparedFlags Flags { get; private set; } + + /// + /// The end LSN of the prepared transaction. + /// + public NpgsqlLogSequenceNumber PreparedTransactionEndLsn => FirstLsn; + + /// + /// The end LSN of the rollback prepared transaction. + /// + public NpgsqlLogSequenceNumber RollbackPreparedEndLsn => SecondLsn; + + /// + /// Prepare timestamp of the transaction. + /// + public DateTime TransactionPrepareTimestamp => Timestamp; + + /// + /// Rollback timestamp of the transaction. 
+ /// + public DateTime TransactionRollbackTimestamp { get; private set; } + + internal RollbackPreparedMessage() {} + + internal RollbackPreparedMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, RollbackPreparedFlags flags, + NpgsqlLogSequenceNumber preparedTransactionEndLsn, NpgsqlLogSequenceNumber rollbackPreparedEndLsn, DateTime transactionPrepareTimestamp, DateTime transactionRollbackTimestamp, + uint transactionXid, string transactionGid) + { + base.Populate(walStart, walEnd, serverClock, + firstLsn: preparedTransactionEndLsn, + secondLsn: rollbackPreparedEndLsn, + timestamp: transactionPrepareTimestamp, + transactionXid: transactionXid, + transactionGid: transactionGid); + Flags = flags; + TransactionRollbackTimestamp = transactionRollbackTimestamp; + return this; + } + + /// + /// Flags for the rollback prepared; currently unused. + /// + [Flags] + public enum RollbackPreparedFlags : byte + { + /// + /// No flags. + /// + None = 0 + } +} diff --git a/src/Npgsql/Replication/PgOutput/Messages/StreamAbortMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/StreamAbortMessage.cs new file mode 100644 index 0000000000..23fc2c5a24 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/StreamAbortMessage.cs @@ -0,0 +1,25 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol stream abort message +/// +public sealed class StreamAbortMessage : TransactionControlMessage +{ + /// + /// Xid of the subtransaction (will be same as xid of the transaction for top-level transactions). 
+ /// + public uint SubtransactionXid { get; private set; } + + internal StreamAbortMessage() {} + + internal StreamAbortMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + uint transactionXid, uint subtransactionXid) + { + base.Populate(walStart, walEnd, serverClock, transactionXid); + SubtransactionXid = subtransactionXid; + return this; + } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/StreamCommitMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/StreamCommitMessage.cs new file mode 100644 index 0000000000..ae6aacc584 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/StreamCommitMessage.cs @@ -0,0 +1,43 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol stream commit message +/// +public sealed class StreamCommitMessage : TransactionControlMessage +{ + /// + /// Flags; currently unused (must be 0). + /// + public byte Flags { get; private set; } + + /// + /// The LSN of the commit. + /// + public NpgsqlLogSequenceNumber CommitLsn { get; private set; } + + /// + /// The end LSN of the transaction. + /// + public NpgsqlLogSequenceNumber TransactionEndLsn { get; private set; } + + /// + /// Commit timestamp of the transaction. 
+ /// + public DateTime TransactionCommitTimestamp { get; private set; } + + internal StreamCommitMessage() {} + + internal StreamCommitMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + uint transactionXid, byte flags, NpgsqlLogSequenceNumber commitLsn, NpgsqlLogSequenceNumber transactionEndLsn, DateTime transactionCommitTimestamp) + { + base.Populate(walStart, walEnd, serverClock, transactionXid); + Flags = flags; + CommitLsn = commitLsn; + TransactionEndLsn = transactionEndLsn; + TransactionCommitTimestamp = transactionCommitTimestamp; + return this; + } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/StreamPrepareMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/StreamPrepareMessage.cs new file mode 100644 index 0000000000..4947e0d046 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/StreamPrepareMessage.cs @@ -0,0 +1,45 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol stream prepare message +/// +public sealed class StreamPrepareMessage : PrepareMessageBase +{ + /// + /// Flags for the prepare; currently unused. + /// + public StreamPrepareFlags Flags { get; private set; } + + internal StreamPrepareMessage() {} + + internal StreamPrepareMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, StreamPrepareFlags flags, + NpgsqlLogSequenceNumber prepareLsn, NpgsqlLogSequenceNumber prepareEndLsn, DateTime transactionPrepareTimestamp, + uint transactionXid, string transactionGid) + { + base.Populate(walStart, walEnd, serverClock, + prepareLsn: prepareLsn, + prepareEndLsn: prepareEndLsn, + transactionPrepareTimestamp: transactionPrepareTimestamp, + transactionXid: transactionXid, + transactionGid: transactionGid); + Flags = flags; + + return this; + } + + /// + /// Flags for the prepare; currently unused. 
+ /// + [Flags] + public enum StreamPrepareFlags : byte + { + /// + /// No flags. + /// + None = 0 + } +} diff --git a/src/Npgsql/Replication/PgOutput/Messages/StreamStartMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/StreamStartMessage.cs new file mode 100644 index 0000000000..4b0ace1cf7 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/StreamStartMessage.cs @@ -0,0 +1,25 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol stream start message +/// +public sealed class StreamStartMessage : TransactionControlMessage +{ + /// + /// A value of 1 indicates this is the first stream segment for this XID, 0 for any other stream segment. + /// + public byte StreamSegmentIndicator { get; private set; } + + internal StreamStartMessage() {} + + internal StreamStartMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + uint transactionXid, byte streamSegmentIndicator) + { + base.Populate(walStart, walEnd, serverClock, transactionXid); + StreamSegmentIndicator = streamSegmentIndicator; + return this; + } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/StreamStopMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/StreamStopMessage.cs new file mode 100644 index 0000000000..f3fd165a1e --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/StreamStopMessage.cs @@ -0,0 +1,18 @@ +using NpgsqlTypes; +using System; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol stream stop message +/// +public sealed class StreamStopMessage : PgOutputReplicationMessage +{ + internal StreamStopMessage() {} + + internal new StreamStopMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock) + { + base.Populate(walStart, walEnd, serverClock); + return this; + } +} \ No newline at end of file diff --git 
a/src/Npgsql/Replication/PgOutput/Messages/TransactionControlMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/TransactionControlMessage.cs new file mode 100644 index 0000000000..4c3c901b2f --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/TransactionControlMessage.cs @@ -0,0 +1,22 @@ +using System; +using NpgsqlTypes; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// The common base class for all replication messages that set the transaction xid of a transaction +/// +public abstract class TransactionControlMessage : PgOutputReplicationMessage +{ + /// + /// Xid of the transaction. + /// + public uint TransactionXid { get; private set; } + + private protected void Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint transactionXid) + { + base.Populate(walStart, walEnd, serverClock); + + TransactionXid = transactionXid; + } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/TransactionalMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/TransactionalMessage.cs new file mode 100644 index 0000000000..d5aac683a2 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/Messages/TransactionalMessage.cs @@ -0,0 +1,22 @@ +using System; +using NpgsqlTypes; + +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// The common base class for all streaming replication messages that can be part of a streaming transaction (protocol V2) +/// +public abstract class TransactionalMessage : PgOutputReplicationMessage +{ + /// + /// Xid of the transaction (only present for streamed transactions). + /// + public uint? TransactionXid { get; private set; } + + private protected void Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? 
transactionXid) + { + base.Populate(walStart, walEnd, serverClock); + + TransactionXid = transactionXid; + } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/TruncateMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/TruncateMessage.cs index 0fa1d2b7ac..47837f93f3 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/TruncateMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/TruncateMessage.cs @@ -1,45 +1,55 @@ using NpgsqlTypes; using System; +using System.Collections.Generic; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol truncate message +/// +public sealed class TruncateMessage : TransactionalMessage { /// - /// Logical Replication Protocol truncate message + /// Option flags for TRUNCATE + /// + public TruncateOptions Options { get; private set; } + + /// + /// The relations being truncated. + /// + public IReadOnlyList Relations { get; private set; } = ReadOnlyArrayBuffer.Empty; + + internal TruncateMessage() {} + + internal TruncateMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? transactionXid, TruncateOptions options, + ReadOnlyArrayBuffer relations) + { + base.Populate(walStart, walEnd, serverClock, transactionXid); + Options = options; + Relations = relations; + return this; + } + + /// + /// Enum representing the additional options for the TRUNCATE command as flags /// - public sealed class TruncateMessage : PgOutputReplicationMessage + [Flags] + public enum TruncateOptions : byte { /// - /// Option flags for TRUNCATE + /// No additional option was specified + /// + None = 0, + + /// + /// CASCADE was specified /// - public TruncateOptions Options { get; private set; } + Cascade = 1, /// - /// IDs of the relations corresponding to the ID in the relation message. 
+ /// RESTART IDENTITY was specified /// - public uint[] RelationIds { get; private set; } = default!; - - internal TruncateMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, TruncateOptions options, - uint[] relationIds) - { - base.Populate(walStart, walEnd, serverClock); - - Options = options; - RelationIds = relationIds; - - return this; - } - - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override TruncateMessage Clone() -#endif - { - var clone = new TruncateMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, Options, RelationIds); // TODO: RelationIds... - return clone; - } + RestartIdentity = 2 } -} +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/TruncateOptions.cs b/src/Npgsql/Replication/PgOutput/Messages/TruncateOptions.cs deleted file mode 100644 index 3b5fbcb038..0000000000 --- a/src/Npgsql/Replication/PgOutput/Messages/TruncateOptions.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System; - -namespace Npgsql.Replication.PgOutput.Messages -{ - /// - /// Enum representing the additional options for the TRUNCATE command as flags - /// - [Flags] - public enum TruncateOptions : byte - { - /// - /// No additional option was specified - /// - None = 0, - - /// - /// CASCADE was specified - /// - Cascade = 1, - - /// - /// RESTART IDENTITY was specified - /// - RestartIdentity = 2 - } -} diff --git a/src/Npgsql/Replication/PgOutput/Messages/TupleData.cs b/src/Npgsql/Replication/PgOutput/Messages/TupleData.cs deleted file mode 100644 index bf1fe33a40..0000000000 --- a/src/Npgsql/Replication/PgOutput/Messages/TupleData.cs +++ /dev/null @@ -1,111 +0,0 @@ -using System; -using System.Runtime.InteropServices; - -namespace Npgsql.Replication.PgOutput.Messages -{ - /// - /// Represents the data transmitted for a tuple in a Logical Replication Protocol message - /// - 
[StructLayout(LayoutKind.Explicit)] - public readonly struct TupleData - { - internal TupleData(TupleDataKind kind) - { - Kind = kind; - _textValue = null; - // _binaryValue = null; - } - - internal TupleData(string textValue) - { - Kind = TupleDataKind.TextValue; - _textValue = textValue; - // _binaryValue = null; - } - - internal TupleData(byte[] binaryValue) - { - Kind = TupleDataKind.TextValue; - _textValue = null; - // _binaryValue = binaryValue; - } - - /// - /// The kind of data in the tuple - /// - [field: FieldOffset(0)] - public TupleDataKind Kind { get; } - - [FieldOffset(8)] readonly string? _textValue; - - /// - /// The value of the tuple, if is . Otherwise throws. - /// - public string TextValue => Kind == TupleDataKind.TextValue - ? _textValue! - : throw new InvalidOperationException("Tuple kind is " + Kind); - -#if PG14 - [FieldOffset(8)] readonly byte[]? _binaryValue; - - /// - /// The value of the tuple, if is . Otherwise throws. - /// - public byte[] BinaryValue => Kind == TupleDataKind.BinaryValue - ? _binaryValue! - : throw new InvalidOperationException("Tuple kind is " + Kind); -#endif - - /// - /// The value of the tuple, in text format if is , or other formats - /// as may be added. Otherwise . - /// - public object? 
Value => Kind switch - { - TupleDataKind.Null => null, - TupleDataKind.UnchangedToastedValue => null, - TupleDataKind.TextValue => _textValue, - // TupleDataKind.BinaryValue => _binaryValue, - _ => throw new ArgumentOutOfRangeException($"Unhandled {nameof(TupleDataKind)}: {Kind}") - }; - - /// - public override string ToString() => Kind switch - { - TupleDataKind.Null => "", - TupleDataKind.UnchangedToastedValue => "", - TupleDataKind.TextValue => TextValue, - // TupleDataKind.BinaryValue => , - _ => throw new ArgumentOutOfRangeException($"Unhandled {nameof(TupleDataKind)}: {Kind}") - }; - } - - /// - /// The kind of data transmitted for a tuple in a Logical Replication Protocol message - /// - public enum TupleDataKind : byte - { - /// - /// Identifies the data as NULL value. - /// - Null = (byte)'n', - - /// - /// Identifies unchanged TOASTed value (the actual value is not sent). - /// - UnchangedToastedValue = (byte)'u', - - /// - /// Identifies the data as text formatted value. - /// - TextValue = (byte)'t', - -#if PG14 - /// - /// Identifies the data as binary value. - /// - /// Added in PG14 - BinaryValue = (byte)'b' -#endif - } -} diff --git a/src/Npgsql/Replication/PgOutput/Messages/TypeMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/TypeMessage.cs index 2e0062bfcb..5e188de4fe 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/TypeMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/TypeMessage.cs @@ -1,50 +1,37 @@ using NpgsqlTypes; using System; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Logical Replication Protocol type message +/// +public sealed class TypeMessage : TransactionalMessage { /// - /// Logical Replication Protocol type message + /// ID of the data type. /// - public sealed class TypeMessage : PgOutputReplicationMessage - { - /// - /// ID of the data type. - /// - public uint TypeId { get; private set; } - - /// - /// Namespace (empty string for pg_catalog). 
- /// - public string Namespace { get; private set; } = string.Empty; + public uint TypeId { get; private set; } - /// - /// Name of the data type. - /// - public string Name { get; private set; } = string.Empty; - - internal TypeMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint typeId, string ns, string name) - { - base.Populate(walStart, walEnd, serverClock); + /// + /// Namespace (empty string for pg_catalog). + /// + public string Namespace { get; private set; } = string.Empty; - TypeId = typeId; - Namespace = ns; - Name = name; + /// + /// Name of the data type. + /// + public string Name { get; private set; } = string.Empty; - return this; - } + internal TypeMessage() {} - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override TypeMessage Clone() -#endif - { - var clone = new TypeMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, TypeId, Namespace, Name); - return clone; - } + internal TypeMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? 
transactionXid, uint typeId, string ns, string name) + { + base.Populate(walStart, walEnd, serverClock, transactionXid); + TypeId = typeId; + Namespace = ns; + Name = name; + return this; } -} +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/Messages/UpdateMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/UpdateMessage.cs index a446dede46..135ff0ddaf 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/UpdateMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/UpdateMessage.cs @@ -1,49 +1,75 @@ using NpgsqlTypes; using System; using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; -namespace Npgsql.Replication.PgOutput.Messages +namespace Npgsql.Replication.PgOutput.Messages; + +/// +/// Abstract base class for Logical Replication Protocol delete message types. +/// +public abstract class UpdateMessage : TransactionalMessage { /// - /// Logical Replication Protocol update message for tables with REPLICA IDENTITY set to DEFAULT. + /// The relation for this . + /// + public RelationMessage Relation { get; private set; } = null!; + + /// + /// Columns representing the new row. /// - /// - /// This is the base type of all update messages containing only the tuples for the new row. - /// - public class UpdateMessage : PgOutputReplicationMessage + public abstract ReplicationTuple NewRow { get; } + + internal UpdateMessage() {} + + internal UpdateMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint? transactionXid, + RelationMessage relation) { - /// - /// ID of the relation corresponding to the ID in the relation message. - /// - public uint RelationId { get; private set; } - - /// - /// Columns representing the new row. 
- /// - public ReadOnlyMemory NewRow { get; private set; } - - internal UpdateMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, uint relationId, - ReadOnlyMemory newRow) + base.Populate(walStart, walEnd, serverClock, transactionXid); + + Relation = relation; + + return this; + } + + private protected sealed class SecondRowTupleEnumerable : ReplicationTuple + { + readonly ReplicationTuple _oldRowTupleEnumerable; + + internal SecondRowTupleEnumerable(NpgsqlConnector connector, ReplicationTuple oldRowTupleEnumerable) + : base(connector) + => _oldRowTupleEnumerable = oldRowTupleEnumerable; + + public override async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - base.Populate(walStart, walEnd, serverClock); + // This will throw if we're already reading (or consumed) the second row + var enumerator = base.GetAsyncEnumerator(cancellationToken); - RelationId = relationId; - NewRow = newRow; + await _oldRowTupleEnumerable.Consume(cancellationToken).ConfigureAwait(false); + await ReadBuffer.EnsureAsync(3).ConfigureAwait(false); + var tupleType = (TupleType)ReadBuffer.ReadByte(); + Debug.Assert(tupleType == TupleType.NewTuple); + _ = ReadBuffer.ReadUInt16(); // numColumns, - return this; + while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + yield return enumerator.Current; } - /// -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - public override PgOutputReplicationMessage Clone() -#else - public override UpdateMessage Clone() -#endif + internal new async Task Consume(CancellationToken cancellationToken) { - var clone = new UpdateMessage(); - clone.Populate(WalStart, WalEnd, ServerClock, RelationId, NewRow.ToArray()); - return clone; + if (State == RowState.NotRead) + { + await _oldRowTupleEnumerable.Consume(cancellationToken).ConfigureAwait(false); + await ReadBuffer.EnsureAsync(3).ConfigureAwait(false); + var tupleType = (TupleType)ReadBuffer.ReadByte(); + 
Debug.Assert(tupleType == TupleType.NewTuple); + _ = ReadBuffer.ReadUInt16(); // numColumns, + } + await base.Consume(cancellationToken).ConfigureAwait(false); } } } diff --git a/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs b/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs index e68ed1c9d4..ae26d229f6 100644 --- a/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs +++ b/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs @@ -3,305 +3,492 @@ using System.Diagnostics; using System.Threading; using System.Threading.Tasks; +using Npgsql.BackendMessages; +using Npgsql.Internal; using Npgsql.Replication.Internal; using Npgsql.Replication.PgOutput.Messages; -using Npgsql.TypeHandlers.DateTimeHandlers; using NpgsqlTypes; -namespace Npgsql.Replication.PgOutput +namespace Npgsql.Replication.PgOutput; + +sealed class PgOutputAsyncEnumerable : IAsyncEnumerable { - class PgOutputAsyncEnumerable : IAsyncEnumerable + readonly LogicalReplicationConnection _connection; + readonly PgOutputReplicationSlot _slot; + readonly PgOutputReplicationOptions _options; + readonly CancellationToken _baseCancellationToken; + readonly NpgsqlLogSequenceNumber? 
_walLocation; + + #region Cached logical streaming replication protocol messages + + // V1 + readonly BeginMessage _beginMessage = new(); + readonly LogicalDecodingMessage _logicalDecodingMessage = new(); + readonly CommitMessage _commitMessage = new(); + readonly OriginMessage _originMessage = new(); + readonly Dictionary _relations = new(); + readonly TypeMessage _typeMessage = new(); + readonly InsertMessage _insertMessage; + readonly DefaultUpdateMessage _defaultUpdateMessage; + readonly FullUpdateMessage _fullUpdateMessage; + readonly IndexUpdateMessage _indexUpdateMessage; + readonly FullDeleteMessage _fullDeleteMessage; + readonly KeyDeleteMessage _keyDeleteMessage; + readonly TruncateMessage _truncateMessage = new(); + readonly ReadOnlyArrayBuffer _truncateMessageRelations = new(); + + // V2 + readonly StreamStartMessage _streamStartMessage = new(); + readonly StreamStopMessage _streamStopMessage = new(); + readonly StreamCommitMessage _streamCommitMessage = new(); + readonly StreamAbortMessage _streamAbortMessage = new(); + + // V3 + readonly BeginPrepareMessage _beginPrepareMessage = new(); + readonly PrepareMessage _prepareMessage = new(); + readonly CommitPreparedMessage _commitPreparedMessage = new(); + readonly RollbackPreparedMessage _rollbackPreparedMessage = new(); + readonly StreamPrepareMessage _streamPrepareMessage = new(); + + #endregion + + internal PgOutputAsyncEnumerable( + LogicalReplicationConnection connection, + PgOutputReplicationSlot slot, + PgOutputReplicationOptions options, + CancellationToken cancellationToken, + NpgsqlLogSequenceNumber? walLocation = null) { - readonly LogicalReplicationConnection _connection; - readonly PgOutputReplicationSlot _slot; - readonly PgOutputReplicationOptions _options; - readonly CancellationToken _baseCancellationToken; - readonly NpgsqlLogSequenceNumber? 
_walLocation; - - #region Cached messages - - readonly BeginMessage _beginMessage = new BeginMessage(); - readonly CommitMessage _commitMessage = new CommitMessage(); - readonly FullDeleteMessage _fullDeleteMessage = new FullDeleteMessage(); - readonly FullUpdateMessage _fullUpdateMessage = new FullUpdateMessage(); - readonly IndexUpdateMessage _indexUpdateMessage = new IndexUpdateMessage(); - readonly InsertMessage _insertMessage = new InsertMessage(); - readonly KeyDeleteMessage _keyDeleteMessage = new KeyDeleteMessage(); - readonly OriginMessage _originMessage = new OriginMessage(); - readonly RelationMessage _relationMessage = new RelationMessage(); - readonly TruncateMessage _truncateMessage = new TruncateMessage(); - readonly TypeMessage _typeMessage = new TypeMessage(); - readonly UpdateMessage _updateMessage = new UpdateMessage(); - - TupleData[] _tupleDataArray1 = Array.Empty(); - TupleData[] _tupleDataArray2 = Array.Empty(); - RelationMessage.Column[] _relationalMessageColumns = Array.Empty(); - - #endregion - - internal PgOutputAsyncEnumerable( - LogicalReplicationConnection connection, - PgOutputReplicationSlot slot, - PgOutputReplicationOptions options, - CancellationToken cancellationToken, - NpgsqlLogSequenceNumber? 
walLocation = null) - { - _connection = connection; - _slot = slot; - _options = options; - _baseCancellationToken = cancellationToken; - _walLocation = walLocation; - } + _connection = connection; + _slot = slot; + _options = options; + _baseCancellationToken = cancellationToken; + _walLocation = walLocation; - public IAsyncEnumerator GetAsyncEnumerator( - CancellationToken cancellationToken = new CancellationToken()) - { - using (NoSynchronizationContextScope.Enter()) - return StartReplicationInternal( - CancellationTokenSource.CreateLinkedTokenSource(_baseCancellationToken, cancellationToken).Token); - } + var connector = _connection.Connector; + _insertMessage = new(connector); + _defaultUpdateMessage = new(connector); + _fullUpdateMessage = new(connector); + _indexUpdateMessage = new(connector); + _fullDeleteMessage = new(connector); + _keyDeleteMessage = new(connector); + } - async IAsyncEnumerator StartReplicationInternal(CancellationToken cancellationToken) - { - var stream = _connection.StartLogicalReplication( - _slot, cancellationToken, _walLocation, _options.GetOptionPairs(), bypassingStream: true); - var buf = _connection.Connector!.ReadBuffer; + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + => StartReplicationInternal(CancellationTokenSource.CreateLinkedTokenSource(_baseCancellationToken, cancellationToken).Token); + + async IAsyncEnumerator StartReplicationInternal(CancellationToken cancellationToken) + { + var stream = _connection.StartLogicalReplication( + _slot, cancellationToken, _walLocation, _options.GetOptionPairs(), bypassingStream: true); + var buf = _connection.Connector!.ReadBuffer; + var inStreamingTransaction = false; + var dataFormat = _options.Binary ?? false ? 
DataFormat.Binary : DataFormat.Text; - await foreach (var xLogData in stream.WithCancellation(cancellationToken)) + await foreach (var xLogData in stream.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + await buf.EnsureAsync(1).ConfigureAwait(false); + var messageCode = (BackendReplicationMessageCode)buf.ReadByte(); + switch (messageCode) + { + case BackendReplicationMessageCode.Begin: + { + await buf.EnsureAsync(20).ConfigureAwait(false); + yield return _beginMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + transactionFinalLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionCommitTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionXid: buf.ReadUInt32()); + continue; + } + case BackendReplicationMessageCode.Message: { - await buf.EnsureAsync(1); - var messageCode = (BackendReplicationMessageCode)buf.ReadByte(); - switch (messageCode) + uint? transactionXid; + if (inStreamingTransaction) { - case BackendReplicationMessageCode.Begin: + await buf.EnsureAsync(14).ConfigureAwait(false); + transactionXid = buf.ReadUInt32(); + } + else { - await buf.EnsureAsync(20); - yield return _beginMessage.Populate( - xLogData.WalStart, - xLogData.WalEnd, - xLogData.ServerClock, - new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - TimestampHandler.FromPostgresTimestamp(buf.ReadInt64()), - buf.ReadUInt32() - ); - continue; + await buf.EnsureAsync(10).ConfigureAwait(false); + transactionXid = null; } - case BackendReplicationMessageCode.Commit: + + var flags = buf.ReadByte(); + var messageLsn = new NpgsqlLogSequenceNumber(buf.ReadUInt64()); + var prefix = await buf.ReadNullTerminatedString(async: true, cancellationToken).ConfigureAwait(false); + await buf.EnsureAsync(4).ConfigureAwait(false); + var length = buf.ReadUInt32(); + var data = (NpgsqlReadBuffer.ColumnStream)xLogData.Data; + data.Init(checked((int)length), canSeek: false, commandScoped: false); + yield return 
_logicalDecodingMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, + flags, messageLsn, prefix, data); + await data.DisposeAsync().ConfigureAwait(false); + continue; + } + case BackendReplicationMessageCode.Commit: + { + await buf.EnsureAsync(25).ConfigureAwait(false); + yield return _commitMessage.Populate( + xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + (CommitMessage.CommitFlags)buf.ReadByte(), + commitLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionCommitTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc)); + continue; + } + case BackendReplicationMessageCode.Origin: + { + await buf.EnsureAsync(9).ConfigureAwait(false); + yield return _originMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + originCommitLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + originName: await buf.ReadNullTerminatedString(async: true, cancellationToken).ConfigureAwait(false)); + continue; + } + case BackendReplicationMessageCode.Relation: + { + uint? 
transactionXid; + if (inStreamingTransaction) { - await buf.EnsureAsync(25); - yield return _commitMessage.Populate( - xLogData.WalStart, - xLogData.WalEnd, - xLogData.ServerClock, - buf.ReadByte(), - new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - TimestampHandler.FromPostgresTimestamp(buf.ReadInt64()) - ); - continue; + await buf.EnsureAsync(10).ConfigureAwait(false); + transactionXid = buf.ReadUInt32(); } - case BackendReplicationMessageCode.Origin: + else { - await buf.EnsureAsync(9); - yield return _originMessage.Populate( - xLogData.WalStart, - xLogData.WalEnd, - xLogData.ServerClock, - new NpgsqlLogSequenceNumber(buf.ReadUInt64()), - await buf.ReadNullTerminatedString(async: true, cancellationToken)); - continue; + await buf.EnsureAsync(6).ConfigureAwait(false); + transactionXid = null; } - case BackendReplicationMessageCode.Relation: + + var relationId = buf.ReadUInt32(); + var ns = await buf.ReadNullTerminatedString(async: true, cancellationToken).ConfigureAwait(false); + var relationName = await buf.ReadNullTerminatedString(async: true, cancellationToken).ConfigureAwait(false); + await buf.EnsureAsync(3).ConfigureAwait(false); + var relationReplicaIdentitySetting = (RelationMessage.ReplicaIdentitySetting)buf.ReadByte(); + var numColumns = buf.ReadUInt16(); + + if (!_relations.TryGetValue(relationId, out var msg)) + msg = _relations[relationId] = new RelationMessage(); + + msg.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, relationId, ns, relationName, + relationReplicaIdentitySetting); + + var columns = msg.InternalColumns; + columns.Count = numColumns; + for (var i = 0; i < numColumns; i++) { - await buf.EnsureAsync(6); - var relationId = buf.ReadUInt32(); - var ns = await buf.ReadNullTerminatedString(async: true, cancellationToken); - var relationName = await buf.ReadNullTerminatedString(async: true, cancellationToken); - await buf.EnsureAsync(3); - var 
relationReplicaIdentitySetting = (char)buf.ReadByte(); - var numColumns = buf.ReadUInt16(); - if (numColumns > _relationalMessageColumns.Length) - _relationalMessageColumns = new RelationMessage.Column[numColumns]; - for (var i = 0; i < numColumns; i++) - { - await buf.EnsureAsync(2); - var flags = buf.ReadByte(); - var columnName = await buf.ReadNullTerminatedString(async: true, cancellationToken); - await buf.EnsureAsync(8); - var dateTypeId = buf.ReadUInt32(); - var typeModifier = buf.ReadInt32(); - _relationalMessageColumns[i] = new RelationMessage.Column(flags, columnName, dateTypeId, typeModifier); - } + await buf.EnsureAsync(2).ConfigureAwait(false); + var flags = (RelationMessage.Column.ColumnFlags)buf.ReadByte(); + var columnName = await buf.ReadNullTerminatedString(async: true, cancellationToken).ConfigureAwait(false); + await buf.EnsureAsync(8).ConfigureAwait(false); + var dateTypeId = buf.ReadUInt32(); + var typeModifier = buf.ReadInt32(); + columns[i] = new RelationMessage.Column(flags, columnName, dateTypeId, typeModifier); + } - yield return _relationMessage.Populate( - xLogData.WalStart, - xLogData.WalEnd, - xLogData.ServerClock, - relationId, - ns, - relationName, - relationReplicaIdentitySetting, - new ReadOnlyMemory(_relationalMessageColumns, 0, numColumns) - ); + msg.RowDescription = RowDescriptionMessage.CreateForReplication( + _connection.Connector.SerializerOptions, relationId, dataFormat, columns); - continue; + yield return msg; + continue; + } + case BackendReplicationMessageCode.Type: + { + uint? 
transactionXid; + if (inStreamingTransaction) + { + await buf.EnsureAsync(9).ConfigureAwait(false); + transactionXid = buf.ReadUInt32(); } - case BackendReplicationMessageCode.Type: + else { - await buf.EnsureAsync(5); - var typeId = buf.ReadUInt32(); - var ns = await buf.ReadNullTerminatedString(async: true, cancellationToken); - var name = await buf.ReadNullTerminatedString(async: true, cancellationToken); - yield return _typeMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, typeId, ns, name); + await buf.EnsureAsync(5).ConfigureAwait(false); + transactionXid = null; + } - continue; + var typeId = buf.ReadUInt32(); + var ns = await buf.ReadNullTerminatedString(async: true, cancellationToken).ConfigureAwait(false); + var name = await buf.ReadNullTerminatedString(async: true, cancellationToken).ConfigureAwait(false); + yield return _typeMessage.Populate( + xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, typeId, ns, name); + continue; + } + case BackendReplicationMessageCode.Insert: + { + uint? 
transactionXid; + if (inStreamingTransaction) + { + await buf.EnsureAsync(11).ConfigureAwait(false); + transactionXid = buf.ReadUInt32(); } - case BackendReplicationMessageCode.Insert: + else { - await buf.EnsureAsync(7); - var relationId = buf.ReadUInt32(); - var tupleDataType = (TupleType)buf.ReadByte(); - Debug.Assert(tupleDataType == TupleType.NewTuple); - var numColumns = buf.ReadUInt16(); - var newRow = await ReadTupleDataAsync(ref _tupleDataArray1, numColumns); - yield return _insertMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, relationId, newRow); + await buf.EnsureAsync(7).ConfigureAwait(false); + transactionXid = null; + } - continue; + var relationId = buf.ReadUInt32(); + var tupleDataType = (TupleType)buf.ReadByte(); + Debug.Assert(tupleDataType == TupleType.NewTuple); + var numColumns = buf.ReadUInt16(); + + if (!_relations.TryGetValue(relationId, out var relation)) + { + throw new InvalidOperationException( + $"Could not find previous Relation message for relation ID {relationId} when processing Insert message"); } - case BackendReplicationMessageCode.Update: + + Debug.Assert(numColumns == relation.RowDescription.Count); + + yield return _insertMessage.Populate( + xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, relation, numColumns); + await _insertMessage.Consume(cancellationToken).ConfigureAwait(false); + + continue; + } + case BackendReplicationMessageCode.Update: + { + uint? 
transactionXid; + if (inStreamingTransaction) { - await buf.EnsureAsync(7); - var relationId = buf.ReadUInt32(); - var tupleType = (TupleType)buf.ReadByte(); - var numColumns = buf.ReadUInt16(); - switch (tupleType) - { - case TupleType.Key: - var keyRow = await ReadTupleDataAsync(ref _tupleDataArray1, numColumns); - await buf.EnsureAsync(3); - tupleType = (TupleType)buf.ReadByte(); - Debug.Assert(tupleType == TupleType.NewTuple); - numColumns = buf.ReadUInt16(); - var newRow = await ReadTupleDataAsync(ref _tupleDataArray2, numColumns); - yield return _indexUpdateMessage.Populate(xLogData.WalStart, xLogData.WalEnd, - xLogData.ServerClock, relationId, newRow, keyRow); - continue; - case TupleType.OldTuple: - var oldRow = await ReadTupleDataAsync(ref _tupleDataArray1, numColumns); - await buf.EnsureAsync(3); - tupleType = (TupleType)buf.ReadByte(); - Debug.Assert(tupleType == TupleType.NewTuple); - numColumns = buf.ReadUInt16(); - newRow = await ReadTupleDataAsync(ref _tupleDataArray2, numColumns); - yield return _fullUpdateMessage.Populate(xLogData.WalStart, xLogData.WalEnd, - xLogData.ServerClock, relationId, newRow, oldRow); - continue; - case TupleType.NewTuple: - newRow = await ReadTupleDataAsync(ref _tupleDataArray1, numColumns); - yield return _updateMessage.Populate(xLogData.WalStart, xLogData.WalEnd, - xLogData.ServerClock, relationId, newRow); - continue; - default: - throw new NotSupportedException($"The tuple type '{tupleType}' is not supported."); - } + await buf.EnsureAsync(11).ConfigureAwait(false); + transactionXid = buf.ReadUInt32(); } - case BackendReplicationMessageCode.Delete: + else { - await buf.EnsureAsync(7); - var relationId = buf.ReadUInt32(); - var tupleDataType = (TupleType)buf.ReadByte(); - var numColumns = buf.ReadUInt16(); - switch (tupleDataType) - { - case TupleType.Key: - yield return _keyDeleteMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, - relationId, await ReadTupleDataAsync(ref _tupleDataArray1, 
numColumns)); - continue; - case TupleType.OldTuple: - yield return _fullDeleteMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, - relationId, await ReadTupleDataAsync(ref _tupleDataArray1, numColumns)); - continue; - default: - throw new NotSupportedException($"The tuple type '{tupleDataType}' is not supported."); - } + await buf.EnsureAsync(7).ConfigureAwait(false); + transactionXid = null; } - case BackendReplicationMessageCode.Truncate: + + var relationId = buf.ReadUInt32(); + var tupleType = (TupleType)buf.ReadByte(); + var numColumns = buf.ReadUInt16(); + + if (!_relations.TryGetValue(relationId, out var relation)) { - await buf.EnsureAsync(9); - // Don't dare to truncate more than 2147483647 tables at once! - var numRels = checked((int)buf.ReadUInt32()); - var truncateOptions = (TruncateOptions)buf.ReadByte(); - var relationIds = new uint[numRels]; - await buf.EnsureAsync(checked(numRels * 4)); - - for (var i = 0; i < numRels; i++) - relationIds[i] = buf.ReadUInt32(); - - yield return _truncateMessage.Populate( - xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, truncateOptions, relationIds); - continue; + throw new InvalidOperationException( + $"Could not find previous Relation message for relation ID {relationId} when processing Update message"); } + + Debug.Assert(numColumns == relation.RowDescription.Count); + + switch (tupleType) + { + case TupleType.Key: + yield return _indexUpdateMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, + relation, numColumns); + await _indexUpdateMessage.Consume(cancellationToken).ConfigureAwait(false); + continue; + case TupleType.OldTuple: + yield return _fullUpdateMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, + relation, numColumns); + await _fullUpdateMessage.Consume(cancellationToken).ConfigureAwait(false); + continue; + case TupleType.NewTuple: + yield return 
_defaultUpdateMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, + relation, numColumns); + await _defaultUpdateMessage.Consume(cancellationToken).ConfigureAwait(false); + continue; default: - throw new NotSupportedException( - $"Invalid message code {messageCode} in Logical Replication Protocol."); + throw new NotSupportedException($"The tuple type '{tupleType}' is not supported."); } } + case BackendReplicationMessageCode.Delete: + { + uint? transactionXid; + if (inStreamingTransaction) + { + await buf.EnsureAsync(11).ConfigureAwait(false); + transactionXid = buf.ReadUInt32(); + } + else + { + await buf.EnsureAsync(7).ConfigureAwait(false); + transactionXid = null; + } + + var relationId = buf.ReadUInt32(); + var tupleDataType = (TupleType)buf.ReadByte(); + var numColumns = buf.ReadUInt16(); + + if (!_relations.TryGetValue(relationId, out var relation)) + { + throw new InvalidOperationException( + $"Could not find previous Relation message for relation ID {relationId} when processing Update message"); + } - // We never get here - the above is an endless loop that terminates only with a cancellation exception + Debug.Assert(numColumns == relation.RowDescription.Count); - ValueTask> ReadTupleDataAsync(ref TupleData[] array, ushort numberOfColumns) + switch (tupleDataType) + { + case TupleType.Key: + yield return _keyDeleteMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, + relation, numColumns); + await _keyDeleteMessage.Consume(cancellationToken).ConfigureAwait(false); + continue; + case TupleType.OldTuple: + yield return _fullDeleteMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, + relation, numColumns); + await _fullDeleteMessage.Consume(cancellationToken).ConfigureAwait(false); + continue; + default: + throw new NotSupportedException($"The tuple type '{tupleDataType}' is not supported."); + } + } + case 
BackendReplicationMessageCode.Truncate: { - if (array.Length < numberOfColumns) - array = new TupleData[numberOfColumns]; - var nonRefArray = array; - return ReadTupleDataAsync2(); + uint? transactionXid; + if (inStreamingTransaction) + { + await buf.EnsureAsync(9).ConfigureAwait(false); + transactionXid = buf.ReadUInt32(); + } + else + { + await buf.EnsureAsync(5).ConfigureAwait(false); + transactionXid = null; + } - async ValueTask> ReadTupleDataAsync2() + // Don't dare to truncate more than 2147483647 tables at once! + var numRels = checked((int)buf.ReadUInt32()); + var truncateOptions = (TruncateMessage.TruncateOptions)buf.ReadByte(); + _truncateMessageRelations.Count = numRels; + for (var i = 0; i < numRels; i++) { - for (var i = 0; i < numberOfColumns; i++) + await buf.EnsureAsync(4).ConfigureAwait(false); + + var relationId = buf.ReadUInt32(); + if (!_relations.TryGetValue(relationId, out var relation)) { - await buf.EnsureAsync(1); - var subMessageKind = (TupleDataKind)buf.ReadByte(); - switch (subMessageKind) - { - case TupleDataKind.Null: - case TupleDataKind.UnchangedToastedValue: - nonRefArray[i] = new TupleData(subMessageKind); - continue; - case TupleDataKind.TextValue: - await buf.EnsureAsync(4); - var len = buf.ReadInt32(); - await buf.EnsureAsync(len); - nonRefArray![i] = new TupleData(buf.ReadString(len)); - continue; - default: - throw new NotSupportedException($"The tuple data kind '{subMessageKind}' is not supported."); - } + throw new InvalidOperationException( + $"Could not find previous Relation message for relation ID {relationId} when processing Update message"); } - return new ReadOnlyMemory(nonRefArray, 0, numberOfColumns); + _truncateMessageRelations[i] = relation; } + + yield return _truncateMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, transactionXid, + truncateOptions, _truncateMessageRelations); + continue; + } + case BackendReplicationMessageCode.StreamStart: + { + await 
buf.EnsureAsync(5).ConfigureAwait(false); + inStreamingTransaction = true; + yield return _streamStartMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + transactionXid: buf.ReadUInt32(), streamSegmentIndicator: buf.ReadByte()); + continue; + } + case BackendReplicationMessageCode.StreamStop: + { + inStreamingTransaction = false; + yield return _streamStopMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock); + continue; + } + case BackendReplicationMessageCode.StreamCommit: + { + await buf.EnsureAsync(29).ConfigureAwait(false); + yield return _streamCommitMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + transactionXid: buf.ReadUInt32(), flags: buf.ReadByte(), commitLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionCommitTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc)); + continue; + } + case BackendReplicationMessageCode.StreamAbort: + { + await buf.EnsureAsync(8).ConfigureAwait(false); + yield return _streamAbortMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + transactionXid: buf.ReadUInt32(), subtransactionXid: buf.ReadUInt32()); + continue; + } + case BackendReplicationMessageCode.BeginPrepare: + { + await buf.EnsureAsync(29).ConfigureAwait(false); + yield return _beginPrepareMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + prepareLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + prepareEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionPrepareTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionXid: buf.ReadUInt32(), + transactionGid: buf.ReadNullTerminatedString()); + continue; + } + case BackendReplicationMessageCode.Prepare: + { + await buf.EnsureAsync(30).ConfigureAwait(false); + yield return _prepareMessage.Populate(xLogData.WalStart, xLogData.WalEnd, 
xLogData.ServerClock, + flags: (PrepareMessage.PrepareFlags)buf.ReadByte(), + prepareLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + prepareEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionPrepareTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionXid: buf.ReadUInt32(), + transactionGid: buf.ReadNullTerminatedString()); + continue; + } + case BackendReplicationMessageCode.CommitPrepared: + { + await buf.EnsureAsync(30).ConfigureAwait(false); + yield return _commitPreparedMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + flags: (CommitPreparedMessage.CommitPreparedFlags)buf.ReadByte(), + commitPreparedLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + commitPreparedEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionCommitTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionXid: buf.ReadUInt32(), + transactionGid: buf.ReadNullTerminatedString()); + continue; + } + case BackendReplicationMessageCode.RollbackPrepared: + { + await buf.EnsureAsync(38).ConfigureAwait(false); + yield return _rollbackPreparedMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + flags: (RollbackPreparedMessage.RollbackPreparedFlags)buf.ReadByte(), + preparedTransactionEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + rollbackPreparedEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionPrepareTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionRollbackTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionXid: buf.ReadUInt32(), + transactionGid: buf.ReadNullTerminatedString()); + continue; + } + case BackendReplicationMessageCode.StreamPrepare: + { + await buf.EnsureAsync(30).ConfigureAwait(false); + yield return _streamPrepareMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + flags: 
(StreamPrepareMessage.StreamPrepareFlags)buf.ReadByte(), + prepareLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + prepareEndLsn: new NpgsqlLogSequenceNumber(buf.ReadUInt64()), + transactionPrepareTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc), + transactionXid: buf.ReadUInt32(), + transactionGid: buf.ReadNullTerminatedString()); + continue; + } + default: + throw new NotSupportedException( + $"Invalid message code {messageCode} in Logical Replication Protocol."); } } - enum BackendReplicationMessageCode : byte - { - Begin = (byte)'B', - Commit = (byte)'C', - Origin = (byte)'O', - Relation = (byte)'R', - Type = (byte)'Y', - Insert = (byte)'I', - Update = (byte)'U', - Delete = (byte)'D', - Truncate = (byte)'T' - } + // We never get here - the above is an endless loop that terminates only with a cancellation exception + } - enum TupleType : byte - { - Key = (byte)'K', - NewTuple = (byte)'N', - OldTuple = (byte)'O', - } + enum BackendReplicationMessageCode : byte + { + Begin = (byte)'B', + Message = (byte)'M', + Commit = (byte)'C', + Origin = (byte)'O', + Relation = (byte)'R', + Type = (byte)'Y', + Insert = (byte)'I', + Update = (byte)'U', + Delete = (byte)'D', + Truncate = (byte)'T', + StreamStart = (byte)'S', + StreamStop = (byte)'E', + StreamCommit = (byte)'c', + StreamAbort = (byte)'A', + BeginPrepare = (byte)'b', + Prepare = (byte)'P', + CommitPrepared = (byte)'K', + RollbackPrepared = (byte)'r', + StreamPrepare = (byte)'p', } } diff --git a/src/Npgsql/Replication/PgOutput/PgOutputConnectionExtensions.cs b/src/Npgsql/Replication/PgOutput/PgOutputConnectionExtensions.cs index d616acec1d..c67af16d58 100644 --- a/src/Npgsql/Replication/PgOutput/PgOutputConnectionExtensions.cs +++ b/src/Npgsql/Replication/PgOutput/PgOutputConnectionExtensions.cs @@ -7,82 +7,86 @@ using Npgsql.Replication.PgOutput.Messages; // ReSharper disable once CheckNamespace -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Extension methods 
to use with the pg_output logical decoding plugin. +/// +public static class PgOutputConnectionExtensions { /// - /// Extension methods to use with the pg_output logical decoding plugin. + /// Creates a class that wraps a replication slot using the + /// "pgoutput" logical decoding plugin and can be used to start streaming replication via the logical + /// streaming replication protocol. /// - public static class PgOutputConnectionExtensions + /// + /// See https://www.postgresql.org/docs/current/protocol-logical-replication.html + /// and https://www.postgresql.org/docs/current/protocol-logicalrep-message-formats.html + /// for more information. + /// + /// The to use for creating the replication slot + /// The name of the slot to create. Must be a valid replication slot name (see + /// https://www.postgresql.org/docs/current/warm-standby.html#STREAMING-REPLICATION-SLOTS-MANIPULATION). + /// + /// + /// if this replication slot shall be temporary one; otherwise . + /// Temporary slots are not saved to disk and are automatically dropped on error or when the session has finished. + /// + /// + /// A to specify what to do with the snapshot created during logical slot + /// initialization. , which is also the default, will export the + /// snapshot for use in other sessions. This option can't be used inside a transaction. + /// will use the snapshot for the current transaction executing the + /// command. This option must be used in a transaction, and must be the + /// first command run in that transaction. Finally, will just use + /// the snapshot for logical decoding as normal but won't do anything else with it. + /// + /// + /// If , this logical replication slot supports decoding of two-phase transactions. With this option, + /// two-phase commands like PREPARE TRANSACTION, COMMIT PREPARED and ROLLBACK PREPARED are decoded and transmitted. + /// The transaction will be decoded and transmitted at PREPARE TRANSACTION time. The default is . 
+ /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// + /// A that wraps the newly-created replication slot. + /// + public static async Task CreatePgOutputReplicationSlot( + this LogicalReplicationConnection connection, + string slotName, + bool temporarySlot = false, + LogicalSlotSnapshotInitMode? slotSnapshotInitMode = null, + bool twoPhase = false, + CancellationToken cancellationToken = default) { - /// - /// Creates a class that wraps a replication slot using the - /// "pgoutput" logical decoding plugin and can be used to start streaming replication via the logical - /// streaming replication protocol. - /// - /// - /// See https://www.postgresql.org/docs/current/protocol-logical-replication.html - /// and https://www.postgresql.org/docs/current/protocol-logicalrep-message-formats.html - /// for more information. - /// - /// The to use for creating the replication slot - /// The name of the slot to create. Must be a valid replication slot name (see - /// https://www.postgresql.org/docs/current/warm-standby.html#STREAMING-REPLICATION-SLOTS-MANIPULATION). - /// - /// - /// if this replication slot shall be temporary one; otherwise . - /// Temporary slots are not saved to disk and are automatically dropped on error or when the session has finished. - /// - /// - /// A to specify what to do with the snapshot created during logical slot - /// initialization. , which is also the default, will export the - /// snapshot for use in other sessions. This option can't be used inside a transaction. - /// will use the snapshot for the current transaction executing the - /// command. This option must be used in a transaction, and must be the - /// first command run in that transaction. Finally, will just use - /// the snapshot for logical decoding as normal but won't do anything else with it. - /// - /// - /// The token to monitor for cancellation requests. - /// The default value is . 
- /// - /// - /// A that wraps the newly-created replication slot. - /// - public static async Task CreatePgOutputReplicationSlot( - this LogicalReplicationConnection connection, - string slotName, - bool temporarySlot = false, - LogicalSlotSnapshotInitMode? slotSnapshotInitMode = null, - CancellationToken cancellationToken = default) - { - // We don't enter NoSynchronizationContextScope here since we (have to) do it in CreateLogicalReplicationSlot, because - // otherwise it wouldn't be set for external plugins. - var options = await connection.CreateLogicalReplicationSlot( - slotName, "pgoutput", temporarySlot, slotSnapshotInitMode, cancellationToken).ConfigureAwait(false); - return new PgOutputReplicationSlot(options); - } - - /// - /// Instructs the server to start the Logical Streaming Replication Protocol (pgoutput logical decoding plugin), - /// starting at WAL location or at the slot's consistent point if - /// isn't specified. - /// The server can reply with an error, for example if the requested section of the WAL has already been recycled. - /// - /// The to use for starting replication - /// The replication slot that will be updated as replication progresses so that the server - /// knows which WAL segments are still needed by the standby. - /// - /// The collection of options passed to the slot's logical decoding plugin. - /// The token to monitor for stopping the replication. - /// The WAL location to begin streaming at. - /// A representing an that - /// can be used to stream WAL entries in form of instances. - public static IAsyncEnumerable StartReplication( - this LogicalReplicationConnection connection, - PgOutputReplicationSlot slot, - PgOutputReplicationOptions options, - CancellationToken cancellationToken, - NpgsqlLogSequenceNumber? 
walLocation = null) - => new PgOutputAsyncEnumerable(connection, slot, options, cancellationToken, walLocation); + // We don't enter NoSynchronizationContextScope here since we (have to) do it in CreateLogicalReplicationSlot, because + // otherwise it wouldn't be set for external plugins. + var options = await connection.CreateLogicalReplicationSlot( + slotName, "pgoutput", temporarySlot, slotSnapshotInitMode, twoPhase, cancellationToken).ConfigureAwait(false); + return new PgOutputReplicationSlot(options); } -} + + /// + /// Instructs the server to start the Logical Streaming Replication Protocol (pgoutput logical decoding plugin), + /// starting at WAL location or at the slot's consistent point if + /// isn't specified. + /// The server can reply with an error, for example if the requested section of the WAL has already been recycled. + /// + /// The to use for starting replication + /// The replication slot that will be updated as replication progresses so that the server + /// knows which WAL segments are still needed by the standby. + /// + /// The collection of options passed to the slot's logical decoding plugin. + /// The token to monitor for stopping the replication. + /// The WAL location to begin streaming at. + /// A representing an that + /// can be used to stream WAL entries in form of instances. + public static IAsyncEnumerable StartReplication( + this LogicalReplicationConnection connection, + PgOutputReplicationSlot slot, + PgOutputReplicationOptions options, + CancellationToken cancellationToken, + NpgsqlLogSequenceNumber? 
walLocation = null) + => new PgOutputAsyncEnumerable(connection, slot, options, cancellationToken, walLocation); +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/PgOutputReplicationOptions.cs b/src/Npgsql/Replication/PgOutput/PgOutputReplicationOptions.cs index 13895c98bf..93039fdf25 100644 --- a/src/Npgsql/Replication/PgOutput/PgOutputReplicationOptions.cs +++ b/src/Npgsql/Replication/PgOutput/PgOutputReplicationOptions.cs @@ -2,109 +2,141 @@ using System.Collections.Generic; using System.Globalization; -namespace Npgsql.Replication.PgOutput +namespace Npgsql.Replication.PgOutput; + +/// +/// Options to be passed to the pgoutput plugin +/// +public class PgOutputReplicationOptions : IEquatable { /// - /// Options to be passed to the pgoutput plugin + /// Creates a new instance of . + /// + /// The publication names to include into the stream + /// The version of the logical streaming replication protocol + /// Send values in binary representation + /// Enable streaming of in-progress transactions + /// Write logical decoding messages into the replication stream + /// Enable streaming of prepared transactions + public PgOutputReplicationOptions(string publicationName, ulong protocolVersion, bool? binary = null, bool? streaming = null, bool? messages = null, bool? twoPhase = null) + : this(new List { publicationName ?? throw new ArgumentNullException(nameof(publicationName)) }, protocolVersion, binary, streaming, messages, twoPhase) + { } + + /// + /// Creates a new instance of . + /// + /// The publication names to include into the stream + /// The version of the logical streaming replication protocol + /// Send values in binary representation + /// Enable streaming of in-progress transactions + /// Write logical decoding messages into the replication stream + /// Enable streaming of prepared transactions + public PgOutputReplicationOptions(IEnumerable publicationNames, ulong protocolVersion, bool? binary = null, bool? 
streaming = null, bool? messages = null, bool? twoPhase = null) + { + var publicationNamesList = new List(publicationNames); + if (publicationNamesList.Count < 1) + throw new ArgumentException("You have to pass at least one publication name.", nameof(publicationNames)); + + foreach (var publicationName in publicationNamesList) + if (string.IsNullOrWhiteSpace(publicationName)) + throw publicationName is null + ? new ArgumentNullException(nameof(publicationName)) + : new ArgumentException("Invalid publication name", nameof(publicationName)); + + PublicationNames = publicationNamesList; + ProtocolVersion = protocolVersion; + Binary = binary; + Streaming = streaming; + Messages = messages; + TwoPhase = twoPhase; + } + + /// + /// The version of the Logical Streaming Replication Protocol + /// + public ulong ProtocolVersion { get; } + + /// + /// The publication names to stream + /// + public List PublicationNames { get; } + + /// + /// Send values in binary representation /// - public class PgOutputReplicationOptions : IEquatable + /// + /// This works in PostgreSQL versions 14+ + /// + // See: https://github.com/postgres/postgres/commit/9de77b5453130242654ff0b30a551c9c862ed661 + public bool? Binary { get; } + + /// + /// Enable streaming of in-progress transactions + /// + /// + /// This works as of logical streaming replication protocol version 2 (PostgreSQL 14+) + /// + // See: https://github.com/postgres/postgres/commit/464824323e57dc4b397e8b05854d779908b55304 + public bool? Streaming { get; } + + /// + /// Write logical decoding messages into the replication stream + /// + /// + /// This works in PostgreSQL versions 14+ + /// + // See: https://github.com/postgres/postgres/commit/ac4645c0157fc5fcef0af8ff571512aa284a2cec + public bool? 
Messages { get; } + + /// + /// Enable streaming of prepared transactions + /// + /// + /// This works in PostgreSQL versions 15+ + /// + // See: https://github.com/postgres/postgres/commit/a8fd13cab0ba815e9925dc9676e6309f699b5f72 + // and https://github.com/postgres/postgres/commit/63cf61cdeb7b0450dcf3b2f719c553177bac85a2 + public bool? TwoPhase { get; } + + internal IEnumerable> GetOptionPairs() + { + yield return new KeyValuePair("proto_version", ProtocolVersion.ToString(CultureInfo.InvariantCulture)); + yield return new KeyValuePair("publication_names", "\"" + string.Join("\",\"", PublicationNames) + "\""); + + if (Binary != null) + yield return new KeyValuePair("binary", Binary.Value ? "on" : "off"); + if (Streaming != null) + yield return new KeyValuePair("streaming", Streaming.Value ? "on" : "off"); + if (Messages != null) + yield return new KeyValuePair("messages", Messages.Value ? "on" : "off"); + if (TwoPhase != null) + yield return new KeyValuePair("two_phase", TwoPhase.Value ? "on" : "off"); + } + + /// + public bool Equals(PgOutputReplicationOptions? other) + => other != null && ( + ReferenceEquals(this, other) || + ProtocolVersion == other.ProtocolVersion && PublicationNames.Equals(other.PublicationNames) && Binary == other.Binary && + Streaming == other.Streaming && Messages == other.Messages && TwoPhase == other.TwoPhase); + + /// + public override bool Equals(object? obj) + => obj is PgOutputReplicationOptions other && other.Equals(this); + + /// + public override int GetHashCode() { - /// - /// Creates a new instance of . - /// - /// The publication names to include into the stream - /// The version of the logical streaming replication protocol - /// Send values in binary representation - /// Enable streaming output - public PgOutputReplicationOptions(string publicationName, ulong protocolVersion = 1UL, bool? binary = null, bool? streaming = null) - : this(new List { publicationName ?? 
throw new ArgumentNullException(nameof(publicationName)) }, protocolVersion, binary, streaming) - { } - - /// - /// Creates a new instance of . - /// - /// The publication names to include into the stream - /// The version of the logical streaming replication protocol - /// Send values in binary representation - /// Enable streaming output - public PgOutputReplicationOptions(IEnumerable publicationNames, ulong protocolVersion = 1UL, bool? binary = null, bool? streaming = null) - { - var publicationNamesList = new List(publicationNames); - if (publicationNamesList.Count < 1) - throw new ArgumentException("You have to pass at least one publication name.", nameof(publicationNames)); - - foreach (var publicationName in publicationNamesList) - if (string.IsNullOrWhiteSpace(publicationName)) - throw publicationName is null - ? new ArgumentNullException(nameof(publicationName)) - : new ArgumentException("Invalid publication name", nameof(publicationName)); - - PublicationNames = publicationNamesList; - ProtocolVersion = protocolVersion; - Binary = binary; - Streaming = streaming; - } - - /// - /// The version of the logical streaming replication protocol - /// - public ulong ProtocolVersion { get; } - - /// - /// The publication names to stream - /// - public List PublicationNames { get; } - - /// - /// Send values in binary representation - /// - /// - /// This works in PostgreSQL versions 14+ - /// - public bool? Binary { get; } - - /// - /// Enable streaming output - /// - /// - /// This works in PostgreSQL versions 14+ - /// - public bool? Streaming { get; } - - internal IEnumerable> GetOptionPairs() - { - yield return new KeyValuePair("proto_version", ProtocolVersion.ToString(CultureInfo.InvariantCulture)); - yield return new KeyValuePair("publication_names", "\"" + string.Join("\",\"", PublicationNames) + "\""); - - if (Binary != null) - yield return new KeyValuePair("binary", Binary.Value ? 
"t" : "f"); - if (Streaming != null) - yield return new KeyValuePair("streaming", Streaming.Value ? "t" : "f"); - } - - /// - public bool Equals(PgOutputReplicationOptions? other) - => other != null && ( - ReferenceEquals(this, other) || - ProtocolVersion == other.ProtocolVersion && PublicationNames.Equals(other.PublicationNames) && Binary == other.Binary && - Streaming == other.Streaming); - - /// - public override bool Equals(object? obj) - => obj is PgOutputReplicationOptions other && other.Equals(this); - - /// - public override int GetHashCode() - { #if NETSTANDARD2_0 - var hashCode = ProtocolVersion.GetHashCode(); - hashCode = (hashCode * 397) ^ PublicationNames.GetHashCode(); - hashCode = (hashCode * 397) ^ Binary.GetHashCode(); - hashCode = (hashCode * 397) ^ Streaming.GetHashCode(); - return hashCode; + var hashCode = ProtocolVersion.GetHashCode(); + hashCode = (hashCode * 397) ^ PublicationNames.GetHashCode(); + hashCode = (hashCode * 397) ^ Binary.GetHashCode(); + hashCode = (hashCode * 397) ^ Streaming.GetHashCode(); + hashCode = (hashCode * 397) ^ Messages.GetHashCode(); + hashCode = (hashCode * 397) ^ TwoPhase.GetHashCode(); + return hashCode; #else - return HashCode.Combine(ProtocolVersion, PublicationNames, Binary, Streaming); + return HashCode.Combine(ProtocolVersion, PublicationNames, Binary, Streaming, Messages, TwoPhase); #endif - } } } diff --git a/src/Npgsql/Replication/PgOutput/PgOutputReplicationSlot.cs b/src/Npgsql/Replication/PgOutput/PgOutputReplicationSlot.cs index 0141a5ddff..a873f585fc 100644 --- a/src/Npgsql/Replication/PgOutput/PgOutputReplicationSlot.cs +++ b/src/Npgsql/Replication/PgOutput/PgOutputReplicationSlot.cs @@ -1,45 +1,44 @@ using Npgsql.Replication.Internal; -namespace Npgsql.Replication.PgOutput +namespace Npgsql.Replication.PgOutput; + +/// +/// Acts as a proxy for a logical replication slot initialized for for the logical streaming replication protocol +/// (pgoutput logical decoding plugin). 
+/// +public class PgOutputReplicationSlot : LogicalReplicationSlot { /// - /// Acts as a proxy for a logical replication slot initialized for for the logical streaming replication protocol - /// (pgoutput logical decoding plugin). + /// Creates a new instance. /// - public class PgOutputReplicationSlot : LogicalReplicationSlot - { - /// - /// Creates a new instance. - /// - /// - /// Create a instance with this - /// constructor to wrap an existing PostgreSQL replication slot that has - /// been initialized for the pgoutput logical decoding plugin. - /// - /// The name of the existing replication slot - public PgOutputReplicationSlot(string slotName) - : this(new ReplicationSlotOptions(slotName)) { } + /// + /// Create a instance with this + /// constructor to wrap an existing PostgreSQL replication slot that has + /// been initialized for the pgoutput logical decoding plugin. + /// + /// The name of the existing replication slot + public PgOutputReplicationSlot(string slotName) + : this(new ReplicationSlotOptions(slotName)) { } - /// - /// Creates a new instance. - /// - /// - /// Create a instance with this - /// constructor to wrap an existing PostgreSQL replication slot that has - /// been initialized for the pgoutput logical decoding plugin. - /// - /// The representing the existing replication slot - public PgOutputReplicationSlot(ReplicationSlotOptions options) : base("pgoutput", options) { } + /// + /// Creates a new instance. + /// + /// + /// Create a instance with this + /// constructor to wrap an existing PostgreSQL replication slot that has + /// been initialized for the pgoutput logical decoding plugin. + /// + /// The representing the existing replication slot + public PgOutputReplicationSlot(ReplicationSlotOptions options) : base("pgoutput", options) { } - /// - /// Creates a new instance. 
- /// - /// - /// This constructor is intended to be consumed by plugins sitting on top of - /// - /// - /// The from which the new instance should be initialized - protected PgOutputReplicationSlot(PgOutputReplicationSlot slot) - : base(slot.OutputPlugin, new ReplicationSlotOptions(slot.Name, slot.ConsistentPoint, slot.SnapshotName)) { } - } -} + /// + /// Creates a new instance. + /// + /// + /// This constructor is intended to be consumed by plugins sitting on top of + /// + /// + /// The from which the new instance should be initialized + protected PgOutputReplicationSlot(PgOutputReplicationSlot slot) + : base(slot.OutputPlugin, new ReplicationSlotOptions(slot.Name, slot.ConsistentPoint, slot.SnapshotName)) { } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs b/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs new file mode 100644 index 0000000000..df910af4d2 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Npgsql.Replication.PgOutput; + +sealed class ReadOnlyArrayBuffer : IReadOnlyList +{ + public static readonly ReadOnlyArrayBuffer Empty = new(); + T[] _items; + int _size; + + public ReadOnlyArrayBuffer() + => _items = Array.Empty(); + + ReadOnlyArrayBuffer(T[] items) + { + _items = items; + _size = items.Length; + } + + public IEnumerator GetEnumerator() + { + for (var i = 0; i < _size; i++) + { + yield return _items[i]; + } + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public int Count + { + get => _size; + internal set + { + if (_items.Length < value) + _items = new T[value]; + + _size = value; + } + } + + public T this[int index] + { + get => index < _size ? 
_items[index] : throw new IndexOutOfRangeException(); + internal set => _items[index] = value; + } + + public ReadOnlyArrayBuffer Clone() + { + var newItems = new T[_size]; + if (_size > 0) + Array.Copy(_items, newItems, _size); + return new(newItems); + } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/ReplicationTuple.cs b/src/Npgsql/Replication/PgOutput/ReplicationTuple.cs new file mode 100644 index 0000000000..43bd08b4ac --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/ReplicationTuple.cs @@ -0,0 +1,79 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.BackendMessages; +using Npgsql.Internal; + +namespace Npgsql.Replication.PgOutput; + +/// +/// Represents a streaming tuple containing . +/// +public class ReplicationTuple : IAsyncEnumerable +{ + private protected readonly NpgsqlReadBuffer ReadBuffer; + readonly TupleEnumerator _tupleEnumerator; + + internal RowState State; + + /// + /// The number of columns in the tuple. 
+ /// + public ushort NumColumns { get; private set; } + + RowDescriptionMessage _rowDescription = null!; + + internal ReplicationTuple(NpgsqlConnector connector) + => (ReadBuffer, _tupleEnumerator) = (connector.ReadBuffer, new(this, connector)); + + internal void Reset(ushort numColumns, RowDescriptionMessage rowDescription) + { + State = RowState.NotRead; + (NumColumns, _rowDescription) = (numColumns, rowDescription); + } + + /// + public virtual IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + switch (State) + { + case RowState.NotRead: + _tupleEnumerator.Reset(NumColumns, _rowDescription, cancellationToken); + State = RowState.Reading; + return _tupleEnumerator; + case RowState.Reading: + throw new InvalidOperationException("The row is already been read."); + case RowState.Consumed: + throw new InvalidOperationException("The row has already been consumed."); + default: + throw new ArgumentOutOfRangeException(); + } + } + + internal async Task Consume(CancellationToken cancellationToken) + { + switch (State) + { + case RowState.NotRead: + State = RowState.Reading; + _tupleEnumerator.Reset(NumColumns, _rowDescription, cancellationToken); + while (await _tupleEnumerator.MoveNextAsync().ConfigureAwait(false)) { } + break; + case RowState.Reading: + while (await _tupleEnumerator.MoveNextAsync().ConfigureAwait(false)) { } + break; + case RowState.Consumed: + return; + default: + throw new ArgumentOutOfRangeException(); + } + } +} + +enum RowState +{ + NotRead, + Reading, + Consumed +} diff --git a/src/Npgsql/Replication/PgOutput/ReplicationValue.cs b/src/Npgsql/Replication/PgOutput/ReplicationValue.cs new file mode 100644 index 0000000000..aed44411d7 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/ReplicationValue.cs @@ -0,0 +1,198 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.BackendMessages; +using Npgsql.Internal; +using Npgsql.PostgresTypes; + +namespace 
Npgsql.Replication.PgOutput; + +/// +/// Represents a column value in a logical replication session. +/// +public class ReplicationValue +{ + readonly NpgsqlReadBuffer _readBuffer; + + /// + /// The length of the value in bytes. + /// + public int Length { get; private set; } + + /// + /// The kind of data transmitted for a tuple in a Logical Replication Protocol message. + /// + public TupleDataKind Kind { get; private set; } + + FieldDescription _fieldDescription = null!; + ColumnInfo _lastInfo; + bool _isConsumed; + + PgReader PgReader => _readBuffer.PgReader; + + internal ReplicationValue(NpgsqlConnector connector) => _readBuffer = connector.ReadBuffer; + + internal void Reset(TupleDataKind kind, int length, FieldDescription fieldDescription) + { + Kind = kind; + Length = length; + _fieldDescription = fieldDescription; + _lastInfo = default; + _isConsumed = false; + } + + // ReSharper disable once InconsistentNaming + /// + /// Gets a value that indicates whether the column contains nonexistent or missing values. + /// + /// true if the specified column is equivalent to ; otherwise false. + public bool IsDBNull + => Kind == TupleDataKind.Null; + + /// + /// Gets a value that indicates whether the column contains an unchanged TOASTed value (the actual value is not sent). + /// + /// Whether the specified column is an unchanged TOASTed value. + public bool IsUnchangedToastedValue + => Kind == TupleDataKind.UnchangedToastedValue; + + /// + /// Gets a representation of the PostgreSQL data type for the specified field. + /// The returned representation can be used to access various information about the field. + /// + public PostgresType GetPostgresType() => _fieldDescription.PostgresType; + + /// + /// Gets the data type information for the specified field. + /// This is be the PostgreSQL type name (e.g. double precision), not the .NET type + /// (see for that). 
+ /// + public string GetDataTypeName() => _fieldDescription.TypeDisplayName; + + /// + /// Gets the data type of the specified column. + /// + /// The data type of the specified column. + public Type GetFieldType() => _fieldDescription.FieldType; + + /// + /// Gets the value of the specified column as a type. + /// + /// The type of the value to be returned. + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// + public async ValueTask Get(CancellationToken cancellationToken = default) + { + CheckActive(); + + _fieldDescription.GetInfo(typeof(T), ref _lastInfo); + var info = _lastInfo; + + switch (Kind) + { + case TupleDataKind.Null: + // When T is a Nullable (and only in that case), we support returning null + if (default(T) is null && typeof(T).IsValueType) + return default!; + + if (typeof(T) == typeof(object)) + return (T)(object)DBNull.Value; + + ThrowHelper.ThrowInvalidCastException_NoValue(_fieldDescription); + break; + + case TupleDataKind.UnchangedToastedValue: + throw new InvalidCastException( + $"Column '{_fieldDescription.Name}' is an unchanged TOASTed value (actual value not sent)."); + } + + using var registration = _readBuffer.Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + + var reader = PgReader.Init(Length, _fieldDescription.DataFormat); + await reader.StartReadAsync(info.ConverterInfo.BufferRequirement, cancellationToken).ConfigureAwait(false); + var result = info.AsObject + ? (T)await info.ConverterInfo.Converter.ReadAsObjectAsync(reader, cancellationToken).ConfigureAwait(false) + : await info.ConverterInfo.Converter.UnsafeDowncast().ReadAsync(reader, cancellationToken).ConfigureAwait(false); + await reader.EndReadAsync().ConfigureAwait(false); + return result; + } + + /// + /// Gets the value of the specified column as an instance of . + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + /// + public ValueTask Get(CancellationToken cancellationToken = default) => Get(cancellationToken); + + /// + /// Retrieves data as a . + /// + public Stream GetStream() + { + CheckActive(); + + switch (Kind) + { + case TupleDataKind.Null: + ThrowHelper.ThrowInvalidCastException_NoValue(_fieldDescription); + break; + + case TupleDataKind.UnchangedToastedValue: + throw new InvalidCastException($"Column '{_fieldDescription.Name}' is an unchanged TOASTed value (actual value not sent)."); + } + + var reader = _readBuffer.PgReader.Init(Length, _fieldDescription.DataFormat); + return reader.GetStream(canSeek: false); + } + + /// + /// Retrieves data as a . + /// + public TextReader GetTextReader() + { + CheckActive(); + + ref var info = ref _lastInfo; + _fieldDescription.GetInfo(typeof(TextReader), ref info); + + switch (Kind) + { + case TupleDataKind.Null: + ThrowHelper.ThrowInvalidCastException_NoValue(_fieldDescription); + break; + + case TupleDataKind.UnchangedToastedValue: + throw new InvalidCastException($"Column '{_fieldDescription.Name}' is an unchanged TOASTed value (actual value not sent)."); + } + + var reader = PgReader.Init(Length, _fieldDescription.DataFormat); + reader.StartRead(info.ConverterInfo.BufferRequirement); + var result = (TextReader)info.ConverterInfo.Converter.ReadAsObject(reader); + reader.EndRead(); + return result; + } + + internal async Task Consume(CancellationToken cancellationToken) + { + if (_isConsumed) + return; + + if (!PgReader.Initialized) + PgReader.Init(Length, _fieldDescription.DataFormat); + await PgReader.ConsumeAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); + + _isConsumed = true; + } + + void CheckActive() + { + if (PgReader.Initialized) + throw new InvalidOperationException("Column has already been consumed"); + } +} diff --git a/src/Npgsql/Replication/PgOutput/TupleDataKind.cs b/src/Npgsql/Replication/PgOutput/TupleDataKind.cs 
new file mode 100644 index 0000000000..141e4af16e --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/TupleDataKind.cs @@ -0,0 +1,28 @@ +namespace Npgsql.Replication.PgOutput; + +/// +/// The kind of data transmitted for a tuple in a Logical Replication Protocol message. +/// +public enum TupleDataKind : byte +{ + /// + /// Identifies the data as NULL value. + /// + Null = (byte)'n', + + /// + /// Identifies unchanged TOASTed value (the actual value is not sent). + /// + UnchangedToastedValue = (byte)'u', + + /// + /// Identifies the data as text formatted value. + /// + TextValue = (byte)'t', + + /// + /// Identifies the data as binary value. + /// + /// Added in PG14 + BinaryValue = (byte)'b' +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PgOutput/TupleEnumerator.cs b/src/Npgsql/Replication/PgOutput/TupleEnumerator.cs new file mode 100644 index 0000000000..cee25671af --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/TupleEnumerator.cs @@ -0,0 +1,93 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.BackendMessages; +using Npgsql.Internal; + +namespace Npgsql.Replication.PgOutput; + +sealed class TupleEnumerator : IAsyncEnumerator +{ + readonly ReplicationTuple _tupleEnumerable; + readonly NpgsqlReadBuffer _readBuffer; + readonly ReplicationValue _value; + + ushort _numColumns; + int _pos; + RowDescriptionMessage _rowDescription = null!; + CancellationToken _cancellationToken; + + internal TupleEnumerator(ReplicationTuple tupleEnumerable, NpgsqlConnector connector) + { + _tupleEnumerable = tupleEnumerable; + _readBuffer = connector.ReadBuffer; + _value = new(connector); + } + + internal void Reset(ushort numColumns, RowDescriptionMessage rowDescription, CancellationToken cancellationToken) + { + _pos = -1; + _numColumns = numColumns; + _rowDescription = rowDescription; + _cancellationToken = cancellationToken; + } + + public ValueTask MoveNextAsync() + { + if 
(_tupleEnumerable.State != RowState.Reading) + throw new ObjectDisposedException(null); + + return MoveNextCore(); + + async ValueTask MoveNextCore() + { + // Consume the previous column + if (_pos != -1) + await _value.Consume(_cancellationToken).ConfigureAwait(false); + + if (_pos + 1 == _numColumns) + return false; + _pos++; + + // Read the next column + await _readBuffer.Ensure(1, async: true).ConfigureAwait(false); + var kind = (TupleDataKind)_readBuffer.ReadByte(); + int len; + switch (kind) + { + case TupleDataKind.Null: + case TupleDataKind.UnchangedToastedValue: + len = 0; + break; + case TupleDataKind.TextValue: + case TupleDataKind.BinaryValue: + await _readBuffer.Ensure(4, async: true).ConfigureAwait(false); + len = _readBuffer.ReadInt32(); + break; + default: + throw new ArgumentOutOfRangeException(); + } + + _value.Reset(kind, len, _rowDescription[_pos]); + + return true; + } + } + + public ReplicationValue Current => _tupleEnumerable.State switch + { + RowState.NotRead => throw new ObjectDisposedException(null), + RowState.Reading => _value, + RowState.Consumed => throw new ObjectDisposedException(null), + _ => throw new ArgumentOutOfRangeException() + }; + + public async ValueTask DisposeAsync() + { + if (_tupleEnumerable.State == RowState.Reading) + while (await MoveNextAsync().ConfigureAwait(false)) { /* Do nothing, just iterate the enumerator */ } + + _tupleEnumerable.State = RowState.Consumed; + } +} diff --git a/src/Npgsql/Replication/PgOutput/TupleType.cs b/src/Npgsql/Replication/PgOutput/TupleType.cs new file mode 100644 index 0000000000..80a4f8cc67 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/TupleType.cs @@ -0,0 +1,8 @@ +namespace Npgsql.Replication.PgOutput; + +enum TupleType : byte +{ + Key = (byte)'K', + NewTuple = (byte)'N', + OldTuple = (byte)'O', +} \ No newline at end of file diff --git a/src/Npgsql/Replication/PhysicalReplicationConnection.cs b/src/Npgsql/Replication/PhysicalReplicationConnection.cs index 
c0b410a884..05d0af33ca 100644 --- a/src/Npgsql/Replication/PhysicalReplicationConnection.cs +++ b/src/Npgsql/Replication/PhysicalReplicationConnection.cs @@ -1,99 +1,125 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Globalization; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; using NpgsqlTypes; -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Represents a physical replication connection to a PostgreSQL server. +/// +public sealed class PhysicalReplicationConnection : ReplicationConnection { + private protected override ReplicationMode ReplicationMode => ReplicationMode.Physical; + + /// + /// Initializes a new instance of . + /// + public PhysicalReplicationConnection() {} + /// - /// Represents a physical replication connection to a PostgreSQL server. + /// Initializes a new instance of with the given connection string. /// - public sealed class PhysicalReplicationConnection : ReplicationConnection + /// The connection used to open the PostgreSQL database. + public PhysicalReplicationConnection(string? connectionString) : base(connectionString) {} + + /// + /// Creates a that wraps a PostgreSQL physical replication slot and + /// can be used to start physical streaming replication + /// + /// + /// The name of the slot to create. Must be a valid replication slot name + /// (see Section 26.2.6.1). + /// + /// + /// if this replication slot shall be a temporary one; otherwise + /// . Temporary slots are not saved to disk and are automatically dropped on error or + /// when the session has finished. + /// + /// + /// If this is set to this physical replication slot reserves WAL immediately. Otherwise, + /// WAL is only reserved upon connection from a streaming replication client. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + /// A representing a that represents the + /// newly-created replication slot. + /// + public async Task CreateReplicationSlot( + string slotName, bool isTemporary = false, bool reserveWal = false, CancellationToken cancellationToken = default) { - private protected override ReplicationMode ReplicationMode => ReplicationMode.Physical; - - /// - /// Initializes a new instance of . - /// - public PhysicalReplicationConnection() {} - - /// - /// Initializes a new instance of with the given connection string. - /// - /// The connection used to open the PostgreSQL database. - public PhysicalReplicationConnection(string? connectionString) : base(connectionString) {} - - /// - /// Creates a that wraps a PostgreSQL physical replication slot and - /// can be used to start physical streaming replication - /// - /// - /// The name of the slot to create. Must be a valid replication slot name - /// (see Section 26.2.6.1). - /// - /// - /// if this replication slot shall be a temporary one; otherwise - /// . Temporary slots are not saved to disk and are automatically dropped on error or - /// when the session has finished. - /// - /// - /// If this is set to this physical replication slot reserves WAL immediately. Otherwise, - /// WAL is only reserved upon connection from a streaming replication client. - /// - /// - /// The token to monitor for cancellation requests. The default value is . - /// - /// A that wraps the newly-created replication slot. 
- /// - public Task CreateReplicationSlot( - string slotName, bool isTemporary = false, bool reserveWal = false, CancellationToken cancellationToken = default) - { - using var _ = NoSynchronizationContextScope.Enter(); - return CreatePhysicalReplicationSlot(); - - async Task CreatePhysicalReplicationSlot() - { - var builder = new StringBuilder("CREATE_REPLICATION_SLOT ").Append(slotName); - if (isTemporary) - builder.Append(" TEMPORARY"); - builder.Append(" PHYSICAL"); - if (reserveWal) - builder.Append(" RESERVE_WAL"); - - var slotOptions = await CreateReplicationSlot(builder.ToString(), isTemporary, cancellationToken); - - return new PhysicalReplicationSlot(slotOptions.SlotName); - } - } + CheckDisposed(); + + var builder = new StringBuilder("CREATE_REPLICATION_SLOT ").Append(slotName); + if (isTemporary) + builder.Append(" TEMPORARY"); + builder.Append(" PHYSICAL"); + if (reserveWal) + builder.Append(PostgreSqlVersion.Major >= 15 ? " (RESERVE_WAL)" : " RESERVE_WAL"); + + var command = builder.ToString(); + + LogMessages.CreatingReplicationSlot(ReplicationLogger, slotName, command, Connector.Id); - /// - /// Instructs the server to start streaming the WAL for physical replication, starting at WAL location - /// . The server can reply with an error, for example if the requested - /// section of the WAL has already been recycled. - /// - /// - /// If the client requests a timeline that's not the latest but is part of the history of the server, the server - /// will stream all the WAL on that timeline starting from the requested start point up to the point where the - /// server switched to another timeline. - /// - /// - /// The replication slot that will be updated as replication progresses so that the server - /// knows which WAL segments are still needed by the standby. - /// - /// The WAL location to begin streaming at. - /// The token to be used for stopping the replication. - /// Streaming starts on timeline tli. 
- /// A representing an that - /// can be used to stream WAL entries in form of instances. - public IAsyncEnumerable StartReplication(PhysicalReplicationSlot? slot, + var slotOptions = await CreateReplicationSlot(builder.ToString(), cancellationToken).ConfigureAwait(false); + + return new PhysicalReplicationSlot(slotOptions.SlotName); + } + + /// + /// Read some information associated to a replication slot. + /// + /// This command is currently only supported for physical replication slots. + /// + /// + /// + /// The name of the slot to read. Must be a valid replication slot name + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A representing a or + /// if the replication slot does not exist. + public Task ReadReplicationSlot(string slotName, CancellationToken cancellationToken = default) + => ReadReplicationSlotInternal(slotName, cancellationToken); + + /// + /// Instructs the server to start streaming the WAL for physical replication, starting at WAL location + /// . The server can reply with an error, for example if the requested + /// section of the WAL has already been recycled. + /// + /// + /// If the client requests a timeline that's not the latest but is part of the history of the server, the server + /// will stream all the WAL on that timeline starting from the requested start point up to the point where the + /// server switched to another timeline. + /// + /// + /// The replication slot that will be updated as replication progresses so that the server + /// knows which WAL segments are still needed by the standby. + /// + /// The WAL location to begin streaming at. + /// The token to be used for stopping the replication. + /// Streaming starts on timeline tli. + /// A representing an that + /// can be used to stream WAL entries in form of instances. + public IAsyncEnumerable StartReplication(PhysicalReplicationSlot? 
slot, + NpgsqlLogSequenceNumber walLocation, + CancellationToken cancellationToken, + uint timeline = default) + { + return StartPhysicalReplication(slot, walLocation, cancellationToken, timeline); + + // Local method to avoid having to add the EnumeratorCancellation attribute to the public signature. + async IAsyncEnumerable StartPhysicalReplication(PhysicalReplicationSlot? slot, NpgsqlLogSequenceNumber walLocation, - CancellationToken cancellationToken, - uint timeline = default) + [EnumeratorCancellation] CancellationToken cancellationToken, + uint timeline) { - using var _ = NoSynchronizationContextScope.Enter(); - var builder = new StringBuilder("START_REPLICATION"); if (slot != null) builder.Append(" SLOT ").Append(slot.Name); @@ -101,26 +127,61 @@ public IAsyncEnumerable StartReplication(PhysicalReplicationSlo if (timeline != default) builder.Append(" TIMELINE ").Append(timeline.ToString(CultureInfo.InvariantCulture)); - return StartReplicationInternal(builder.ToString(), bypassingStream: false, cancellationToken); + var command = builder.ToString(); + + LogMessages.StartingPhysicalReplication(ReplicationLogger, slot?.Name, command, Connector.Id); + + var enumerator = StartReplicationInternalWrapper(command, bypassingStream: false, cancellationToken); + while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + yield return enumerator.Current; } + } + + /// + /// Instructs the server to start streaming the WAL for logical replication, starting at WAL location + /// . The server can reply with an error, for example if the requested + /// section of WAL has already been recycled. + /// + /// + /// If the client requests a timeline that's not the latest but is part of the history of the server, the server + /// will stream all the WAL on that timeline starting from the requested start point up to the point where the + /// server switched to another timeline. + /// + /// The WAL location to begin streaming at. 
+ /// The token to be used for stopping the replication. + /// Streaming starts on timeline tli. + /// A representing an that + /// can be used to stream WAL entries in form of instances. + public IAsyncEnumerable StartReplication( + NpgsqlLogSequenceNumber walLocation, CancellationToken cancellationToken, uint timeline = default) + => StartReplication(slot: null, walLocation: walLocation, timeline: timeline, cancellationToken: cancellationToken); - /// - /// Instructs the server to start streaming the WAL for logical replication, starting at WAL location - /// . The server can reply with an error, for example if the requested - /// section of WAL has already been recycled. - /// - /// - /// If the client requests a timeline that's not the latest but is part of the history of the server, the server - /// will stream all the WAL on that timeline starting from the requested start point up to the point where the - /// server switched to another timeline. - /// - /// The WAL location to begin streaming at. - /// The token to be used for stopping the replication. - /// Streaming starts on timeline tli. - /// A representing an that - /// can be used to stream WAL entries in form of instances. - public IAsyncEnumerable StartReplication( - NpgsqlLogSequenceNumber walLocation, CancellationToken cancellationToken, uint timeline = default) - => StartReplication(slot: null, walLocation: walLocation, timeline: timeline, cancellationToken: cancellationToken); + /// + /// Instructs the server to start streaming the WAL for physical replication, starting at the WAL location + /// and timeline id specified in . The server can reply with an error, for example + /// if the requested section of the WAL has already been recycled. 
+ /// + /// + /// If the client requests a timeline that's not the latest but is part of the history of the server, the server + /// will stream all the WAL on that timeline starting from the requested start point up to the point where the + /// server switched to another timeline. + /// + /// + /// The replication slot that will be updated as replication progresses so that the server + /// knows which WAL segments are still needed by the standby. + /// + /// The must contain a valid to be used for this overload. + /// + /// + /// The token to be used for stopping the replication. + /// A representing an that + /// can be used to stream WAL entries in form of instances. + public IAsyncEnumerable StartReplication(PhysicalReplicationSlot slot, CancellationToken cancellationToken) + { + if (!slot.RestartLsn.HasValue) + throw new ArgumentException($"For this overload of {nameof(StartReplication)} the {nameof(slot)} argument must contain a " + + $"valid {nameof(slot.RestartLsn)}. Please use an overload with the walLocation argument otherwise.", + nameof(slot)); + return StartReplication(slot, slot.RestartLsn.Value, cancellationToken, slot.RestartTimeline ?? default); } } diff --git a/src/Npgsql/Replication/PhysicalReplicationSlot.cs b/src/Npgsql/Replication/PhysicalReplicationSlot.cs index 02e87b0db2..7aba817fe2 100644 --- a/src/Npgsql/Replication/PhysicalReplicationSlot.cs +++ b/src/Npgsql/Replication/PhysicalReplicationSlot.cs @@ -1,11 +1,36 @@ -namespace Npgsql.Replication +using NpgsqlTypes; + +namespace Npgsql.Replication; + +/// +/// Wraps a replication slot that uses physical replication. +/// +public class PhysicalReplicationSlot : ReplicationSlot { /// - /// Wraps a replication slot that uses physical replication. + /// Creates a new instance. /// - public class PhysicalReplicationSlot : ReplicationSlot + /// + /// Create a instance with this constructor to wrap an existing PostgreSQL replication slot + /// that has been initialized for physical replication. 
+ /// + /// The name of the existing replication slot + /// The replication slot's restart_lsn + /// The timeline ID associated to restart_lsn, following the current timeline history. + public PhysicalReplicationSlot(string slotName, NpgsqlLogSequenceNumber? restartLsn = null, uint? restartTimeline = null) + : base(slotName) { - internal PhysicalReplicationSlot(string name) - : base(name) { } + RestartLsn = restartLsn; + RestartTimeline = restartTimeline; } + + /// + /// The replication slot's restart_lsn. + /// + public NpgsqlLogSequenceNumber? RestartLsn { get; } + + /// + /// The timeline ID associated to restart_lsn, following the current timeline history. + /// + public uint? RestartTimeline { get; } } diff --git a/src/Npgsql/Replication/ReplicationConnection.cs b/src/Npgsql/Replication/ReplicationConnection.cs index 3bfeab3c71..058b12a2c0 100644 --- a/src/Npgsql/Replication/ReplicationConnection.cs +++ b/src/Npgsql/Replication/ReplicationConnection.cs @@ -1,6 +1,4 @@ using Npgsql.BackendMessages; -using Npgsql.Logging; -using Npgsql.TypeHandlers.DateTimeHandlers; using NpgsqlTypes; using System; using System.Collections.Generic; @@ -8,819 +6,928 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; -using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Npgsql.Internal; using static Npgsql.Util.Statics; +using Npgsql.Util; -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Defines the core behavior of replication connections and provides the base class for +/// and +/// . 
+/// +public abstract class ReplicationConnection : IAsyncDisposable { + #region Fields + + static readonly Version FirstVersionWithTwoPhaseSupport = new(15, 0); + static readonly Version FirstVersionWithoutDropSlotDoubleCommandCompleteMessage = new(13, 0); + static readonly Version FirstVersionWithTemporarySlotsAndSlotSnapshotInitMode = new(10, 0); + readonly NpgsqlConnection _npgsqlConnection; + readonly SemaphoreSlim _feedbackSemaphore = new(1, 1); + string? _userFacingConnectionString; + TimeSpan? _commandTimeout; + TimeSpan _walReceiverTimeout = TimeSpan.FromSeconds(60d); + Timer? _sendFeedbackTimer; + Timer? _requestFeedbackTimer; + TimeSpan _requestFeedbackInterval; + + IAsyncEnumerator? _currentEnumerator; + CancellationTokenSource? _replicationCancellationTokenSource; + bool _pgCancellationSupported; + bool _isDisposed; + + // We represent the log sequence numbers as unsigned long + // although we have a special struct to represent them and + // they are in fact unsigned 64-bit integers, because + // we access them via Interlocked to synchronize access + // and overcome non-atomic reads/writes on 32-bit platforms + long _lastReceivedLsn; + long _lastFlushedLsn; + long _lastAppliedLsn; + + readonly XLogDataMessage _cachedXLogDataMessage = new(); + + internal ILogger ReplicationLogger { get; private set; } = default!; // Initialized in Open, shouldn't be used otherwise + + #endregion Fields + + #region Constructors + + private protected ReplicationConnection() + { + _npgsqlConnection = new NpgsqlConnection(); + _requestFeedbackInterval = new TimeSpan(_walReceiverTimeout.Ticks / 2); + } + + private protected ReplicationConnection(string? connectionString) : this() + => ConnectionString = connectionString; + + #endregion + + #region Properties + /// - /// Defines the core behavior of replication connections and provides the base class for - /// and - /// . + /// Gets or sets the string used to connect to a PostgreSQL database. See the manual for details. 
/// - public abstract class ReplicationConnection : IAsyncDisposable - { - #region Fields - - static readonly Version FirstVersionWithoutDropSlotDoubleCommandCompleteMessage = new Version(13, 0); - static readonly Version FirstVersionWithTemporarySlotsAndSlotSnapshotInitMode = new Version(10, 0); - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(ReplicationConnection)); - readonly NpgsqlConnection _npgsqlConnection; - readonly SemaphoreSlim _feedbackSemaphore = new SemaphoreSlim(1, 1); - string? _userFacingConnectionString; - TimeSpan? _commandTimeout; - TimeSpan _walReceiverTimeout = TimeSpan.FromSeconds(60d); - Timer? _sendFeedbackTimer; - Timer? _requestFeedbackTimer; - TimeSpan _requestFeedbackInterval; - Task _replicationCompletion = Task.CompletedTask; - bool _pgCancellationSupported; - bool _isDisposed; - - // We represent the log sequence numbers as unsigned long - // although we have a special struct to represent them and - // they are in fact unsigned 64-bit integers, because - // we access them via Interlocked to synchronize access - // and overcome non-atomic reads/writes on 32-bit platforms - long _lastReceivedLsn; - long _lastFlushedLsn; - long _lastAppliedLsn; - - readonly XLogDataMessage _cachedXLogDataMessage = new XLogDataMessage(); - - #endregion Fields - - #region Constructors - - private protected ReplicationConnection() + /// + /// The connection string that includes the server name, the database name, and other parameters needed to establish the initial + /// connection. The default value is an empty string. + /// + /// + /// Since replication connections are a special kind of connection, + /// , , + /// and + /// are always disabled no matter what you set them to in your connection string. + /// + [AllowNull] + public string ConnectionString { + get => _userFacingConnectionString ?? 
string.Empty; + set { - _npgsqlConnection = new NpgsqlConnection(); - _requestFeedbackInterval = new TimeSpan(_walReceiverTimeout.Ticks / 2); - } - - private protected ReplicationConnection(string? connectionString) : this() - => ConnectionString = connectionString; - - #endregion - - #region Properties - - /// - /// Gets or sets the string used to connect to a PostgreSQL database. See the manual for details. - /// - /// - /// The connection string that includes the server name, the database name, and other parameters needed to establish the initial - /// connection. The default value is an empty string. - /// - /// - /// Since replication connections are a special kind of connection, - /// , , - /// and - /// are always disabled no matter what you set them to in your connection string. - /// - [AllowNull] - public string ConnectionString { - get => _userFacingConnectionString ?? string.Empty; - set + _userFacingConnectionString = value; + var cs = new NpgsqlConnectionStringBuilder(value) { - _userFacingConnectionString = value; - _npgsqlConnection.ConnectionString = new NpgsqlConnectionStringBuilder(value) - { - Pooling = false, - Enlist = false, - Multiplexing = false, - KeepAlive = 0, - ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, - ReplicationMode = ReplicationMode - }.ToString(); - } + Pooling = false, + Enlist = false, + Multiplexing = false, + KeepAlive = 0, + ReplicationMode = ReplicationMode + }; + + // Physical replication connections don't allow regular queries, so we can't load types from PG + if (ReplicationMode == ReplicationMode.Physical) + cs.ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading; + + _npgsqlConnection.ConnectionString = cs.ToString(); } + } - /// - /// The location of the last WAL byte + 1 received in the standby. 
- /// - public NpgsqlLogSequenceNumber LastReceivedLsn - { - get => (NpgsqlLogSequenceNumber)unchecked((ulong)Interlocked.Read(ref _lastReceivedLsn)); - private protected set => Interlocked.Exchange(ref _lastReceivedLsn, unchecked((long)(ulong)value)); - } + /// + /// The location of the last WAL byte + 1 received in the standby. + /// + public NpgsqlLogSequenceNumber LastReceivedLsn + { + get => (NpgsqlLogSequenceNumber)unchecked((ulong)Interlocked.Read(ref _lastReceivedLsn)); + private protected set => Interlocked.Exchange(ref _lastReceivedLsn, unchecked((long)(ulong)value)); + } - /// - /// The location of the last WAL byte + 1 flushed to disk in the standby. - /// - public NpgsqlLogSequenceNumber LastFlushedLsn - { - get => (NpgsqlLogSequenceNumber)unchecked((ulong)Interlocked.Read(ref _lastFlushedLsn)); - set => Interlocked.Exchange(ref _lastFlushedLsn, unchecked((long)(ulong)value)); - } + /// + /// The location of the last WAL byte + 1 flushed to disk in the standby. + /// + public NpgsqlLogSequenceNumber LastFlushedLsn + { + get => (NpgsqlLogSequenceNumber)unchecked((ulong)Interlocked.Read(ref _lastFlushedLsn)); + set => Interlocked.Exchange(ref _lastFlushedLsn, unchecked((long)(ulong)value)); + } - /// - /// The location of the last WAL byte + 1 applied (e. g. written to disk) in the standby. - /// - public NpgsqlLogSequenceNumber LastAppliedLsn - { - get => (NpgsqlLogSequenceNumber)unchecked((ulong)Interlocked.Read(ref _lastAppliedLsn)); - set => Interlocked.Exchange(ref _lastAppliedLsn, unchecked((long)(ulong)value)); - } + /// + /// The location of the last WAL byte + 1 applied (e. g. written to disk) in the standby. + /// + public NpgsqlLogSequenceNumber LastAppliedLsn + { + get => (NpgsqlLogSequenceNumber)unchecked((ulong)Interlocked.Read(ref _lastAppliedLsn)); + set => Interlocked.Exchange(ref _lastAppliedLsn, unchecked((long)(ulong)value)); + } - /// - /// Send replies at least this often. - /// Timeout. disables automated replies. 
- /// - public TimeSpan WalReceiverStatusInterval { get; set; } = TimeSpan.FromSeconds(10d); - - /// - /// Time that receiver waits for communication from master. - /// Timeout. disables the timeout. - /// - public TimeSpan WalReceiverTimeout + /// + /// Send replies at least this often. + /// Timeout. disables automated replies. + /// + public TimeSpan WalReceiverStatusInterval { get; set; } = TimeSpan.FromSeconds(10d); + + /// + /// Time that receiver waits for communication from master. + /// Timeout. disables the timeout. + /// + public TimeSpan WalReceiverTimeout + { + get => _walReceiverTimeout; + set { - get => _walReceiverTimeout; - set - { - _walReceiverTimeout = value; - _requestFeedbackInterval = value == Timeout.InfiniteTimeSpan - ? value - : new TimeSpan(value.Ticks / 2); - } + _walReceiverTimeout = value; + _requestFeedbackInterval = value == Timeout.InfiniteTimeSpan + ? value + : new TimeSpan(value.Ticks / 2); } + } - private protected abstract ReplicationMode ReplicationMode { get; } + private protected abstract ReplicationMode ReplicationMode { get; } - internal Version PostgreSqlVersion => _npgsqlConnection.PostgreSqlVersion; + /// + /// The version of the PostgreSQL server we're connected to. + /// + ///

+ /// This can only be called when the connection is open. + ///

+ ///

+ /// In case of a development or pre-release version this field will contain + /// the version of the next version to be released from this branch. + ///

+ ///
+ ///
+ public Version PostgreSqlVersion => _npgsqlConnection.PostgreSqlVersion; + + /// + /// The PostgreSQL server version as returned by the server_version option. + /// + /// This can only be called when the connection is open. + /// + /// + public string ServerVersion => _npgsqlConnection.ServerVersion; - internal NpgsqlConnector Connector - => _npgsqlConnection.Connector ?? - throw new InvalidOperationException($"The {Connector} property can only be used when there is an active connection"); + internal NpgsqlConnector Connector + => _npgsqlConnection.Connector ?? + throw new InvalidOperationException($"The {nameof(Connector)} property can only be used when there is an active connection"); - /// - /// Gets or sets the wait time before terminating the attempt to execute a command and generating an error. - /// - /// The time to wait for the command to execute. The default value is 30 seconds. - public TimeSpan CommandTimeout + /// + /// Gets or sets the wait time before terminating the attempt to execute a command and generating an error. + /// + /// The time to wait for the command to execute. The default value is 30 seconds. + public TimeSpan CommandTimeout + { + get => _commandTimeout ?? (_npgsqlConnection.CommandTimeout > 0 + ? TimeSpan.FromSeconds(_npgsqlConnection.CommandTimeout) + : Timeout.InfiniteTimeSpan); + set { - get => _commandTimeout ?? (_npgsqlConnection.CommandTimeout > 0 - ? 
TimeSpan.FromSeconds(_npgsqlConnection.CommandTimeout) - : Timeout.InfiniteTimeSpan); - set - { - if (value < TimeSpan.Zero && value != Timeout.InfiniteTimeSpan) - throw new ArgumentOutOfRangeException(nameof(value), value, - $"A finite CommandTimeout can't be less than {TimeSpan.Zero}."); + if (value < TimeSpan.Zero && value != Timeout.InfiniteTimeSpan) + throw new ArgumentOutOfRangeException(nameof(value), value, + $"A finite CommandTimeout can't be less than {TimeSpan.Zero}."); - _commandTimeout = value; - if (Connector.State != ConnectorState.Replication) - SetTimeouts(value, value); - } + _commandTimeout = value; + if (Connector.State != ConnectorState.Replication) + SetTimeouts(value, value); } + } - /// - /// The client encoding for the connection - /// This can only be called when there is an active connection. - /// - public Encoding Encoding => _npgsqlConnection.Connector?.TextEncoding ?? throw new InvalidOperationException($"The {Encoding} property can only be used when there is an active connection"); - - /// - /// Process id of backend server. - /// This can only be called when there is an active connection. - /// - public int ProcessID => _npgsqlConnection.Connector?.BackendProcessId ?? throw new InvalidOperationException($"The {ProcessID} property can only be used when there is an active connection"); - - #endregion Properties - - #region Open / Dispose - - /// - /// Opens a database replication connection with the property settings specified by the - /// ConnectionString. - /// - /// The token to monitor for cancellation requests. - /// The default value is . - /// - /// A task representing the asynchronous open operation. - public async Task Open(CancellationToken cancellationToken = default) - { - CheckDisposed(); + /// + /// The client encoding for the connection + /// This can only be called when there is an active connection. + /// + public Encoding Encoding => _npgsqlConnection.Connector?.TextEncoding ?? 
throw new InvalidOperationException($"The {nameof(Encoding)} property can only be used when there is an active connection"); - await _npgsqlConnection.OpenAsync(cancellationToken) - .ConfigureAwait(false); + /// + /// Process id of backend server. + /// This can only be called when there is an active connection. + /// + public int ProcessID => _npgsqlConnection.Connector?.BackendProcessId ?? throw new InvalidOperationException($"The {nameof(ProcessID)} property can only be used when there is an active connection"); - // PG versions before 10 ignore cancellations during replication - _pgCancellationSupported = _npgsqlConnection.PostgreSqlVersion >= new Version(10, 0); + #endregion Properties - SetTimeouts(CommandTimeout, CommandTimeout); - } + #region Open / Dispose - /// - /// Closes the replication connection and performs tasks associated - /// with freeing, releasing, or resetting its unmanaged resources asynchronously. - /// - /// A task that represents the asynchronous dispose operation. - public ValueTask DisposeAsync() - { - using (NoSynchronizationContextScope.Enter()) - return DisposeAsyncCore(); + /// + /// Opens a database replication connection with the property settings specified by the + /// . + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task representing the asynchronous open operation. 
+ public async Task Open(CancellationToken cancellationToken = default) + { + CheckDisposed(); - async ValueTask DisposeAsyncCore() - { - if (_isDisposed) - return; + await _npgsqlConnection.OpenAsync(cancellationToken).ConfigureAwait(false); - if (Connector.State == ConnectorState.Replication) - { - Connector.PerformPostgresCancellation(); - await _replicationCompletion; - } + // PG versions before 10 ignore cancellations during replication + _pgCancellationSupported = _npgsqlConnection.PostgreSqlVersion.IsGreaterOrEqual(10); - Debug.Assert(_sendFeedbackTimer is null, "Send feedback timer isn't null at replication shutdown"); - Debug.Assert(_requestFeedbackTimer is null, "Request feedback timer isn't null at replication shutdown"); - _feedbackSemaphore.Dispose(); - await _npgsqlConnection.Close(async: true); - _isDisposed = true; - } - } + SetTimeouts(CommandTimeout, CommandTimeout); - #endregion Open / Dispose + _npgsqlConnection.Connector!.LongRunningConnection = true; - #region Replication methods + ReplicationLogger = _npgsqlConnection.Connector!.LoggingConfiguration.ReplicationLogger; + } + + /// + /// Closes the replication connection and performs tasks associated + /// with freeing, releasing, or resetting its unmanaged resources asynchronously. + /// + /// A task that represents the asynchronous dispose operation. + public async ValueTask DisposeAsync() + { + if (_isDisposed) + return; - /// - /// Requests the server to identify itself. - /// - /// The token to monitor for cancellation requests. - /// The default value is . - /// - /// A containing information - /// about the system we are connected to. 
- /// - public Task IdentifySystem(CancellationToken cancellationToken = default) + if (_npgsqlConnection.Connector?.State == ConnectorState.Replication) { - using (NoSynchronizationContextScope.Enter()) - return IdentifySystemInternal(); + Debug.Assert(_currentEnumerator is not null); + Debug.Assert(_replicationCancellationTokenSource is not null); - async Task IdentifySystemInternal() + // Replication is in progress; cancel it (soft or hard) and iterate the enumerator until we get the cancellation + // exception. Note: this isn't thread-safe: a user calling DisposeAsync and enumerating at the same time is violating + // our contract. + _replicationCancellationTokenSource.Cancel(); + try + { + while (await _currentEnumerator.MoveNextAsync().ConfigureAwait(false)) + { + // Do nothing with messages - simply enumerate until cancellation/termination + } + } + catch { - var row = await ReadSingleRow("IDENTIFY_SYSTEM", cancellationToken); - return new ReplicationSystemIdentification( - (string)row[0], (uint)row[1], NpgsqlLogSequenceNumber.Parse((string)row[2]), (string)row[3]); + // Cancellation/termination occurred } } - /// - /// Requests the server to send the current setting of a run-time parameter. - /// This is similar to the SQL command SHOW. - /// - /// The name of a run-time parameter. - /// Available parameters are documented in https://www.postgresql.org/docs/current/runtime-config.html. - /// - /// The token to monitor for cancellation requests. - /// The default value is . - /// The current setting of the run-time parameter specified in as . 
- public Task Show(string parameterName, CancellationToken cancellationToken = default) - { - if (parameterName is null) - throw new ArgumentNullException(nameof(parameterName)); - - using (NoSynchronizationContextScope.Enter()) - return ShowInternal(); + Debug.Assert(_sendFeedbackTimer is null, "Send feedback timer isn't null at replication shutdown"); + Debug.Assert(_requestFeedbackTimer is null, "Request feedback timer isn't null at replication shutdown"); + _feedbackSemaphore.Dispose(); - async Task ShowInternal() - => (string)(await ReadSingleRow("SHOW " + parameterName, cancellationToken))[0]; + try + { + await _npgsqlConnection.Close(async: true).ConfigureAwait(false); } - - /// - /// Requests the server to send over the timeline history file for timeline tli. - /// - /// The timeline for which the history file should be sent. - /// The token to monitor for cancellation requests. - /// The default value is . - /// The timeline history file for timeline tli - public Task TimelineHistory(uint tli, CancellationToken cancellationToken = default) + catch { - using (NoSynchronizationContextScope.Enter()) - return TimelineHistoryInternal(); - - async Task TimelineHistoryInternal() - { - var result = await ReadSingleRow($"TIMELINE_HISTORY {tli:D}", cancellationToken); - return new TimelineHistoryFile((string)result[0], (byte[])result[1]); - } + // Dispose } - internal async Task CreateReplicationSlot( - string command, bool temporarySlot, CancellationToken cancellationToken = default) - { - CheckDisposed(); + _isDisposed = true; + } - using var _ = Connector.StartUserAction(cancellationToken, attemptPgCancellation: _pgCancellationSupported); + #endregion Open / Dispose - await Connector.WriteQuery(command, true, cancellationToken); - await Connector.Flush(true, cancellationToken); + #region Replication methods - try - { - var rowDescription = Expect(await Connector.ReadMessage(true), Connector); - Debug.Assert(rowDescription.NumFields == 4); - 
Debug.Assert(rowDescription.Fields[0].TypeOID == 25u, "slot_name expected as text"); - Debug.Assert(rowDescription.Fields[1].TypeOID == 25u, "consistent_point expected as text"); - Debug.Assert(rowDescription.Fields[2].TypeOID == 25u, "snapshot_name expected as text"); - Debug.Assert(rowDescription.Fields[3].TypeOID == 25u, "output_plugin expected as text"); - Expect(await Connector.ReadMessage(true), Connector); - var buf = Connector.ReadBuffer; - await buf.EnsureAsync(2); - var results = new object[buf.ReadInt16()]; - Debug.Assert(results.Length == 4); - - // slot_name - await buf.EnsureAsync(4); - var len = buf.ReadInt32(); - Debug.Assert(len > 0, "slot_name should never be empty"); - await buf.EnsureAsync(len); - var slotNameResult = buf.ReadString(len); - - // consistent_point - await buf.EnsureAsync(4); - len = buf.ReadInt32(); - Debug.Assert(len > 0, "consistent_point should never be empty"); - await buf.EnsureAsync(len); - var consistentPoint = NpgsqlLogSequenceNumber.Parse(buf.ReadString(len)); - - // snapshot_name - await buf.EnsureAsync(4); - len = buf.ReadInt32(); - string? snapshotName; - if (len == -1) - snapshotName = null; - else - { - await buf.EnsureAsync(len); - snapshotName = buf.ReadString(len); - } + /// + /// Requests the server to identify itself. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// + /// A containing information about the system we are connected to. + /// + public async Task IdentifySystem(CancellationToken cancellationToken = default) + { + var row = await ReadSingleRow("IDENTIFY_SYSTEM", cancellationToken).ConfigureAwait(false); + return new ReplicationSystemIdentification( + (string)row[0], (uint)row[1], NpgsqlLogSequenceNumber.Parse((string)row[2]), (string)row[3]); + } + /// + /// Requests the server to send the current setting of a run-time parameter. + /// This is similar to the SQL command SHOW. + /// + /// The name of a run-time parameter. 
+ /// Available parameters are documented in https://www.postgresql.org/docs/current/runtime-config.html. + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// The current setting of the run-time parameter specified in as . + public Task Show(string parameterName, CancellationToken cancellationToken = default) + { + if (parameterName is null) + throw new ArgumentNullException(nameof(parameterName)); - // output_plugin - await buf.EnsureAsync(4); - len = buf.ReadInt32(); - if (len != -1) - { - await buf.EnsureAsync(len); - buf.Skip(len); // We know already what we created - } + return ShowInternal(parameterName, cancellationToken); - Expect(await Connector.ReadMessage(true), Connector); - Expect(await Connector.ReadMessage(true), Connector); + async Task ShowInternal(string parameterName, CancellationToken cancellationToken) + => (string)(await ReadSingleRow("SHOW " + parameterName, cancellationToken).ConfigureAwait(false))[0]; + } - return new ReplicationSlotOptions(slotNameResult, consistentPoint, snapshotName); - } - catch (PostgresException e) + /// + /// Requests the server to send over the timeline history file for timeline tli. + /// + /// The timeline for which the history file should be sent. + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + /// The timeline history file for timeline tli + public async Task TimelineHistory(uint tli, CancellationToken cancellationToken = default) + { + var result = await ReadSingleRow($"TIMELINE_HISTORY {tli:D}", cancellationToken).ConfigureAwait(false); + return new TimelineHistoryFile((string)result[0], (byte[])result[1]); + } + + internal async Task CreateReplicationSlot(string command, CancellationToken cancellationToken = default) + { + try + { + var result = await ReadSingleRow(command, cancellationToken).ConfigureAwait(false); + var slotName = (string)result[0]; + var consistentPoint = (string)result[1]; + var snapshotName = (string?)result[2]; + return new ReplicationSlotOptions(slotName, NpgsqlLogSequenceNumber.Parse(consistentPoint), snapshotName); + } + catch (PostgresException e) when (!Connector.IsBroken && e.SqlState == PostgresErrorCodes.SyntaxError) + { + if (PostgreSqlVersion < FirstVersionWithTwoPhaseSupport && command.Contains(" TWO_PHASE")) + throw new NotSupportedException("Logical replication support for prepared transactions was introduced in PostgreSQL " + + FirstVersionWithTwoPhaseSupport.ToString(1) + + ". Using PostgreSQL version " + + (PostgreSqlVersion.Build == -1 + ? PostgreSqlVersion.ToString(2) + : PostgreSqlVersion.ToString(3)) + + " you have to set the twoPhase argument to false.", e); + if (PostgreSqlVersion < FirstVersionWithTemporarySlotsAndSlotSnapshotInitMode) { - if (PostgreSqlVersion < FirstVersionWithTemporarySlotsAndSlotSnapshotInitMode && e.SqlState == PostgresErrorCodes.SyntaxError) - { - if (temporarySlot) - throw new NotSupportedException("Temporary replication slots were introduced in PostgreSQL " + - $"{FirstVersionWithTemporarySlotsAndSlotSnapshotInitMode.ToString(1)}. 
" + - $"Using PostgreSQL version {PostgreSqlVersion.ToString(3)} you " + - $"have to set the {nameof(temporarySlot)} argument to false.", e); - if (command.Contains("_SNAPSHOT")) - throw new NotSupportedException( - "The EXPORT_SNAPSHOT, USE_SNAPSHOT and NOEXPORT_SNAPSHOT syntax was introduced in PostgreSQL " + - $"{FirstVersionWithTemporarySlotsAndSlotSnapshotInitMode.ToString(1)}. Using PostgreSQL version " + - $"{PostgreSqlVersion.ToString(3)} you have to omit the slotSnapshotInitMode argument.", e); - } - throw; + if (command.Contains(" TEMPORARY")) + throw new NotSupportedException("Temporary replication slots were introduced in PostgreSQL " + + $"{FirstVersionWithTemporarySlotsAndSlotSnapshotInitMode.ToString(1)}. " + + $"Using PostgreSQL version {PostgreSqlVersion.ToString(3)} you " + + $"have to set the isTemporary argument to false.", e); + if (command.Contains(" EXPORT_SNAPSHOT") || command.Contains(" NOEXPORT_SNAPSHOT") || command.Contains(" USE_SNAPSHOT")) + throw new NotSupportedException( + "The EXPORT_SNAPSHOT, USE_SNAPSHOT and NOEXPORT_SNAPSHOT syntax was introduced in PostgreSQL " + + $"{FirstVersionWithTemporarySlotsAndSlotSnapshotInitMode.ToString(1)}. 
Using PostgreSQL version " + + $"{PostgreSqlVersion.ToString(3)} you have to omit the slotSnapshotInitMode argument.", e); } + throw; } + } + + internal async Task ReadReplicationSlotInternal(string slotName, CancellationToken cancellationToken = default) + { + var result = await ReadSingleRow($"READ_REPLICATION_SLOT {slotName}", cancellationToken).ConfigureAwait(false); + var slotType = (string?)result[0]; - internal async IAsyncEnumerable StartReplicationInternal( - string command, - bool bypassingStream, - [EnumeratorCancellation] CancellationToken cancellationToken) + // Currently (2021-12-30) slot_type is always 'physical' for existing slots or null for slot names that don't exist but that + // might change and we'd have to adopt our implementation in that case so check it just in case + switch (slotType) { - CheckDisposed(); + case "physical": + var restartLsn = (string?)result[1]; + var restartTli = (uint?)result[2]; + return new PhysicalReplicationSlot( + slotName.ToLowerInvariant(), + restartLsn == null ? null : NpgsqlLogSequenceNumber.Parse(restartLsn), + restartTli); + case null: + return null; + default: + throw new NotSupportedException( + $"The replication slot type '{slotType}' is currently not supported by Npgsql. 
Please file an issue."); + } + } - var connector = _npgsqlConnection.Connector!; + internal IAsyncEnumerator StartReplicationInternalWrapper( + string command, + bool bypassingStream, + CancellationToken cancellationToken) + { + _currentEnumerator = StartReplicationInternal(command, bypassingStream, cancellationToken); + return _currentEnumerator; + } - using var _ = Connector.StartUserAction( - ConnectorState.Replication, cancellationToken, attemptPgCancellation: _pgCancellationSupported); + internal async IAsyncEnumerator StartReplicationInternal( + string command, + bool bypassingStream, + CancellationToken cancellationToken) + { + CheckDisposed(); - var completionSource = new TaskCompletionSource(); - _replicationCompletion = completionSource.Task; + var connector = _npgsqlConnection.Connector!; - try + _replicationCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + using var _ = connector.StartUserAction( + ConnectorState.Replication, _replicationCancellationTokenSource.Token, attemptPgCancellation: _pgCancellationSupported); + + NpgsqlReadBuffer.ColumnStream? 
columnStream = null; + + try + { + await connector.WriteQuery(command, true, cancellationToken).ConfigureAwait(false); + await connector.Flush(true, cancellationToken).ConfigureAwait(false); + + var msg = await connector.ReadMessage(true).ConfigureAwait(false); + switch (msg.Code) + { + case BackendMessageCode.CopyBothResponse: + break; + case BackendMessageCode.CommandComplete: { - await connector.WriteQuery(command, true, cancellationToken); - await connector.Flush(true, cancellationToken); + yield break; + } + default: + throw connector.UnexpectedMessageReceived(msg.Code); + } - var msg = await connector.ReadMessage(true); - switch (msg.Code) - { - case BackendMessageCode.CopyBothResponse: - break; - case BackendMessageCode.CommandComplete: - { - yield break; - } - default: - throw connector.UnexpectedMessageReceived(msg.Code); - } + var buf = connector.ReadBuffer; + + columnStream = new NpgsqlReadBuffer.ColumnStream(connector); - var buf = connector.ReadBuffer; + SetTimeouts(_walReceiverTimeout, CommandTimeout); - // Cancellation is handled at the replication level - we don't want every ReadAsync - var columnStream = new NpgsqlReadBuffer.ColumnStream(connector, startCancellableOperations: false); + _sendFeedbackTimer = new Timer(TimerSendFeedback, state: null, WalReceiverStatusInterval, Timeout.InfiniteTimeSpan); + _requestFeedbackTimer = new Timer(TimerRequestFeedback, state: null, _requestFeedbackInterval, Timeout.InfiniteTimeSpan); - SetTimeouts(_walReceiverTimeout, CommandTimeout); + while (true) + { + msg = await connector.ReadMessage(async: true).ConfigureAwait(false); + Expect(msg, Connector); - _sendFeedbackTimer = new Timer(TimerSendFeedback, state: null, WalReceiverStatusInterval, Timeout.InfiniteTimeSpan); - _requestFeedbackTimer = new Timer(TimerRequestFeedback, state: null, _requestFeedbackInterval, Timeout.InfiniteTimeSpan); + // We received some message so there's no need to forcibly request feedback + // Reset the timer to request feedback. 
+ _requestFeedbackTimer.Change(_requestFeedbackInterval, Timeout.InfiniteTimeSpan); - while (true) + var messageLength = ((CopyDataMessage)msg).Length; + await buf.EnsureAsync(1).ConfigureAwait(false); + var code = (char)buf.ReadByte(); + switch (code) { - msg = await Connector.ReadMessage(async: true); - Expect(msg, Connector); - - // We received some message so there's no need to forcibly request feedback - // Reset the timer to request feedback. - _requestFeedbackTimer.Change(_requestFeedbackInterval, Timeout.InfiniteTimeSpan); + case 'w': // XLogData + { + await buf.EnsureAsync(24).ConfigureAwait(false); + var startLsn = buf.ReadUInt64(); + var endLsn = buf.ReadUInt64(); + var sendTime = PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc); - var messageLength = ((CopyDataMessage)msg).Length; - await buf.EnsureAsync(1); - var code = (char)buf.ReadByte(); - switch (code) - { - case 'w': // XLogData - { - await buf.EnsureAsync(24); - var startLsn = buf.ReadUInt64(); - var endLsn = buf.ReadUInt64(); - var sendTime = TimestampHandler.FromPostgresTimestamp(buf.ReadInt64()).ToLocalTime(); + if (unchecked((ulong)Interlocked.Read(ref _lastReceivedLsn)) < startLsn) + Interlocked.Exchange(ref _lastReceivedLsn, unchecked((long)startLsn)); + if (unchecked((ulong)Interlocked.Read(ref _lastReceivedLsn)) < endLsn) + Interlocked.Exchange(ref _lastReceivedLsn, unchecked((long)endLsn)); - if (unchecked((ulong)Interlocked.Read(ref _lastReceivedLsn)) < startLsn) - Interlocked.Exchange(ref _lastReceivedLsn, unchecked((long)startLsn)); - if (unchecked((ulong)Interlocked.Read(ref _lastReceivedLsn)) < endLsn) - Interlocked.Exchange(ref _lastReceivedLsn, unchecked((long)endLsn)); + // dataLen = msg.Length - (code = 1 + walStart = 8 + walEnd = 8 + serverClock = 8) + var dataLen = messageLength - 25; + columnStream.Init(dataLen, canSeek: false, commandScoped: false); - // dataLen = msg.Length - (code = 1 + walStart = 8 + walEnd = 8 + serverClock = 8) - var dataLen = 
messageLength - 25; - columnStream.Init(dataLen, canSeek: false); + _cachedXLogDataMessage.Populate(new NpgsqlLogSequenceNumber(startLsn), new NpgsqlLogSequenceNumber(endLsn), + sendTime, columnStream); + yield return _cachedXLogDataMessage; - _cachedXLogDataMessage.Populate(new NpgsqlLogSequenceNumber(startLsn), new NpgsqlLogSequenceNumber(endLsn), - sendTime, columnStream); - yield return _cachedXLogDataMessage; + // Our consumer may not have read the stream to the end, but it might as well have been us + // ourselves bypassing the stream and reading directly from the buffer in StartReplication() + if (!columnStream.IsDisposed && columnStream.Position < columnStream.Length && !bypassingStream) + await buf.Skip(checked((int)(columnStream.Length - columnStream.Position)), true).ConfigureAwait(false); - // Our consumer may not have read the stream to the end, but it might as well have been us - // ourselves bypassing the stream and reading directly from the buffer in StartReplication() - if (!columnStream.IsDisposed && columnStream.Position < columnStream.Length && !bypassingStream) - await buf.Skip(columnStream.Length - columnStream.Position, true); + continue; + } - continue; - } + case 'k': // Primary keepalive message + { + await buf.EnsureAsync(17).ConfigureAwait(false); + var end = buf.ReadUInt64(); - case 'k': // Primary keepalive message + if (ReplicationLogger.IsEnabled(LogLevel.Trace)) { - await buf.EnsureAsync(17); - var endLsn = buf.ReadUInt64(); - var timestamp = buf.ReadInt64(); - var replyRequested = buf.ReadByte() == 1; - if (unchecked((ulong)Interlocked.Read(ref _lastReceivedLsn)) < endLsn) - Interlocked.Exchange(ref _lastReceivedLsn, unchecked((long)endLsn)); + var endLsn = new NpgsqlLogSequenceNumber(end); + var timestamp = PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc); + LogMessages.ReceivedReplicationPrimaryKeepalive(ReplicationLogger, endLsn, timestamp, Connector.Id); + } + else + buf.Skip(8); - if (replyRequested) - await 
SendFeedback(waitOnSemaphore: true, cancellationToken: CancellationToken.None); + var replyRequested = buf.ReadByte() == 1; + if (unchecked((ulong)Interlocked.Read(ref _lastReceivedLsn)) < end) + Interlocked.Exchange(ref _lastReceivedLsn, unchecked((long)end)); - continue; + if (replyRequested) + { + LogMessages.SendingReplicationStandbyStatusUpdate(ReplicationLogger, "the server requested it", Connector.Id); + await SendFeedback(waitOnSemaphore: true, cancellationToken: CancellationToken.None).ConfigureAwait(false); } - default: - throw Connector.Break(new NpgsqlException($"Unknown replication message code '{code}'")); - } - } - } - finally - { -#if NETSTANDARD2_0 - if (_sendFeedbackTimer != null) - { - var mre = new ManualResetEvent(false); - var actuallyDisposed = _sendFeedbackTimer.Dispose(mre); - Debug.Assert(actuallyDisposed, $"{nameof(_sendFeedbackTimer)} had already been disposed when completing replication"); - if (actuallyDisposed) - await mre.WaitOneAsync(cancellationToken); + continue; } - if (_requestFeedbackTimer != null) - { - var mre = new ManualResetEvent(false); - var actuallyDisposed = _requestFeedbackTimer.Dispose(mre); - Debug.Assert(actuallyDisposed, $"{nameof(_requestFeedbackTimer)} had already been disposed when completing replication"); - if (actuallyDisposed) - await mre.WaitOneAsync(cancellationToken); + default: + throw Connector.Break(new NpgsqlException($"Unknown replication message code '{code}'")); } -#else - - if (_sendFeedbackTimer != null) - await _sendFeedbackTimer.DisposeAsync(); - if (_requestFeedbackTimer != null) - await _requestFeedbackTimer.DisposeAsync(); -#endif - _sendFeedbackTimer = null; - _requestFeedbackTimer = null; - - SetTimeouts(CommandTimeout, CommandTimeout); - - completionSource.SetResult(0); } } - - /// - /// Sends a forced status update to PostgreSQL with the current WAL tracking information. 
- /// - /// The connection currently isn't streaming - /// A Task representing the sending of the status update (and not any PostgreSQL response). - public Task SendStatusUpdate(CancellationToken cancellationToken = default) + finally { - using (NoSynchronizationContextScope.Enter()) - return SendStatusUpdateInternal(); + if (columnStream != null && !bypassingStream && !_replicationCancellationTokenSource.Token.IsCancellationRequested) + await columnStream.DisposeAsync().ConfigureAwait(false); - async Task SendStatusUpdateInternal() +#if NETSTANDARD2_0 + if (_sendFeedbackTimer != null) { - CheckDisposed(); - cancellationToken.ThrowIfCancellationRequested(); - - // TODO: If the user accidentally does concurrent usage of the connection, the following is vulnerable to race conditions. - // However, we generally aren't safe for this in Npgsql, leaving as-is for now. - if (Connector.State != ConnectorState.Replication) - throw new InvalidOperationException("Status update can only be sent during replication"); + var mre = new ManualResetEvent(false); + var actuallyDisposed = _sendFeedbackTimer.Dispose(mre); + Debug.Assert(actuallyDisposed, $"{nameof(_sendFeedbackTimer)} had already been disposed when completing replication"); + if (actuallyDisposed) + await mre.WaitOneAsync(cancellationToken).ConfigureAwait(false); + } - await SendFeedback(waitOnSemaphore: true, cancellationToken: cancellationToken); + if (_requestFeedbackTimer != null) + { + var mre = new ManualResetEvent(false); + var actuallyDisposed = _requestFeedbackTimer.Dispose(mre); + Debug.Assert(actuallyDisposed, $"{nameof(_requestFeedbackTimer)} had already been disposed when completing replication"); + if (actuallyDisposed) + await mre.WaitOneAsync(cancellationToken).ConfigureAwait(false); } +#else + + if (_sendFeedbackTimer != null) + await _sendFeedbackTimer.DisposeAsync().ConfigureAwait(false); + if (_requestFeedbackTimer != null) + await _requestFeedbackTimer.DisposeAsync().ConfigureAwait(false); +#endif 
+ _sendFeedbackTimer = null; + _requestFeedbackTimer = null; + + SetTimeouts(CommandTimeout, CommandTimeout); + + _replicationCancellationTokenSource.Dispose(); + _replicationCancellationTokenSource = null; + + _currentEnumerator = null; } + } - async Task SendFeedback(bool waitOnSemaphore = false, bool requestReply = false, CancellationToken cancellationToken = default) - { - var taken = waitOnSemaphore - ? await _feedbackSemaphore.WaitAsync(Timeout.Infinite, cancellationToken) - : await _feedbackSemaphore.WaitAsync(TimeSpan.Zero, cancellationToken); + /// + /// Sets the current status of the replication as it is interpreted by the consuming client. The value supplied + /// in will be sent to the server via and + /// with the next status update. + /// + /// A status update which will happen upon server request, upon expiration of + /// our upon an enforced status update via , whichever happens first. + /// If you want the value you set here to be pushed to the server immediately (e. g. in synchronous replication scenarios), + /// call after calling this method. + /// + /// + /// + /// This is a convenience method setting both and in one operation. + /// You can use it if your application processes replication messages in a way that doesn't care about the difference between + /// writing a message and flushing it to a permanent storage medium. + /// + /// The location of the last WAL byte + 1 applied (e. g. processed or written to disk) and flushed to disk in the standby. + public void SetReplicationStatus(NpgsqlLogSequenceNumber lastAppliedAndFlushedLsn) + { + Interlocked.Exchange(ref _lastAppliedLsn, unchecked((long)(ulong)lastAppliedAndFlushedLsn)); + Interlocked.Exchange(ref _lastFlushedLsn, unchecked((long)(ulong)lastAppliedAndFlushedLsn)); + } - if (!taken) - return; + /// + /// Sends a forced status update to PostgreSQL with the current WAL tracking information. 
+ /// + /// The connection currently isn't streaming + /// A Task representing the sending of the status update (and not any PostgreSQL response). + public async Task SendStatusUpdate(CancellationToken cancellationToken = default) + { + CheckDisposed(); + cancellationToken.ThrowIfCancellationRequested(); - try - { - var connector = _npgsqlConnection.Connector!; - var buf = connector.WriteBuffer; + // TODO: If the user accidentally does concurrent usage of the connection, the following is vulnerable to race conditions. + // However, we generally aren't safe for this in Npgsql, leaving as-is for now. + if (Connector.State != ConnectorState.Replication) + throw new InvalidOperationException("Status update can only be sent during replication"); - const int len = 39; + LogMessages.SendingReplicationStandbyStatusUpdate(ReplicationLogger, nameof(SendStatusUpdate) + "was called", Connector.Id); + await SendFeedback(waitOnSemaphore: true, cancellationToken: cancellationToken).ConfigureAwait(false); + } - if (buf.WriteSpaceLeft < len) - await connector.Flush(async: true, cancellationToken); + async Task SendFeedback(bool waitOnSemaphore = false, bool requestReply = false, CancellationToken cancellationToken = default) + { + var taken = waitOnSemaphore + ? await _feedbackSemaphore.WaitAsync(Timeout.Infinite, cancellationToken).ConfigureAwait(false) + : await _feedbackSemaphore.WaitAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false); - buf.WriteByte(FrontendMessageCode.CopyData); - buf.WriteInt32(len - 1); - buf.WriteByte((byte)'r'); // TODO: enum/const? - // We write the LSNs as Int64 here to save us the casting - buf.WriteInt64(Interlocked.Read(ref _lastReceivedLsn)); - buf.WriteInt64(Interlocked.Read(ref _lastFlushedLsn)); - buf.WriteInt64(Interlocked.Read(ref _lastAppliedLsn)); - buf.WriteInt64(TimestampHandler.ToPostgresTimestamp(DateTime.Now)); - buf.WriteByte(requestReply ? 
(byte)1 : (byte)0); + if (!taken) + { + ReplicationLogger.LogTrace($"Aborting feedback due to expired {nameof(WalReceiverStatusInterval)} because of a concurrent feedback request"); + return; + } - await connector.Flush(async: true, cancellationToken); - } - finally + try + { + var connector = _npgsqlConnection.Connector!; + var buf = connector.WriteBuffer; + + const int len = 39; + + if (buf.WriteSpaceLeft < len) + await connector.Flush(async: true, cancellationToken).ConfigureAwait(false); + + buf.StartMessage(len); + buf.WriteByte(FrontendMessageCode.CopyData); + buf.WriteInt32(len - 1); + buf.WriteByte((byte)'r'); // TODO: enum/const? + // We write the LSNs as Int64 here to save us the casting + var lastReceivedLsn = Interlocked.Read(ref _lastReceivedLsn); + var lastFlushedLsn = Interlocked.Read(ref _lastFlushedLsn); + var lastAppliedLsn = Interlocked.Read(ref _lastAppliedLsn); + var timestamp = DateTime.UtcNow; + buf.WriteInt64(lastReceivedLsn); + buf.WriteInt64(lastFlushedLsn); + buf.WriteInt64(lastAppliedLsn); + buf.WriteInt64(PgDateTime.EncodeTimestamp(timestamp)); + buf.WriteByte(requestReply ? 
(byte)1 : (byte)0); + + await connector.Flush(async: true, cancellationToken).ConfigureAwait(false); + + if (ReplicationLogger.IsEnabled(LogLevel.Trace)) { - _sendFeedbackTimer!.Change(WalReceiverStatusInterval, Timeout.InfiniteTimeSpan); - _requestFeedbackTimer!.Change(_requestFeedbackInterval, Timeout.InfiniteTimeSpan); - _feedbackSemaphore.Release(); + LogMessages.SentReplicationFeedbackMessage( + ReplicationLogger, + new NpgsqlLogSequenceNumber(unchecked((ulong)lastReceivedLsn)), + new NpgsqlLogSequenceNumber(unchecked((ulong)lastFlushedLsn)), + new NpgsqlLogSequenceNumber(unchecked((ulong)lastAppliedLsn)), + timestamp, + Connector.Id); } } + catch (Exception e) + { + LogMessages.ReplicationFeedbackMessageSendingFailed(ReplicationLogger, _npgsqlConnection?.Connector?.Id, e); + } + finally + { + _sendFeedbackTimer!.Change(WalReceiverStatusInterval, Timeout.InfiniteTimeSpan); + if (requestReply) + _requestFeedbackTimer!.Change(_requestFeedbackInterval, Timeout.InfiniteTimeSpan); + _feedbackSemaphore.Release(); + } + } - async void TimerRequestFeedback(object? obj) + async void TimerRequestFeedback(object? obj) + { + try { - try - { - if (Connector.State != ConnectorState.Replication) - return; + if (Connector.State != ConnectorState.Replication) + return; - await SendFeedback(waitOnSemaphore: true, requestReply: true); - } - catch (Exception e) - { - Log.Error("An exception occurred while requesting streaming replication feedback from the server.", e, _npgsqlConnection?.Connector?.Id ?? 0); - } + if (ReplicationLogger.IsEnabled(LogLevel.Trace)) + LogMessages.SendingReplicationStandbyStatusUpdate(ReplicationLogger, $"half of the {nameof(WalReceiverTimeout)} of {WalReceiverTimeout} has expired", Connector.Id); + + await SendFeedback(waitOnSemaphore: true, requestReply: true).ConfigureAwait(false); } + catch + { + // Already logged inside SendFeedback + } + } - async void TimerSendFeedback(object? obj) + async void TimerSendFeedback(object? 
obj) + { + try { - try - { - if (Connector.State != ConnectorState.Replication) - return; + if (Connector.State != ConnectorState.Replication) + return; - await SendFeedback(); - } - catch (Exception e) - { - Log.Error("An exception occurred while sending streaming replication feedback to the server.", e, _npgsqlConnection?.Connector?.Id ?? 0); - } - } + if (ReplicationLogger.IsEnabled(LogLevel.Trace)) + LogMessages.SendingReplicationStandbyStatusUpdate(ReplicationLogger, $"{nameof(WalReceiverStatusInterval)} of {WalReceiverStatusInterval} has expired", Connector.Id); - /// - /// Drops a replication slot, freeing any reserved server-side resources. - /// If the slot is a logical slot that was created in a database other than - /// the database the walsender is connected to, this command fails. - /// - /// The name of the slot to drop. - /// - /// causes the command to wait until the slot becomes - /// inactive if it currently is active instead of the default behavior of raising an error. - /// - /// The token to monitor for cancellation requests. - /// The default value is . - /// A task representing the asynchronous drop operation. - public Task DropReplicationSlot(string slotName, bool wait = false, CancellationToken cancellationToken = default) + await SendFeedback().ConfigureAwait(false); + } + catch { - if (slotName is null) - throw new ArgumentNullException(nameof(slotName)); + // Already logged inside SendFeedback + } + } - using (NoSynchronizationContextScope.Enter()) - return DropReplicationSlotInternal(); + /// + /// Drops a replication slot, freeing any reserved server-side resources. + /// If the slot is a logical slot that was created in a database other than + /// the database the walsender is connected to, this command fails. + /// + /// The name of the slot to drop. + /// + /// causes the command to wait until the slot becomes + /// inactive if it currently is active instead of the default behavior of raising an error. 
+ /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// A task representing the asynchronous drop operation. + public Task DropReplicationSlot(string slotName, bool wait = false, CancellationToken cancellationToken = default) + { + if (slotName is null) + throw new ArgumentNullException(nameof(slotName)); - async Task DropReplicationSlotInternal() - { - CheckDisposed(); + CheckDisposed(); - using var _ = Connector.StartUserAction(cancellationToken, attemptPgCancellation: _pgCancellationSupported); + return DropReplicationSlotInternal(slotName, wait, cancellationToken); - var command = "DROP_REPLICATION_SLOT " + slotName; - if (wait) - command += " WAIT"; + async Task DropReplicationSlotInternal(string slotName, bool wait, CancellationToken cancellationToken) + { + using var _ = Connector.StartUserAction(cancellationToken, attemptPgCancellation: _pgCancellationSupported); - await Connector.WriteQuery(command, true, CancellationToken.None); - await Connector.Flush(true, CancellationToken.None); + var command = "DROP_REPLICATION_SLOT " + slotName; + if (wait) + command += " WAIT"; - Expect(await Connector.ReadMessage(true), Connector); + LogMessages.DroppingReplicationSlot(ReplicationLogger, slotName, command, Connector.Id); - // Two CommandComplete messages are returned - if (PostgreSqlVersion < FirstVersionWithoutDropSlotDoubleCommandCompleteMessage) - Expect(await Connector.ReadMessage(true), Connector); + await Connector.WriteQuery(command, true, CancellationToken.None).ConfigureAwait(false); + await Connector.Flush(true, CancellationToken.None).ConfigureAwait(false); - Expect(await Connector.ReadMessage(true), Connector); - } + Expect(await Connector.ReadMessage(true).ConfigureAwait(false), Connector); + + // Two CommandComplete messages are returned + if (PostgreSqlVersion < FirstVersionWithoutDropSlotDoubleCommandCompleteMessage) + Expect(await Connector.ReadMessage(true).ConfigureAwait(false), Connector); + 
+ Expect(await Connector.ReadMessage(true).ConfigureAwait(false), Connector); } + } - #endregion + #endregion - async Task ReadSingleRow(string command, CancellationToken cancellationToken = default) - { - CheckDisposed(); + async Task ReadSingleRow(string command, CancellationToken cancellationToken = default) + { + CheckDisposed(); - using var _ = Connector.StartUserAction(cancellationToken, attemptPgCancellation: _pgCancellationSupported); + using var _ = Connector.StartUserAction(cancellationToken, attemptPgCancellation: _pgCancellationSupported); - await Connector.WriteQuery(command, true, cancellationToken); - await Connector.Flush(true, cancellationToken); + LogMessages.ExecutingReplicationCommand(ReplicationLogger, command, Connector.Id); - var description = - Expect(await Connector.ReadMessage(true), Connector); - Expect(await Connector.ReadMessage(true), Connector); - var buf = Connector.ReadBuffer; - await buf.EnsureAsync(2); - var results = new object[buf.ReadInt16()]; - for (var i = 0; i < results.Length; i++) - { - await buf.EnsureAsync(4); - var len = buf.ReadInt32(); - if (len == -1) - continue; + await Connector.WriteQuery(command, true, cancellationToken).ConfigureAwait(false); + await Connector.Flush(true, cancellationToken).ConfigureAwait(false); - await buf.EnsureAsync(len); - var field = description.Fields[i]; - switch (field.PostgresType.Name) + var rowDescription = Expect(await Connector.ReadMessage(true).ConfigureAwait(false), Connector); + Expect(await Connector.ReadMessage(true).ConfigureAwait(false), Connector); + var buf = Connector.ReadBuffer; + await buf.EnsureAsync(2).ConfigureAwait(false); + var results = new object[buf.ReadInt16()]; + for (var i = 0; i < results.Length; i++) + { + await buf.EnsureAsync(4).ConfigureAwait(false); + var len = buf.ReadInt32(); + if (len == -1) + continue; + + await buf.EnsureAsync(len).ConfigureAwait(false); + var field = rowDescription[i]; + switch (field.PostgresType.Name) + { + case "text": + 
results[i] = buf.ReadString(len); + continue; + // Currently in all instances where ReadSingleRow gets called, we expect unsigned integer values only, since that's always + // TimeLineID which is a uint32 in PostgreSQL that is sent as integer up to PG 15 and as bigint as of PG 16 + // (https://github.com/postgres/postgres/blob/57d0051706b897048063acc14c2c3454200c488f/src/include/access/xlogdefs.h#L59 and + // https://github.com/postgres/postgres/commit/ec40f3422412cfdc140b5d3f67db7fd2dac0f1e2). + // Because of this, it is safe to always parse the values we get as unit although, according to the row description message + // we formally could also get a signed int or long value. + // Whenever ReadSingleRow gets used in a new context we have to check, whether this contract is still + // valid in that context and if it isn't, adjust the method accordingly (e.g. by switching on the command). + case "integer": + case "bigint": + { + var str = buf.ReadString(len); + if (!uint.TryParse(str, NumberStyles.None, null, out var num)) { - case "text": - results[i] = buf.ReadString(len); - continue; - case "integer": - var str = buf.ReadString(len); - if (!uint.TryParse(str, NumberStyles.None, null, out var num)) - { - throw Connector.Break( - new NpgsqlException( - $"Could not parse '{str}' as unsigned integer in field {field.Name}")); - } + throw Connector.Break( + new NpgsqlException( + $"Could not parse '{str}' as unsigned integer in field {field.Name}")); + } - results[i] = num; - continue; - case "bytea": - try - { - var bytes = buf.ReadMemory(len); - // Theoretically we could just copy over the raw bytes here, since bytea - // only comes from TIMELINE_HISTORY which doesn't really send bytea but raw bytes - // but let's not rely on this implementation detail and stay compatible - results[i] = ParseBytea(bytes.Span); - } - catch (Exception e) - { - throw Connector.Break( - new NpgsqlException($"Could not parse data as bytea in field {field.Name}", e)); - } + results[i] = 
num; + continue; + } + case "bytea": + try + { + var bytes = buf.ReadMemory(len); + // Theoretically we could just copy over the raw bytes here, since bytea + // only comes from TIMELINE_HISTORY which doesn't really send bytea but raw bytes + // but let's not rely on this implementation detail and stay compatible + results[i] = ParseBytea(bytes.Span); + } + catch (Exception e) + { + throw Connector.Break( + new NpgsqlException($"Could not parse data as bytea in field {field.Name}", e)); + } - continue; - default: + continue; + default: - throw Connector.Break(new NpgsqlException( - $"Field {field.Name} has PostgreSQL type {field.PostgresType.Name} which isn't supported yet")); - } + throw Connector.Break(new NpgsqlException( + $"Field {field.Name} has PostgreSQL type {field.PostgresType.Name} which isn't supported yet")); } + } + + Expect(await Connector.ReadMessage(true).ConfigureAwait(false), Connector); + Expect(await Connector.ReadMessage(true).ConfigureAwait(false), Connector); + return results; - Expect(await Connector.ReadMessage(true), Connector); - Expect(await Connector.ReadMessage(true), Connector); - return results; + static byte[] ParseBytea(ReadOnlySpan bytes) + { + return bytes.Length >= 2 && bytes[0] == '\\' && bytes[1] == 'x' + ? ParseByteaHex(bytes.Slice(2)) + : ParseByteaEscape(bytes); - byte[] ParseBytea(ReadOnlySpan bytes) + static byte[] ParseByteaHex(ReadOnlySpan inBytes) { - return bytes.Length >= 2 && bytes[0] == '\\' && bytes[1] == 'x' - ? ParseByteaHex(bytes.Slice(2)) - : ParseByteaEscape(bytes); + var outBytes = new byte[inBytes.Length / 2]; + for (var i = 0; i < inBytes.Length; i++) + { + var v1 = inBytes[i++]; + var v2 = inBytes[i]; + outBytes[i / 2] = + (byte)(((v1 - (v1 < 0x3A ? 0x30 : 87)) << 4) | (v2 - (v2 < 0x3A ? 
0x30 : 87))); + } - byte[] ParseByteaHex(ReadOnlySpan inBytes) + return outBytes; + } + + static byte[] ParseByteaEscape(ReadOnlySpan inBytes) + { + var result = new MemoryStream(new byte[inBytes.Length]); + for (var tp = 0; tp < inBytes.Length;) { - var outBytes = new byte[inBytes.Length / 2]; - for (var i = 0; i < inBytes.Length; i++) + var c1 = inBytes[tp]; + if (c1 != '\\') { - var v1 = inBytes[i++]; - var v2 = inBytes[i]; - outBytes[i / 2] = - (byte)(((v1 - (v1 < 0x3A ? 0x30 : 87)) << 4) | (v2 - (v2 < 0x3A ? 0x30 : 87))); + // Don't validate whether c1 >= 0x20 && c1 <= 0x7e here + // TIMELINE_HISTORY currently (2020-09-13) sends raw + // bytes instead of bytea for the content value. + result.WriteByte(c1); + tp++; + continue; } - return outBytes; - } + var c2 = inBytes[tp + 1]; + if (c2 == '\\') + { + result.WriteByte(c2); + tp += 2; + continue; + } - byte[] ParseByteaEscape(ReadOnlySpan inBytes) - { - var result = new MemoryStream(new byte[inBytes.Length]); - for (var tp = 0; tp < inBytes.Length;) + var c3 = inBytes[tp + 2]; + var c4 = inBytes[tp + 3]; + if (c2 >= '0' && c2 <= '3' && + c3 >= '0' && c3 <= '7' && + c4 >= '0' && c4 <= '7') { - var c1 = inBytes[tp]; - if (c1 != '\\') - { - // Don't validate whether c1 >= 0x20 && c1 <= 0x7e here - // TIMELINE_HISTORY currently (2020-09-13) sends raw - // bytes instead of bytea for the content value. 
- result.WriteByte(c1); - tp++; - continue; - } - - var c2 = inBytes[tp + 1]; - if (c2 == '\\') - { - result.WriteByte(c2); - tp += 2; - continue; - } - - var c3 = inBytes[tp + 2]; - var c4 = inBytes[tp + 3]; - if (c2 >= '0' && c2 <= '3' && - c3 >= '0' && c3 <= '7' && - c4 >= '0' && c4 <= '7') - { - c2 <<= 3; - c2 += c3; - c2 <<= 3; - result.WriteByte((byte)(c2 + c4)); - - tp += 4; - continue; - } - - throw new FormatException("Invalid syntax for type bytea"); + c2 <<= 3; + c2 += c3; + c2 <<= 3; + result.WriteByte((byte)(c2 + c4)); + + tp += 4; + continue; } - return result.ToArray(); + throw new FormatException("Invalid syntax for type bytea"); } + + return result.ToArray(); } } + } - void SetTimeouts(TimeSpan readTimeout, TimeSpan writeTimeout) - { - var connector = Connector; - connector.UserTimeout = readTimeout > TimeSpan.Zero ? (int)readTimeout.TotalMilliseconds : 0; - - var writeBuffer = connector.WriteBuffer; - if (writeBuffer != null) - writeBuffer.Timeout = writeTimeout; - } + void SetTimeouts(TimeSpan readTimeout, TimeSpan writeTimeout) + { + var connector = Connector; + var readBuffer = connector.ReadBuffer; + if (readBuffer != null) + readBuffer.Timeout = readTimeout > TimeSpan.Zero ? 
readTimeout : TimeSpan.Zero; + + var writeBuffer = connector.WriteBuffer; + if (writeBuffer != null) + writeBuffer.Timeout = writeTimeout; + } - void CheckDisposed() - { - if (_isDisposed) - throw new ObjectDisposedException(GetType().Name); - } + internal void CheckDisposed() + { + if (_isDisposed) + throw new ObjectDisposedException(GetType().Name); } } diff --git a/src/Npgsql/Replication/ReplicationMessage.cs b/src/Npgsql/Replication/ReplicationMessage.cs index a03807c0e9..be957346cb 100644 --- a/src/Npgsql/Replication/ReplicationMessage.cs +++ b/src/Npgsql/Replication/ReplicationMessage.cs @@ -1,33 +1,37 @@ using NpgsqlTypes; using System; -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// The common base class for all streaming replication messages +/// +public abstract class ReplicationMessage { /// - /// The common base class for all streaming replication messages + /// The starting point of the WAL data in this message. /// - public abstract class ReplicationMessage - { - /// - /// The starting point of the WAL data in this message. - /// - public NpgsqlLogSequenceNumber WalStart { get; private set; } + public NpgsqlLogSequenceNumber WalStart { get; private set; } - /// - /// The current end of WAL on the server. - /// - public NpgsqlLogSequenceNumber WalEnd { get; private set; } + /// + /// The current end of WAL on the server. + /// + public NpgsqlLogSequenceNumber WalEnd { get; private set; } - /// - /// The server's system clock at the time this message was transmitted, as microseconds since midnight on 2000-01-01. - /// - public DateTime ServerClock { get; private set; } + /// + /// The server's system clock at the time this message was transmitted, as microseconds since midnight on 2000-01-01. + /// + /// + /// Since the client using Npgsql and the server may be located in different time zones, + /// as of Npgsql 7.0 this value is no longer converted to local time but keeps its original value in UTC. 
+ /// You can check if you don't want to introduce behavior depending on Npgsql versions. + /// + public DateTime ServerClock { get; private set; } - private protected void Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock) - { - WalStart = walStart; - WalEnd = walEnd; - ServerClock = serverClock; - } + private protected void Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock) + { + WalStart = walStart; + WalEnd = walEnd; + ServerClock = serverClock; } -} +} \ No newline at end of file diff --git a/src/Npgsql/Replication/ReplicationSlot.cs b/src/Npgsql/Replication/ReplicationSlot.cs index 1378bec13e..8790303444 100644 --- a/src/Npgsql/Replication/ReplicationSlot.cs +++ b/src/Npgsql/Replication/ReplicationSlot.cs @@ -1,18 +1,17 @@ -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Contains information about a newly-created replication slot. +/// +public abstract class ReplicationSlot { - /// - /// Contains information about a newly-created replication slot. - /// - public abstract class ReplicationSlot + internal ReplicationSlot(string name) { - internal ReplicationSlot(string name) - { - Name = name; - } - - /// - /// The name of the newly-created replication slot. - /// - public string Name { get; } + Name = name; } -} + + /// + /// The name of the newly-created replication slot. + /// + public string Name { get; } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/ReplicationSlotOptions.cs b/src/Npgsql/Replication/ReplicationSlotOptions.cs index 9de31f6f50..669e8711df 100644 --- a/src/Npgsql/Replication/ReplicationSlotOptions.cs +++ b/src/Npgsql/Replication/ReplicationSlotOptions.cs @@ -1,61 +1,59 @@ -using JetBrains.Annotations; +using System; using NpgsqlTypes; -using System; -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Contains information about a replication slot. 
+/// +public readonly struct ReplicationSlotOptions { /// - /// Contains information about a replication slot. + /// Creates a new instance. /// - public readonly struct ReplicationSlotOptions - { - /// - /// Creates a new instance. - /// - /// - /// The name of the replication slot. - /// - /// - /// The WAL location at which the slot became consistent. - /// - public ReplicationSlotOptions(string slotName, string? consistentPoint = null) - : this(slotName, consistentPoint is null ? default : NpgsqlLogSequenceNumber.Parse(consistentPoint), null){} + /// + /// The name of the replication slot. + /// + /// + /// The WAL location at which the slot became consistent. + /// + public ReplicationSlotOptions(string slotName, string? consistentPoint = null) + : this(slotName, consistentPoint is null ? default : NpgsqlLogSequenceNumber.Parse(consistentPoint), null){} - /// - /// Creates a new instance. - /// - /// - /// The name of the replication slot. - /// - /// - /// The WAL location at which the slot became consistent. - /// - public ReplicationSlotOptions(string slotName, NpgsqlLogSequenceNumber consistentPoint) - : this(slotName, consistentPoint, null) {} + /// + /// Creates a new instance. + /// + /// + /// The name of the replication slot. + /// + /// + /// The WAL location at which the slot became consistent. + /// + public ReplicationSlotOptions(string slotName, NpgsqlLogSequenceNumber consistentPoint) + : this(slotName, consistentPoint, null) {} - internal ReplicationSlotOptions( - string slotName, - NpgsqlLogSequenceNumber consistentPoint, - string? snapshotName) - { - SlotName = slotName ?? throw new ArgumentNullException(nameof(slotName), "The replication slot name cannot be null."); - ConsistentPoint = consistentPoint; - SnapshotName = snapshotName; - } + internal ReplicationSlotOptions( + string slotName, + NpgsqlLogSequenceNumber consistentPoint, + string? snapshotName) + { + SlotName = slotName ?? 
throw new ArgumentNullException(nameof(slotName), "The replication slot name cannot be null."); + ConsistentPoint = consistentPoint; + SnapshotName = snapshotName; + } - /// - /// The name of the replication slot. - /// - public string SlotName { get; } + /// + /// The name of the replication slot. + /// + public string SlotName { get; } - /// - /// The WAL location at which the slot became consistent. - /// - public NpgsqlLogSequenceNumber ConsistentPoint { get; } + /// + /// The WAL location at which the slot became consistent. + /// + public NpgsqlLogSequenceNumber ConsistentPoint { get; } - /// - /// The identifier of the snapshot exported by the CREATE_REPLICATION_SLOT command. - /// - internal string? SnapshotName { get; } - } -} + /// + /// The identifier of the snapshot exported by the CREATE_REPLICATION_SLOT command. + /// + internal string? SnapshotName { get; } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/ReplicationSystemIdentification.cs b/src/Npgsql/Replication/ReplicationSystemIdentification.cs index fa8c81b43b..7e6673e702 100644 --- a/src/Npgsql/Replication/ReplicationSystemIdentification.cs +++ b/src/Npgsql/Replication/ReplicationSystemIdentification.cs @@ -1,39 +1,38 @@ using NpgsqlTypes; -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Contains server identification information returned from . +/// +public class ReplicationSystemIdentification { - /// - /// Contains server identification information returned from . 
- /// - public class ReplicationSystemIdentification + internal ReplicationSystemIdentification(string systemId, uint timeline, NpgsqlLogSequenceNumber xLogPos, string dbName) { - internal ReplicationSystemIdentification(string systemId, uint timeline, NpgsqlLogSequenceNumber xLogPos, string dbName) - { - SystemId = systemId; - Timeline = timeline; - XLogPos = xLogPos; - DbName = dbName; - } + SystemId = systemId; + Timeline = timeline; + XLogPos = xLogPos; + DbName = dbName; + } - /// - /// The unique system identifier identifying the cluster. - /// This can be used to check that the base backup used to initialize the standby came from the same cluster. - /// - public string SystemId { get; } + /// + /// The unique system identifier identifying the cluster. + /// This can be used to check that the base backup used to initialize the standby came from the same cluster. + /// + public string SystemId { get; } - /// - /// Current timeline ID. Also useful to check that the standby is consistent with the master. - /// - public uint Timeline { get; } + /// + /// Current timeline ID. Also useful to check that the standby is consistent with the master. + /// + public uint Timeline { get; } - /// - /// Current WAL flush location. Useful to get a known location in the write-ahead log where streaming can start. - /// - public NpgsqlLogSequenceNumber XLogPos { get; } + /// + /// Current WAL flush location. Useful to get a known location in the write-ahead log where streaming can start. + /// + public NpgsqlLogSequenceNumber XLogPos { get; } - /// - /// Database connected to. - /// - public string? DbName { get; } - } + /// + /// Database connected to. + /// + public string? 
DbName { get; } } diff --git a/src/Npgsql/Replication/TestDecoding/TestDecodingAsyncEnumerable.cs b/src/Npgsql/Replication/TestDecoding/TestDecodingAsyncEnumerable.cs index 0bf648b2b5..aca7ee70ea 100644 --- a/src/Npgsql/Replication/TestDecoding/TestDecodingAsyncEnumerable.cs +++ b/src/Npgsql/Replication/TestDecoding/TestDecodingAsyncEnumerable.cs @@ -1,57 +1,78 @@ +using System.Buffers; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Threading; using System.Threading.Tasks; using Npgsql.Replication.Internal; using NpgsqlTypes; -namespace Npgsql.Replication.TestDecoding +namespace Npgsql.Replication.TestDecoding; + +sealed class TestDecodingAsyncEnumerable : IAsyncEnumerable { - class TestDecodingAsyncEnumerable : IAsyncEnumerable + readonly LogicalReplicationConnection _connection; + readonly TestDecodingReplicationSlot _slot; + readonly TestDecodingOptions _options; + readonly CancellationToken _baseCancellationToken; + readonly NpgsqlLogSequenceNumber? _walLocation; + + readonly TestDecodingData _cachedMessage = new(); + + internal TestDecodingAsyncEnumerable( + LogicalReplicationConnection connection, + TestDecodingReplicationSlot slot, + TestDecodingOptions options, + CancellationToken cancellationToken, + NpgsqlLogSequenceNumber? walLocation = null) { - readonly LogicalReplicationConnection _connection; - readonly TestDecodingReplicationSlot _slot; - readonly TestDecodingOptions _options; - readonly CancellationToken _baseCancellationToken; - readonly NpgsqlLogSequenceNumber? _walLocation; - - readonly TestDecodingData _cachedMessage = new TestDecodingData(); - - internal TestDecodingAsyncEnumerable( - LogicalReplicationConnection connection, - TestDecodingReplicationSlot slot, - TestDecodingOptions options, - CancellationToken cancellationToken, - NpgsqlLogSequenceNumber? 
walLocation = null) - { - _connection = connection; - _slot = slot; - _options = options; - _baseCancellationToken = cancellationToken; - _walLocation = walLocation; - } + _connection = connection; + _slot = slot; + _options = options; + _baseCancellationToken = cancellationToken; + _walLocation = walLocation; + } - public IAsyncEnumerator GetAsyncEnumerator( - CancellationToken cancellationToken = new CancellationToken()) - { - using (NoSynchronizationContextScope.Enter()) - return StartReplicationInternal( - CancellationTokenSource.CreateLinkedTokenSource(_baseCancellationToken, cancellationToken).Token); - } + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + cancellationToken = CancellationTokenSource.CreateLinkedTokenSource(_baseCancellationToken, cancellationToken).Token; - async IAsyncEnumerator StartReplicationInternal(CancellationToken cancellationToken) - { - var stream = _connection.StartLogicalReplication( - _slot, cancellationToken, _walLocation, _options.GetOptionPairs()); - var encoding = _connection.Encoding!; + var stream = _connection.StartLogicalReplication( + _slot, cancellationToken, _walLocation, _options.GetOptionPairs()); + var encoding = _connection.Encoding!; + + var buffer = ArrayPool.Shared.Rent(4096); - await foreach (var msg in stream.WithCancellation(cancellationToken)) + try + { + await foreach (var msg in stream.ConfigureAwait(false)) { - var memoryStream = new MemoryStream(); - await msg.Data.CopyToAsync(memoryStream, 4096, CancellationToken.None); - var data = encoding.GetString(memoryStream.ToArray()); + var len = (int)msg.Data.Length; + Debug.Assert(msg.Data.Position == 0); + if (len > buffer.Length) + { + ArrayPool.Shared.Return(buffer); + buffer = ArrayPool.Shared.Rent(len); + } + + var offset = 0; + while (offset < len) + { + var read = await msg.Data.ReadAsync(buffer, offset, len - offset, CancellationToken.None).ConfigureAwait(false); + if (read == 0) + throw new 
EndOfStreamException(); + offset += read; + } + + Debug.Assert(offset == len); + var data = encoding.GetString(buffer, 0, len); + yield return _cachedMessage.Populate(msg.WalStart, msg.WalEnd, msg.ServerClock, data); } } + finally + { + ArrayPool.Shared.Return(buffer); + } } } diff --git a/src/Npgsql/Replication/TestDecoding/TestDecodingConnectionExtensions.cs b/src/Npgsql/Replication/TestDecoding/TestDecodingConnectionExtensions.cs index 53610fbfc3..77321711d9 100644 --- a/src/Npgsql/Replication/TestDecoding/TestDecodingConnectionExtensions.cs +++ b/src/Npgsql/Replication/TestDecoding/TestDecodingConnectionExtensions.cs @@ -6,83 +6,87 @@ using Npgsql.Replication.TestDecoding; // ReSharper disable once CheckNamespace -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Extension methods to use with the +/// test_decoding logical decoding plugin. +/// See https://www.postgresql.org/docs/current/test-decoding.html. +/// +public static class TestDecodingConnectionExtensions { /// - /// Extension methods to use with the + /// Creates a class that wraps a replication slot using the /// test_decoding logical decoding plugin. - /// See https://www.postgresql.org/docs/current/test-decoding.html. /// - public static class TestDecodingConnectionExtensions + /// + /// See https://www.postgresql.org/docs/current/test-decoding.html + /// for more information. + /// + /// The to use for creating the + /// replication slot + /// The name of the slot to create. Must be a valid replication slot name (see + /// https://www.postgresql.org/docs/current/warm-standby.html#STREAMING-REPLICATION-SLOTS-MANIPULATION). + /// + /// + /// if this replication slot shall be temporary one; otherwise . + /// Temporary slots are not saved to disk and are automatically dropped on error or when the session has finished. + /// + /// + /// A to specify what to do with the snapshot created during logical slot + /// initialization. 
, which is also the default, will export the + /// snapshot for use in other sessions. This option can't be used inside a transaction. + /// will use the snapshot for the current transaction executing the + /// command. This option must be used in a transaction, and must be the + /// first command run in that transaction. Finally, will just use + /// the snapshot for logical decoding as normal but won't do anything else with it. + /// + /// + /// If , this logical replication slot supports decoding of two-phase transactions. With this option, + /// two-phase commands like PREPARE TRANSACTION, COMMIT PREPARED and ROLLBACK PREPARED are decoded and transmitted. + /// The transaction will be decoded and transmitted at PREPARE TRANSACTION time. The default is . + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . + /// + /// + /// A that wraps the newly-created replication slot. + /// + public static async Task CreateTestDecodingReplicationSlot( + this LogicalReplicationConnection connection, + string slotName, + bool temporarySlot = false, + LogicalSlotSnapshotInitMode? slotSnapshotInitMode = null, + bool twoPhase = false, + CancellationToken cancellationToken = default) { - /// - /// Creates a class that wraps a replication slot using the - /// test_decoding logical decoding plugin. - /// - /// - /// See https://www.postgresql.org/docs/current/test-decoding.html - /// for more information. - /// - /// The to use for creating the - /// replication slot - /// The name of the slot to create. Must be a valid replication slot name (see - /// https://www.postgresql.org/docs/current/warm-standby.html#STREAMING-REPLICATION-SLOTS-MANIPULATION). - /// - /// - /// if this replication slot shall be temporary one; otherwise . - /// Temporary slots are not saved to disk and are automatically dropped on error or when the session has finished. 
- /// - /// - /// A to specify what to do with the snapshot created during logical slot - /// initialization. , which is also the default, will export the - /// snapshot for use in other sessions. This option can't be used inside a transaction. - /// will use the snapshot for the current transaction executing the - /// command. This option must be used in a transaction, and must be the - /// first command run in that transaction. Finally, will just use - /// the snapshot for logical decoding as normal but won't do anything else with it. - /// - /// - /// The token to monitor for cancellation requests. - /// The default value is . - /// - /// - /// A that wraps the newly-created replication slot. - /// - public static async Task CreateTestDecodingReplicationSlot( - this LogicalReplicationConnection connection, - string slotName, - bool temporarySlot = false, - LogicalSlotSnapshotInitMode? slotSnapshotInitMode = null, - CancellationToken cancellationToken = default) - { - // We don't enter NoSynchronizationContextScope here since we (have to) do it in CreateReplicationSlotForPlugin, because - // otherwise it wouldn't be set for external plugins. - var options = await connection.CreateLogicalReplicationSlot( - slotName, "test_decoding", temporarySlot, slotSnapshotInitMode, cancellationToken).ConfigureAwait(false); - return new TestDecodingReplicationSlot(options); - } - - /// - /// Instructs the server to start streaming the WAL for logical replication using the test_decoding logical decoding plugin, - /// starting at WAL location or at the slot's consistent point if - /// isn't specified. - /// The server can reply with an error, for example if the requested section of the WAL has already been recycled. - /// - /// The to use for starting replication - /// The replication slot that will be updated as replication progresses so that the server - /// knows which WAL segments are still needed by the standby. - /// - /// The token to monitor for stopping the replication. 
- /// The collection of options passed to the slot's logical decoding plugin. - /// The WAL location to begin streaming at. - /// A representing an that - /// can be used to stream WAL entries in form of instances. - public static IAsyncEnumerable StartReplication( - this LogicalReplicationConnection connection, - TestDecodingReplicationSlot slot, - CancellationToken cancellationToken, - TestDecodingOptions? options = default, - NpgsqlLogSequenceNumber? walLocation = null) - => new TestDecodingAsyncEnumerable(connection, slot, options ?? new TestDecodingOptions(), cancellationToken, walLocation); + // We don't enter NoSynchronizationContextScope here since we (have to) do it in CreateReplicationSlotForPlugin, because + // otherwise it wouldn't be set for external plugins. + var options = await connection.CreateLogicalReplicationSlot( + slotName, "test_decoding", temporarySlot, slotSnapshotInitMode, twoPhase, cancellationToken).ConfigureAwait(false); + return new TestDecodingReplicationSlot(options); } -} + + /// + /// Instructs the server to start streaming the WAL for logical replication using the test_decoding logical decoding plugin, + /// starting at WAL location or at the slot's consistent point if + /// isn't specified. + /// The server can reply with an error, for example if the requested section of the WAL has already been recycled. + /// + /// The to use for starting replication + /// The replication slot that will be updated as replication progresses so that the server + /// knows which WAL segments are still needed by the standby. + /// + /// The token to monitor for stopping the replication. + /// The collection of options passed to the slot's logical decoding plugin. + /// The WAL location to begin streaming at. + /// A representing an that + /// can be used to stream WAL entries in form of instances. 
+ public static IAsyncEnumerable StartReplication( + this LogicalReplicationConnection connection, + TestDecodingReplicationSlot slot, + CancellationToken cancellationToken, + TestDecodingOptions? options = default, + NpgsqlLogSequenceNumber? walLocation = null) + => new TestDecodingAsyncEnumerable(connection, slot, options ?? new TestDecodingOptions(), cancellationToken, walLocation); +} \ No newline at end of file diff --git a/src/Npgsql/Replication/TestDecoding/TestDecodingData.cs b/src/Npgsql/Replication/TestDecoding/TestDecodingData.cs index cabecfad31..c887a015ad 100644 --- a/src/Npgsql/Replication/TestDecoding/TestDecodingData.cs +++ b/src/Npgsql/Replication/TestDecoding/TestDecodingData.cs @@ -1,40 +1,39 @@ using NpgsqlTypes; using System; -namespace Npgsql.Replication.TestDecoding +namespace Npgsql.Replication.TestDecoding; + +/// +/// Text representations of PostgreSQL WAL operations decoded by the "test_decoding" plugin. See +/// https://www.postgresql.org/docs/current/test-decoding.html. +/// +public sealed class TestDecodingData : ReplicationMessage { /// - /// Text representations of PostgreSQL WAL operations decoded by the "test_decoding" plugin. See - /// https://www.postgresql.org/docs/current/test-decoding.html. 
+ /// Decoded text representation of the operation performed in this WAL entry /// - public sealed class TestDecodingData : ReplicationMessage - { - /// - /// Decoded text representation of the operation performed in this WAL entry - /// - public string Data { get; private set; } = default!; + public string Data { get; private set; } = default!; - internal TestDecodingData Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, string data) - { - base.Populate(walStart, walEnd, serverClock); + internal TestDecodingData Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, string data) + { + base.Populate(walStart, walEnd, serverClock); - Data = data; + Data = data; - return this; - } + return this; + } - /// - public override string ToString() => Data; + /// + public override string ToString() => Data; - /// - /// Returns a clone of this message, which can be accessed after other replication messages have been retrieved. - /// - public TestDecodingData Clone() - { - var clone = new TestDecodingData(); - clone.Populate(WalStart, WalEnd, ServerClock, Data); - return clone; - } + /// + /// Returns a clone of this message, which can be accessed after other replication messages have been retrieved. 
+ /// + public TestDecodingData Clone() + { + var clone = new TestDecodingData(); + clone.Populate(WalStart, WalEnd, ServerClock, Data); + return clone; } -} +} \ No newline at end of file diff --git a/src/Npgsql/Replication/TestDecoding/TestDecodingOptions.cs b/src/Npgsql/Replication/TestDecoding/TestDecodingOptions.cs index 523ef05b06..b0887a3885 100644 --- a/src/Npgsql/Replication/TestDecoding/TestDecodingOptions.cs +++ b/src/Npgsql/Replication/TestDecoding/TestDecodingOptions.cs @@ -1,115 +1,114 @@ using System; using System.Collections.Generic; -namespace Npgsql.Replication.TestDecoding +namespace Npgsql.Replication.TestDecoding; + +/// +/// Options to be passed to the test_decoding plugin +/// +public class TestDecodingOptions : IEquatable { /// - /// Options to be passed to the test_decoding plugin + /// Creates a new instance of . /// - public class TestDecodingOptions : IEquatable + /// Include the transaction number for BEGIN and COMMIT command output + /// Include the timestamp for COMMIT command output + /// Set the output mode to binary + /// Skip output for transactions that didn't change the database + /// Only output data that don't have the replication origin set + /// Include output from table rewrites that were caused by DDL statements + /// Enable streaming output + public TestDecodingOptions(bool? includeXids = null, bool? includeTimestamp = null, bool? forceBinary = null, + bool? skipEmptyXacts = null, bool? onlyLocal = null, bool? includeRewrites = null, bool? streamChanges = null) { - /// - /// Creates a new instance of . 
- /// - /// Include the transaction number for BEGIN and COMMIT command output - /// Include the timestamp for COMMIT command output - /// Set the output mode to binary - /// Skip output for transactions that didn't change the database - /// Only output data that don't have the replication origin set - /// Include output from table rewrites that were caused by DDL statements - /// Enable streaming output - public TestDecodingOptions(bool? includeXids = null, bool? includeTimestamp = null, bool? forceBinary = null, - bool? skipEmptyXacts = null, bool? onlyLocal = null, bool? includeRewrites = null, bool? streamChanges = null) - { - IncludeXids = includeXids; - IncludeTimestamp = includeTimestamp; - ForceBinary = forceBinary; - SkipEmptyXacts = skipEmptyXacts; - OnlyLocal = onlyLocal; - IncludeRewrites = includeRewrites; - StreamChanges = streamChanges; - } + IncludeXids = includeXids; + IncludeTimestamp = includeTimestamp; + ForceBinary = forceBinary; + SkipEmptyXacts = skipEmptyXacts; + OnlyLocal = onlyLocal; + IncludeRewrites = includeRewrites; + StreamChanges = streamChanges; + } - /// - /// Include the transaction number for BEGIN and COMMIT command output - /// - public bool? IncludeXids { get; } + /// + /// Include the transaction number for BEGIN and COMMIT command output + /// + public bool? IncludeXids { get; } - /// - /// Include the timestamp for COMMIT command output - /// - public bool? IncludeTimestamp { get; } + /// + /// Include the timestamp for COMMIT command output + /// + public bool? IncludeTimestamp { get; } - /// - /// Set the output mode to binary - /// - public bool? ForceBinary { get; } + /// + /// Set the output mode to binary + /// + public bool? ForceBinary { get; } - /// - /// Skip output for transactions that didn't change the database - /// - public bool? SkipEmptyXacts { get; } + /// + /// Skip output for transactions that didn't change the database + /// + public bool? 
SkipEmptyXacts { get; } - /// - /// Only output data that don't have the replication origin set - /// - public bool? OnlyLocal { get; } + /// + /// Only output data that don't have the replication origin set + /// + public bool? OnlyLocal { get; } - /// - /// Include output from table rewrites that were caused by DDL statements - /// - public bool? IncludeRewrites { get; } + /// + /// Include output from table rewrites that were caused by DDL statements + /// + public bool? IncludeRewrites { get; } - /// - /// Enable streaming output - /// - public bool? StreamChanges { get; } + /// + /// Enable streaming output + /// + public bool? StreamChanges { get; } - internal IEnumerable> GetOptionPairs() - { - if (IncludeXids != null) - yield return new KeyValuePair("include-xids", IncludeXids.Value ? null : "f"); - if (IncludeTimestamp != null) - yield return new KeyValuePair("include-timestamp", IncludeTimestamp.Value ? null : "f"); - if (ForceBinary != null) - yield return new KeyValuePair("force-binary", ForceBinary.Value ? "t" : "f"); - if (SkipEmptyXacts != null) - yield return new KeyValuePair("skip-empty-xacts", SkipEmptyXacts.Value ? null : "f"); - if (OnlyLocal != null) - yield return new KeyValuePair("only-local", OnlyLocal.Value ? null : "false"); - if (IncludeRewrites != null) - yield return new KeyValuePair("include-rewrites", IncludeRewrites.Value ? "t" : "f"); - if (StreamChanges != null) - yield return new KeyValuePair("stream-changes", StreamChanges.Value ? "t" : "f"); - } + internal IEnumerable> GetOptionPairs() + { + if (IncludeXids != null) + yield return new KeyValuePair("include-xids", IncludeXids.Value ? null : "f"); + if (IncludeTimestamp != null) + yield return new KeyValuePair("include-timestamp", IncludeTimestamp.Value ? null : "f"); + if (ForceBinary != null) + yield return new KeyValuePair("force-binary", ForceBinary.Value ? "t" : "f"); + if (SkipEmptyXacts != null) + yield return new KeyValuePair("skip-empty-xacts", SkipEmptyXacts.Value ? 
null : "f"); + if (OnlyLocal != null) + yield return new KeyValuePair("only-local", OnlyLocal.Value ? null : "false"); + if (IncludeRewrites != null) + yield return new KeyValuePair("include-rewrites", IncludeRewrites.Value ? "t" : "f"); + if (StreamChanges != null) + yield return new KeyValuePair("stream-changes", StreamChanges.Value ? "t" : "f"); + } - /// - public bool Equals(TestDecodingOptions? other) - => other != null && ( - ReferenceEquals(this, other) || - IncludeXids == other.IncludeXids && IncludeTimestamp == other.IncludeTimestamp && ForceBinary == other.ForceBinary && - SkipEmptyXacts == other.SkipEmptyXacts && OnlyLocal == other.OnlyLocal && IncludeRewrites == other.IncludeRewrites && - StreamChanges == other.StreamChanges); + /// + public bool Equals(TestDecodingOptions? other) + => other != null && ( + ReferenceEquals(this, other) || + IncludeXids == other.IncludeXids && IncludeTimestamp == other.IncludeTimestamp && ForceBinary == other.ForceBinary && + SkipEmptyXacts == other.SkipEmptyXacts && OnlyLocal == other.OnlyLocal && IncludeRewrites == other.IncludeRewrites && + StreamChanges == other.StreamChanges); - /// - public override bool Equals(object? obj) - => obj is TestDecodingOptions other && other.Equals(this); + /// + public override bool Equals(object? 
obj) + => obj is TestDecodingOptions other && other.Equals(this); - /// - public override int GetHashCode() - { + /// + public override int GetHashCode() + { #if NETSTANDARD2_0 - var hashCode = IncludeXids.GetHashCode(); - hashCode = (hashCode * 397) ^ IncludeTimestamp.GetHashCode(); - hashCode = (hashCode * 397) ^ ForceBinary.GetHashCode(); - hashCode = (hashCode * 397) ^ SkipEmptyXacts.GetHashCode(); - hashCode = (hashCode * 397) ^ OnlyLocal.GetHashCode(); - hashCode = (hashCode * 397) ^ IncludeRewrites.GetHashCode(); - hashCode = (hashCode * 397) ^ StreamChanges.GetHashCode(); - return hashCode; + var hashCode = IncludeXids.GetHashCode(); + hashCode = (hashCode * 397) ^ IncludeTimestamp.GetHashCode(); + hashCode = (hashCode * 397) ^ ForceBinary.GetHashCode(); + hashCode = (hashCode * 397) ^ SkipEmptyXacts.GetHashCode(); + hashCode = (hashCode * 397) ^ OnlyLocal.GetHashCode(); + hashCode = (hashCode * 397) ^ IncludeRewrites.GetHashCode(); + hashCode = (hashCode * 397) ^ StreamChanges.GetHashCode(); + return hashCode; #else - return HashCode.Combine(IncludeXids, IncludeTimestamp, ForceBinary, SkipEmptyXacts, OnlyLocal, IncludeRewrites, StreamChanges); + return HashCode.Combine(IncludeXids, IncludeTimestamp, ForceBinary, SkipEmptyXacts, OnlyLocal, IncludeRewrites, StreamChanges); #endif - } } -} +} \ No newline at end of file diff --git a/src/Npgsql/Replication/TestDecoding/TestDecodingReplicationSlot.cs b/src/Npgsql/Replication/TestDecoding/TestDecodingReplicationSlot.cs index 014d84a26b..cc5c52e5a4 100644 --- a/src/Npgsql/Replication/TestDecoding/TestDecodingReplicationSlot.cs +++ b/src/Npgsql/Replication/TestDecoding/TestDecodingReplicationSlot.cs @@ -1,34 +1,33 @@ using Npgsql.Replication.Internal; -namespace Npgsql.Replication.TestDecoding +namespace Npgsql.Replication.TestDecoding; + +/// +/// Acts as a proxy for a logical replication slot +/// initialized for for the test_decoding logical decoding plugin. 
+/// +public class TestDecodingReplicationSlot : LogicalReplicationSlot { /// - /// Acts as a proxy for a logical replication slot - /// initialized for for the test_decoding logical decoding plugin. + /// Creates a new instance. /// - public class TestDecodingReplicationSlot : LogicalReplicationSlot - { - /// - /// Creates a new instance. - /// - /// - /// Create a instance with this - /// constructor to wrap an existing PostgreSQL replication slot that has - /// been initialized for the test_decoding logical decoding plugin. - /// - /// The name of the existing replication slot - public TestDecodingReplicationSlot(string slotName) - : this(new ReplicationSlotOptions(slotName)) { } + /// + /// Create a instance with this + /// constructor to wrap an existing PostgreSQL replication slot that has + /// been initialized for the test_decoding logical decoding plugin. + /// + /// The name of the existing replication slot + public TestDecodingReplicationSlot(string slotName) + : this(new ReplicationSlotOptions(slotName)) { } - /// - /// Creates a new instance. - /// - /// - /// Create a instance with this - /// constructor to wrap an existing PostgreSQL replication slot that has - /// been initialized for the test_decoding logical decoding plugin. - /// - /// The representing the existing replication slot - public TestDecodingReplicationSlot(ReplicationSlotOptions options) : base("test_decoding", options) { } - } -} + /// + /// Creates a new instance. + /// + /// + /// Create a instance with this + /// constructor to wrap an existing PostgreSQL replication slot that has + /// been initialized for the test_decoding logical decoding plugin. 
+ /// + /// The representing the existing replication slot + public TestDecodingReplicationSlot(ReplicationSlotOptions options) : base("test_decoding", options) { } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/TimelineHistoryFile.cs b/src/Npgsql/Replication/TimelineHistoryFile.cs index 1ff03f0287..89a15ffd69 100644 --- a/src/Npgsql/Replication/TimelineHistoryFile.cs +++ b/src/Npgsql/Replication/TimelineHistoryFile.cs @@ -1,30 +1,29 @@ -namespace Npgsql.Replication +namespace Npgsql.Replication; + +/// +/// Represents a PostgreSQL timeline history file +/// +public readonly struct TimelineHistoryFile { - /// - /// Represents a PostgreSQL timeline history file - /// - public readonly struct TimelineHistoryFile + internal TimelineHistoryFile(string fileName, byte[] content) { - internal TimelineHistoryFile(string fileName, byte[] content) - { - FileName = fileName; - Content = content; - } + FileName = fileName; + Content = content; + } - /// - /// File name of the timeline history file, e.g., 00000002.history. - /// - public string FileName { get; } + /// + /// File name of the timeline history file, e.g., 00000002.history. + /// + public string FileName { get; } - // While it is pretty safe to assume that a timeline history file - // only contains ASCII bytes since it is automatically written and - // parsed by the PostgreSQL backend, we don't want to claim anything - // about its content (we get it as bytes and we hand it over as bytes). + // While it is pretty safe to assume that a timeline history file + // only contains ASCII bytes since it is automatically written and + // parsed by the PostgreSQL backend, we don't want to claim anything + // about its content (we get it as bytes and we hand it over as bytes). - /// - /// Contents of the timeline history file. - /// - public byte[] Content { get; } - } -} + /// + /// Contents of the timeline history file. 
+ /// + public byte[] Content { get; } +} \ No newline at end of file diff --git a/src/Npgsql/Replication/XLogDataMessage.cs b/src/Npgsql/Replication/XLogDataMessage.cs index 89104f8819..6b4ecd6dcf 100644 --- a/src/Npgsql/Replication/XLogDataMessage.cs +++ b/src/Npgsql/Replication/XLogDataMessage.cs @@ -1,37 +1,34 @@ -using JetBrains.Annotations; -using NpgsqlTypes; -using System; +using System; using System.IO; +using NpgsqlTypes; -namespace Npgsql.Replication -{ +namespace Npgsql.Replication; +/// +/// A message representing a section of the WAL data stream. +/// +public class XLogDataMessage : ReplicationMessage +{ /// - /// A message representing a section of the WAL data stream. + /// A section of the WAL data stream that is raw WAL data in physical replication or decoded with the selected + /// logical decoding plugin in logical replication. It is only valid until the next + /// is requested from the stream. /// - public class XLogDataMessage : ReplicationMessage - { - /// - /// A section of the WAL data stream that is raw WAL data in physical replication or decoded with the selected - /// logical decoding plugin in logical replication. It is only valid until the next - /// is requested from the stream. - /// - /// - /// A single WAL record is never split across two XLogData messages. - /// When a WAL record crosses a WAL page boundary, and is therefore already split using continuation records, - /// it can be split at the page boundary. In other words, the first main WAL record and its continuation - /// records can be sent in different XLogData messages. - /// - public Stream Data { get; private set; } = default!; + /// + /// A single WAL record is never split across two XLogData messages. + /// When a WAL record crosses a WAL page boundary, and is therefore already split using continuation records, + /// it can be split at the page boundary. 
In other words, the first main WAL record and its continuation + /// records can be sent in different XLogData messages. + /// + public Stream Data { get; private set; } = default!; - internal XLogDataMessage Populate( - NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, Stream data) - { - base.Populate(walStart, walEnd, serverClock); + internal XLogDataMessage Populate( + NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, Stream data) + { + base.Populate(walStart, walEnd, serverClock); - Data = data; + Data = data; - return this; - } + return this; } -} +} \ No newline at end of file diff --git a/src/Npgsql/Schema/DbColumnSchemaGenerator.cs b/src/Npgsql/Schema/DbColumnSchemaGenerator.cs index 01fe817b70..300001e72d 100644 --- a/src/Npgsql/Schema/DbColumnSchemaGenerator.cs +++ b/src/Npgsql/Schema/DbColumnSchemaGenerator.cs @@ -2,36 +2,37 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Data; -using System.Linq; using System.Threading; using System.Threading.Tasks; using System.Transactions; using Npgsql.BackendMessages; -using Npgsql.TypeHandlers; -using Npgsql.TypeHandlers.CompositeHandlers; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; +using Npgsql.PostgresTypes; using Npgsql.Util; +using NpgsqlTypes; -namespace Npgsql.Schema +namespace Npgsql.Schema; + +sealed class DbColumnSchemaGenerator { - class DbColumnSchemaGenerator - { - readonly RowDescriptionMessage _rowDescription; - readonly NpgsqlConnection _connection; - readonly bool _fetchAdditionalInfo; + readonly RowDescriptionMessage _rowDescription; + readonly NpgsqlConnection _connection; + readonly bool _fetchAdditionalInfo; - internal DbColumnSchemaGenerator(NpgsqlConnection connection, RowDescriptionMessage rowDescription, bool fetchAdditionalInfo) - { - _connection = connection; - _rowDescription = rowDescription; - _fetchAdditionalInfo = fetchAdditionalInfo; - } + internal 
DbColumnSchemaGenerator(NpgsqlConnection connection, RowDescriptionMessage rowDescription, bool fetchAdditionalInfo) + { + _connection = connection; + _rowDescription = rowDescription; + _fetchAdditionalInfo = fetchAdditionalInfo; + } - #region Columns queries + #region Columns queries - static string GenerateColumnsQuery(Version pgVersion, string columnFieldFilter) => -$@"SELECT + static string GenerateColumnsQuery(Version pgVersion, string columnFieldFilter) => + $@"SELECT typ.oid AS typoid, nspname, relname, attname, attrelid, attnum, attnotnull, - {(pgVersion >= new Version(10, 0) ? "attidentity != ''" : "FALSE")} AS isidentity, + {(pgVersion.IsGreaterOrEqual(10) ? "attidentity != ''" : "FALSE")} AS isidentity, CASE WHEN typ.typtype = 'd' THEN typ.typtypmod ELSE atttypmod END AS typmod, CASE WHEN atthasdef THEN (SELECT pg_get_expr(adbin, cls.oid) FROM pg_attrdef WHERE adrelid = cls.oid AND adnum = attr.attnum) ELSE NULL END AS default, CASE WHEN col.is_updatable = 'YES' THEN true ELSE false END AS is_updatable, @@ -45,7 +46,7 @@ pg_index.indisprimary AND SELECT * FROM pg_index WHERE pg_index.indrelid = cls.oid AND pg_index.indisunique AND - pg_index.{(pgVersion >= new Version(11, 0) ? "indnkeyatts" : "indnatts")} = 1 AND + pg_index.{(pgVersion.IsGreaterOrEqual(11) ? "indnkeyatts" : "indnatts")} = 1 AND attnum = pg_index.indkey[0] ) AS isunique FROM pg_attribute AS attr @@ -64,11 +65,11 @@ nspname NOT IN ('pg_catalog', 'information_schema') AND ({columnFieldFilter}) ORDER BY attnum"; - /// - /// Stripped-down version of , mainly to support Amazon Redshift. - /// - static string GenerateOldColumnsQuery(string columnFieldFilter) => - $@"SELECT + /// + /// Stripped-down version of , mainly to support Amazon Redshift. 
+ /// + static string GenerateOldColumnsQuery(string columnFieldFilter) => + $@"SELECT typ.oid AS typoid, nspname, relname, attname, attrelid, attnum, attnotnull, CASE WHEN typ.typtype = 'd' THEN typ.typtypmod ELSE atttypmod END AS typmod, CASE WHEN atthasdef THEN (SELECT pg_get_expr(adbin, cls.oid) FROM pg_attrdef WHERE adrelid = cls.oid AND adnum = attr.attnum) ELSE NULL END AS default, @@ -91,29 +92,36 @@ nspname NOT IN ('pg_catalog', 'information_schema') AND ({columnFieldFilter}) ORDER BY attnum"; - #endregion Column queries + #endregion Column queries - internal async Task> GetColumnSchema(bool async, CancellationToken cancellationToken = default) - { - // This is mainly for Amazon Redshift - var oldQueryMode = _connection.PostgreSqlVersion < new Version(8, 2); + internal async Task> GetColumnSchema(bool async, CancellationToken cancellationToken = default) + { + // This is mainly for Amazon Redshift + var oldQueryMode = _connection.PostgreSqlVersion < new Version(8, 2); - var fields = _rowDescription.Fields; - var result = new List(fields.Count); - for (var i = 0; i < fields.Count; i++) - result.Add(null); - var populatedColumns = 0; + var numFields = _rowDescription.Count; + var result = new List(numFields); + for (var i = 0; i < numFields; i++) + result.Add(null); + var populatedColumns = 0; + if (_fetchAdditionalInfo) + { // We have two types of fields - those which correspond to actual database columns // and those that don't (e.g. SELECT 8). 
For the former we load lots of info from // the backend (if fetchAdditionalInfo is true), for the latter we only have the RowDescription - var columnFieldFilter = _rowDescription.Fields - .Where(f => f.TableOID != 0) // Only column fields - .Select(c => $"(attr.attrelid={c.TableOID} AND attr.attnum={c.ColumnAttributeNumber})") - .Join(" OR "); + var filters = new List(); + for (var index = 0; index < _rowDescription.Count; index++) + { + var f = _rowDescription[index]; + // Only column fields + if (f.TableOID != 0) + filters.Add($"(attr.attrelid={f.TableOID} AND attr.attnum={f.ColumnAttributeNumber})"); + } - if (_fetchAdditionalInfo && columnFieldFilter != "") + var columnFieldFilter = string.Join(" OR ", filters); + if (columnFieldFilter != string.Empty) { var query = oldQueryMode ? GenerateOldColumnsQuery(columnFieldFilter) @@ -124,140 +132,150 @@ internal async Task> GetColumnSchema(bool asy async ? TransactionScopeAsyncFlowOption.Enabled : TransactionScopeAsyncFlowOption.Suppress); using var connection = (NpgsqlConnection)((ICloneable)_connection).Clone(); - await connection.Open(async, cancellationToken); + await connection.Open(async, cancellationToken).ConfigureAwait(false); using var cmd = new NpgsqlCommand(query, connection); - using var reader = await cmd.ExecuteReader(CommandBehavior.Default, async, cancellationToken); - while (async ? await reader.ReadAsync(cancellationToken): reader.Read()) + var reader = await cmd.ExecuteReader(async, CommandBehavior.Default, cancellationToken).ConfigureAwait(false); + try { - var column = LoadColumnDefinition(reader, _connection.Connector!.TypeMapper.DatabaseInfo, oldQueryMode); - for (var ordinal = 0; ordinal < fields.Count; ordinal++) + while (async ? 
await reader.ReadAsync(cancellationToken).ConfigureAwait(false) : reader.Read()) { - var field = fields[ordinal]; - if (field.TableOID == column.TableOID && - field.ColumnAttributeNumber == column.ColumnAttributeNumber) + var column = LoadColumnDefinition(reader, _connection.Connector!.DatabaseInfo, oldQueryMode); + for (var ordinal = 0; ordinal < numFields; ordinal++) { - populatedColumns++; - - if (column.ColumnOrdinal.HasValue) - column = column.Clone(); - - // The column's ordinal is with respect to the resultset, not its table - column.ColumnOrdinal = ordinal; - result[ordinal] = column; + var field = _rowDescription[ordinal]; + if (field.TableOID == column.TableOID && + field.ColumnAttributeNumber == column.ColumnAttributeNumber) + { + populatedColumns++; + + if (column.ColumnOrdinal.HasValue) + column = column.Clone(); + + // The column's ordinal is with respect to the resultset, not its table + column.ColumnOrdinal = ordinal; + result[ordinal] = column; + } } } } - } - - // We had some fields which don't correspond to regular table columns (or fetchAdditionalInfo is false). - // Fill in whatever info we have from the RowDescription itself - for (var i = 0; i < fields.Count; i++) - { - var column = result[i]; - var field = fields[i]; - - if (column is null) + finally { - column = SetUpNonColumnField(field); - column.ColumnOrdinal = i; - result[i] = column; - populatedColumns++; + if (async) + await reader.DisposeAsync().ConfigureAwait(false); + else + reader.Dispose(); } - - column.ColumnName = field.Name; - column.IsAliased = column.BaseColumnName is null ? default(bool?) 
: (column.BaseColumnName != column.ColumnName); } - - if (populatedColumns != fields.Count) - throw new NpgsqlException("Could not load all columns for the resultset"); - - return result.AsReadOnly()!; } - NpgsqlDbColumn LoadColumnDefinition(NpgsqlDataReader reader, NpgsqlDatabaseInfo databaseInfo, bool oldQueryMode) + // We had some fields which don't correspond to regular table columns (or fetchAdditionalInfo is false). + // Fill in whatever info we have from the RowDescription itself + for (var i = 0; i < numFields; i++) { - // We don't set ColumnName here. It should always contain the column alias rather than - // the table column name (i.e. in case of "SELECT foo AS foo_alias"). It will be set later. - var column = new NpgsqlDbColumn - { - AllowDBNull = !reader.GetBoolean(reader.GetOrdinal("attnotnull")), - BaseCatalogName = _connection.Database!, - BaseSchemaName = reader.GetString(reader.GetOrdinal("nspname")), - BaseServerName = _connection.Host!, - BaseTableName = reader.GetString(reader.GetOrdinal("relname")), - BaseColumnName = reader.GetString(reader.GetOrdinal("attname")), - ColumnAttributeNumber = reader.GetInt16(reader.GetOrdinal("attnum")), - IsKey = reader.GetBoolean(reader.GetOrdinal("isprimarykey")), - IsReadOnly = !reader.GetBoolean(reader.GetOrdinal("is_updatable")), - IsUnique = reader.GetBoolean(reader.GetOrdinal("isunique")), - - TableOID = reader.GetFieldValue(reader.GetOrdinal("attrelid")), - TypeOID = reader.GetFieldValue(reader.GetOrdinal("typoid")) - }; - - column.PostgresType = databaseInfo.ByOID[column.TypeOID]; - column.DataTypeName = column.PostgresType.DisplayName; // Facets do not get included - - var defaultValueOrdinal = reader.GetOrdinal("default"); - column.DefaultValue = reader.IsDBNull(defaultValueOrdinal) ? 
null : reader.GetString(defaultValueOrdinal); - - column.IsAutoIncrement = - !oldQueryMode && reader.GetBoolean(reader.GetOrdinal("isidentity")) || - column.DefaultValue != null && column.DefaultValue.StartsWith("nextval("); - - ColumnPostConfig(column, reader.GetInt32(reader.GetOrdinal("typmod"))); - - return column; - } + var column = result[i]; + var field = _rowDescription[i]; - NpgsqlDbColumn SetUpNonColumnField(FieldDescription field) - { - // ColumnName and BaseColumnName will be set later - var column = new NpgsqlDbColumn + if (column is null) { - BaseCatalogName = _connection.Database!, - BaseServerName = _connection.Host!, - IsReadOnly = true, - DataTypeName = field.PostgresType.DisplayName, - TypeOID = field.TypeOID, - TableOID = field.TableOID, - ColumnAttributeNumber = field.ColumnAttributeNumber, - PostgresType = field.PostgresType - }; - - ColumnPostConfig(column, field.TypeModifier); - - return column; + column = SetUpNonColumnField(field); + column.ColumnOrdinal = i; + result[i] = column; + populatedColumns++; + } + + column.ColumnName = field.Name; + column.IsAliased = column.BaseColumnName is null ? default(bool?) : (column.BaseColumnName != column.ColumnName); } - /// - /// Performs some post-setup configuration that's common to both table columns and non-columns. - /// - void ColumnPostConfig(NpgsqlDbColumn column, int typeModifier) + if (populatedColumns != numFields) + throw new NpgsqlException("Could not load all columns for the resultset"); + + return result.AsReadOnly()!; + } + + NpgsqlDbColumn LoadColumnDefinition(NpgsqlDataReader reader, NpgsqlDatabaseInfo databaseInfo, bool oldQueryMode) + { + // We don't set ColumnName here. It should always contain the column alias rather than + // the table column name (i.e. in case of "SELECT foo AS foo_alias"). It will be set later. 
+ var column = new NpgsqlDbColumn { - var typeMapper = _connection.Connector!.TypeMapper; + AllowDBNull = !reader.GetBoolean(reader.GetOrdinal("attnotnull")), + BaseCatalogName = _connection.Database!, + BaseSchemaName = reader.GetString(reader.GetOrdinal("nspname")), + BaseServerName = _connection.Host!, + BaseTableName = reader.GetString(reader.GetOrdinal("relname")), + BaseColumnName = reader.GetString(reader.GetOrdinal("attname")), + ColumnAttributeNumber = reader.GetInt16(reader.GetOrdinal("attnum")), + IsKey = reader.GetBoolean(reader.GetOrdinal("isprimarykey")), + IsReadOnly = !reader.GetBoolean(reader.GetOrdinal("is_updatable")), + IsUnique = reader.GetBoolean(reader.GetOrdinal("isunique")), - column.NpgsqlDbType = typeMapper.GetTypeInfoByOid(column.TypeOID).npgsqlDbType; - column.DataType = typeMapper.TryGetByOID(column.TypeOID, out var handler) - ? handler.GetFieldType() - : null; + TableOID = reader.GetFieldValue(reader.GetOrdinal("attrelid")), + TypeOID = reader.GetFieldValue(reader.GetOrdinal("typoid")) + }; - if (column.DataType != null) - { - column.IsLong = handler is ByteaHandler; + column.PostgresType = databaseInfo.ByOID[column.TypeOID]; + column.DataTypeName = column.PostgresType.DisplayName; // Facets do not get included - if (handler is ICompositeHandler) - column.UdtAssemblyQualifiedName = column.DataType.AssemblyQualifiedName; - } + var defaultValueOrdinal = reader.GetOrdinal("default"); + column.DefaultValue = reader.IsDBNull(defaultValueOrdinal) ? 
null : reader.GetString(defaultValueOrdinal); + + column.IsIdentity = !oldQueryMode && reader.GetBoolean(reader.GetOrdinal("isidentity")); + + // Use a heuristic to discover old SERIAL columns + column.IsAutoIncrement = + column.IsIdentity == true || + column.DefaultValue != null && column.DefaultValue.StartsWith("nextval(", StringComparison.Ordinal); + + ColumnPostConfig(column, reader.GetInt32(reader.GetOrdinal("typmod"))); - var facets = column.PostgresType.GetFacets(typeModifier); - if (facets.Size != null) - column.ColumnSize = facets.Size; - if (facets.Precision != null) - column.NumericPrecision = facets.Precision; - if (facets.Scale != null) - column.NumericScale = facets.Scale; + return column; + } + + NpgsqlDbColumn SetUpNonColumnField(FieldDescription field) + { + // ColumnName and BaseColumnName will be set later + var column = new NpgsqlDbColumn + { + BaseCatalogName = _connection.Database!, + BaseServerName = _connection.Host!, + IsReadOnly = true, + DataTypeName = field.PostgresType.DisplayName, + TypeOID = field.TypeOID, + TableOID = field.TableOID, + ColumnAttributeNumber = field.ColumnAttributeNumber, + PostgresType = field.PostgresType + }; + + ColumnPostConfig(column, field.TypeModifier); + + return column; + } + + /// + /// Performs some post-setup configuration that's common to both table columns and non-columns. 
+ /// + void ColumnPostConfig(NpgsqlDbColumn column, int typeModifier) + { + var serializerOptions = _connection.Connector!.SerializerOptions; + + column.NpgsqlDbType = column.PostgresType.DataTypeName.ToNpgsqlDbType(); + if (serializerOptions.GetObjectOrDefaultTypeInfo(column.PostgresType) is { } typeInfo) + { + column.DataType = typeInfo.Type; + column.IsLong = column.PostgresType.DataTypeName == DataTypeNames.Bytea; + + if (column.PostgresType is PostgresCompositeType) + column.UdtAssemblyQualifiedName = typeInfo.Type.AssemblyQualifiedName; } + + var facets = column.PostgresType.GetFacets(typeModifier); + if (facets.Size != null) + column.ColumnSize = facets.Size; + if (facets.Precision != null) + column.NumericPrecision = facets.Precision; + if (facets.Scale != null) + column.NumericScale = facets.Scale; } } diff --git a/src/Npgsql/Schema/NpgsqlDbColumn.cs b/src/Npgsql/Schema/NpgsqlDbColumn.cs index 8961bd01fb..e4597e3d86 100644 --- a/src/Npgsql/Schema/NpgsqlDbColumn.cs +++ b/src/Npgsql/Schema/NpgsqlDbColumn.cs @@ -1,230 +1,234 @@ using System; using System.Data.Common; -using System.Runtime.CompilerServices; -using JetBrains.Annotations; using Npgsql.PostgresTypes; using NpgsqlTypes; -namespace Npgsql.Schema +namespace Npgsql.Schema; + +/// +/// Provides schema information about a column. +/// +/// +/// Note that this can correspond to a field returned in a query which isn't an actual table column +/// +/// See https://msdn.microsoft.com/en-us/library/system.data.sqlclient.sqldatareader.getschematable(v=vs.110).aspx +/// for information on the meaning of the different fields. +/// +public class NpgsqlDbColumn : DbColumn { /// - /// Provides schema information about a column. + /// Initializes a new instance of the class. 
/// - /// - /// Note that this can correspond to a field returned in a query which isn't an actual table column - /// - /// See https://msdn.microsoft.com/en-us/library/system.data.sqlclient.sqldatareader.getschematable(v=vs.110).aspx - /// for information on the meaning of the different fields. - /// - public class NpgsqlDbColumn : DbColumn - { - /// - /// Initializes a new instance of the class. - /// - public NpgsqlDbColumn() - { - PostgresType = UnknownBackendType.Instance; - - // Not supported in PostgreSQL - IsExpression = false; - IsAliased = false; - IsHidden = false; - IsIdentity = false; - } - - internal NpgsqlDbColumn Clone() => - Unsafe.As(MemberwiseClone()); - - #region Standard fields - // ReSharper disable once InconsistentNaming - /// - public new bool? AllowDBNull - { - get => base.AllowDBNull; - protected internal set => base.AllowDBNull = value; - } + public NpgsqlDbColumn() + { + PostgresType = UnknownBackendType.Instance; - /// - public new string BaseCatalogName - { - get => base.BaseCatalogName!; - protected internal set => base.BaseCatalogName = value; - } + // Not supported in PostgreSQL + IsExpression = false; + IsAliased = false; + IsHidden = false; + IsIdentity = false; + } - /// - public new string? BaseColumnName - { - get => base.BaseColumnName; - protected internal set => base.BaseColumnName = value; - } + internal NpgsqlDbColumn Clone() => + (NpgsqlDbColumn)MemberwiseClone(); - /// - public new string? BaseSchemaName - { - get => base.BaseSchemaName; - protected internal set => base.BaseSchemaName = value; - } + #region Standard fields + // ReSharper disable once InconsistentNaming + /// + public new bool? 
AllowDBNull + { + get => base.AllowDBNull; + protected internal set => base.AllowDBNull = value; + } - /// - public new string BaseServerName - { - get => base.BaseServerName!; - protected internal set => base.BaseServerName = value; - } + /// + public new string BaseCatalogName + { + get => base.BaseCatalogName!; + protected internal set => base.BaseCatalogName = value; + } - /// - public new string? BaseTableName - { - get => base.BaseTableName; - protected internal set => base.BaseTableName = value; - } + /// + public new string? BaseColumnName + { + get => base.BaseColumnName; + protected internal set => base.BaseColumnName = value; + } - /// - public new string ColumnName - { - get => base.ColumnName; - protected internal set => base.ColumnName = value; - } + /// + public new string? BaseSchemaName + { + get => base.BaseSchemaName; + protected internal set => base.BaseSchemaName = value; + } - /// - public new int? ColumnOrdinal - { - get => base.ColumnOrdinal; - protected internal set => base.ColumnOrdinal = value; - } + /// + public new string BaseServerName + { + get => base.BaseServerName!; + protected internal set => base.BaseServerName = value; + } - /// - public new int? ColumnSize - { - get => base.ColumnSize; - protected internal set => base.ColumnSize = value; - } + /// + public new string? BaseTableName + { + get => base.BaseTableName; + protected internal set => base.BaseTableName = value; + } - /// - public new bool? IsAliased - { - get => base.IsAliased; - protected internal set => base.IsAliased = value; - } + /// + public new string ColumnName + { + get => base.ColumnName; + protected internal set => base.ColumnName = value; + } - /// - public new bool? IsAutoIncrement - { - get => base.IsAutoIncrement; - protected internal set => base.IsAutoIncrement = value; - } + /// + public new int? ColumnOrdinal + { + get => base.ColumnOrdinal; + protected internal set => base.ColumnOrdinal = value; + } - /// - public new bool? 
IsKey - { - get => base.IsKey; - protected internal set => base.IsKey = value; - } + /// + public new int? ColumnSize + { + get => base.ColumnSize; + protected internal set => base.ColumnSize = value; + } - /// - public new bool? IsLong - { - get => base.IsLong; - protected internal set => base.IsLong = value; - } + /// + public new bool? IsAliased + { + get => base.IsAliased; + protected internal set => base.IsAliased = value; + } - /// - public new bool? IsReadOnly - { - get => base.IsReadOnly; - protected internal set => base.IsReadOnly = value; - } + /// + public new bool? IsAutoIncrement + { + get => base.IsAutoIncrement; + protected internal set => base.IsAutoIncrement = value; + } - /// - public new bool? IsUnique - { - get => base.IsUnique; - protected internal set => base.IsUnique = value; - } + /// + public new bool? IsIdentity + { + get => base.IsIdentity; + protected internal set => base.IsIdentity = value; + } - /// - public new int? NumericPrecision - { - get => base.NumericPrecision; - protected internal set => base.NumericPrecision = value; - } + /// + public new bool? IsKey + { + get => base.IsKey; + protected internal set => base.IsKey = value; + } - /// - public new int? NumericScale - { - get => base.NumericScale; - protected internal set => base.NumericScale = value; - } + /// + public new bool? IsLong + { + get => base.IsLong; + protected internal set => base.IsLong = value; + } - /// - public new string? UdtAssemblyQualifiedName - { - get => base.UdtAssemblyQualifiedName; - protected internal set => base.UdtAssemblyQualifiedName = value; - } + /// + public new bool? IsReadOnly + { + get => base.IsReadOnly; + protected internal set => base.IsReadOnly = value; + } - /// - public new Type? DataType - { - get => base.DataType; - protected internal set => base.DataType = value; - } + /// + public new bool? 
IsUnique + { + get => base.IsUnique; + protected internal set => base.IsUnique = value; + } - /// - public new string DataTypeName - { - get => base.DataTypeName!; - protected internal set => base.DataTypeName = value; - } - - #endregion Standard fields - - #region Npgsql-specific fields - - /// - /// The describing the type of this column. - /// - public PostgresType PostgresType { get; internal set; } - - /// - /// The OID of the type of this column in the PostgreSQL pg_type catalog table. - /// - public uint TypeOID { get; internal set; } - - /// - /// The OID of the PostgreSQL table of this column. - /// - public uint TableOID { get; internal set; } - - /// - /// The column's position within its table. Note that this is different from , - /// which is the column's position within the resultset. - /// - public short? ColumnAttributeNumber { get; internal set; } - - /// - /// The default SQL expression for this column. - /// - public string? DefaultValue { get; internal set; } - - /// - /// The value for this column's type. - /// - public NpgsqlDbType? NpgsqlDbType { get; internal set; } - - /// - public override object? this[string propertyName] - => propertyName switch - { - nameof(PostgresType) => PostgresType, - nameof(TypeOID) => TypeOID, - nameof(TableOID) => TableOID, - nameof(ColumnAttributeNumber) => ColumnAttributeNumber, - nameof(DefaultValue) => DefaultValue, - nameof(NpgsqlDbType) => NpgsqlDbType, - _ => base[propertyName] - }; - - #endregion Npgsql-specific fields + /// + public new int? NumericPrecision + { + get => base.NumericPrecision; + protected internal set => base.NumericPrecision = value; + } + + /// + public new int? NumericScale + { + get => base.NumericScale; + protected internal set => base.NumericScale = value; + } + + /// + public new string? UdtAssemblyQualifiedName + { + get => base.UdtAssemblyQualifiedName; + protected internal set => base.UdtAssemblyQualifiedName = value; + } + + /// + public new Type? 
DataType + { + get => base.DataType; + protected internal set => base.DataType = value; + } + + /// + public new string DataTypeName + { + get => base.DataTypeName!; + protected internal set => base.DataTypeName = value; } + + #endregion Standard fields + + #region Npgsql-specific fields + + /// + /// The describing the type of this column. + /// + public PostgresType PostgresType { get; internal set; } + + /// + /// The OID of the type of this column in the PostgreSQL pg_type catalog table. + /// + public uint TypeOID { get; internal set; } + + /// + /// The OID of the PostgreSQL table of this column. + /// + public uint TableOID { get; internal set; } + + /// + /// The column's position within its table. Note that this is different from , + /// which is the column's position within the resultset. + /// + public short? ColumnAttributeNumber { get; internal set; } + + /// + /// The default SQL expression for this column. + /// + public string? DefaultValue { get; internal set; } + + /// + /// The value for this column's type. + /// + public NpgsqlDbType? NpgsqlDbType { get; internal set; } + + /// + public override object? 
this[string propertyName] + => propertyName switch + { + nameof(PostgresType) => PostgresType, + nameof(TypeOID) => TypeOID, + nameof(TableOID) => TableOID, + nameof(ColumnAttributeNumber) => ColumnAttributeNumber, + nameof(DefaultValue) => DefaultValue, + nameof(NpgsqlDbType) => NpgsqlDbType, + _ => base[propertyName] + }; + + #endregion Npgsql-specific fields } diff --git a/src/Npgsql/Shims/Batching.cs b/src/Npgsql/Shims/Batching.cs new file mode 100644 index 0000000000..c8e7ddec1c --- /dev/null +++ b/src/Npgsql/Shims/Batching.cs @@ -0,0 +1,130 @@ +#if !NET6_0_OR_GREATER +using System.Collections; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable 1591,RS0016 + +// ReSharper disable once CheckNamespace +namespace System.Data.Common +{ + public abstract class DbBatch : IDisposable, IAsyncDisposable + { + public DbBatchCommandCollection BatchCommands => DbBatchCommands; + + protected abstract DbBatchCommandCollection DbBatchCommands { get; } + + public abstract int Timeout { get; set; } + + public DbConnection? Connection + { + get => DbConnection; + set => DbConnection = value; + } + + protected abstract DbConnection? DbConnection { get; set; } + + public DbTransaction? Transaction + { + get => DbTransaction; + set => DbTransaction = value; + } + + protected abstract DbTransaction? 
DbTransaction { get; set; } + + public DbDataReader ExecuteReader(CommandBehavior behavior = CommandBehavior.Default) + => ExecuteDbDataReader(behavior); + + protected abstract DbDataReader ExecuteDbDataReader(CommandBehavior behavior); + + public Task ExecuteReaderAsync(CancellationToken cancellationToken = default) + => ExecuteDbDataReaderAsync(CommandBehavior.Default, cancellationToken); + + public Task ExecuteReaderAsync( + CommandBehavior behavior, + CancellationToken cancellationToken = default) + => ExecuteDbDataReaderAsync(behavior, cancellationToken); + + protected abstract Task ExecuteDbDataReaderAsync( + CommandBehavior behavior, + CancellationToken cancellationToken); + + public abstract int ExecuteNonQuery(); + + public abstract Task ExecuteNonQueryAsync(CancellationToken cancellationToken = default); + + public abstract object? ExecuteScalar(); + + public abstract Task ExecuteScalarAsync(CancellationToken cancellationToken = default); + + public abstract void Prepare(); + + public abstract Task PrepareAsync(CancellationToken cancellationToken = default); + + public abstract void Cancel(); + + public DbBatchCommand CreateBatchCommand() => CreateDbBatchCommand(); + + protected abstract DbBatchCommand CreateDbBatchCommand(); + + public virtual void Dispose() {} + + public virtual ValueTask DisposeAsync() + { + Dispose(); + return default; + } + } + + public abstract class DbBatchCommand + { + public abstract string CommandText { get; set; } + + public abstract CommandType CommandType { get; set; } + + public abstract int RecordsAffected { get; } + + public DbParameterCollection Parameters => DbParameterCollection; + + protected abstract DbParameterCollection DbParameterCollection { get; } + } + + public abstract class DbBatchCommandCollection : IList + { + public abstract IEnumerator GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public abstract void Add(DbBatchCommand item); + + public abstract void Clear(); + + 
public abstract bool Contains(DbBatchCommand item); + + public abstract void CopyTo(DbBatchCommand[] array, int arrayIndex); + + public abstract bool Remove(DbBatchCommand item); + + public abstract int Count { get; } + + public abstract bool IsReadOnly { get; } + + public abstract int IndexOf(DbBatchCommand item); + + public abstract void Insert(int index, DbBatchCommand item); + + public abstract void RemoveAt(int index); + + public DbBatchCommand this[int index] + { + get => GetBatchCommand(index); + set => SetBatchCommand(index, value); + } + + protected abstract DbBatchCommand GetBatchCommand(int index); + + protected abstract void SetBatchCommand(int index, DbBatchCommand batchCommand); + } +} +#endif diff --git a/src/Npgsql/Shims/ConcurrentDictionaryExtensions.cs b/src/Npgsql/Shims/ConcurrentDictionaryExtensions.cs new file mode 100644 index 0000000000..02f5c2077c --- /dev/null +++ b/src/Npgsql/Shims/ConcurrentDictionaryExtensions.cs @@ -0,0 +1,15 @@ +namespace System.Collections.Concurrent; + +#if NETSTANDARD2_0 +static class ConcurrentDictionaryExtensions +{ + public static TValue GetOrAdd(this ConcurrentDictionary instance, TKey key, + Func valueFactory, TArg factoryArgument) + { + // The actual closure capture exists in a local function to prevent a display class allocation at the start of the method. + return instance.TryGetValue(key, out var value) ? 
value : GetOrAddWithClosure(instance, key, valueFactory, factoryArgument); + + static TValue GetOrAddWithClosure(ConcurrentDictionary instance, TKey key, Func valuefactory, TArg factoryargument) => instance.GetOrAdd(key, key => valuefactory(key, factoryargument)); + } +} +#endif diff --git a/src/Npgsql/Netstandard20/DbDataReaderExtensions.cs b/src/Npgsql/Shims/DbDataReaderExtensions.cs similarity index 100% rename from src/Npgsql/Netstandard20/DbDataReaderExtensions.cs rename to src/Npgsql/Shims/DbDataReaderExtensions.cs diff --git a/src/Npgsql/Shims/DbDataSource.cs b/src/Npgsql/Shims/DbDataSource.cs new file mode 100644 index 0000000000..6951d427fb --- /dev/null +++ b/src/Npgsql/Shims/DbDataSource.cs @@ -0,0 +1,70 @@ +#if !NET7_0_OR_GREATER + +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member (compatibility shim for old TFMs) + +// ReSharper disable once CheckNamespace +namespace System.Data.Common; + +public abstract class DbDataSource : IDisposable, IAsyncDisposable +{ + public abstract string ConnectionString { get; } + + protected abstract DbConnection CreateDbConnection(); + + // No need for an actual implementation in this compat shim - it's only implementation will be NpgsqlDataSource, which overrides this. + protected virtual DbConnection OpenDbConnection() + => throw new NotSupportedException(); + + // No need for an actual implementation in this compat shim - it's only implementation will be NpgsqlDataSource, which overrides this. + protected virtual ValueTask OpenDbConnectionAsync(CancellationToken cancellationToken = default) + => throw new NotSupportedException(); + + // No need for an actual implementation in this compat shim - it's only implementation will be NpgsqlDataSource, which overrides this. + protected virtual DbCommand CreateDbCommand(string? 
commandText = null) + => throw new NotSupportedException(); + + // No need for an actual implementation in this compat shim - it's only implementation will be NpgsqlDataSource, which overrides this. + protected virtual DbBatch CreateDbBatch() + => throw new NotSupportedException(); + + public DbConnection CreateConnection() + => CreateDbConnection(); + + public DbConnection OpenConnection() + => OpenDbConnection(); + + public ValueTask OpenConnectionAsync(CancellationToken cancellationToken = default) + => OpenDbConnectionAsync(cancellationToken); + + public DbCommand CreateCommand(string? commandText = null) + => CreateDbCommand(commandText); + + public DbBatch CreateBatch() + => CreateDbBatch(); + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + public async ValueTask DisposeAsync() + { + await DisposeAsyncCore().ConfigureAwait(false); + + Dispose(disposing: false); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + } + + protected virtual ValueTask DisposeAsyncCore() + => default; +} + +#endif \ No newline at end of file diff --git a/src/Npgsql/Shims/DictonaryExtensions.cs b/src/Npgsql/Shims/DictonaryExtensions.cs new file mode 100644 index 0000000000..a13397a39e --- /dev/null +++ b/src/Npgsql/Shims/DictonaryExtensions.cs @@ -0,0 +1,19 @@ +#if NETSTANDARD2_0 +// ReSharper disable once CheckNamespace +namespace System.Collections.Generic; + +// Helpers for Dictionary before netstandard 2.1 +static class DictonaryExtensions +{ + public static bool TryAdd(this Dictionary dictionary, TKey key, TValue value) + { + if (!dictionary.ContainsKey(key)) + { + dictionary.Add(key, value); + return true; + } + + return false; + } +} +#endif diff --git a/src/Npgsql/Shims/EncodingExtensions.cs b/src/Npgsql/Shims/EncodingExtensions.cs new file mode 100644 index 0000000000..e66d8fffa9 --- /dev/null +++ b/src/Npgsql/Shims/EncodingExtensions.cs @@ -0,0 +1,237 @@ +// ReSharper disable 
RedundantUsingDirective +using System.Buffers; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +// ReSharper restore RedundantUsingDirective + +// ReSharper disable once CheckNamespace +namespace System.Text; + +static class EncodingExtensions +{ +#if NETSTANDARD2_0 + + /// + /// Returns a reference to the 0th element of the ReadOnlySpan. If the ReadOnlySpan is empty, returns a reference to fake non-null pointer. Such a reference + /// can be used for pinning but must never be dereferenced. This is useful for interop with methods that do not accept null pointers for zero-sized buffers. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static unsafe ref readonly T GetNonNullPinnableReference(ReadOnlySpan span) + => ref span.Length != 0 ? ref span.GetPinnableReference() : ref Unsafe.AsRef((void*)1); + + /// + /// Returns a reference to the 0th element of the ReadOnlySpan. If the ReadOnlySpan is empty, returns a reference to fake non-null pointer. Such a reference + /// can be used for pinning but must never be dereferenced. This is useful for interop with methods that do not accept null pointers for zero-sized buffers. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static unsafe ref T GetNonNullPinnableReference(Span span) + => ref span.Length != 0 ? 
ref span.GetPinnableReference() : ref Unsafe.AsRef((void*)1); + + public static unsafe int GetByteCount(this Encoding encoding, ReadOnlySpan chars) + { + fixed (char* charsPtr = &GetNonNullPinnableReference(chars)) + { + return encoding.GetByteCount(charsPtr, chars.Length); + } + } + + public static unsafe int GetBytes(this Encoding encoding, ReadOnlySpan chars, Span bytes) + { + fixed (char* charsPtr = &GetNonNullPinnableReference(chars)) + fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes)) + { + return encoding.GetBytes(charsPtr, chars.Length, bytesPtr, bytes.Length); + } + } + + public static unsafe int GetCharCount(this Encoding encoding, ReadOnlySpan bytes) + { + fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes)) + { + return encoding.GetCharCount(bytesPtr, bytes.Length); + } + } + + public static unsafe int GetCharCount(this Decoder encoding, ReadOnlySpan bytes, bool flush) + { + fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes)) + { + return encoding.GetCharCount(bytesPtr, bytes.Length, flush); + } + } + + public static unsafe int GetChars(this Decoder encoding, ReadOnlySpan bytes, Span chars, bool flush) + { + fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes)) + fixed (char* charsPtr = &GetNonNullPinnableReference(chars)) + { + return encoding.GetChars(bytesPtr, bytes.Length, charsPtr, chars.Length, flush); + } + } + + public static unsafe int GetChars(this Encoding encoding, ReadOnlySpan bytes, Span chars) + { + fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes)) + fixed (char* charsPtr = &GetNonNullPinnableReference(chars)) + { + return encoding.GetChars(bytesPtr, bytes.Length, charsPtr, chars.Length); + } + } + + public static unsafe void Convert(this Encoder encoder, ReadOnlySpan chars, Span bytes, bool flush, out int charsUsed, out int bytesUsed, out bool completed) + { + fixed (char* charsPtr = &GetNonNullPinnableReference(chars)) + fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes)) + { + 
encoder.Convert(charsPtr, chars.Length, bytesPtr, bytes.Length, flush, out charsUsed, out bytesUsed, out completed); + } + } + + public static unsafe void Convert(this Decoder encoder, ReadOnlySpan bytes, Span chars, bool flush, out int bytesUsed, out int charsUsed, out bool completed) + { + fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes)) + fixed (char* charsPtr = &GetNonNullPinnableReference(chars)) + { + encoder.Convert(bytesPtr, bytes.Length, charsPtr, chars.Length, flush, out bytesUsed, out charsUsed, out completed); + } + } +#endif + +#if NETSTANDARD + /// + /// Decodes the specified to s using the specified + /// and outputs the result to . + /// + /// The which represents how the data in is encoded. + /// The to decode to characters. + /// The destination buffer to which the decoded characters will be written. + /// The number of chars written to . + /// Thrown if is not large enough to contain the encoded form of . + /// Thrown if contains data that cannot be decoded and is configured + /// to throw an exception when such data is seen. + public static int GetChars(this Encoding encoding, in ReadOnlySequence bytes, Span chars) + { + if (encoding is null) + throw new ArgumentNullException(nameof(encoding)); + + if (bytes.IsSingleSegment) + { + // If the incoming sequence is single-segment, one-shot this. + + return encoding.GetChars(bytes.First.Span, chars); + } + else + { + // If the incoming sequence is multi-segment, create a stateful Decoder + // and use it as the workhorse. On the final iteration we'll pass flush=true. 
+ + var remainingBytes = bytes; + var originalCharsLength = chars.Length; + var decoder = encoding.GetDecoder(); + bool isFinalSegment; + + do + { + var firstSpan = remainingBytes.First.Span; + var next = remainingBytes.GetPosition(firstSpan.Length); + isFinalSegment = remainingBytes.IsSingleSegment; + + var charsWrittenJustNow = decoder.GetChars(firstSpan, chars, flush: isFinalSegment); + chars = chars.Slice(charsWrittenJustNow); + remainingBytes = remainingBytes.Slice(next); + } while (!isFinalSegment); + + return originalCharsLength - chars.Length; // total number of chars we wrote + } + } + + public static string GetString(this Encoding encoding, in ReadOnlySequence bytes) + { + if (encoding is null) + throw new ArgumentNullException(nameof(encoding)); + + // If the incoming sequence is single-segment, one-shot this. + if (bytes.IsSingleSegment) + { +#if NETSTANDARD2_1 + return encoding.GetString(bytes.First.Span); +#else + var rented = false; + byte[] arr; + var offset = 0; + var memory = bytes.First; + if (MemoryMarshal.TryGetArray(memory, out var segment)) + { + arr = segment.Array!; + offset = segment.Offset; + } + else + { + rented = true; + arr = ArrayPool.Shared.Rent(memory.Length); + bytes.First.Span.CopyTo(arr); + } + var ret = encoding.GetString(arr, offset, memory.Length); + if (rented) + ArrayPool.Shared.Return(arr); + return ret; +#endif + } + + // If the incoming sequence is multi-segment, create a stateful Decoder + // and use it as the workhorse. On the final iteration we'll pass flush=true. + + var decoder = encoding.GetDecoder(); + + // Maintain a list of all the segments we'll need to concat together. + // These will be released back to the pool at the end of the method. 
+ + var listOfSegments = new List<(char[], int)>(); + var totalCharCount = 0; + + var remainingBytes = bytes; + bool isFinalSegment; + + do + { + var firstSpan = remainingBytes.First.Span; + var next = remainingBytes.GetPosition(firstSpan.Length); + isFinalSegment = remainingBytes.IsSingleSegment; + + var charCountThisIteration = decoder.GetCharCount(firstSpan, flush: isFinalSegment); // could throw ArgumentException if overflow would occur + var rentedArray = ArrayPool.Shared.Rent(charCountThisIteration); + var actualCharsWrittenThisIteration = decoder.GetChars(firstSpan, rentedArray, flush: isFinalSegment); + listOfSegments.Add((rentedArray, actualCharsWrittenThisIteration)); + + totalCharCount += actualCharsWrittenThisIteration; + if (totalCharCount < 0) + { + // If we overflowed, call string.Create, passing int.MaxValue. + // This will end up throwing the expected OutOfMemoryException + // since strings are limited to under int.MaxValue elements in length. + + totalCharCount = int.MaxValue; + break; + } + + remainingBytes = remainingBytes.Slice(next); + } while (!isFinalSegment); + + // Now build up the string to return, then release all of our scratch buffers + // back to the shared pool. + var chars = ArrayPool.Shared.Rent(totalCharCount); + var span = chars.AsSpan(); + foreach (var (array, length) in listOfSegments) + { + array.AsSpan(0, length).CopyTo(span); + ArrayPool.Shared.Return(array); + span = span.Slice(length); + } + + var str = new string(chars); + ArrayPool.Shared.Return(chars); + return str; + } +#endif +} diff --git a/src/Npgsql/Shims/ExperimentalAttribute.cs b/src/Npgsql/Shims/ExperimentalAttribute.cs new file mode 100644 index 0000000000..36ff9ee11d --- /dev/null +++ b/src/Npgsql/Shims/ExperimentalAttribute.cs @@ -0,0 +1,21 @@ +#if !NET8_0_OR_GREATER +namespace System.Diagnostics.CodeAnalysis; + +/// Indicates that an API is experimental and it may change in the future. 
+[AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Module | AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Enum | AttributeTargets.Constructor | AttributeTargets.Method | AttributeTargets.Property | AttributeTargets.Field | AttributeTargets.Event | AttributeTargets.Interface | AttributeTargets.Delegate, Inherited = false)] +public sealed class ExperimentalAttribute : Attribute +{ + /// Initializes a new instance of the class, specifying the ID that the compiler will use when reporting a use of the API the attribute applies to. + /// The ID that the compiler will use when reporting a use of the API the attribute applies to. + public ExperimentalAttribute(string diagnosticId) => this.DiagnosticId = diagnosticId; + + /// Gets the ID that the compiler will use when reporting a use of the API the attribute applies to. + /// The unique diagnostic ID. + public string DiagnosticId { get; } + + /// Gets or sets the URL for corresponding documentation. + /// The API accepts a format string instead of an actual URL, creating a generic URL that includes the diagnostic ID. + /// The format string that represents a URL to corresponding documentation. + public string? 
UrlFormat { get; set; } +} +#endif diff --git a/src/Npgsql/Shims/MemoryExtensions.cs b/src/Npgsql/Shims/MemoryExtensions.cs new file mode 100644 index 0000000000..0da143f3c4 --- /dev/null +++ b/src/Npgsql/Shims/MemoryExtensions.cs @@ -0,0 +1,20 @@ +#if !NET7_0_OR_GREATER +using System; + +namespace Npgsql; + +static class MemoryExtensions +{ + public static int IndexOfAnyExcept(this ReadOnlySpan span, T value0, T value1) where T : IEquatable + { + for (var i = 0; i < span.Length; i++) + { + var v = span[i]; + if (!v.Equals(value0) && !v.Equals(value1)) + return i; + } + + return -1; + } +} +#endif diff --git a/src/Npgsql/Shims/ReadOnlySequenceExtensions.cs b/src/Npgsql/Shims/ReadOnlySequenceExtensions.cs new file mode 100644 index 0000000000..0370285a7d --- /dev/null +++ b/src/Npgsql/Shims/ReadOnlySequenceExtensions.cs @@ -0,0 +1,13 @@ +namespace System.Buffers; + +static class ReadOnlySequenceExtensions +{ + public static ReadOnlySpan GetFirstSpan(this ReadOnlySequence sequence) + { +#if NETSTANDARD + return sequence.First.Span; +# else + return sequence.FirstSpan; +#endif + } +} diff --git a/src/Npgsql/Shims/ReadOnlySpanOfCharExtensions.cs b/src/Npgsql/Shims/ReadOnlySpanOfCharExtensions.cs new file mode 100644 index 0000000000..11a70c9793 --- /dev/null +++ b/src/Npgsql/Shims/ReadOnlySpanOfCharExtensions.cs @@ -0,0 +1,15 @@ +using System; +using System.Runtime.CompilerServices; + +namespace Npgsql.Netstandard20; + +static class ReadOnlySpanOfCharExtensions +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int ParseInt(this ReadOnlySpan span) + => int.Parse(span +#if NETSTANDARD2_0 + .ToString() +#endif + ); +} \ No newline at end of file diff --git a/src/Npgsql/Shims/ReferenceEqualityComparer.cs b/src/Npgsql/Shims/ReferenceEqualityComparer.cs new file mode 100644 index 0000000000..38515ed90f --- /dev/null +++ b/src/Npgsql/Shims/ReferenceEqualityComparer.cs @@ -0,0 +1,48 @@ +using System.Runtime.CompilerServices; + +namespace 
System.Collections.Generic; + +#if NETSTANDARD +sealed class ReferenceEqualityComparer : IEqualityComparer, IEqualityComparer +{ + ReferenceEqualityComparer() { } + + /// + /// Gets the singleton instance. + /// + public static ReferenceEqualityComparer Instance { get; } = new(); + + /// + /// Determines whether two object references refer to the same object instance. + /// + /// The first object to compare. + /// The second object to compare. + /// + /// if both and refer to the same object instance + /// or if both are ; otherwise, . + /// + /// + /// This API is a wrapper around . + /// It is not necessarily equivalent to calling . + /// + public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); + + /// + /// Returns a hash code for the specified object. The returned hash code is based on the object + /// identity, not on the contents of the object. + /// + /// The object for which to retrieve the hash code. + /// A hash code for the identity of . + /// + /// This API is a wrapper around . + /// It is not necessarily equivalent to calling . + /// + public int GetHashCode(object? obj) + { + // Depending on target framework, RuntimeHelpers.GetHashCode might not be annotated + // with the proper nullability attribute. We'll suppress any warning that might + // result. 
+ return RuntimeHelpers.GetHashCode(obj!); + } +} +#endif diff --git a/src/Npgsql/Shims/RequiresPreviewFeaturesAttribute.cs b/src/Npgsql/Shims/RequiresPreviewFeaturesAttribute.cs new file mode 100644 index 0000000000..4f7673959f --- /dev/null +++ b/src/Npgsql/Shims/RequiresPreviewFeaturesAttribute.cs @@ -0,0 +1,48 @@ +#if !NET6_0_OR_GREATER + +// ReSharper disable once CheckNamespace +namespace System.Runtime.Versioning; + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member +#pragma warning disable RS0016 // Add public types and members to the declared API + +[AttributeUsage(AttributeTargets.Assembly | + AttributeTargets.Module | + AttributeTargets.Class | + AttributeTargets.Interface | + AttributeTargets.Delegate | + AttributeTargets.Struct | + AttributeTargets.Enum | + AttributeTargets.Constructor | + AttributeTargets.Method | + AttributeTargets.Property | + AttributeTargets.Field | + AttributeTargets.Event, Inherited = false)] +public sealed class RequiresPreviewFeaturesAttribute : Attribute +{ + /// + /// Initializes a new instance of the class. + /// + public RequiresPreviewFeaturesAttribute() { } + + /// + /// Initializes a new instance of the class with the specified message. + /// + /// An optional message associated with this attribute instance. + public RequiresPreviewFeaturesAttribute(string? message) + { + Message = message; + } + + /// + /// Returns the optional message associated with this attribute instance. + /// + public string? Message { get; } + + /// + /// Returns the optional URL associated with this attribute instance. + /// + public string? 
Url { get; set; } +} + +#endif \ No newline at end of file diff --git a/src/Npgsql/Netstandard20/StreamExtensions.cs b/src/Npgsql/Shims/StreamExtensions.cs similarity index 67% rename from src/Npgsql/Netstandard20/StreamExtensions.cs rename to src/Npgsql/Shims/StreamExtensions.cs index 925061870d..6a6a54231b 100644 --- a/src/Npgsql/Netstandard20/StreamExtensions.cs +++ b/src/Npgsql/Shims/StreamExtensions.cs @@ -1,7 +1,9 @@ -#if NETSTANDARD2_0 +#if NETSTANDARD2_0 || !NET7_0_OR_GREATER using System.Buffers; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; +using Npgsql; // ReSharper disable once CheckNamespace namespace System.IO @@ -9,6 +11,33 @@ namespace System.IO // Helpers to read/write Span/Memory to Stream before netstandard 2.1 static class StreamExtensions { + public static void ReadExactly(this Stream stream, Span buffer) + { + var totalRead = 0; + while (totalRead < buffer.Length) + { + var read = stream.Read(buffer.Slice(totalRead)); + if (read is 0) + throw new EndOfStreamException(); + + totalRead += read; + } + } + + public static async ValueTask ReadExactlyAsync(this Stream stream, Memory buffer, CancellationToken cancellationToken = default) + { + var totalRead = 0; + while (totalRead < buffer.Length) + { + var read = await stream.ReadAsync(buffer.Slice(totalRead), cancellationToken).ConfigureAwait(false); + if (read is 0) + throw new EndOfStreamException(); + + totalRead += read; + } + } + +#if NETSTANDARD2_0 public static int Read(this Stream stream, Span buffer) { var sharedBuffer = ArrayPool.Shared.Rent(buffer.Length); @@ -29,7 +58,7 @@ public static async ValueTask ReadAsync(this Stream stream, Memory bu var sharedBuffer = ArrayPool.Shared.Rent(buffer.Length); try { - var result = await stream.ReadAsync(sharedBuffer, 0, buffer.Length, cancellationToken); + var result = await stream.ReadAsync(sharedBuffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); new Span(sharedBuffer, 0, 
result).CopyTo(buffer.Span); return result; } @@ -59,13 +88,14 @@ public static async ValueTask WriteAsync(this Stream stream, ReadOnlyMemory.Shared.Return(sharedBuffer); } } +#endif } } #endif diff --git a/src/Npgsql/Netstandard20/StringBuilderExtensions.cs b/src/Npgsql/Shims/StringBuilderExtensions.cs similarity index 100% rename from src/Npgsql/Netstandard20/StringBuilderExtensions.cs rename to src/Npgsql/Shims/StringBuilderExtensions.cs diff --git a/src/Npgsql/Shims/TaskExtensions.cs b/src/Npgsql/Shims/TaskExtensions.cs new file mode 100644 index 0000000000..a7d56948e9 --- /dev/null +++ b/src/Npgsql/Shims/TaskExtensions.cs @@ -0,0 +1,65 @@ +#if !NET6_0_OR_GREATER +using System.Collections.Generic; + +namespace System.Threading.Tasks; + +static class TaskExtensions +{ + /// + /// Gets a that will complete when this completes, when the specified timeout expires, or when the specified has cancellation requested. + /// + /// The representing the asynchronous wait. + /// The timeout after which the should be faulted with a if it hasn't otherwise completed. + /// The to monitor for a cancellation request. + /// The representing the asynchronous wait. + /// This method reproduces new to the .NET 6.0 API .WaitAsync. + public static async Task WaitAsync(this Task task, TimeSpan timeout, CancellationToken cancellationToken) + { + var tasks = new List(3); + + Task? cancellationTask = default; + CancellationTokenRegistration registration = default; + if (cancellationToken.CanBeCanceled) + { + var tcs = new TaskCompletionSource(); + registration = cancellationToken.Register(s => ((TaskCompletionSource)s!).TrySetResult(true), tcs); + cancellationTask = tcs.Task; + tasks.Add(cancellationTask); + } + + Task? delayTask = default; + CancellationTokenSource? 
delayCts = default; + if (timeout != Timeout.InfiniteTimeSpan) + { + var timeLeft = timeout; + delayCts = new CancellationTokenSource(); + delayTask = Task.Delay(timeLeft, delayCts.Token); + tasks.Add(delayTask); + } + + try + { + if (tasks.Count != 0) + { + tasks.Add(task); + var result = await Task.WhenAny(tasks).ConfigureAwait(false); + if (result == cancellationTask) + { + task = Task.FromCanceled(cancellationToken); + } + else if (result == delayTask) + { + task = Task.FromException(new TimeoutException()); + } + } + await task.ConfigureAwait(false); + } + finally + { + delayCts?.Cancel(); + delayCts?.Dispose(); + registration.Dispose(); + } + } +} +#endif diff --git a/src/Npgsql/Netstandard20/UnixDomainSocketEndPoint.cs b/src/Npgsql/Shims/UnixDomainSocketEndPoint.cs similarity index 98% rename from src/Npgsql/Netstandard20/UnixDomainSocketEndPoint.cs rename to src/Npgsql/Shims/UnixDomainSocketEndPoint.cs index 60430f8a76..6135590493 100644 --- a/src/Npgsql/Netstandard20/UnixDomainSocketEndPoint.cs +++ b/src/Npgsql/Shims/UnixDomainSocketEndPoint.cs @@ -6,7 +6,7 @@ namespace System.Net { // Copied and adapted from https://github.com/mono/mono/blob/master/mcs/class/Mono.Posix/Mono.Unix/UnixEndPoint.cs - class UnixDomainSocketEndPoint : EndPoint + sealed class UnixDomainSocketEndPoint : EndPoint { string _filename; diff --git a/src/Npgsql/Shims/UnreachableException.cs b/src/Npgsql/Shims/UnreachableException.cs new file mode 100644 index 0000000000..c45f3fd1d8 --- /dev/null +++ b/src/Npgsql/Shims/UnreachableException.cs @@ -0,0 +1,41 @@ +#if !NET7_0_OR_GREATER +using System; + +namespace System.Diagnostics; + +/// +/// Exception thrown when the program executes an instruction that was thought to be unreachable. +/// +sealed class UnreachableException : Exception +{ + /// + /// Initializes a new instance of the class with the default error message. 
+ /// + public UnreachableException() + : base("The program executed an instruction that was thought to be unreachable.") + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message. + /// + /// The error message that explains the reason for the exception. + public UnreachableException(string? message) + : base(message) + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message and a reference to the inner exception that is the cause of + /// this exception. + /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception. + public UnreachableException(string? message, Exception? innerException) + : base(message, innerException) + { + } +} +#endif diff --git a/src/Npgsql/Netstandard20/WaitHandleExtensions.cs b/src/Npgsql/Shims/WaitHandleExtensions.cs similarity index 96% rename from src/Npgsql/Netstandard20/WaitHandleExtensions.cs rename to src/Npgsql/Shims/WaitHandleExtensions.cs index dbb3cc4259..5f746cc296 100644 --- a/src/Npgsql/Netstandard20/WaitHandleExtensions.cs +++ b/src/Npgsql/Shims/WaitHandleExtensions.cs @@ -23,7 +23,7 @@ internal static async Task WaitOneAsync( state: tcs, millisecondsTimeout, executeOnlyOnce: true); - return await tcs.Task; + return await tcs.Task.ConfigureAwait(false); } finally { diff --git a/src/Npgsql/SingleThreadSynchronizationContext.cs b/src/Npgsql/SingleThreadSynchronizationContext.cs deleted file mode 100644 index 6650d09cb0..0000000000 --- a/src/Npgsql/SingleThreadSynchronizationContext.cs +++ /dev/null @@ -1,92 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Threading; -using Npgsql.Logging; - -namespace Npgsql -{ - sealed class SingleThreadSynchronizationContext : SynchronizationContext, IDisposable - { - readonly BlockingCollection _tasks = new BlockingCollection(); - Thread? 
_thread; - - const int ThreadStayAliveMs = 10000; - readonly string _threadName; - - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(SingleThreadSynchronizationContext)); - - internal SingleThreadSynchronizationContext(string threadName) - => _threadName = threadName; - - internal Disposable Enter() => new Disposable(this); - - public override void Post(SendOrPostCallback callback, object? state) - { - _tasks.Add(new CallbackAndState { Callback = callback, State = state }); - - if (_thread == null) - { - lock (this) - { - if (_thread != null) - return; - _thread = new Thread(WorkLoop) { Name = _threadName, IsBackground = true }; - _thread.Start(); - } - } - } - - public void Dispose() - { - _tasks.CompleteAdding(); - _tasks.Dispose(); - - lock (this) - { - _thread?.Join(); - } - } - - void WorkLoop() - { - try - { - while (true) - { - var taken = _tasks.TryTake(out var callbackAndState, ThreadStayAliveMs); - if (!taken) - return; - callbackAndState.Callback(callbackAndState.State); - } - } - catch (Exception e) - { - Log.Error($"Exception caught in {nameof(SingleThreadSynchronizationContext)}", e); - } - finally - { - lock (this) { _thread = null; } - } - } - - struct CallbackAndState - { - internal SendOrPostCallback Callback; - internal object? State; - } - - internal struct Disposable : IDisposable - { - readonly SynchronizationContext? 
_synchronizationContext; - - internal Disposable(SynchronizationContext synchronizationContext) - { - _synchronizationContext = Current; - SetSynchronizationContext(synchronizationContext); - } - - public void Dispose() - => SetSynchronizationContext(_synchronizationContext); - } - } -} diff --git a/src/Npgsql/SqlQueryParser.cs b/src/Npgsql/SqlQueryParser.cs index cce56a8ce1..2e9e37a010 100644 --- a/src/Npgsql/SqlQueryParser.cs +++ b/src/Npgsql/SqlQueryParser.cs @@ -3,467 +3,532 @@ using System.Diagnostics; using System.Text; -namespace Npgsql +namespace Npgsql; + +sealed class SqlQueryParser { - class SqlQueryParser + static NpgsqlParameterCollection EmptyParameters { get; } = new(); + + readonly Dictionary _paramIndexMap = new(StringComparer.OrdinalIgnoreCase); + readonly StringBuilder _rewrittenSql = new(); + + /// + ///

+ /// Receives a user SQL query as passed in by the user in or + /// , and rewrites it for PostgreSQL compatibility. + ///

+ ///

+ /// This includes doing rewriting named parameter placeholders to positional (@p => $1), and splitting the query + /// up by semicolons (legacy batching, SELECT 1; SELECT 2). + ///

+ ///
+ /// The user-facing being executed. + /// Whether PostgreSQL standards-conforming are used. + /// + /// A bool indicating whether parameters contains a list of preconfigured parameters or an empty list to be filled with derived + /// parameters. + /// + internal void ParseRawQuery( + NpgsqlCommand? command, + bool standardConformingStrings = true, + bool deriveParameters = false) + => ParseRawQuery(command, batchCommand: null, standardConformingStrings, deriveParameters); + + /// + ///

+ /// Receives a user SQL query as passed in by the user in or + /// , and rewrites it for PostgreSQL compatibility. + ///

+ ///

+ /// This includes doing rewriting named parameter placeholders to positional (@p => $1), and splitting the query + /// up by semicolons (legacy batching, SELECT 1; SELECT 2). + ///

+ ///
+ /// The user-facing being executed. + /// Whether PostgreSQL standards-conforming are used. + /// + /// A bool indicating whether parameters contains a list of preconfigured parameters or an empty list to be filled with derived + /// parameters. + /// + internal void ParseRawQuery( + NpgsqlBatchCommand? batchCommand, + bool standardConformingStrings = true, + bool deriveParameters = false) + => ParseRawQuery(command: null, batchCommand, standardConformingStrings, deriveParameters); + + void ParseRawQuery( + NpgsqlCommand? command, + NpgsqlBatchCommand? batchCommand, + bool standardConformingStrings = true, + bool deriveParameters = false) { - readonly Dictionary _paramIndexMap = new Dictionary(); - readonly StringBuilder _rewrittenSql = new StringBuilder(); - - /// - /// Receives a raw SQL query as passed in by the user, and performs some processing necessary - /// before sending to the backend. - /// This includes doing parameter placeholder processing (@p => $1), and splitting the query - /// up by semicolons if needed (SELECT 1; SELECT 2) - /// - /// Raw user-provided query. - /// The parameters configured on the of this query - /// or an empty if deriveParameters is set to true. - /// An empty list to be populated with the statements parsed by this method - /// A bool indicating whether parameters contains a list of preconfigured parameters or an empty list to be filled with derived parameters. - internal void ParseRawQuery( - string sql, - NpgsqlParameterCollection parameters, - List statements, - bool deriveParameters = false) - => ParseRawQuery(sql.AsSpan(), parameters, statements, deriveParameters); - - void ParseRawQuery( - ReadOnlySpan sql, - NpgsqlParameterCollection parameters, - List statements, - bool deriveParameters) + string sql; + NpgsqlParameterCollection parameters; + List? batchCommands; + + var statementIndex = 0; + if (command is null) + { + // Batching mode. 
We're processing only one batch - if we encounter a semicolon (legacy batching), that's an error. + Debug.Assert(batchCommand is not null); + sql = batchCommand.CommandText; + parameters = batchCommand._parameters ?? EmptyParameters; + batchCommands = null; + } + else { - Debug.Assert(deriveParameters == false || parameters.Count == 0); + // Command mode. Semicolons (legacy batching) may occur. + Debug.Assert(batchCommand is null); + sql = command.CommandText; + parameters = command._parameters ?? EmptyParameters; + batchCommands = command.InternalBatchCommands; + MoveToNextBatchCommand(); + } + + Debug.Assert(batchCommand is not null); + Debug.Assert(parameters.PlaceholderType != PlaceholderType.Positional); + Debug.Assert(deriveParameters == false || parameters.Count == 0); + // Debug.Assert(batchCommand.PositionalParameters is not null && batchCommand.PositionalParameters.Count == 0); - NpgsqlStatement statement = null!; - var statementIndex = -1; - MoveToNextStatement(); + _paramIndexMap.Clear(); + _rewrittenSql.Clear(); - var currCharOfs = 0; - var end = sql.Length; - var ch = '\0'; - int dollarTagStart; - int dollarTagEnd; - var currTokenBeg = 0; - var blockCommentLevel = 0; - var parenthesisLevel = 0; + var currCharOfs = 0; + var end = sql.Length; + var ch = '\0'; + int dollarTagStart; + int dollarTagEnd; + var currTokenBeg = 0; + var blockCommentLevel = 0; + var parenthesisLevel = 0; None: - if (currCharOfs >= end) - goto Finish; - var lastChar = ch; - ch = sql[currCharOfs++]; + if (currCharOfs >= end) + goto Finish; + var lastChar = ch; + ch = sql[currCharOfs++]; NoneContinue: - for (; ; lastChar = ch, ch = sql[currCharOfs++]) + while (true) + { + switch (ch) { - switch (ch) - { - case '/': - goto BlockCommentBegin; - case '-': - goto LineCommentBegin; - case '\'': + case '/': + goto BlockCommentBegin; + case '-': + goto LineCommentBegin; + case '\'': + if (standardConformingStrings) goto Quoted; - case '$': - if (!IsIdentifier(lastChar)) - goto 
DollarQuotedStart; - else - break; - case '"': - goto DoubleQuoted; - case ':': - if (lastChar != ':') - goto ParamStart; - else - break; - case '@': - if (lastChar != '@') - goto ParamStart; - else - break; - case ';': - if (parenthesisLevel == 0) - goto SemiColon; - break; - case '(': - parenthesisLevel++; - break; - case ')': - parenthesisLevel--; - break; - case 'e': - case 'E': - if (!IsLetter(lastChar)) - goto EscapedStart; - else - break; - } - - if (currCharOfs >= end) - goto Finish; + goto Escaped; + case '$': + if (!IsIdentifier(lastChar)) + goto DollarQuotedStart; + break; + case '"': + goto Quoted; + case ':': + if (lastChar != ':') + goto NamedParamStart; + break; + case '@': + if (lastChar != '@') + goto NamedParamStart; + break; + case ';': + if (parenthesisLevel == 0) + goto SemiColon; + break; + case '(': + parenthesisLevel++; + break; + case ')': + parenthesisLevel--; + break; + case 'e': + case 'E': + if (!IsLetter(lastChar)) + goto EscapedStart; + break; } - ParamStart: - if (currCharOfs < end) + if (currCharOfs >= end) + goto Finish; + + lastChar = ch; + ch = sql[currCharOfs++]; + } + + NamedParamStart: + if (currCharOfs < end) + { + lastChar = ch; + ch = sql[currCharOfs]; + if (IsParamNameChar(ch)) { - lastChar = ch; - ch = sql[currCharOfs]; - if (IsParamNameChar(ch)) - { - if (currCharOfs - 1 > currTokenBeg) - _rewrittenSql.Append(sql.Slice(currTokenBeg, currCharOfs - 1 - currTokenBeg)); - currTokenBeg = currCharOfs++ - 1; - goto Param; - } - currCharOfs++; - goto NoneContinue; + if (currCharOfs - 1 > currTokenBeg) + _rewrittenSql.Append(sql, currTokenBeg, currCharOfs - 1 - currTokenBeg); + currTokenBeg = currCharOfs++ - 1; + goto NamedParam; } - goto Finish; + currCharOfs++; + goto NoneContinue; + } + goto Finish; - Param: - // We have already at least one character of the param name - for (;;) + NamedParam: + // We have already at least one character of the param name + while (true) + { + lastChar = ch; + if (currCharOfs >= end || 
!IsParamNameChar(ch = sql[currCharOfs])) { - lastChar = ch; - if (currCharOfs >= end || !IsParamNameChar(ch = sql[currCharOfs])) - { - var paramName = sql.Slice(currTokenBeg + 1, currCharOfs - (currTokenBeg + 1)).ToString(); + var paramName = sql.Substring(currTokenBeg + 1, currCharOfs - (currTokenBeg + 1)); - if (!_paramIndexMap.TryGetValue(paramName, out var index)) + if (!_paramIndexMap.TryGetValue(paramName, out var index)) + { + // Parameter hasn't been seen before in this query + if (!parameters.TryGetValue(paramName, out var parameter)) { - // Parameter hasn't been seen before in this query - if (!parameters.TryGetValue(paramName, out var parameter)) + if (deriveParameters) { - if (deriveParameters) - { - parameter = new NpgsqlParameter { ParameterName = paramName }; - parameters.Add(parameter); - } - else - { - // Parameter placeholder does not match a parameter on this command. - // Leave the text as it was in the SQL, it may not be a an actual placeholder - _rewrittenSql.Append(sql.Slice(currTokenBeg, currCharOfs - currTokenBeg)); - currTokenBeg = currCharOfs; - if (currCharOfs >= end) - goto Finish; - - currCharOfs++; - goto NoneContinue; - } + parameter = new NpgsqlParameter { ParameterName = paramName }; + parameters.Add(parameter); + } + else + { + // Parameter placeholder does not match a parameter on this command. 
+ // Leave the text as it was in the SQL, it may not be a an actual placeholder + _rewrittenSql.Append(sql, currTokenBeg, currCharOfs - currTokenBeg); + currTokenBeg = currCharOfs; + if (currCharOfs >= end) + goto Finish; + + currCharOfs++; + goto NoneContinue; } - - if (!parameter.IsInputDirection) - throw new Exception($"Parameter '{paramName}' referenced in SQL but is an out-only parameter"); - - statement.InputParameters.Add(parameter); - index = _paramIndexMap[paramName] = statement.InputParameters.Count; } - _rewrittenSql.Append('$'); - _rewrittenSql.Append(index); - currTokenBeg = currCharOfs; - if (currCharOfs >= end) - goto Finish; + if (!parameter.IsInputDirection) + ThrowHelper.ThrowInvalidOperationException("Parameter '{0}' referenced in SQL but is an out-only parameter", paramName); - currCharOfs++; - goto NoneContinue; + batchCommand.PositionalParameters.Add(parameter); + index = _paramIndexMap[paramName] = batchCommand.PositionalParameters.Count; } + _rewrittenSql.Append('$'); + _rewrittenSql.Append(index); + currTokenBeg = currCharOfs; + + if (currCharOfs >= end) + goto Finish; currCharOfs++; + goto NoneContinue; } - Quoted: - while (currCharOfs < end) - { - if (sql[currCharOfs++] == '\'') - { - ch = '\0'; - goto None; - } - } - goto Finish; + currCharOfs++; + } - DoubleQuoted: - while (currCharOfs < end) - { - if (sql[currCharOfs++] == '"') - { - ch = '\0'; - goto None; - } - } - goto Finish; + Quoted: + Debug.Assert(ch == '\'' || ch == '"'); + while (currCharOfs < end && sql[currCharOfs] != ch) + { + currCharOfs++; + } + if (currCharOfs < end) + { + currCharOfs++; + ch = '\0'; + goto None; + } + goto Finish; EscapedStart: - if (currCharOfs < end) - { - lastChar = ch; - ch = sql[currCharOfs++]; - if (ch == '\'') - goto Escaped; - goto NoneContinue; - } - goto Finish; + if (currCharOfs < end) + { + lastChar = ch; + ch = sql[currCharOfs++]; + if (ch == '\'') + goto Escaped; + goto NoneContinue; + } + goto Finish; Escaped: - while (currCharOfs < end) 
+ while (currCharOfs < end) + { + ch = sql[currCharOfs++]; + switch (ch) { - ch = sql[currCharOfs++]; - switch (ch) - { - case '\'': - goto MaybeConcatenatedEscaped; - case '\\': - { - if (currCharOfs >= end) - goto Finish; - currCharOfs++; - break; - } - } + case '\'': + goto MaybeConcatenatedEscaped; + case '\\': + { + if (currCharOfs >= end) + goto Finish; + currCharOfs++; + break; } - goto Finish; + } + } + goto Finish; MaybeConcatenatedEscaped: - while (currCharOfs < end) + while (currCharOfs < end) + { + ch = sql[currCharOfs++]; + switch (ch) { - ch = sql[currCharOfs++]; - switch (ch) - { - case '\r': - case '\n': - goto MaybeConcatenatedEscaped2; - case ' ': - case '\t': - case '\f': - continue; - default: - lastChar = '\0'; - goto NoneContinue; - } + case '\r': + case '\n': + goto MaybeConcatenatedEscaped2; + case ' ': + case '\t': + case '\f': + continue; + default: + lastChar = '\0'; + goto NoneContinue; } - goto Finish; + } + goto Finish; MaybeConcatenatedEscaped2: - while (currCharOfs < end) + while (currCharOfs < end) + { + ch = sql[currCharOfs++]; + switch (ch) { + case '\'': + goto Escaped; + case '-': + { + if (currCharOfs >= end) + goto Finish; ch = sql[currCharOfs++]; - switch (ch) - { - case '\'': - goto Escaped; - case '-': - { - if (currCharOfs >= end) - goto Finish; - ch = sql[currCharOfs++]; - if (ch == '-') - goto MaybeConcatenatedEscapeAfterComment; - lastChar = '\0'; - goto NoneContinue; - } - case ' ': - case '\t': - case '\n': - case '\r': - case '\f': - continue; - default: - lastChar = '\0'; - goto NoneContinue; - } + if (ch == '-') + goto MaybeConcatenatedEscapeAfterComment; + lastChar = '\0'; + goto NoneContinue; } - goto Finish; + case ' ': + case '\t': + case '\n': + case '\r': + case '\f': + continue; + default: + lastChar = '\0'; + goto NoneContinue; + } + } + goto Finish; MaybeConcatenatedEscapeAfterComment: - while (currCharOfs < end) - { - ch = sql[currCharOfs++]; - if (ch == '\r' || ch == '\n') - goto 
MaybeConcatenatedEscaped2; - } - goto Finish; + while (currCharOfs < end) + { + ch = sql[currCharOfs++]; + if (ch == '\r' || ch == '\n') + goto MaybeConcatenatedEscaped2; + } + goto Finish; DollarQuotedStart: - if (currCharOfs < end) + if (currCharOfs < end) + { + ch = sql[currCharOfs]; + if (ch == '$') { - ch = sql[currCharOfs]; - if (ch == '$') - { - // Empty tag - dollarTagStart = dollarTagEnd = currCharOfs; - currCharOfs++; - goto DollarQuoted; - } - if (IsIdentifierStart(ch)) - { - dollarTagStart = currCharOfs; - currCharOfs++; - goto DollarQuotedInFirstDelim; - } - lastChar = '$'; + // Empty tag + dollarTagStart = dollarTagEnd = currCharOfs; currCharOfs++; - goto NoneContinue; + goto DollarQuoted; } - goto Finish; + if (IsIdentifierStart(ch)) + { + dollarTagStart = currCharOfs; + currCharOfs++; + goto DollarQuotedInFirstDelim; + } + lastChar = '$'; + currCharOfs++; + goto NoneContinue; + } + goto Finish; DollarQuotedInFirstDelim: - while (currCharOfs < end) + while (currCharOfs < end) + { + lastChar = ch; + ch = sql[currCharOfs++]; + if (ch == '$') { - lastChar = ch; - ch = sql[currCharOfs++]; - if (ch == '$') - { - dollarTagEnd = currCharOfs - 1; - goto DollarQuoted; - } - if (!IsDollarTagIdentifier(ch)) - goto NoneContinue; + dollarTagEnd = currCharOfs - 1; + goto DollarQuoted; } - goto Finish; + if (!IsDollarTagIdentifier(ch)) + goto NoneContinue; + } + goto Finish; DollarQuoted: - var tag = sql.Slice(dollarTagStart - 1, dollarTagEnd - dollarTagStart + 2); - var pos = sql.Slice(dollarTagEnd + 1).IndexOf(tag); - if (pos == -1) - { - currCharOfs = end; - goto Finish; - } - pos += dollarTagEnd + 1; // If the substring is found adjust the position to be relative to the entire span - currCharOfs = pos + dollarTagEnd - dollarTagStart + 2; - ch = '\0'; - goto None; + var tag = sql.AsSpan(dollarTagStart - 1, dollarTagEnd - dollarTagStart + 2); + var pos = sql.AsSpan(dollarTagEnd + 1).IndexOf(tag); + if (pos == -1) + { + currCharOfs = end; + goto Finish; + } + pos 
+= dollarTagEnd + 1; // If the substring is found adjust the position to be relative to the entire string + currCharOfs = pos + dollarTagEnd - dollarTagStart + 2; + ch = '\0'; + goto None; LineCommentBegin: - if (currCharOfs < end) - { - ch = sql[currCharOfs++]; - if (ch == '-') - goto LineComment; - lastChar = '\0'; - goto NoneContinue; - } - goto Finish; + if (currCharOfs < end) + { + ch = sql[currCharOfs++]; + if (ch == '-') + goto LineComment; + lastChar = '\0'; + goto NoneContinue; + } + goto Finish; LineComment: - while (currCharOfs < end) - { - ch = sql[currCharOfs++]; - if (ch == '\r' || ch == '\n') - goto None; - } - goto Finish; + while (currCharOfs < end) + { + ch = sql[currCharOfs++]; + if (ch == '\r' || ch == '\n') + goto None; + } + goto Finish; BlockCommentBegin: - while (currCharOfs < end) + while (currCharOfs < end) + { + ch = sql[currCharOfs++]; + if (ch == '*') { - ch = sql[currCharOfs++]; - if (ch == '*') - { - blockCommentLevel++; + blockCommentLevel++; + goto BlockComment; + } + if (ch != '/') + { + if (blockCommentLevel > 0) goto BlockComment; - } - if (ch != '/') - { - if (blockCommentLevel > 0) - goto BlockComment; - lastChar = '\0'; - goto NoneContinue; - } + lastChar = '\0'; + goto NoneContinue; } - goto Finish; + } + goto Finish; BlockComment: - while (currCharOfs < end) + while (currCharOfs < end) + { + ch = sql[currCharOfs++]; + switch (ch) { - ch = sql[currCharOfs++]; - switch (ch) - { - case '*': - goto BlockCommentEnd; - case '/': - goto BlockCommentBegin; - } + case '*': + goto BlockCommentEnd; + case '/': + goto BlockCommentBegin; } - goto Finish; + } + goto Finish; BlockCommentEnd: - while (currCharOfs < end) + while (currCharOfs < end) + { + ch = sql[currCharOfs++]; + if (ch == '/') { - ch = sql[currCharOfs++]; - if (ch == '/') - { - if (--blockCommentLevel > 0) - goto BlockComment; - goto None; - } - if (ch != '*') + if (--blockCommentLevel > 0) goto BlockComment; + goto None; } - goto Finish; + if (ch != '*') + goto 
BlockComment; + } + goto Finish; SemiColon: - _rewrittenSql.Append(sql.Slice(currTokenBeg, currCharOfs - currTokenBeg - 1)); - statement.SQL = _rewrittenSql.ToString(); - while (currCharOfs < end) + _rewrittenSql.Append(sql, currTokenBeg, currCharOfs - currTokenBeg - 1); + batchCommand.FinalCommandText = _rewrittenSql.ToString(); + while (currCharOfs < end) + { + ch = sql[currCharOfs]; + if (char.IsWhiteSpace(ch)) { - ch = sql[currCharOfs]; - if (char.IsWhiteSpace(ch)) - { - currCharOfs++; - continue; - } - // TODO: Handle end of line comment? Although psql doesn't seem to handle them... + currCharOfs++; + continue; + } + // TODO: Handle end of line comment? Although psql doesn't seem to handle them... - currTokenBeg = currCharOfs; - if (_rewrittenSql.Length > 0) - MoveToNextStatement(); - goto None; + // We've found a non-whitespace character after a semicolon - this is legacy batching. + + if (command is null) + { + ThrowHelper.ThrowNotSupportedException($"Specifying multiple SQL statements in a single {nameof(NpgsqlBatchCommand)} isn't supported, " + + "please remove all semicolons."); } - if (statements.Count > statementIndex + 1) - statements.RemoveRange(statementIndex + 1, statements.Count - (statementIndex + 1)); - return; - Finish: - _rewrittenSql.Append(sql.Slice(currTokenBeg, end - currTokenBeg)); - statement.SQL = _rewrittenSql.ToString(); - if (statements.Count > statementIndex + 1) - statements.RemoveRange(statementIndex + 1, statements.Count - (statementIndex + 1)); + statementIndex++; + MoveToNextBatchCommand(); + _paramIndexMap.Clear(); + _rewrittenSql.Clear(); + + currTokenBeg = currCharOfs; + goto None; + } + if (batchCommands is not null && batchCommands.Count > statementIndex + 1) + batchCommands.RemoveRange(statementIndex + 1, batchCommands.Count - (statementIndex + 1)); + return; - void MoveToNextStatement() + Finish: + _rewrittenSql.Append(sql, currTokenBeg, end - currTokenBeg); + if (statementIndex is 0 && _paramIndexMap.Count is 0) + // 
Single statement, no parameters, no rewriting necessary + batchCommand.FinalCommandText = sql; + else + batchCommand.FinalCommandText = _rewrittenSql.ToString(); + if (batchCommands is not null && batchCommands.Count > statementIndex + 1) + batchCommands.RemoveRange(statementIndex + 1, batchCommands.Count - (statementIndex + 1)); + + void MoveToNextBatchCommand() + { + Debug.Assert(batchCommands is not null); + if (batchCommands.Count > statementIndex) { - statementIndex++; - if (statements.Count > statementIndex) - { - statement = statements[statementIndex]; - statement.Reset(); - } - else - { - statement = new NpgsqlStatement(); - statements.Add(statement); - } - _paramIndexMap.Clear(); - _rewrittenSql.Clear(); + batchCommand = batchCommands[statementIndex]; + batchCommand.Reset(); + batchCommand._parameters = parameters; + } + else + { + batchCommand = new NpgsqlBatchCommand { _parameters = parameters }; + batchCommands.Add(batchCommand); } } + } - static bool IsLetter(char ch) - => 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z'; + // Is ASCII letter comparison optimization https://github.com/dotnet/runtime/blob/60cfaec2e6cffeb9a006bec4b8908ffcf71ac5b4/src/libraries/System.Private.CoreLib/src/System/Char.cs#L236 - static bool IsIdentifierStart(char ch) - => 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || 128 <= ch && ch <= 255; + static bool IsLetter(char ch) + // [a-zA-Z] + => (uint)((ch | 0x20) - 'a') <= ('z' - 'a'); - static bool IsDollarTagIdentifier(char ch) - => 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || '0' <= ch && ch <= '9' || ch == '_' || 128 <= ch && ch <= 255; + static bool IsIdentifierStart(char ch) + // [a-zA-Z_\x80-\xFF] + => (uint)((ch | 0x20) - 'a') <= ('z' - 'a') || ch == '_' || (uint)(ch - 128) <= 127u; - static bool IsIdentifier(char ch) - => 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || '0' <= ch && ch <= '9' || ch == '_' || ch == '$' || 128 <= ch && ch <= 255; + static bool IsDollarTagIdentifier(char ch) + 
// [a-zA-Z0-9_\x80-\xFF] + => (uint)((ch | 0x20) - 'a') <= ('z' - 'a') || (uint)(ch - '0') <= ('9' - '0') || ch == '_' || (uint)(ch - 128) <= 127u; - static bool IsParamNameChar(char ch) - => char.IsLetterOrDigit(ch) || ch == '_' || ch == '.'; // why dot?? - } + static bool IsIdentifier(char ch) + // [a-zA-Z0-9_$\x80-\xFF] + => (uint)((ch | 0x20) - 'a') <= ('z' - 'a') || (uint)(ch - '0') <= ('9' - '0') || ch == '_' || ch == '$' || (uint)(ch - 128) <= 127u; + + static bool IsParamNameChar(char ch) + => char.IsLetterOrDigit(ch) || ch == '_' || ch == '.'; // why dot?? } diff --git a/src/Npgsql/TargetSessionAttributes.cs b/src/Npgsql/TargetSessionAttributes.cs new file mode 100644 index 0000000000..43ce3344a2 --- /dev/null +++ b/src/Npgsql/TargetSessionAttributes.cs @@ -0,0 +1,45 @@ +namespace Npgsql; + +#pragma warning disable RS0016 + +/// +/// Specifies server type preference. +/// +public enum TargetSessionAttributes : byte +{ + /// + /// Any successful connection is acceptable. + /// + Any = 0, + + /// + /// Session must accept read-write transactions by default (that is, the server must not be in hot standby mode and the + /// default_transaction_read_only parameter must be off). + /// + ReadWrite = 1, + + /// + /// Session must not accept read-write transactions by default (the converse). + /// + ReadOnly = 2, + + /// + /// Server must not be in hot standby mode. + /// + Primary = 3, + + /// + /// Server must be in hot standby mode. + /// + Standby = 4, + + /// + /// First try to find a primary server, but if none of the listed hosts is a primary server, try again in mode. + /// + PreferPrimary = 5, + + /// + /// First try to find a standby server, but if none of the listed hosts is a standby server, try again in mode. 
+ /// + PreferStandby = 6, +} diff --git a/src/Npgsql/TaskExtensions.cs b/src/Npgsql/TaskExtensions.cs deleted file mode 100644 index 4d2c8f12c2..0000000000 --- a/src/Npgsql/TaskExtensions.cs +++ /dev/null @@ -1,91 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.Util; - -namespace Npgsql -{ - static class TaskExtensions - { - /// - /// Utility that simplifies awaiting a task with a timeout. If the given task does not - /// complete within , a is thrown. - /// - /// The task to be awaited - /// How much time to allow to complete before throwing a - /// An awaitable task that represents the original task plus the timeout - internal static async Task WithTimeout(this Task task, NpgsqlTimeout timeout) - { - if (!timeout.IsSet) - return await task; - var timeLeft = timeout.TimeLeft; - if (timeLeft <= TimeSpan.Zero) - throw new TimeoutException(); - if (task != await Task.WhenAny(task, Task.Delay(timeLeft))) - throw new TimeoutException(); - return await task; - } - - /// - /// Allows you to cancel awaiting for a non-cancellable task. - /// - /// - /// Read https://blogs.msdn.com/b/pfxteam/archive/2012/10/05/how-do-i-cancel-non-cancelable-async-operations.aspx - /// and be very careful with this. - /// - internal static async Task WithCancellation(this Task task, CancellationToken cancellationToken) - { - var tcs = new TaskCompletionSource(); - using (cancellationToken.Register(s => ((TaskCompletionSource)s!).TrySetResult(true), tcs)) - if (task != await Task.WhenAny(task, tcs.Task)) - throw new TaskCanceledException(task); - return await task; - } - - internal static Task WithCancellationAndTimeout(this Task task, NpgsqlTimeout timeout, CancellationToken cancellationToken) - => task.WithCancellation(cancellationToken).WithTimeout(timeout); - -#if NETSTANDARD2_0 || NETSTANDARD2_1 || NETCOREAPP3_1 - /// - /// Utility that simplifies awaiting a task with a timeout. If the given task does not - /// complete within , a is thrown. 
- /// - /// The task to be awaited - /// How much time to allow to complete before throwing a - /// An awaitable task that represents the original task plus the timeout - internal static async Task WithTimeout(this Task task, NpgsqlTimeout timeout) - { - if (!timeout.IsSet) - { - await task; - return; - } - var timeLeft = timeout.TimeLeft; - if (timeLeft <= TimeSpan.Zero) - throw new TimeoutException(); - if (task != await Task.WhenAny(task, Task.Delay(timeLeft))) - throw new TimeoutException(); - await task; - } - - /// - /// Allows you to cancel awaiting for a non-cancellable task. - /// - /// - /// Read https://blogs.msdn.com/b/pfxteam/archive/2012/10/05/how-do-i-cancel-non-cancelable-async-operations.aspx - /// and be very careful with this. - /// - internal static async Task WithCancellation(this Task task, CancellationToken cancellationToken) - { - var tcs = new TaskCompletionSource(); - using (cancellationToken.Register(s => ((TaskCompletionSource)s!).TrySetResult(true), tcs)) - if (task != await Task.WhenAny(task, tcs.Task)) - throw new TaskCanceledException(task); - await task; - } - - internal static Task WithCancellationAndTimeout(this Task task, NpgsqlTimeout timeout, CancellationToken cancellationToken) - => task.WithCancellation(cancellationToken).WithTimeout(timeout); -#endif - } -} diff --git a/src/Npgsql/TaskTimeoutAndCancellation.cs b/src/Npgsql/TaskTimeoutAndCancellation.cs new file mode 100644 index 0000000000..ceed87ba94 --- /dev/null +++ b/src/Npgsql/TaskTimeoutAndCancellation.cs @@ -0,0 +1,66 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Util; + +namespace Npgsql; + +/// +/// Utility class to execute a potentially non-cancellable while allowing to timeout and/or cancel awaiting for it and at the same time prevent event if the original fails later. +/// +static class TaskTimeoutAndCancellation +{ + /// + /// Executes a potentially non-cancellable while allowing to timeout and/or cancel awaiting for it. 
+ /// If the given task does not complete within , a is thrown. + /// The executed may be left in an incomplete state after the that this method returns completes dues to timeout and/or cancellation request. + /// The method guarantees that the abandoned, incomplete is not going to produce event if it fails later. + /// + /// Gets the for execution with a combined that attempts to cancel the in an event of the timeout or external cancellation request. + /// The timeout after which the should be faulted with a if it hasn't otherwise completed. + /// The to monitor for a cancellation request. + /// The result . + /// The representing the asynchronous wait. + internal static async Task ExecuteAsync(Func> getTaskFunc, NpgsqlTimeout timeout, CancellationToken cancellationToken) + { + Task? task = default; + await ExecuteAsync(ct => (Task)(task = getTaskFunc(ct)), timeout, cancellationToken).ConfigureAwait(false); + return await task!.ConfigureAwait(false); + } + + /// + /// Executes a potentially non-cancellable while allowing to timeout and/or cancel awaiting for it. + /// If the given task does not complete within , a is thrown. + /// The executed may be left in an incomplete state after the that this method returns completes dues to timeout and/or cancellation request. + /// The method guarantees that the abandoned, incomplete is not going to produce event if it fails later. + /// + /// Gets the for execution with a combined that attempts to cancel the in an event of the timeout or external cancellation request. + /// The timeout after which the should be faulted with a if it hasn't otherwise completed. + /// The to monitor for a cancellation request. + /// The representing the asynchronous wait. + internal static async Task ExecuteAsync(Func getTaskFunc, NpgsqlTimeout timeout, CancellationToken cancellationToken) + { + using var combinedCts = timeout.IsSet ? 
CancellationTokenSource.CreateLinkedTokenSource(cancellationToken) : null; + var task = getTaskFunc(combinedCts?.Token ?? cancellationToken); + try + { + try + { + await task.WaitAsync(timeout.CheckAndGetTimeLeft(), cancellationToken).ConfigureAwait(false); + } + catch (TimeoutException) when (!task!.IsCompleted) + { + // Attempt to stop the Task in progress. + combinedCts?.Cancel(); + throw; + } + } + catch + { + // Prevent unobserved Task notifications by observing the failed Task exception. + // To test: comment the next line out and re-run TaskExtensionsTest.DelayedFaultedTaskCancellation. + _ = task.ContinueWith(t => _ = t.Exception, CancellationToken.None, TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Current); + throw; + } + } +} diff --git a/src/Npgsql/ThrowHelper.cs b/src/Npgsql/ThrowHelper.cs index 24de0f62e8..f20dac780c 100644 --- a/src/Npgsql/ThrowHelper.cs +++ b/src/Npgsql/ThrowHelper.cs @@ -1,41 +1,105 @@ using Npgsql.BackendMessages; -using Npgsql.TypeHandling; using System; using System.Diagnostics.CodeAnalysis; -using System.Reflection; +using Npgsql.Internal; -namespace Npgsql +namespace Npgsql; + +static class ThrowHelper { - static class ThrowHelper - { - [DoesNotReturn] - internal static void ThrowInvalidCastException_NotSupportedType(NpgsqlTypeHandler handler, NpgsqlParameter? parameter, Type type) - { - var parameterName = parameter is null - ? null - : parameter.TrimmedName == string.Empty - ? $"${parameter.Collection!.IndexOf(parameter) + 1}" - : parameter.TrimmedName; - - throw new InvalidCastException(parameterName is null - ? $"Cannot write a value of CLR type '{type}' as database type '{handler.PgDisplayName}'." 
- : $"Cannot write a value of CLR type '{type}' as database type '{handler.PgDisplayName}' for parameter '{parameterName}'."); - } - - [DoesNotReturn] - internal static void ThrowInvalidCastException_NoValue(FieldDescription field) => - throw new InvalidCastException($"Column '{field.Name}' is null."); - - [DoesNotReturn] - internal static void ThrowInvalidOperationException_NoPropertyGetter(Type type, MemberInfo property) => - throw new InvalidOperationException($"Composite type '{type}' cannot be written because the '{property}' property has no getter."); - - [DoesNotReturn] - internal static void ThrowInvalidOperationException_NoPropertySetter(Type type, MemberInfo property) => - throw new InvalidOperationException($"Composite type '{type}' cannot be read because the '{property}' property has no setter."); - - [DoesNotReturn] - internal static void ThrowInvalidOperationException_BinaryImportParametersMismatch(int columnCount, int valueCount) => - throw new InvalidOperationException($"The binary import operation was started with {columnCount} column(s), but {valueCount} value(s) were provided."); - } + [DoesNotReturn] + internal static void ThrowArgumentOutOfRangeException() + => throw new ArgumentOutOfRangeException(); + + [DoesNotReturn] + internal static void ThrowArgumentOutOfRangeException(string paramName, string message) + => throw new ArgumentOutOfRangeException(paramName, message); + + [DoesNotReturn] + internal static void ThrowArgumentOutOfRangeException(string paramName, string message, object argument) + => throw new ArgumentOutOfRangeException(paramName, string.Format(message, argument)); + + [DoesNotReturn] + internal static void ThrowInvalidOperationException() + => throw new InvalidOperationException(); + + [DoesNotReturn] + internal static void ThrowInvalidOperationException(string message) + => throw new InvalidOperationException(message); + + [DoesNotReturn] + internal static void ThrowInvalidOperationException(string message, object argument) 
+ => throw new InvalidOperationException(string.Format(message, argument)); + + [DoesNotReturn] + internal static void ThrowObjectDisposedException(string? objectName) + => throw new ObjectDisposedException(objectName); + + [DoesNotReturn] + internal static void ThrowObjectDisposedException(string objectName, string message) + => throw new ObjectDisposedException(objectName, message); + + [DoesNotReturn] + internal static void ThrowObjectDisposedException(string objectName, Exception? innerException) + => throw new ObjectDisposedException(objectName, innerException); + + [DoesNotReturn] + internal static void ThrowInvalidCastException(string message, object argument) + => throw new InvalidCastException(string.Format(message, argument)); + + [DoesNotReturn] + internal static void ThrowInvalidCastException_NoValue(FieldDescription field) => + throw new InvalidCastException($"Column '{field.Name}' is null."); + + [DoesNotReturn] + internal static void ThrowInvalidCastException(string message) => + throw new InvalidCastException(message); + + [DoesNotReturn] + internal static void ThrowInvalidCastException_NoValue() => + throw new InvalidCastException("Field is null."); + + [DoesNotReturn] + internal static void ThrowNpgsqlException(string message) + => throw new NpgsqlException(message); + + [DoesNotReturn] + internal static void ThrowNpgsqlException(string message, Exception? 
innerException) + => throw new NpgsqlException(message, innerException); + + [DoesNotReturn] + internal static void ThrowNpgsqlOperationInProgressException(NpgsqlCommand command) + => throw new NpgsqlOperationInProgressException(command); + + [DoesNotReturn] + internal static void ThrowNpgsqlOperationInProgressException(ConnectorState state) + => throw new NpgsqlOperationInProgressException(state); + + [DoesNotReturn] + internal static void ThrowArgumentException(string message) + => throw new ArgumentException(message); + + [DoesNotReturn] + internal static void ThrowArgumentException(string message, string paramName) + => throw new ArgumentException(message, paramName); + + [DoesNotReturn] + internal static void ThrowArgumentNullException(string paramName) + => throw new ArgumentNullException(paramName); + + [DoesNotReturn] + internal static void ThrowArgumentNullException(string message, string paramName) + => throw new ArgumentNullException(paramName, message); + + [DoesNotReturn] + internal static void ThrowIndexOutOfRangeException(string message) + => throw new IndexOutOfRangeException(message); + + [DoesNotReturn] + internal static void ThrowNotSupportedException(string? 
message = null) + => throw new NotSupportedException(message); + + [DoesNotReturn] + internal static void ThrowNpgsqlExceptionWithInnerTimeoutException(string message) + => throw new NpgsqlException(message, new TimeoutException()); } diff --git a/src/Npgsql/TypeHandlers/ArrayHandler.cs b/src/Npgsql/TypeHandlers/ArrayHandler.cs deleted file mode 100644 index ee24828b3e..0000000000 --- a/src/Npgsql/TypeHandlers/ArrayHandler.cs +++ /dev/null @@ -1,509 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Linq.Expressions; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -namespace Npgsql.TypeHandlers -{ - /// - /// Non-generic base class for all type handlers which handle PostgreSQL arrays. - /// Extend from instead. - /// - /// - /// https://www.postgresql.org/docs/current/static/arrays.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public abstract class ArrayHandler : NpgsqlTypeHandler - { - private protected int LowerBound { get; } // The lower bound value sent to the backend when writing arrays. Normally 1 (the PG default) but is 0 for OIDVector. 
- private protected NpgsqlTypeHandler ElementHandler { get; } - - static readonly MethodInfo ReadArrayMethod = typeof(ArrayHandler).GetMethod(nameof(ReadArray), BindingFlags.NonPublic | BindingFlags.Instance)!; - static readonly MethodInfo ReadListMethod = typeof(ArrayHandler).GetMethod(nameof(ReadList), BindingFlags.NonPublic | BindingFlags.Instance)!; - - /// - protected ArrayHandler(PostgresType arrayPostgresType, NpgsqlTypeHandler elementHandler, int lowerBound = 1) - : base(arrayPostgresType) - { - LowerBound = lowerBound; - ElementHandler = elementHandler; - } - - internal override Type GetFieldType(FieldDescription? fieldDescription = null) => typeof(Array); - internal override Type GetProviderSpecificFieldType(FieldDescription? fieldDescription = null) => typeof(Array); - - /// - public override ArrayHandler CreateArrayHandler(PostgresArrayType arrayBackendType) - => throw new NotSupportedException(); - - /// - public override IRangeHandler CreateRangeHandler(PostgresType rangeBackendType) - => throw new NotSupportedException(); - - #region Read - - /// - public override TRequestedArray Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => Read(buf, len, false, fieldDescription).GetAwaiter().GetResult(); - - /// - protected internal override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - if (ArrayTypeInfo.IsArray) - return (TRequestedArray)(object)await ArrayTypeInfo.ReadArrayFunc(this, buf, async); - - if (ArrayTypeInfo.IsList) - return await ArrayTypeInfo.ReadListFunc(this, buf, async); - - throw new InvalidCastException(fieldDescription == null - ? $"Can't cast database type to {typeof(TRequestedArray).Name}" - : $"Can't cast database type {fieldDescription.Handler.PgDisplayName} to {typeof(TRequestedArray).Name}" - ); - } - - /// - /// Reads an array of element type from the given buffer . 
- /// - protected async ValueTask ReadArray(NpgsqlReadBuffer buf, bool async, int expectedDimensions = 0) - { - await buf.Ensure(12, async); - var dimensions = buf.ReadInt32(); - var containsNulls = buf.ReadInt32() == 1; - buf.ReadUInt32(); // Element OID. Ignored. - - if (ElementTypeInfo.IsNonNullable && containsNulls) - throw new InvalidOperationException(ReadNonNullableCollectionWithNullsExceptionMessage); - - if (dimensions == 0) - return expectedDimensions > 1 ? Array.CreateInstance(typeof(TRequestedElement), new int[expectedDimensions]) : Array.Empty(); - - if (expectedDimensions > 0 && dimensions != expectedDimensions) - throw new InvalidOperationException($"Cannot read an array with {expectedDimensions} dimension(s) from an array with {dimensions} dimension(s)"); - - if (dimensions == 1) - { - await buf.Ensure(8, async); - var arrayLength = buf.ReadInt32(); - - buf.ReadInt32(); // Lower bound - - var oneDimensional = new TRequestedElement[arrayLength]; - for (var i = 0; i < oneDimensional.Length; i++) - oneDimensional[i] = await ElementHandler.ReadWithLength(buf, async); - - return oneDimensional; - } - - var dimLengths = new int[dimensions]; - await buf.Ensure(dimensions * 8, async); - - for (var i = 0; i < dimensions; i++) - { - dimLengths[i] = buf.ReadInt32(); - buf.ReadInt32(); // Lower bound - } - - var result = Array.CreateInstance(typeof(TRequestedElement), dimLengths); - - // Multidimensional arrays - // We can't avoid boxing here - var indices = new int[dimensions]; - while (true) - { - var element = await ElementHandler.ReadWithLength(buf, async); - result.SetValue(element, indices); - - // TODO: Overly complicated/inefficient... 
- indices[dimensions - 1]++; - for (var dim = dimensions - 1; dim >= 0; dim--) - { - if (indices[dim] <= result.GetUpperBound(dim)) - continue; - - if (dim == 0) - return result; - - for (var j = dim; j < dimensions; j++) - indices[j] = result.GetLowerBound(j); - indices[dim - 1]++; - } - } - } - - /// - /// Reads a generic list containing elements of type from the given buffer . - /// - protected async ValueTask> ReadList(NpgsqlReadBuffer buf, bool async) - { - await buf.Ensure(12, async); - var dimensions = buf.ReadInt32(); - var containsNulls = buf.ReadInt32() == 1; - buf.ReadUInt32(); // Element OID. Ignored. - - if (dimensions == 0) - return ElementTypeInfo.EmptyList; - if (dimensions > 1) - throw new NotSupportedException($"Can't read multidimensional array as List<{typeof(TRequestedElement).Name}>"); - if (ElementTypeInfo.IsNonNullable && containsNulls) - throw new InvalidOperationException(ReadNonNullableCollectionWithNullsExceptionMessage); - - await buf.Ensure(8, async); - var length = buf.ReadInt32(); - buf.ReadInt32(); // We don't care about the lower bounds - - var list = new List(length); - for (var i = 0; i < length; i++) - list.Add(await ElementHandler.ReadWithLength(buf, async)); - return list; - } - - internal const string ReadNonNullableCollectionWithNullsExceptionMessage = - "Cannot read a non-nullable collection of elements because the returned array contains nulls. " + - "Call GetFieldValue with a nullable array instead."; - - #endregion Read - - #region Static generic caching helpers - - internal static class ElementTypeInfo - { - public static readonly bool IsNonNullable = - typeof(TElement).IsValueType && Nullable.GetUnderlyingType(typeof(TElement)) is null; - - public static readonly List EmptyList = new List(0); - } - - internal static class ArrayTypeInfo - { - // ReSharper disable StaticMemberInGenericType - public static readonly bool IsArray; - public static readonly bool IsList; - public static readonly Type? 
ElementType; - - public static readonly Func> ReadArrayFunc = default!; - public static readonly Func> ReadListFunc = default!; - // ReSharper restore StaticMemberInGenericType - - public static bool IsArrayOrList => IsArray || IsList; - - static ArrayTypeInfo() - { - var type = typeof(TArrayOrList); - IsArray = type.IsArray; - IsList = type.IsGenericType && type.GetGenericTypeDefinition() == typeof(List<>); - - ElementType = IsArray - ? type.GetElementType() - : IsList - ? type.GetGenericArguments()[0] - : null; - - if (ElementType == null) - return; - - // Initialize delegates - var arrayHandlerParam = Expression.Parameter(typeof(ArrayHandler), "arrayHandler"); - var bufferParam = Expression.Parameter(typeof(NpgsqlReadBuffer), "buf"); - var asyncParam = Expression.Parameter(typeof(bool), "async"); - - if (IsArray) - { - ReadArrayFunc = Expression - .Lambda>>( - Expression.Call( - arrayHandlerParam, - ReadArrayMethod.MakeGenericMethod(ElementType), - bufferParam, asyncParam, Expression.Constant(type.GetArrayRank())), - arrayHandlerParam, bufferParam, asyncParam) - .Compile(); - } - - if (IsList) - { - ReadListFunc = Expression - .Lambda>>( - Expression.Call( - arrayHandlerParam, - ReadListMethod.MakeGenericMethod(ElementType), - bufferParam, asyncParam), - arrayHandlerParam, bufferParam, asyncParam) - .Compile(); - } - } - } - - #endregion Static generic caching helpers - } - - /// - /// Base class for all type handlers which handle PostgreSQL arrays. - /// - /// - /// https://www.postgresql.org/docs/current/static/arrays.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - public class ArrayHandler : ArrayHandler - { - /// - public ArrayHandler(PostgresType arrayPostgresType, NpgsqlTypeHandler elementHandler, int lowerBound = 1) - : base(arrayPostgresType, elementHandler, lowerBound) {} - - #region Read - - internal override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => await ReadArray(buf, async); - - internal override object ReadAsObject(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => ReadArray(buf, false).GetAwaiter().GetResult(); - - #endregion - - #region Write - - static Exception MixedTypesOrJaggedArrayException(Exception innerException) - => new Exception("While trying to write an array, one of its elements failed validation. " + - "You may be trying to mix types in a non-generic IList, or to write a jagged array.", innerException); - - static Exception CantWriteTypeException(Type type) - => new InvalidCastException($"Can't write type {type} as an array of {typeof(TElement)}"); - - // Since TAny isn't constrained to class? or struct (C# doesn't have a non-nullable constraint that doesn't limit us to either struct or class), - // we must use the bang operator here to tell the compiler that a null value will never be returned. - - /// - protected internal override int ValidateAndGetLength(TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value!, ref lengthCache); - - /// - protected internal override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value!, ref lengthCache); - - int ValidateAndGetLength(object value, ref NpgsqlLengthCache? 
lengthCache) - { - if (lengthCache == null) - lengthCache = new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - if (value is ICollection generic) - return ValidateAndGetLengthGeneric(generic, ref lengthCache); - if (value is ICollection nonGeneric) - return ValidateAndGetLengthNonGeneric(nonGeneric, ref lengthCache); - throw CantWriteTypeException(value.GetType()); - } - - // Handle single-dimensional arrays and generic IList - int ValidateAndGetLengthGeneric(ICollection value, ref NpgsqlLengthCache lengthCache) - { - // Leave empty slot for the entire array length, and go ahead an populate the element slots - var pos = lengthCache.Position; - var len = - 4 + // dimensions - 4 + // has_nulls (unused) - 4 + // type OID - 1 * 8 + // number of dimensions (1) * (length + lower bound) - 4 * value.Count; // sum of element lengths - - lengthCache.Set(0); - NpgsqlLengthCache? elemLengthCache = lengthCache; - - foreach (var element in value) - { - if (element is null || typeof(TElement) == typeof(DBNull)) - continue; - - try - { - len += ElementHandler.ValidateAndGetLength(element, ref elemLengthCache, null); - } - catch (Exception e) - { - throw MixedTypesOrJaggedArrayException(e); - } - } - - lengthCache.Lengths[pos] = len; - return len; - } - - // Take care of multi-dimensional arrays and non-generic IList, we have no choice but to box/unbox - int ValidateAndGetLengthNonGeneric(ICollection value, ref NpgsqlLengthCache lengthCache) - { - var asMultidimensional = value as Array; - var dimensions = asMultidimensional?.Rank ?? 1; - - // Leave empty slot for the entire array length, and go ahead an populate the element slots - var pos = lengthCache.Position; - var len = - 4 + // dimensions - 4 + // has_nulls (unused) - 4 + // type OID - dimensions * 8 + // number of dimensions * (length + lower bound) - 4 * value.Count; // sum of element lengths - - lengthCache.Set(0); - NpgsqlLengthCache? 
elemLengthCache = lengthCache; - - foreach (var element in value) - if (element != null && element != DBNull.Value) - try - { - len += ElementHandler.ValidateObjectAndGetLength(element, ref elemLengthCache, null); - } - catch (Exception e) - { - throw MixedTypesOrJaggedArrayException(e); - } - - lengthCache.Lengths[pos] = len; - return len; - } - - internal override Task WriteWithLengthInternal([AllowNull] TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - return WriteWithLengthLong(); - - return WriteWithLength(); - - async Task WriteWithLengthLong() - { - await buf.Flush(async, cancellationToken); - await WriteWithLength(); - } - - Task WriteWithLength() - { - if (value == null || typeof(TAny) == typeof(DBNull)) - { - buf.WriteInt32(-1); - return Task.CompletedTask; - } - - buf.WriteInt32(ValidateAndGetLength(value, ref lengthCache, parameter)); - - if (value is ICollection list) - return WriteGeneric(list, buf, lengthCache, async, cancellationToken); - - if (value is ICollection nonGeneric) - return WriteNonGeneric(nonGeneric, buf, lengthCache, async, cancellationToken); - - throw CantWriteTypeException(value.GetType()); - } - } - - // The default WriteObjectWithLength casts the type handler to INpgsqlTypeHandler, but that's not sufficient for - // us (need to handle many types of T, e.g. int[], int[,]...) - /// - protected internal override Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value is DBNull - ? WriteWithLengthInternal(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken) - : WriteWithLengthInternal(value, buf, lengthCache, parameter, async, cancellationToken); - - async Task WriteGeneric(ICollection value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? 
lengthCache, bool async, CancellationToken cancellationToken = default) - { - var len = - 4 + // dimensions - 4 + // has_nulls (unused) - 4 + // type OID - 1 * 8; // number of dimensions (1) * (length + lower bound) - if (buf.WriteSpaceLeft < len) - { - await buf.Flush(async, cancellationToken); - Debug.Assert(buf.WriteSpaceLeft >= len, "Buffer too small for header"); - } - - buf.WriteInt32(1); - buf.WriteInt32(1); // has_nulls = 1. Not actually used by the backend. - buf.WriteUInt32(ElementHandler.PostgresType.OID); - buf.WriteInt32(value.Count); - buf.WriteInt32(LowerBound); // We don't map .NET lower bounds to PG - - foreach (var element in value) - await ElementHandler.WriteWithLengthInternal(element, buf, lengthCache, null, async, cancellationToken); - } - - async Task WriteNonGeneric(ICollection value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, bool async, CancellationToken cancellationToken = default) - { - var asArray = value as Array; - var dimensions = asArray?.Rank ?? 1; - - var len = - 4 + // ndim - 4 + // has_nulls - 4 + // element_oid - dimensions * 8; // dim (4) + lBound (4) - - if (buf.WriteSpaceLeft < len) - { - await buf.Flush(async, cancellationToken); - Debug.Assert(buf.WriteSpaceLeft >= len, "Buffer too small for header"); - } - - buf.WriteInt32(dimensions); - buf.WriteInt32(1); // HasNulls=1. Not actually used by the backend. - buf.WriteUInt32(ElementHandler.PostgresType.OID); - if (asArray != null) - { - for (var i = 0; i < dimensions; i++) - { - buf.WriteInt32(asArray.GetLength(i)); - buf.WriteInt32(LowerBound); // We don't map .NET lower bounds to PG - } - } - else - { - buf.WriteInt32(value.Count); - buf.WriteInt32(LowerBound); // We don't map .NET lower bounds to PG - } - - foreach (var element in value) - await ElementHandler.WriteObjectWithLength(element ?? 
DBNull.Value, buf, lengthCache, null, async, cancellationToken); - } - - #endregion - } - - /// - /// https://www.postgresql.org/docs/current/static/arrays.html - /// - /// The .NET type contained as an element within this array - /// The .NET provider-specific type contained as an element within this array - class ArrayHandlerWithPsv : ArrayHandler - { - public ArrayHandlerWithPsv(PostgresType arrayPostgresType, NpgsqlTypeHandler elementHandler) - : base(arrayPostgresType, elementHandler) { } - - protected internal override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - if (ArrayTypeInfo.ElementType == typeof(TElementPsv)) - { - if (ArrayTypeInfo.IsArray) - return (TRequestedArray)(object)await ReadArray(buf, async, typeof(TRequestedArray).GetArrayRank()); - - if (ArrayTypeInfo.IsList) - return (TRequestedArray)(object)await ReadList(buf, async); - } - return await base.Read(buf, len, async, fieldDescription); - } - - internal override object ReadPsvAsObject(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => ReadPsvAsObject(buf, len, false, fieldDescription).GetAwaiter().GetResult(); - - internal override async ValueTask ReadPsvAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - => await ReadArray(buf, async); - } -} diff --git a/src/Npgsql/TypeHandlers/BitStringHandler.cs b/src/Npgsql/TypeHandlers/BitStringHandler.cs deleted file mode 100644 index 280480c681..0000000000 --- a/src/Npgsql/TypeHandlers/BitStringHandler.cs +++ /dev/null @@ -1,312 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Collections.Specialized; -using System.Diagnostics; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - /// - /// A type handler for the PostgreSQL bit string data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-bit.html. - /// - /// Note that for BIT(1), this handler will return a bool by default, to align with SQLClient - /// (see discussion https://github.com/npgsql/npgsql/pull/362#issuecomment-59622101). - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("bit varying", NpgsqlDbType.Varbit, new[] { typeof(BitArray), typeof(BitVector32) })] - [TypeMapping("bit", NpgsqlDbType.Bit)] - public class BitStringHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, INpgsqlTypeHandler - { - /// - public BitStringHandler(PostgresType postgresType) : base(postgresType) {} - - internal override Type GetFieldType(FieldDescription? fieldDescription = null) - => fieldDescription != null && fieldDescription.TypeModifier == 1 ? typeof(bool) : typeof(BitArray); - - internal override Type GetProviderSpecificFieldType(FieldDescription? 
fieldDescription = null) - => GetFieldType(fieldDescription); - - // BitString requires a special array handler which returns bool or BitArray - /// - public override ArrayHandler CreateArrayHandler(PostgresArrayType backendType) - => new BitStringArrayHandler(backendType, this); - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numBits = buf.ReadInt32(); - var result = new BitArray(numBits); - var bytesLeft = len - 4; // Remove leading number of bits - if (bytesLeft == 0) - return result; - - var bitNo = 0; - while (true) - { - var iterationEndPos = bytesLeft > buf.ReadBytesLeft - ? bytesLeft - buf.ReadBytesLeft - : 1; - - for (; bytesLeft > iterationEndPos; bytesLeft--) - { - // ReSharper disable ShiftExpressionRealShiftCountIsZero - var chunk = buf.ReadByte(); - result[bitNo++] = (chunk & (1 << 7)) != 0; - result[bitNo++] = (chunk & (1 << 6)) != 0; - result[bitNo++] = (chunk & (1 << 5)) != 0; - result[bitNo++] = (chunk & (1 << 4)) != 0; - result[bitNo++] = (chunk & (1 << 3)) != 0; - result[bitNo++] = (chunk & (1 << 2)) != 0; - result[bitNo++] = (chunk & (1 << 1)) != 0; - result[bitNo++] = (chunk & (1 << 0)) != 0; - } - - if (bytesLeft == 1) - break; - - Debug.Assert(buf.ReadBytesLeft == 0); - await buf.Ensure(Math.Min(bytesLeft, buf.Size), async); - } - - if (bitNo < result.Length) - { - var remainder = result.Length - bitNo; - await buf.Ensure(1, async); - var lastChunk = buf.ReadByte(); - for (var i = 7; i >= 8 - remainder; i--) - result[bitNo++] = (lastChunk & (1 << i)) != 0; - } - - return result; - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - { - if (len > 4 + 4) - throw new InvalidCastException("Can't read PostgreSQL bitstring with more than 32 bits into BitVector32"); - - await buf.Ensure(4 + 4, async); - - var numBits = buf.ReadInt32(); - return numBits == 0 - ? new BitVector32(0) - : new BitVector32(buf.ReadInt32()); - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - await buf.Ensure(5, async); - var bitLen = buf.ReadInt32(); - if (bitLen != 1) - throw new InvalidCastException("Can't convert a BIT(N) type to bool, only BIT(1)"); - var b = buf.ReadByte(); - return (b & 128) != 0; - } - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new NotSupportedException("Only writing string to PostgreSQL bitstring is supported, no reading."); - - internal override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => fieldDescription?.TypeModifier == 1 - ? (object)await Read(buf, len, async, fieldDescription) - : await Read(buf, len, async, fieldDescription); - - internal override object ReadAsObject(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => fieldDescription?.TypeModifier == 1 - ? (object)Read(buf, len, false, fieldDescription).Result - : Read(buf, len, false, fieldDescription).Result; - - #endregion - - #region Write - - /// - public override int ValidateAndGetLength(BitArray value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => 4 + (value.Length + 7) / 8; - - /// - public int ValidateAndGetLength(BitVector32 value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Data == 0 ? 4 : 8; - - /// - public int ValidateAndGetLength(bool value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => 5; - - /// - public int ValidateAndGetLength(string value, ref NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter) - { - if (value.Any(c => c != '0' && c != '1')) - throw new FormatException("Cannot interpret as ASCII BitString: " + value); - return 4 + (value.Length + 7) / 8; - } - - /// - public override async Task Write(BitArray value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - // Initial bitlength byte - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(value.Length); - - var byteLen = (value.Length + 7) / 8; - var pos = 0; - while (true) - { - var endPos = pos + Math.Min(byteLen - pos, buf.WriteSpaceLeft); - for (; pos < endPos; pos++) - { - var bitPos = pos*8; - var b = 0; - for (var i = 0; i < Math.Min(8, value.Length - bitPos); i++) - b += (value[bitPos + i] ? 1 : 0) << (8 - i - 1); - buf.WriteByte((byte)b); - } - - if (pos == byteLen) - return; - await buf.Flush(async, cancellationToken); - } - } - - /// - public async Task Write(BitVector32 value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 8) - await buf.Flush(async, cancellationToken); - - if (value.Data == 0) - buf.WriteInt32(0); - else - { - buf.WriteInt32(32); - buf.WriteInt32(value.Data); - } - } - - /// - public async Task Write(bool value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 5) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(1); - buf.WriteByte(value ? (byte)0x80 : (byte)0); - } - - /// - public async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - // Initial bitlength byte - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(value.Length); - - var pos = 0; - var byteLen = (value.Length + 7) / 8; - var bytePos = 0; - - while (true) - { - var endBytePos = bytePos + Math.Min(byteLen - bytePos - 1, buf.WriteSpaceLeft); - - for (; bytePos < endBytePos; bytePos++) - { - var b = 0; - b += (value[pos++] - '0') << 7; - b += (value[pos++] - '0') << 6; - b += (value[pos++] - '0') << 5; - b += (value[pos++] - '0') << 4; - b += (value[pos++] - '0') << 3; - b += (value[pos++] - '0') << 2; - b += (value[pos++] - '0') << 1; - b += (value[pos++] - '0'); - buf.WriteByte((byte)b); - } - - if (bytePos >= byteLen - 1) - break; - await buf.Flush(async, cancellationToken); - } - - if (pos < value.Length) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - var remainder = value.Length - pos; - var lastChunk = 0; - for (var i = 7; i >= 8 - remainder; i--) - lastChunk += (value[pos++] - '0') << i; - buf.WriteByte((byte)lastChunk); - } - } - - #endregion - } - - /// - /// A special handler for arrays of bit strings. - /// Differs from the standard array handlers in that it returns arrays of bool for BIT(1) and arrays - /// of BitArray otherwise (just like the scalar BitStringHandler does). - /// - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public class BitStringArrayHandler : ArrayHandler - { - /// - public BitStringArrayHandler(PostgresType postgresType, BitStringHandler elementHandler) - : base(postgresType, elementHandler) {} - - /// - protected internal override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - if (ArrayTypeInfo.ElementType == typeof(BitArray)) - { - if (ArrayTypeInfo.IsArray) - return (TRequestedArray)(object)await ReadArray(buf, async); - - if (ArrayTypeInfo.IsList) - return (TRequestedArray)(object)await ReadList(buf, async); - } - - if (ArrayTypeInfo.ElementType == typeof(bool)) - { - if (ArrayTypeInfo.IsArray) - return (TRequestedArray)(object)await ReadArray(buf, async); - - if (ArrayTypeInfo.IsList) - return (TRequestedArray)(object)await ReadList(buf, async); - } - - return await base.Read(buf, len, async, fieldDescription); - } - - internal override object ReadAsObject(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => ReadAsObject(buf, len, false, fieldDescription).Result; - - internal override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => fieldDescription?.TypeModifier == 1 - ? await ReadArray(buf, async) - : await ReadArray(buf, async); - } -} diff --git a/src/Npgsql/TypeHandlers/BoolHandler.cs b/src/Npgsql/TypeHandlers/BoolHandler.cs deleted file mode 100644 index b3688e0fb9..0000000000 --- a/src/Npgsql/TypeHandlers/BoolHandler.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - /// - /// A type handler for the PostgreSQL bool data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-boolean.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("boolean", NpgsqlDbType.Boolean, DbType.Boolean, typeof(bool))] - public class BoolHandler : NpgsqlSimpleTypeHandler - { - /// - public BoolHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override bool Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadByte() != 0; - - /// - public override int ValidateAndGetLength(bool value, NpgsqlParameter? parameter) - => 1; - - /// - public override void Write(bool value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteByte(value ? (byte)1 : (byte)0); - } -} diff --git a/src/Npgsql/TypeHandlers/ByteaHandler.cs b/src/Npgsql/TypeHandlers/ByteaHandler.cs deleted file mode 100644 index 6c00a19b2f..0000000000 --- a/src/Npgsql/TypeHandlers/ByteaHandler.cs +++ /dev/null @@ -1,139 +0,0 @@ -using System; -using System.Data; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - /// - /// A type handler for the PostgreSQL bytea data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-binary.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping( - "bytea", - NpgsqlDbType.Bytea, - DbType.Binary, - new[] { - typeof(byte[]), - typeof(ArraySegment), -#if !NETSTANDARD2_0 - typeof(ReadOnlyMemory), - typeof(Memory) -#endif - })] - public class ByteaHandler : NpgsqlTypeHandler, INpgsqlTypeHandler> -#if !NETSTANDARD2_0 - , INpgsqlTypeHandler>, INpgsqlTypeHandler> -#endif - { - /// - /// Constructs a . 
- /// - public ByteaHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - var bytes = new byte[len]; - var pos = 0; - while (true) - { - var toRead = Math.Min(len - pos, buf.ReadBytesLeft); - buf.ReadBytes(bytes, pos, toRead); - pos += toRead; - if (pos == len) - break; - await buf.ReadMore(async); - } - return bytes; - } - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new NotSupportedException("Only writing ArraySegment to PostgreSQL bytea is supported, no reading."); - - int ValidateAndGetLength(int bufferLen, NpgsqlParameter? parameter) - => parameter == null || parameter.Size <= 0 || parameter.Size >= bufferLen - ? bufferLen - : parameter.Size; - - /// - public override int ValidateAndGetLength(byte[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value.Length, parameter); - - /// - public int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value.Count, parameter); - - /// - public override Task Write(byte[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write(value, buf, 0, ValidateAndGetLength(value.Length, parameter), async, cancellationToken); - - /// - public Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value.Array is null ? 
Task.CompletedTask : Write(value.Array, buf, value.Offset, ValidateAndGetLength(value.Count, parameter), async, cancellationToken); - - async Task Write(byte[] value, NpgsqlWriteBuffer buf, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - // The entire segment fits in our buffer, copy it as usual. - if (count <= buf.WriteSpaceLeft) - { - buf.WriteBytes(value, offset, count); - return; - } - - // The segment is larger than our buffer. Flush whatever is currently in the buffer and - // write the array directly to the socket. - await buf.Flush(async, cancellationToken); - await buf.DirectWrite(new ReadOnlyMemory(value, offset, count), async, cancellationToken); - } - -#if !NETSTANDARD2_0 - /// - public int ValidateAndGetLength(Memory value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value.Length, parameter); - - /// - public int ValidateAndGetLength(ReadOnlyMemory value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value.Length, parameter); - - /// - public async Task Write(ReadOnlyMemory value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (parameter != null && parameter.Size > 0 && parameter.Size < value.Length) - value = value.Slice(0, parameter.Size); - - // The entire segment fits in our buffer, copy it into the buffer as usual. - if (value.Length <= buf.WriteSpaceLeft) - { - buf.WriteBytes(value.Span); - return; - } - - // The segment is larger than our buffer. Perform a direct write, flushing whatever is currently in the buffer - // and then writing the array directly to the socket. - await buf.DirectWrite(value, async, cancellationToken); - } - - /// - public Task Write(Memory value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => Write((ReadOnlyMemory)value, buf, lengthCache, parameter, async, cancellationToken); - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescriptioncancellationToken) - => throw new NotSupportedException("Only writing ReadOnlyMemory to PostgreSQL bytea is supported, no reading."); - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => throw new NotSupportedException("Only writing Memory to PostgreSQL bytea is supported, no reading."); -#endif - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/ByReference.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/ByReference.cs deleted file mode 100644 index 64df793adb..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/ByReference.cs +++ /dev/null @@ -1,11 +0,0 @@ - -// Only used for value types, but can't constrain because MappedCompositeHandler isn't constrained -#nullable disable - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - sealed class ByReference - { - public T Value; - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeConstructorHandler.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeConstructorHandler.cs deleted file mode 100644 index e023b8cd35..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeConstructorHandler.cs +++ /dev/null @@ -1,64 +0,0 @@ -using System; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.PostgresTypes; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - class CompositeConstructorHandler - { - public PostgresType PostgresType { get; } - public ConstructorInfo ConstructorInfo { get; } - public CompositeParameterHandler[] Handlers { get; } - - protected CompositeConstructorHandler(PostgresType postgresType, ConstructorInfo constructorInfo, CompositeParameterHandler[] handlers) - { - 
PostgresType = postgresType; - ConstructorInfo = constructorInfo; - Handlers = handlers; - } - - public virtual async ValueTask Read(NpgsqlReadBuffer buffer, bool async) - { - await buffer.Ensure(sizeof(int), async); - - var fieldCount = buffer.ReadInt32(); - if (fieldCount != Handlers.Length) - throw new InvalidOperationException($"pg_attributes contains {Handlers.Length} fields for type {PostgresType.DisplayName}, but {fieldCount} fields were received."); - - var args = new object?[Handlers.Length]; - foreach (var handler in Handlers) - args[handler.ParameterPosition] = await handler.Read(buffer, async); - - return (TComposite)ConstructorInfo.Invoke(args); - } - - public static CompositeConstructorHandler Create(PostgresType postgresType, ConstructorInfo constructorInfo, CompositeParameterHandler[] parameterHandlers) - { - const int maxGenericParameters = 8; - - if (parameterHandlers.Length > maxGenericParameters) - return new CompositeConstructorHandler(postgresType, constructorInfo, parameterHandlers); - - var parameterTypes = new Type[1 + maxGenericParameters]; - foreach (var parameterHandler in parameterHandlers) - parameterTypes[1 + parameterHandler.ParameterPosition] = parameterHandler.ParameterType; - - for (var parameterIndex = 1; parameterIndex < parameterTypes.Length; parameterIndex++) - parameterTypes[parameterIndex] ??= typeof(Unused); - - parameterTypes[0] = typeof(TComposite); - return (CompositeConstructorHandler)Activator.CreateInstance( - typeof(CompositeConstructorHandler<,,,,,,,,>).MakeGenericType(parameterTypes), - BindingFlags.Instance | BindingFlags.Public, - binder: null, - args: new object[] { postgresType, constructorInfo, parameterHandlers }, - culture: null)!; - } - - readonly struct Unused - { - } - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeConstructorHandler`.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeConstructorHandler`.cs deleted file mode 100644 index 5365a62cf0..0000000000 --- 
a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeConstructorHandler`.cs +++ /dev/null @@ -1,68 +0,0 @@ -using System; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.PostgresTypes; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - sealed class CompositeConstructorHandler : CompositeConstructorHandler - { - delegate TComposite CompositeConstructor(in Arguments args); - - readonly CompositeConstructor _constructor; - - public CompositeConstructorHandler(PostgresType postgresType, ConstructorInfo constructorInfo, CompositeParameterHandler[] parameterHandlers) - : base(postgresType, constructorInfo, parameterHandlers) - { - var parameter = Expression.Parameter(typeof(Arguments).MakeByRefType()); - var fields = Enumerable - .Range(1, parameterHandlers.Length) - .Select(i => Expression.Field(parameter, "Argument" + i)); - - _constructor = Expression - .Lambda(Expression.New(constructorInfo, fields), parameter) - .Compile(); - } - - public override async ValueTask Read(NpgsqlReadBuffer buffer, bool async) - { - await buffer.Ensure(sizeof(int), async); - - var fieldCount = buffer.ReadInt32(); - if (fieldCount != Handlers.Length) - throw new InvalidOperationException($"pg_attributes contains {Handlers.Length} fields for type {PostgresType.DisplayName}, but {fieldCount} fields were received."); - - var args = default(Arguments); - - foreach (var handler in Handlers) - switch (handler.ParameterPosition) - { - case 0: args.Argument1 = await handler.Read(buffer, async); break; - case 1: args.Argument2 = await handler.Read(buffer, async); break; - case 2: args.Argument3 = await handler.Read(buffer, async); break; - case 3: args.Argument4 = await handler.Read(buffer, async); break; - case 4: args.Argument5 = await handler.Read(buffer, async); break; - case 5: args.Argument6 = await handler.Read(buffer, async); break; - case 6: args.Argument7 = await handler.Read(buffer, 
async); break; - case 7: args.Argument8 = await handler.Read(buffer, async); break; - } - - return _constructor(args); - } - - struct Arguments - { - public T1 Argument1; - public T2 Argument2; - public T3 Argument3; - public T4 Argument4; - public T5 Argument5; - public T6 Argument6; - public T7 Argument7; - public T8 Argument8; - } - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeHandler.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeHandler.cs deleted file mode 100644 index ee9c44b00f..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeHandler.cs +++ /dev/null @@ -1,258 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - class CompositeHandler : NpgsqlTypeHandler, ICompositeHandler - { - readonly ConnectorTypeMapper _typeMapper; - readonly INpgsqlNameTranslator _nameTranslator; - - [NotNull] Func? _constructor; - [NotNull] CompositeConstructorHandler? _constructorHandler; - [NotNull] CompositeMemberHandler[]? _memberHandlers; - - public Type CompositeType => typeof(T); - - public CompositeHandler(PostgresCompositeType postgresType, ConnectorTypeMapper typeMapper, INpgsqlNameTranslator nameTranslator) - : base(postgresType) - { - _typeMapper = typeMapper; - _nameTranslator = nameTranslator; - } - - public override ValueTask Read(NpgsqlReadBuffer buffer, int length, bool async, FieldDescription? fieldDescription = null) - { - Initialize(); - - return _constructorHandler is null - ? 
ReadUsingMemberHandlers() - : _constructorHandler.Read(buffer, async); - - async ValueTask ReadUsingMemberHandlers() - { - await buffer.Ensure(sizeof(int), async); - - var fieldCount = buffer.ReadInt32(); - if (fieldCount != _memberHandlers.Length) - throw new InvalidOperationException($"pg_attributes contains {_memberHandlers.Length} fields for type {PgDisplayName}, but {fieldCount} fields were received."); - - if (IsValueType.Value) - { - var composite = new ByReference { Value = _constructor() }; - foreach (var member in _memberHandlers) - await member.Read(composite, buffer, async); - - return composite.Value; - } - else - { - var composite = _constructor(); - foreach (var member in _memberHandlers) - await member.Read(composite, buffer, async); - - return composite; - } - } - } - - public override async Task Write(T value, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - Initialize(); - - if (buffer.WriteSpaceLeft < sizeof(int)) - await buffer.Flush(async, cancellationToken); - - buffer.WriteInt32(_memberHandlers.Length); - - foreach (var member in _memberHandlers) - await member.Write(value, buffer, lengthCache, async, cancellationToken); - } - - public override int ValidateAndGetLength(T value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - Initialize(); - - if (lengthCache == null) - lengthCache = new NpgsqlLengthCache(1); - - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - // Leave empty slot for the entire composite type, and go ahead an populate the element slots - var position = lengthCache.Position; - lengthCache.Set(0); - - // number of fields + (type oid + field length) * member count - var length = sizeof(int) + sizeof(int) * 2 * _memberHandlers.Length; - foreach (var member in _memberHandlers) - length += member.ValidateAndGetLength(value, ref lengthCache); - - return lengthCache.Lengths[position] = length; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void Initialize() - { - if (_memberHandlers is null) - InitializeCore(); - - void InitializeCore() - { - var pgType = (PostgresCompositeType)PostgresType; - - _memberHandlers = CreateMemberHandlers(pgType, _typeMapper, _nameTranslator); - _constructorHandler = CreateConstructorHandler(pgType, _typeMapper, _nameTranslator); - _constructor = _constructorHandler is null - ? Expression - .Lambda>(Expression.New(typeof(T))) - .Compile() - : null; - } - } - - static CompositeConstructorHandler? CreateConstructorHandler(PostgresCompositeType pgType, ConnectorTypeMapper typeMapper, INpgsqlNameTranslator nameTranslator) - { - var pgFields = pgType.Fields; - var clrType = typeof(T); - - ConstructorInfo? clrDefaultConstructor = null; - - foreach (var clrConstructor in clrType.GetConstructors()) - { - var clrParameters = clrConstructor.GetParameters(); - if (clrParameters.Length != pgFields.Count) - { - if (clrParameters.Length == 0) - clrDefaultConstructor = clrConstructor; - - continue; - } - - var clrParameterHandlerCount = 0; - var clrParametersMapped = new ParameterInfo[pgFields.Count]; - - foreach (var clrParameter in clrParameters) - { - var attr = clrParameter.GetCustomAttribute(); - var name = attr?.PgName ?? (clrParameter.Name is string clrName ? 
nameTranslator.TranslateMemberName(clrName) : null); - if (name is null) - break; - - for (var pgFieldIndex = pgFields.Count - 1; pgFieldIndex >= 0; --pgFieldIndex) - { - var pgField = pgFields[pgFieldIndex]; - if (pgField.Name != name) - continue; - - if (clrParametersMapped[pgFieldIndex] != null) - throw new AmbiguousMatchException($"Multiple constructor parameters are mapped to the '{pgField.Name}' field."); - - clrParameterHandlerCount++; - clrParametersMapped[pgFieldIndex] = clrParameter; - - break; - } - } - - if (clrParameterHandlerCount < pgFields.Count) - continue; - - var clrParameterHandlers = new CompositeParameterHandler[pgFields.Count]; - for (var pgFieldIndex = 0; pgFieldIndex < pgFields.Count; ++pgFieldIndex) - { - var pgField = pgFields[pgFieldIndex]; - - if (!typeMapper.TryGetByOID(pgField.Type.OID, out var handler)) - throw new NpgsqlException($"PostgreSQL composite type {pgType.DisplayName} has field {pgField.Type.DisplayName} with an unknown type (OID = {pgField.Type.OID})."); - - var clrParameter = clrParametersMapped[pgFieldIndex]; - var clrParameterHandlerType = typeof(CompositeParameterHandler<>) - .MakeGenericType(clrParameter.ParameterType); - - clrParameterHandlers[pgFieldIndex] = (CompositeParameterHandler)Activator.CreateInstance( - clrParameterHandlerType, - BindingFlags.Instance | BindingFlags.Public, - binder: null, - args: new object[] { handler, clrParameter }, - culture: null)!; - } - - return CompositeConstructorHandler.Create(pgType, clrConstructor, clrParameterHandlers); - } - - if (clrDefaultConstructor is null && !clrType.IsValueType) - throw new InvalidOperationException($"No parameterless constructor defined for type '{clrType}'."); - - return null; - } - - static CompositeMemberHandler[] CreateMemberHandlers(PostgresCompositeType pgType, ConnectorTypeMapper typeMapper, INpgsqlNameTranslator nameTranslator) - { - var pgFields = pgType.Fields; - - var clrType = typeof(T); - var clrMemberHandlers = new 
CompositeMemberHandler[pgFields.Count]; - var clrMemberHandlerCount = 0; - var clrMemberHandlerType = IsValueType.Value - ? typeof(CompositeStructMemberHandler<,>) - : typeof(CompositeClassMemberHandler<,>); - - foreach (var clrProperty in clrType.GetProperties(BindingFlags.Instance | BindingFlags.Public)) - CreateMemberHandler(clrProperty, clrProperty.PropertyType); - - foreach (var clrField in clrType.GetFields(BindingFlags.Instance | BindingFlags.Public)) - CreateMemberHandler(clrField, clrField.FieldType); - - if (clrMemberHandlerCount != pgFields.Count) - { - var notMappedFields = string.Join(", ", clrMemberHandlers - .Select((member, memberIndex) => member == null ? $"'{pgFields[memberIndex].Name}'" : null) - .Where(member => member != null)); - throw new InvalidOperationException($"PostgreSQL composite type {pgType.DisplayName} contains fields {notMappedFields} which could not match any on CLR type {clrType.Name}"); - } - - return clrMemberHandlers; - - void CreateMemberHandler(MemberInfo clrMember, Type clrMemberType) - { - var attr = clrMember.GetCustomAttribute(); - var name = attr?.PgName ?? 
nameTranslator.TranslateMemberName(clrMember.Name); - - for (var pgFieldIndex = pgFields.Count - 1; pgFieldIndex >= 0; --pgFieldIndex) - { - var pgField = pgFields[pgFieldIndex]; - if (pgField.Name != name) - continue; - - if (clrMemberHandlers[pgFieldIndex] != null) - throw new AmbiguousMatchException($"Multiple class members are mapped to the '{pgField.Name}' field."); - - if (!typeMapper.TryGetByOID(pgField.Type.OID, out var handler)) - throw new NpgsqlException($"PostgreSQL composite type {pgType.DisplayName} has field {pgField.Type.DisplayName} with an unknown type (OID = {pgField.Type.OID})."); - - clrMemberHandlerCount++; - clrMemberHandlers[pgFieldIndex] = (CompositeMemberHandler)Activator.CreateInstance( - clrMemberHandlerType.MakeGenericType(clrType, clrMemberType), - BindingFlags.Instance | BindingFlags.Public, - binder: null, - args: new object[] { clrMember, pgField.Type, handler }, - culture: null)!; - - break; - } - } - } - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeMemberHandler.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeMemberHandler.cs deleted file mode 100644 index cd635257cd..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeMemberHandler.cs +++ /dev/null @@ -1,27 +0,0 @@ -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.PostgresTypes; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - abstract class CompositeMemberHandler - { - public MemberInfo MemberInfo { get; } - public PostgresType PostgresType { get; } - - protected CompositeMemberHandler(MemberInfo memberInfo, PostgresType postgresType) - { - MemberInfo = memberInfo; - PostgresType = postgresType; - } - - public abstract ValueTask Read(TComposite composite, NpgsqlReadBuffer buffer, bool async); - - public abstract ValueTask Read(ByReference composite, NpgsqlReadBuffer buffer, bool async); - - public abstract Task Write(TComposite composite, NpgsqlWriteBuffer buffer, 
NpgsqlLengthCache? lengthCache, bool async, CancellationToken cancellationToken = default); - - public abstract int ValidateAndGetLength(TComposite composite, ref NpgsqlLengthCache? lengthCache); - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfClass.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfClass.cs deleted file mode 100644 index 7992a1c005..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfClass.cs +++ /dev/null @@ -1,106 +0,0 @@ -using System; -using System.Diagnostics; -using System.Linq.Expressions; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - sealed class CompositeClassMemberHandler : CompositeMemberHandler - where TComposite : class - { - delegate TMember GetMember(TComposite composite); - delegate void SetMember(TComposite composite, TMember value); - - readonly GetMember? _get; - readonly SetMember? 
_set; - readonly NpgsqlTypeHandler _handler; - - public CompositeClassMemberHandler(FieldInfo fieldInfo, PostgresType postgresType, NpgsqlTypeHandler handler) - : base(fieldInfo, postgresType) - { - var composite = Expression.Parameter(typeof(TComposite), "composite"); - var value = Expression.Parameter(typeof(TMember), "value"); - - _get = Expression - .Lambda(Expression.Field(composite, fieldInfo), composite) - .Compile(); - _set = Expression - .Lambda(Expression.Assign(Expression.Field(composite, fieldInfo), value), composite, value) - .Compile(); - _handler = handler; - } - - public CompositeClassMemberHandler(PropertyInfo propertyInfo, PostgresType postgresType, NpgsqlTypeHandler handler) - : base(propertyInfo, postgresType) - { - var getMethod = propertyInfo.GetGetMethod(); - if (getMethod != null) - _get = (GetMember)Delegate.CreateDelegate(typeof(GetMember), getMethod); - - var setMethod = propertyInfo.GetSetMethod(); - if (setMethod != null) - _set = (SetMember)Delegate.CreateDelegate(typeof(SetMember), setMethod); - - Debug.Assert(setMethod != null || getMethod != null); - - _handler = handler; - } - - public override async ValueTask Read(TComposite composite, NpgsqlReadBuffer buffer, bool async) - { - if (_set == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertySetter(typeof(TComposite), MemberInfo); - - await buffer.Ensure(sizeof(uint) + sizeof(int), async); - - var oid = buffer.ReadUInt32(); - Debug.Assert(oid == PostgresType.OID); - - var length = buffer.ReadInt32(); - if (length == -1) - return; - - var value = NullableHandler.Exists - ? await NullableHandler.ReadAsync(_handler, buffer, length, async) - : await _handler.Read(buffer, length, async); - - _set(composite, value); - } - - public override ValueTask Read(ByReference composite, NpgsqlReadBuffer buffer, bool async) - => throw new NotSupportedException(); - - public override async Task Write(TComposite composite, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? 
lengthCache, bool async, CancellationToken cancellationToken = default) - { - if (_get == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertyGetter(typeof(TComposite), MemberInfo); - - if (buffer.WriteSpaceLeft < sizeof(int)) - await buffer.Flush(async, cancellationToken); - - buffer.WriteUInt32(PostgresType.OID); - if (NullableHandler.Exists) - await NullableHandler.WriteAsync(_handler, _get(composite), buffer, lengthCache, null, async, cancellationToken); - else - await _handler.WriteWithLengthInternal(_get(composite), buffer, lengthCache, null, async, cancellationToken); - } - - public override int ValidateAndGetLength(TComposite composite, ref NpgsqlLengthCache? lengthCache) - { - if (_get == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertyGetter(typeof(TComposite), MemberInfo); - - var value = _get(composite); - if (value == null) - return 0; - - return NullableHandler.Exists - ? NullableHandler.ValidateAndGetLength(_handler, value, ref lengthCache, null) - : _handler.ValidateAndGetLength(value, ref lengthCache, null); - } - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfStruct.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfStruct.cs deleted file mode 100644 index 8d0c63b12e..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeMemberHandlerOfStruct.cs +++ /dev/null @@ -1,112 +0,0 @@ -using System; -using System.Diagnostics; -using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - sealed class CompositeStructMemberHandler : CompositeMemberHandler - where TComposite : struct - { - delegate TMember GetMember(ref TComposite composite); - delegate void SetMember(ref TComposite composite, TMember value); - - readonly GetMember? _get; - readonly SetMember? 
_set; - readonly NpgsqlTypeHandler _handler; - - public CompositeStructMemberHandler(FieldInfo fieldInfo, PostgresType postgresType, NpgsqlTypeHandler handler) - : base(fieldInfo, postgresType) - { - var composite = Expression.Parameter(typeof(TComposite).MakeByRefType(), "composite"); - var value = Expression.Parameter(typeof(TMember), "value"); - - _get = Expression - .Lambda(Expression.Field(composite, fieldInfo), composite) - .Compile(); - _set = Expression - .Lambda(Expression.Assign(Expression.Field(composite, fieldInfo), value), composite, value) - .Compile(); - _handler = handler; - } - - public CompositeStructMemberHandler(PropertyInfo propertyInfo, PostgresType postgresType, NpgsqlTypeHandler handler) - : base(propertyInfo, postgresType) - { - var getMethod = propertyInfo.GetGetMethod(); - if (getMethod != null) - _get = (GetMember)Delegate.CreateDelegate(typeof(GetMember), getMethod); - - var setMethod = propertyInfo.GetSetMethod(); - if (setMethod != null) - _set = (SetMember)Delegate.CreateDelegate(typeof(SetMember), setMethod); - - Debug.Assert(setMethod != null || getMethod != null); - - _handler = handler; - } - - public override ValueTask Read(TComposite composite, NpgsqlReadBuffer buffer, bool async) - => throw new NotSupportedException(); - - public override async ValueTask Read(ByReference composite, NpgsqlReadBuffer buffer, bool async) - { - if (_set == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertySetter(typeof(TComposite), MemberInfo); - - await buffer.Ensure(sizeof(uint) + sizeof(int), async); - - var oid = buffer.ReadUInt32(); - Debug.Assert(oid == PostgresType.OID); - - var length = buffer.ReadInt32(); - if (length == -1) - return; - - var value = NullableHandler.Exists - ? await NullableHandler.ReadAsync(_handler, buffer, length, async) - : await _handler.Read(buffer, length, async); - - Set(composite, value); - } - - public override async Task Write(TComposite composite, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? 
lengthCache, bool async, CancellationToken cancellationToken = default) - { - if (_get == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertyGetter(typeof(TComposite), MemberInfo); - - if (buffer.WriteSpaceLeft < sizeof(int)) - await buffer.Flush(async, cancellationToken); - - buffer.WriteUInt32(PostgresType.OID); - await (NullableHandler.Exists - ? NullableHandler.WriteAsync(_handler, _get(ref composite), buffer, lengthCache, null, async, cancellationToken) - : _handler.WriteWithLengthInternal(_get(ref composite), buffer, lengthCache, null, async, cancellationToken)); - } - - public override int ValidateAndGetLength(TComposite composite, ref NpgsqlLengthCache? lengthCache) - { - if (_get == null) - ThrowHelper.ThrowInvalidOperationException_NoPropertyGetter(typeof(TComposite), MemberInfo); - - var value = _get(ref composite); - if (value == null) - return 0; - - return NullableHandler.Exists - ? NullableHandler.ValidateAndGetLength(_handler, value, ref lengthCache, null) - : _handler.ValidateAndGetLength(value, ref lengthCache, null); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - void Set(ByReference composite, TMember value) - { - _set!(ref composite.Value, value); - } - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeParameterHandler.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeParameterHandler.cs deleted file mode 100644 index 24c2ee4905..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeParameterHandler.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.TypeHandling; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - abstract class CompositeParameterHandler - { - public NpgsqlTypeHandler Handler { get; } - public Type ParameterType { get; } - public int ParameterPosition { get; } - - public CompositeParameterHandler(NpgsqlTypeHandler handler, ParameterInfo parameterInfo) - { - Handler = 
handler; - ParameterType = parameterInfo.ParameterType; - ParameterPosition = parameterInfo.Position; - } - - public async ValueTask Read(NpgsqlReadBuffer buffer, bool async) - { - await buffer.Ensure(sizeof(uint) + sizeof(int), async); - - var oid = buffer.ReadUInt32(); - var length = buffer.ReadInt32(); - if (length == -1) - return default!; - - return NullableHandler.Exists - ? await NullableHandler.ReadAsync(Handler, buffer, length, async) - : await Handler.Read(buffer, length, async); - } - - public abstract ValueTask Read(NpgsqlReadBuffer buffer, bool async); - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeParameterHandler`.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeParameterHandler`.cs deleted file mode 100644 index 63274cbbcf..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeParameterHandler`.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.TypeHandling; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - sealed class CompositeParameterHandler : CompositeParameterHandler - { - public CompositeParameterHandler(NpgsqlTypeHandler handler, ParameterInfo parameterInfo) - : base(handler, parameterInfo) { } - - public override ValueTask Read(NpgsqlReadBuffer buffer, bool async) - { - var task = Read(buffer, async); - return task.IsCompleted - ? 
new ValueTask(task.Result) - : AwaitTask(); - - async ValueTask AwaitTask() => await task; - } - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeTypeHandlerFactory.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeTypeHandlerFactory.cs deleted file mode 100644 index 602c4cbab5..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/CompositeTypeHandlerFactory.cs +++ /dev/null @@ -1,16 +0,0 @@ -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - class CompositeTypeHandlerFactory : NpgsqlTypeHandlerFactory, ICompositeTypeHandlerFactory - { - public INpgsqlNameTranslator NameTranslator { get; } - - internal CompositeTypeHandlerFactory(INpgsqlNameTranslator nameTranslator) - => NameTranslator = nameTranslator; - - public override NpgsqlTypeHandler Create(PostgresType pgType, NpgsqlConnection conn) - => new CompositeHandler((PostgresCompositeType)pgType, conn.Connector!.TypeMapper, NameTranslator); - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/ICompositeHandler.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/ICompositeHandler.cs deleted file mode 100644 index 804129c0c6..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/ICompositeHandler.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; - -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - interface ICompositeHandler - { - /// - /// The CLR type mapped to the PostgreSQL composite type. - /// - Type CompositeType { get; } - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/ICompositeTypeHandlerFactory.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/ICompositeTypeHandlerFactory.cs deleted file mode 100644 index 520a662b99..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/ICompositeTypeHandlerFactory.cs +++ /dev/null @@ -1,14 +0,0 @@ -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - /// - /// Interface implemented by all mapped composite handler factories. 
- /// Used to expose the name translator for those reflecting composite mappings (e.g. EF Core). - /// - public interface ICompositeTypeHandlerFactory - { - /// - /// The name translator used for this composite. - /// - INpgsqlNameTranslator NameTranslator { get; } - } -} diff --git a/src/Npgsql/TypeHandlers/CompositeHandlers/IsValueType.cs b/src/Npgsql/TypeHandlers/CompositeHandlers/IsValueType.cs deleted file mode 100644 index 3b4e09755c..0000000000 --- a/src/Npgsql/TypeHandlers/CompositeHandlers/IsValueType.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Npgsql.TypeHandlers.CompositeHandlers -{ - static class IsValueType - { - public static readonly bool Value = typeof(T).IsValueType; - } -} diff --git a/src/Npgsql/TypeHandlers/DateTimeHandlers/DateHandler.cs b/src/Npgsql/TypeHandlers/DateTimeHandlers/DateHandler.cs deleted file mode 100644 index 18cb14ae6e..0000000000 --- a/src/Npgsql/TypeHandlers/DateTimeHandlers/DateHandler.cs +++ /dev/null @@ -1,127 +0,0 @@ -using System; -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.DateTimeHandlers -{ - /// - /// A factory for type handlers for the PostgreSQL date data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("date", NpgsqlDbType.Date, DbType.Date, typeof(NpgsqlDate))] - public class DateHandlerFactory : NpgsqlTypeHandlerFactory - { - /// - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new DateHandler(postgresType, conn.Connector!.ConvertInfinityDateTime); - } - - /// - /// A type handler for the PostgreSQL date data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public class DateHandler : NpgsqlSimpleTypeHandlerWithPsv - { - /// - /// Whether to convert positive and negative infinity values to DateTime.{Max,Min}Value when - /// a DateTime is requested - /// - readonly bool _convertInfinityDateTime; - - /// - /// Constructs a - /// - public DateHandler(PostgresType postgresType, bool convertInfinityDateTime) - : base(postgresType) - => _convertInfinityDateTime = convertInfinityDateTime; - - #region Read - - /// - public override DateTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var npgsqlDate = ReadPsv(buf, len, fieldDescription); - - if (npgsqlDate.IsFinite) - return (DateTime)npgsqlDate; - if (!_convertInfinityDateTime) - throw new InvalidCastException("Can't convert infinite date values to DateTime"); - if (npgsqlDate.IsInfinity) - return DateTime.MaxValue; - return DateTime.MinValue; - } - - /// - /// Copied wholesale from Postgresql backend/utils/adt/datetime.c:j2date - /// - protected override NpgsqlDate ReadPsv(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription = null) - { - var binDate = buf.ReadInt32(); - - return binDate switch - { - int.MaxValue => NpgsqlDate.Infinity, - int.MinValue => NpgsqlDate.NegativeInfinity, - _ => new NpgsqlDate(binDate + 730119) - }; - } - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) => 4; - - /// - public override int ValidateAndGetLength(NpgsqlDate value, NpgsqlParameter? parameter) => 4; - - /// - public override void Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - NpgsqlDate value2; - if (_convertInfinityDateTime) - { - if (value == DateTime.MaxValue) - value2 = NpgsqlDate.Infinity; - else if (value == DateTime.MinValue) - value2 = NpgsqlDate.NegativeInfinity; - else - value2 = new NpgsqlDate(value); - } - else - value2 = new NpgsqlDate(value); - - Write(value2, buf, parameter); - } - - /// - public override void Write(NpgsqlDate value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (value == NpgsqlDate.NegativeInfinity) - buf.WriteInt32(int.MinValue); - else if (value == NpgsqlDate.Infinity) - buf.WriteInt32(int.MaxValue); - else - buf.WriteInt32(value.DaysSinceEra - 730119); - } - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/DateTimeHandlers/IntervalHandler.cs b/src/Npgsql/TypeHandlers/DateTimeHandlers/IntervalHandler.cs deleted file mode 100644 index 004fc27831..0000000000 --- a/src/Npgsql/TypeHandlers/DateTimeHandlers/IntervalHandler.cs +++ /dev/null @@ -1,79 +0,0 @@ -using System; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.DateTimeHandlers -{ - /// - /// A factory for type handlers for the PostgreSQL interval data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. 
However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("interval", NpgsqlDbType.Interval, new[] { typeof(TimeSpan), typeof(NpgsqlTimeSpan) })] - public class IntervalHandlerFactory : NpgsqlTypeHandlerFactory - { - /// - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => conn.HasIntegerDateTimes // Check for the legacy floating point timestamps feature - ? new IntervalHandler(postgresType) - : throw new NotSupportedException($"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - } - - /// - /// A type handler for the PostgreSQL date interval type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public class IntervalHandler : NpgsqlSimpleTypeHandlerWithPsv - { - /// - /// Constructs an - /// - public IntervalHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override TimeSpan Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => (TimeSpan)((INpgsqlSimpleTypeHandler)this).Read(buf, len, fieldDescription); - - /// - protected override NpgsqlTimeSpan ReadPsv(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var ticks = buf.ReadInt64(); - var day = buf.ReadInt32(); - var month = buf.ReadInt32(); - return new NpgsqlTimeSpan(month, day, ticks * 10); - } - - /// - public override int ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) => 16; - - /// - public override int ValidateAndGetLength(NpgsqlTimeSpan value, NpgsqlParameter? 
parameter) => 16; - - /// - public override void Write(NpgsqlTimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - buf.WriteInt64(value.Ticks / 10); // TODO: round? - buf.WriteInt32(value.Days); - buf.WriteInt32(value.Months); - } - - // TODO: Can write directly from TimeSpan - /// - public override void Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => Write(value, buf, parameter); - } -} diff --git a/src/Npgsql/TypeHandlers/DateTimeHandlers/TimeHandler.cs b/src/Npgsql/TypeHandlers/DateTimeHandlers/TimeHandler.cs deleted file mode 100644 index c4f226ebb7..0000000000 --- a/src/Npgsql/TypeHandlers/DateTimeHandlers/TimeHandler.cs +++ /dev/null @@ -1,61 +0,0 @@ -using System; -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.DateTimeHandlers -{ - /// - /// A factory for type handlers for the PostgreSQL time data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("time without time zone", NpgsqlDbType.Time, new[] { DbType.Time })] - public class TimeHandlerFactory : NpgsqlTypeHandlerFactory - { - /// - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => conn.HasIntegerDateTimes // Check for the legacy floating point timestamps feature - ? new TimeHandler(postgresType) - : throw new NotSupportedException($"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - } - - /// - /// A type handler for the PostgreSQL time data type. 
- /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public class TimeHandler : NpgsqlSimpleTypeHandler - { - /// - /// Constructs a . - /// - public TimeHandler(PostgresType postgresType) : base(postgresType) {} - - // PostgreSQL time resolution == 1 microsecond == 10 ticks - /// - public override TimeSpan Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new TimeSpan(buf.ReadInt64() * 10); - - /// - public override int ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) - => 8; - - /// - public override void Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteInt64(value.Ticks / 10); - } -} diff --git a/src/Npgsql/TypeHandlers/DateTimeHandlers/TimeTzHandler.cs b/src/Npgsql/TypeHandlers/DateTimeHandlers/TimeTzHandler.cs deleted file mode 100644 index 4cfb7e2a69..0000000000 --- a/src/Npgsql/TypeHandlers/DateTimeHandlers/TimeTzHandler.cs +++ /dev/null @@ -1,113 +0,0 @@ -using System; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.DateTimeHandlers -{ - /// - /// A factory for type handlers for the PostgreSQL timetz data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("time with time zone", NpgsqlDbType.TimeTz)] - public class TimeTzHandlerFactory : NpgsqlTypeHandlerFactory - { - /// - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => conn.HasIntegerDateTimes // Check for the legacy floating point timestamps feature - ? new TimeTzHandler(postgresType) - : throw new NotSupportedException($"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - } - - /// - /// A type handler for the PostgreSQL timetz data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public class TimeTzHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - // Binary Format: int64 expressing microseconds, int32 expressing timezone in seconds, negative - - /// - /// Constructs an . - /// - public TimeTzHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override DateTimeOffset Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - // Adjust from 1 microsecond to 100ns. Time zone (in seconds) is inverted. - var ticks = buf.ReadInt64() * 10; - var offset = new TimeSpan(0, 0, -buf.ReadInt32()); - return new DateTimeOffset(ticks + TimeSpan.TicksPerDay, offset); - } - - DateTime INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription).LocalDateTime; - - TimeSpan INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - => Read(buf, len, fieldDescription).LocalDateTime.TimeOfDay; - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(DateTimeOffset value, NpgsqlParameter? parameter) => 12; - /// - public int ValidateAndGetLength(TimeSpan value, NpgsqlParameter? parameter) => 12; - /// - public int ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) => 12; - - /// - public override void Write(DateTimeOffset value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - buf.WriteInt64(value.TimeOfDay.Ticks / 10); - buf.WriteInt32(-(int)(value.Offset.Ticks / TimeSpan.TicksPerSecond)); - } - - /// - public void Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - buf.WriteInt64(value.TimeOfDay.Ticks / 10); - - switch (value.Kind) - { - case DateTimeKind.Utc: - buf.WriteInt32(0); - break; - case DateTimeKind.Unspecified: - // Treat as local... - case DateTimeKind.Local: - buf.WriteInt32(-(int)(TimeZoneInfo.Local.BaseUtcOffset.Ticks / TimeSpan.TicksPerSecond)); - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {value.Kind} of enum {nameof(DateTimeKind)}. Please file a bug."); - } - } - - /// - public void Write(TimeSpan value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteInt64(value.Ticks / 10); - buf.WriteInt32(-(int)(TimeZoneInfo.Local.BaseUtcOffset.Ticks / TimeSpan.TicksPerSecond)); - } - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/DateTimeHandlers/TimestampHandler.cs b/src/Npgsql/TypeHandlers/DateTimeHandlers/TimestampHandler.cs deleted file mode 100644 index 04dfb55085..0000000000 --- a/src/Npgsql/TypeHandlers/DateTimeHandlers/TimestampHandler.cs +++ /dev/null @@ -1,202 +0,0 @@ -using System; -using Npgsql.BackendMessages; -using NpgsqlTypes; -using System.Data; -using System.Runtime.CompilerServices; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; - -namespace Npgsql.TypeHandlers.DateTimeHandlers -{ - /// - /// A factory for type handlers for the PostgreSQL timestamp data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("timestamp without time zone", NpgsqlDbType.Timestamp, new[] { DbType.DateTime, DbType.DateTime2 }, new[] { typeof(NpgsqlDateTime), typeof(DateTime) }, DbType.DateTime)] - public class TimestampHandlerFactory : NpgsqlTypeHandlerFactory - { - /// - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => conn.HasIntegerDateTimes // Check for the legacy floating point timestamps feature - ? new TimestampHandler(postgresType, conn.Connector!.ConvertInfinityDateTime) - : throw new NotSupportedException($"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - } - - /// - /// A type handler for the PostgreSQL timestamp data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. 
- /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public class TimestampHandler : NpgsqlSimpleTypeHandlerWithPsv - { - /// - /// Whether to convert positive and negative infinity values to DateTime.{Max,Min}Value when - /// a DateTime is requested - /// - protected readonly bool ConvertInfinityDateTime; - - /// - /// Constructs a . - /// - public TimestampHandler(PostgresType postgresType, bool convertInfinityDateTime) - : base(postgresType) => ConvertInfinityDateTime = convertInfinityDateTime; - - #region Read - - private protected const string InfinityExceptionMessage = "Can't convert infinite timestamp values to DateTime"; - private protected const string OutOfRangeExceptionMessage = "Out of the range of DateTime (year must be between 1 and 9999)"; - - /// - public override DateTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - - var postgresTimestamp = buf.ReadInt64(); - if (postgresTimestamp == long.MaxValue) - return ConvertInfinityDateTime - ? DateTime.MaxValue - : throw new InvalidCastException(InfinityExceptionMessage); - if (postgresTimestamp == long.MinValue) - return ConvertInfinityDateTime - ? DateTime.MinValue - : throw new InvalidCastException(InfinityExceptionMessage); - - try - { - return FromPostgresTimestamp(postgresTimestamp); - } - catch (ArgumentOutOfRangeException e) - { - throw new InvalidCastException(OutOfRangeExceptionMessage, e); - } - } - - /// - protected override NpgsqlDateTime ReadPsv(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => ReadTimeStamp(buf, len, fieldDescription); - - /// - /// Reads a timestamp from the buffer as an . - /// - protected NpgsqlDateTime ReadTimeStamp(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription = null) - { - var value = buf.ReadInt64(); - if (value == long.MaxValue) - return NpgsqlDateTime.Infinity; - if (value == long.MinValue) - return NpgsqlDateTime.NegativeInfinity; - if (value >= 0) - { - var date = (int)(value / 86400000000L); - var time = value % 86400000000L; - - date += 730119; // 730119 = days since era (0001-01-01) for 2000-01-01 - time *= 10; // To 100ns - - return new NpgsqlDateTime(new NpgsqlDate(date), new TimeSpan(time)); - } - else - { - value = -value; - var date = (int)(value / 86400000000L); - var time = value % 86400000000L; - if (time != 0) - { - ++date; - time = 86400000000L - time; - } - - date = 730119 - date; // 730119 = days since era (0001-01-01) for 2000-01-01 - time *= 10; // To 100ns - - return new NpgsqlDateTime(new NpgsqlDate(date), new TimeSpan(time)); - } - } - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(DateTime value, NpgsqlParameter? parameter) => 8; - - /// - public override int ValidateAndGetLength(NpgsqlDateTime value, NpgsqlParameter? parameter) => 8; - - /// - public override void Write(NpgsqlDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (value.IsInfinity) - { - buf.WriteInt64(long.MaxValue); - return; - } - - if (value.IsNegativeInfinity) - { - buf.WriteInt64(long.MinValue); - return; - } - - var uSecsTime = value.Time.Ticks / 10; - - if (value >= new NpgsqlDateTime(2000, 1, 1, 0, 0, 0)) - { - var uSecsDate = (value.Date.DaysSinceEra - 730119) * 86400000000L; - buf.WriteInt64(uSecsDate + uSecsTime); - } - else - { - var uSecsDate = (730119 - value.Date.DaysSinceEra) * 86400000000L; - buf.WriteInt64(-(uSecsDate - uSecsTime)); - } - } - - /// - public override void Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - if (ConvertInfinityDateTime) - { - if (value == DateTime.MaxValue) - { - buf.WriteInt64(long.MaxValue); - return; - } - - if (value == DateTime.MinValue) - { - buf.WriteInt64(long.MinValue); - return; - } - } - - var postgresTimestamp = ToPostgresTimestamp(value); - buf.WriteInt64(postgresTimestamp); - } - - #endregion Write - - const long PostgresTimestampOffsetTicks = 630822816000000000L; - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static long ToPostgresTimestamp(DateTime value) - // Rounding here would cause problems because we would round up DateTime.MaxValue - // which would make it impossible to retrieve it back from the database, so we just drop the additional precision - => (value.Ticks - PostgresTimestampOffsetTicks) / 10; - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static DateTime FromPostgresTimestamp(long value) - => new DateTime(value * 10 + PostgresTimestampOffsetTicks); - } -} diff --git a/src/Npgsql/TypeHandlers/DateTimeHandlers/TimestampTzHandler.cs b/src/Npgsql/TypeHandlers/DateTimeHandlers/TimestampTzHandler.cs deleted file mode 100644 index 18e0f159fe..0000000000 --- a/src/Npgsql/TypeHandlers/DateTimeHandlers/TimestampTzHandler.cs +++ /dev/null @@ -1,136 +0,0 @@ -using System; -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.DateTimeHandlers -{ - /// - /// A factory for type handlers for the PostgreSQL timestamptz data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTimeOffset, typeof(DateTimeOffset))] - public class TimestampTzHandlerFactory : NpgsqlTypeHandlerFactory - { - /// - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => conn.HasIntegerDateTimes // Check for the legacy floating point timestamps feature - ? new TimestampTzHandler(postgresType, conn.Connector!.ConvertInfinityDateTime) - : throw new NotSupportedException($"The deprecated floating-point date/time format is not supported by {nameof(Npgsql)}."); - } - - /// - /// A type handler for the PostgreSQL timestamptz data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-datetime.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public class TimestampTzHandler : TimestampHandler, INpgsqlSimpleTypeHandler - { - /// - /// Constructs an . - /// - public TimestampTzHandler(PostgresType postgresType, bool convertInfinityDateTime) - : base(postgresType, convertInfinityDateTime) {} - - /// - public override IRangeHandler CreateRangeHandler(PostgresType rangeBackendType) - => new RangeHandler(rangeBackendType, this); - - #region Read - - /// - public override DateTime Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => base.Read(buf, len, fieldDescription).ToLocalTime(); - - /// - protected override NpgsqlDateTime ReadPsv(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var ts = ReadTimeStamp(buf, len, fieldDescription); - return new NpgsqlDateTime(ts.Date, ts.Time, DateTimeKind.Utc).ToLocalTime(); - } - - DateTimeOffset INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - { - var postgresTimestamp = buf.ReadInt64(); - if (postgresTimestamp == long.MaxValue) - return ConvertInfinityDateTime - ? DateTimeOffset.MaxValue - : throw new InvalidCastException(InfinityExceptionMessage); - if (postgresTimestamp == long.MinValue) - return ConvertInfinityDateTime - ? DateTimeOffset.MinValue - : throw new InvalidCastException(InfinityExceptionMessage); - try - { - return FromPostgresTimestamp(postgresTimestamp).ToLocalTime(); - } - catch (ArgumentOutOfRangeException e) - { - throw new InvalidCastException(OutOfRangeExceptionMessage, e); - } - } - - #endregion Read - - #region Write - - /// - public int ValidateAndGetLength(DateTimeOffset value, NpgsqlParameter? parameter) => 8; - - /// - public override void Write(NpgsqlDateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - switch (value.Kind) - { - case DateTimeKind.Unspecified: - case DateTimeKind.Utc: - break; - case DateTimeKind.Local: - value = value.ToUniversalTime(); - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {value.Kind} of enum {nameof(DateTimeKind)}. Please file a bug."); - } - - base.Write(value, buf, parameter); - } - - /// - public override void Write(DateTime value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - switch (value.Kind) - { - case DateTimeKind.Unspecified: - case DateTimeKind.Utc: - break; - case DateTimeKind.Local: - value = value.ToUniversalTime(); - break; - default: - throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {value.Kind} of enum {nameof(DateTimeKind)}. Please file a bug."); - } - - base.Write(value, buf, parameter); - } - - /// - public void Write(DateTimeOffset value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => base.Write(value.ToUniversalTime().DateTime, buf, parameter); - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/EnumHandler.cs b/src/Npgsql/TypeHandlers/EnumHandler.cs deleted file mode 100644 index 7e9fed2fb6..0000000000 --- a/src/Npgsql/TypeHandlers/EnumHandler.cs +++ /dev/null @@ -1,117 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Reflection; -using System.Text; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - /// - /// Interface implemented by all concrete handlers which handle enums - /// - interface IEnumHandler - { - /// - /// The CLR enum type mapped to the PostgreSQL enum - /// - Type EnumType { get; } - } - - class EnumHandler : NpgsqlSimpleTypeHandler, IEnumHandler where TEnum : struct, Enum - { - readonly Dictionary _enumToLabel; - readonly Dictionary _labelToEnum; - - public Type EnumType => typeof(TEnum); - - #region Construction - - internal EnumHandler(PostgresType postgresType, Dictionary enumToLabel, Dictionary labelToEnum) - : base(postgresType) - { - Debug.Assert(typeof(TEnum).GetTypeInfo().IsEnum, "EnumHandler instantiated for non-enum type"); - _enumToLabel = enumToLabel; - _labelToEnum = labelToEnum; - } - - #endregion - - #region Read - - public override TEnum Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var str = buf.ReadString(len); - var success = _labelToEnum.TryGetValue(str, out var value); - - if (!success) - throw new InvalidCastException($"Received enum value '{str}' from database which wasn't found on enum {typeof(TEnum)}"); - - return value; - } - - #endregion - - #region Write - - public override int ValidateAndGetLength(TEnum value, NpgsqlParameter? parameter) - => _enumToLabel.TryGetValue(value, out var str) - ? 
Encoding.UTF8.GetByteCount(str) - : throw new InvalidCastException($"Can't write value {value} as enum {typeof(TEnum)}"); - - public override void Write(TEnum value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - if (!_enumToLabel.TryGetValue(value, out var str)) - throw new InvalidCastException($"Can't write value {value} as enum {typeof(TEnum)}"); - buf.WriteString(str); - } - - #endregion - } - - - /// - /// Interface implemented by all enum handler factories. - /// Used to expose the name translator for those reflecting enum mappings (e.g. EF Core). - /// - public interface IEnumTypeHandlerFactory - { - /// - /// The name translator used for this enum. - /// - INpgsqlNameTranslator NameTranslator { get; } - } - - class EnumTypeHandlerFactory : NpgsqlTypeHandlerFactory, IEnumTypeHandlerFactory - where TEnum : struct, Enum - { - readonly Dictionary _enumToLabel = new Dictionary(); - readonly Dictionary _labelToEnum = new Dictionary(); - - internal EnumTypeHandlerFactory(INpgsqlNameTranslator nameTranslator) - { - NameTranslator = nameTranslator; - - foreach (var field in typeof(TEnum).GetFields(BindingFlags.Static | BindingFlags.Public)) - { - var attribute = (PgNameAttribute?)field.GetCustomAttributes(typeof(PgNameAttribute), false).FirstOrDefault(); - var enumName = attribute is null - ? 
nameTranslator.TranslateMemberName(field.Name) - : attribute.PgName; - var enumValue = (TEnum)field.GetValue(null)!; - - _enumToLabel[enumValue] = enumName; - _labelToEnum[enumName] = enumValue; - } - } - - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new EnumHandler(postgresType, _enumToLabel, _labelToEnum); - - public INpgsqlNameTranslator NameTranslator { get; } - } -} diff --git a/src/Npgsql/TypeHandlers/FullTextSearchHandlers/TsQueryHandler.cs b/src/Npgsql/TypeHandlers/FullTextSearchHandlers/TsQueryHandler.cs deleted file mode 100644 index e6f632d89c..0000000000 --- a/src/Npgsql/TypeHandlers/FullTextSearchHandlers/TsQueryHandler.cs +++ /dev/null @@ -1,301 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -// TODO: Need to work on the nullbility here -#nullable disable -#pragma warning disable CS8632 - -namespace Npgsql.TypeHandlers.FullTextSearchHandlers -{ - /// - /// A type handler for the PostgreSQL tsquery data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("tsquery", NpgsqlDbType.TsQuery, new[] { - typeof(NpgsqlTsQuery), typeof(NpgsqlTsQueryAnd), typeof(NpgsqlTsQueryEmpty), typeof(NpgsqlTsQueryFollowedBy), - typeof(NpgsqlTsQueryLexeme), typeof(NpgsqlTsQueryNot), typeof(NpgsqlTsQueryOr), typeof(NpgsqlTsQueryBinOp) }) - ] - public class TsQueryHandler : NpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler, - INpgsqlTypeHandler, INpgsqlTypeHandler - { - // 1 (type) + 1 (weight) + 1 (is prefix search) + 2046 (max str len) + 1 (null terminator) - const int MaxSingleTokenBytes = 2050; - - readonly Stack _stack = new Stack(); - - /// - public TsQueryHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numTokens = buf.ReadInt32(); - if (numTokens == 0) - return new NpgsqlTsQueryEmpty(); - - NpgsqlTsQuery? value = null; - var nodes = new Stack>(); - len -= 4; - - for (var tokenPos = 0; tokenPos < numTokens; tokenPos++) - { - await buf.Ensure(Math.Min(len, MaxSingleTokenBytes), async); - var readPos = buf.ReadPosition; - - var isOper = buf.ReadByte() == 2; - if (isOper) - { - var operKind = (NpgsqlTsQuery.NodeKind)buf.ReadByte(); - if (operKind == NpgsqlTsQuery.NodeKind.Not) - { - var node = new NpgsqlTsQueryNot(null); - InsertInTree(node, nodes, ref value); - nodes.Push(new Tuple(node, 0)); - } - else - { - var node = operKind switch - { - NpgsqlTsQuery.NodeKind.And => (NpgsqlTsQuery)new NpgsqlTsQueryAnd(null, null), - NpgsqlTsQuery.NodeKind.Or => new NpgsqlTsQueryOr(null, null), - NpgsqlTsQuery.NodeKind.Phrase => new NpgsqlTsQueryFollowedBy(null, buf.ReadInt16(), null), - _ => throw new InvalidOperationException($"Internal Npgsql bug: unexpected value {operKind} of enum {nameof(NpgsqlTsQuery.NodeKind)}. 
Please file a bug.") - }; - - InsertInTree(node, nodes, ref value); - - nodes.Push(new Tuple(node, 2)); - nodes.Push(new Tuple(node, 1)); - } - } - else - { - var weight = (NpgsqlTsQueryLexeme.Weight)buf.ReadByte(); - var prefix = buf.ReadByte() != 0; - var str = buf.ReadNullTerminatedString(); - InsertInTree(new NpgsqlTsQueryLexeme(str, weight, prefix), nodes, ref value); - } - - len -= buf.ReadPosition - readPos; - } - - if (nodes.Count != 0) - throw new InvalidOperationException("Internal Npgsql bug, please report."); - - return value!; - - static void InsertInTree(NpgsqlTsQuery node, Stack> nodes, ref NpgsqlTsQuery? value) - { - if (nodes.Count == 0) - value = node; - else - { - var parent = nodes.Pop(); - if (parent.Item2 == 0) - ((NpgsqlTsQueryNot)parent.Item1).Child = node; - else if (parent.Item2 == 1) - ((NpgsqlTsQueryBinOp)parent.Item1).Left = node; - else - ((NpgsqlTsQueryBinOp)parent.Item1).Right = node; - } - } - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryEmpty)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryLexeme)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryNot)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryAnd)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => (NpgsqlTsQueryOr)await Read(buf, len, async, fieldDescription); - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => (NpgsqlTsQueryFollowedBy)await Read(buf, len, async, fieldDescription); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(NpgsqlTsQuery value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Kind == NpgsqlTsQuery.NodeKind.Empty - ? 4 - : 4 + GetNodeLength(value); - - int GetNodeLength(NpgsqlTsQuery node) - { - // TODO: Figure out the nullability strategy here - switch (node.Kind) - { - case NpgsqlTsQuery.NodeKind.Lexeme: - var strLen = Encoding.UTF8.GetByteCount(((NpgsqlTsQueryLexeme)node).Text); - if (strLen > 2046) - throw new InvalidCastException("Lexeme text too long. Must be at most 2046 bytes in UTF8."); - return 4 + strLen; - case NpgsqlTsQuery.NodeKind.And: - case NpgsqlTsQuery.NodeKind.Or: - return 2 + GetNodeLength(((NpgsqlTsQueryBinOp)node).Left) + GetNodeLength(((NpgsqlTsQueryBinOp)node).Right); - case NpgsqlTsQuery.NodeKind.Phrase: - // 2 additional bytes for uint16 phrase operator "distance" field. - return 4 + GetNodeLength(((NpgsqlTsQueryBinOp)node).Left) + GetNodeLength(((NpgsqlTsQueryBinOp)node).Right); - case NpgsqlTsQuery.NodeKind.Not: - return 2 + GetNodeLength(((NpgsqlTsQueryNot)node).Child); - case NpgsqlTsQuery.NodeKind.Empty: - throw new InvalidOperationException("Empty tsquery nodes must be top-level"); - default: - throw new InvalidOperationException("Illegal node kind: " + node.Kind); - } - } - - /// - public override async Task Write(NpgsqlTsQuery query, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - var numTokens = GetTokenCount(query); - - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(numTokens); - - if (numTokens == 0) - return; - - _stack.Push(query); - - while (_stack.Count > 0) - { - if (buf.WriteSpaceLeft < 2) - await buf.Flush(async, cancellationToken); - - if (_stack.Peek().Kind == NpgsqlTsQuery.NodeKind.Lexeme && buf.WriteSpaceLeft < MaxSingleTokenBytes) - await buf.Flush(async, cancellationToken); - - var node = _stack.Pop(); - buf.WriteByte(node.Kind == NpgsqlTsQuery.NodeKind.Lexeme ? (byte)1 : (byte)2); - if (node.Kind != NpgsqlTsQuery.NodeKind.Lexeme) - { - buf.WriteByte((byte)node.Kind); - if (node.Kind == NpgsqlTsQuery.NodeKind.Not) - _stack.Push(((NpgsqlTsQueryNot)node).Child); - else - { - if (node.Kind == NpgsqlTsQuery.NodeKind.Phrase) - buf.WriteInt16(((NpgsqlTsQueryFollowedBy)node).Distance); - - _stack.Push(((NpgsqlTsQueryBinOp)node).Right); - _stack.Push(((NpgsqlTsQueryBinOp)node).Left); - } - } - else - { - var lexemeNode = (NpgsqlTsQueryLexeme)node; - buf.WriteByte((byte)lexemeNode.Weights); - buf.WriteByte(lexemeNode.IsPrefixSearch ? (byte)1 : (byte)0); - buf.WriteString(lexemeNode.Text); - buf.WriteByte(0); - } - } - - _stack.Clear(); - } - - int GetTokenCount(NpgsqlTsQuery node) - { - switch (node.Kind) - { - case NpgsqlTsQuery.NodeKind.Lexeme: - return 1; - case NpgsqlTsQuery.NodeKind.And: - case NpgsqlTsQuery.NodeKind.Or: - case NpgsqlTsQuery.NodeKind.Phrase: - return 1 + GetTokenCount(((NpgsqlTsQueryBinOp)node).Left) + GetTokenCount(((NpgsqlTsQueryBinOp)node).Right); - case NpgsqlTsQuery.NodeKind.Not: - return 1 + GetTokenCount(((NpgsqlTsQueryNot)node).Child); - case NpgsqlTsQuery.NodeKind.Empty: - return 0; - } - return -1; - } - - /// - public int ValidateAndGetLength(NpgsqlTsQueryOr value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryAnd value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryNot value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryLexeme value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryEmpty value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public int ValidateAndGetLength(NpgsqlTsQueryFollowedBy value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((NpgsqlTsQuery)value, ref lengthCache, parameter); - - /// - public Task Write(NpgsqlTsQueryOr value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write(NpgsqlTsQueryAnd value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write(NpgsqlTsQueryNot value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write(NpgsqlTsQueryLexeme value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write(NpgsqlTsQueryEmpty value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - /// - public Task Write( - NpgsqlTsQueryFollowedBy value, - NpgsqlWriteBuffer buf, - NpgsqlLengthCache? lengthCache, - NpgsqlParameter? parameter, - bool async, - CancellationToken cancellationToken = default) - => Write((NpgsqlTsQuery)value, buf, lengthCache, parameter, async, cancellationToken); - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/FullTextSearchHandlers/TsVectorHandler.cs b/src/Npgsql/TypeHandlers/FullTextSearchHandlers/TsVectorHandler.cs deleted file mode 100644 index 6312a44cf6..0000000000 --- a/src/Npgsql/TypeHandlers/FullTextSearchHandlers/TsVectorHandler.cs +++ /dev/null @@ -1,101 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.FullTextSearchHandlers -{ - /// - /// A type handler for the PostgreSQL tsvector data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-textsearch.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. 
However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("tsvector", NpgsqlDbType.TsVector, typeof(NpgsqlTsVector))] - public class TsVectorHandler : NpgsqlTypeHandler - { - // 2561 = 2046 (max length lexeme string) + (1) null terminator + - // 2 (num_pos) + sizeof(int16) * 256 (max_num_pos (positions/wegihts)) - const int MaxSingleLexemeBytes = 2561; - - /// - public TsVectorHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numLexemes = buf.ReadInt32(); - len -= 4; - - var lexemes = new List(); - for (var lexemePos = 0; lexemePos < numLexemes; lexemePos++) - { - await buf.Ensure(Math.Min(len, MaxSingleLexemeBytes), async); - var posBefore = buf.ReadPosition; - - List? positions = null; - - var lexemeString = buf.ReadNullTerminatedString(); - int numPositions = buf.ReadInt16(); - for (var i = 0; i < numPositions; i++) - { - var wordEntryPos = buf.ReadInt16(); - if (positions == null) - positions = new List(); - positions.Add(new NpgsqlTsVector.Lexeme.WordEntryPos(wordEntryPos)); - } - - lexemes.Add(new NpgsqlTsVector.Lexeme(lexemeString, positions, true)); - - len -= buf.ReadPosition - posBefore; - } - - return new NpgsqlTsVector(lexemes, true); - } - - #endregion Read - - #region Write - - // TODO: Implement length cache - /// - public override int ValidateAndGetLength(NpgsqlTsVector value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => 4 + value.Sum(l => Encoding.UTF8.GetByteCount(l.Text) + 1 + 2 + l.Count * 2); - - /// - public override async Task Write(NpgsqlTsVector vector, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(vector.Count); - - foreach (var lexeme in vector) - { - if (buf.WriteSpaceLeft < MaxSingleLexemeBytes) - await buf.Flush(async, cancellationToken); - - buf.WriteString(lexeme.Text); - buf.WriteByte(0); - buf.WriteInt16(lexeme.Count); - for (var i = 0; i < lexeme.Count; i++) - buf.WriteInt16(lexeme[i].Value); - } - } - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/GeometricHandlers/BoxHandler.cs b/src/Npgsql/TypeHandlers/GeometricHandlers/BoxHandler.cs deleted file mode 100644 index 65205321be..0000000000 --- a/src/Npgsql/TypeHandlers/GeometricHandlers/BoxHandler.cs +++ /dev/null @@ -1,45 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.GeometricHandlers -{ - /// - /// A type handler for the PostgreSQL box data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("box", NpgsqlDbType.Box, typeof(NpgsqlBox))] - public class BoxHandler : NpgsqlSimpleTypeHandler - { - /// - public BoxHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override NpgsqlBox Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new NpgsqlBox( - new NpgsqlPoint(buf.ReadDouble(), buf.ReadDouble()), - new NpgsqlPoint(buf.ReadDouble(), buf.ReadDouble()) - ); - - /// - public override int ValidateAndGetLength(NpgsqlBox value, NpgsqlParameter? 
parameter) - => 32; - - /// - public override void Write(NpgsqlBox value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - buf.WriteDouble(value.Right); - buf.WriteDouble(value.Top); - buf.WriteDouble(value.Left); - buf.WriteDouble(value.Bottom); - } - } -} diff --git a/src/Npgsql/TypeHandlers/GeometricHandlers/CircleHandler.cs b/src/Npgsql/TypeHandlers/GeometricHandlers/CircleHandler.cs deleted file mode 100644 index 578be4f910..0000000000 --- a/src/Npgsql/TypeHandlers/GeometricHandlers/CircleHandler.cs +++ /dev/null @@ -1,41 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.GeometricHandlers -{ - /// - /// A type handler for the PostgreSQL circle data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("circle", NpgsqlDbType.Circle, typeof(NpgsqlCircle))] - public class CircleHandler : NpgsqlSimpleTypeHandler - { - /// - public CircleHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override NpgsqlCircle Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new NpgsqlCircle(buf.ReadDouble(), buf.ReadDouble(), buf.ReadDouble()); - - /// - public override int ValidateAndGetLength(NpgsqlCircle value, NpgsqlParameter? parameter) - => 24; - - /// - public override void Write(NpgsqlCircle value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteDouble(value.X); - buf.WriteDouble(value.Y); - buf.WriteDouble(value.Radius); - } - } -} diff --git a/src/Npgsql/TypeHandlers/GeometricHandlers/LineHandler.cs b/src/Npgsql/TypeHandlers/GeometricHandlers/LineHandler.cs deleted file mode 100644 index 8a88c213d2..0000000000 --- a/src/Npgsql/TypeHandlers/GeometricHandlers/LineHandler.cs +++ /dev/null @@ -1,41 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.GeometricHandlers -{ - /// - /// A type handler for the PostgreSQL line data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("line", NpgsqlDbType.Line, typeof(NpgsqlLine))] - public class LineHandler : NpgsqlSimpleTypeHandler - { - /// - public LineHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override NpgsqlLine Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new NpgsqlLine(buf.ReadDouble(), buf.ReadDouble(), buf.ReadDouble()); - - /// - public override int ValidateAndGetLength(NpgsqlLine value, NpgsqlParameter? parameter) - => 24; - - /// - public override void Write(NpgsqlLine value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteDouble(value.A); - buf.WriteDouble(value.B); - buf.WriteDouble(value.C); - } - } -} diff --git a/src/Npgsql/TypeHandlers/GeometricHandlers/LineSegmentHandler.cs b/src/Npgsql/TypeHandlers/GeometricHandlers/LineSegmentHandler.cs deleted file mode 100644 index c0c05427f0..0000000000 --- a/src/Npgsql/TypeHandlers/GeometricHandlers/LineSegmentHandler.cs +++ /dev/null @@ -1,42 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.GeometricHandlers -{ - /// - /// A type handler for the PostgreSQL lseg data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("lseg", NpgsqlDbType.LSeg, typeof(NpgsqlLSeg))] - public class LineSegmentHandler : NpgsqlSimpleTypeHandler - { - /// - public LineSegmentHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override NpgsqlLSeg Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new NpgsqlLSeg(buf.ReadDouble(), buf.ReadDouble(), buf.ReadDouble(), buf.ReadDouble()); - - /// - public override int ValidateAndGetLength(NpgsqlLSeg value, NpgsqlParameter? parameter) - => 32; - - /// - public override void Write(NpgsqlLSeg value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - buf.WriteDouble(value.Start.X); - buf.WriteDouble(value.Start.Y); - buf.WriteDouble(value.End.X); - buf.WriteDouble(value.End.Y); - } - } -} diff --git a/src/Npgsql/TypeHandlers/GeometricHandlers/PathHandler.cs b/src/Npgsql/TypeHandlers/GeometricHandlers/PathHandler.cs deleted file mode 100644 index 830b029552..0000000000 --- a/src/Npgsql/TypeHandlers/GeometricHandlers/PathHandler.cs +++ /dev/null @@ -1,78 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.GeometricHandlers -{ - /// - /// A type handler for the PostgreSQL path data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("path", NpgsqlDbType.Path, typeof(NpgsqlPath))] - public class PathHandler : NpgsqlTypeHandler - { - /// - public PathHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - await buf.Ensure(5, async); - var open = buf.ReadByte() switch - { - 1 => false, - 0 => true, - _ => throw new Exception("Error decoding binary geometric path: bad open byte") - }; - - var numPoints = buf.ReadInt32(); - var result = new NpgsqlPath(numPoints, open); - for (var i = 0; i < numPoints; i++) - { - await buf.Ensure(16, async); - result.Add(new NpgsqlPoint(buf.ReadDouble(), buf.ReadDouble())); - } - return result; - } - - #endregion - - #region Write - - /// - public override int ValidateAndGetLength(NpgsqlPath value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => 5 + value.Count * 16; - - /// - public override async Task Write(NpgsqlPath value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 5) - await buf.Flush(async, cancellationToken); - buf.WriteByte((byte)(value.Open ? 0 : 1)); - buf.WriteInt32(value.Count); - - foreach (var p in value) - { - if (buf.WriteSpaceLeft < 16) - await buf.Flush(async, cancellationToken); - buf.WriteDouble(p.X); - buf.WriteDouble(p.Y); - } - } - - #endregion - } -} diff --git a/src/Npgsql/TypeHandlers/GeometricHandlers/PointHandler.cs b/src/Npgsql/TypeHandlers/GeometricHandlers/PointHandler.cs deleted file mode 100644 index d88d2bee1c..0000000000 --- a/src/Npgsql/TypeHandlers/GeometricHandlers/PointHandler.cs +++ /dev/null @@ -1,40 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.GeometricHandlers -{ - /// - /// A type handler for the PostgreSQL point data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. 
However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("point", NpgsqlDbType.Point, typeof(NpgsqlPoint))] - public class PointHandler : NpgsqlSimpleTypeHandler - { - /// - public PointHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override NpgsqlPoint Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new NpgsqlPoint(buf.ReadDouble(), buf.ReadDouble()); - - /// - public override int ValidateAndGetLength(NpgsqlPoint value, NpgsqlParameter? parameter) - => 16; - - /// - public override void Write(NpgsqlPoint value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - buf.WriteDouble(value.X); - buf.WriteDouble(value.Y); - } - } -} diff --git a/src/Npgsql/TypeHandlers/GeometricHandlers/PolygonHandler.cs b/src/Npgsql/TypeHandlers/GeometricHandlers/PolygonHandler.cs deleted file mode 100644 index 2bc6ffbb33..0000000000 --- a/src/Npgsql/TypeHandlers/GeometricHandlers/PolygonHandler.cs +++ /dev/null @@ -1,69 +0,0 @@ -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.GeometricHandlers -{ - /// - /// A type handler for the PostgreSQL polygon data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("polygon", NpgsqlDbType.Polygon, typeof(NpgsqlPolygon))] - public class PolygonHandler : NpgsqlTypeHandler - { - /// - public PolygonHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numPoints = buf.ReadInt32(); - var result = new NpgsqlPolygon(numPoints); - for (var i = 0; i < numPoints; i++) - { - await buf.Ensure(16, async); - result.Add(new NpgsqlPoint(buf.ReadDouble(), buf.ReadDouble())); - } - return result; - } - - #endregion - - #region Write - - /// - public override int ValidateAndGetLength(NpgsqlPolygon value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => 4 + value.Count * 16; - - /// - public override async Task Write(NpgsqlPolygon value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(value.Count); - - foreach (var p in value) - { - if (buf.WriteSpaceLeft < 16) - await buf.Flush(async, cancellationToken); - buf.WriteDouble(p.X); - buf.WriteDouble(p.Y); - } - } - - #endregion - } -} diff --git a/src/Npgsql/TypeHandlers/HstoreHandler.cs b/src/Npgsql/TypeHandlers/HstoreHandler.cs deleted file mode 100644 index 51970ae2d3..0000000000 --- a/src/Npgsql/TypeHandlers/HstoreHandler.cs +++ /dev/null @@ -1,188 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -#if !NETSTANDARD2_0 && !NETSTANDARD2_1 -using System.Collections.Immutable; -#endif - -namespace Npgsql.TypeHandlers -{ - /// - /// A factory for type 
handlers for the PostgreSQL hstore extension data type, which stores sets of key/value pairs - /// within a single PostgreSQL value. - /// - /// - /// See https://www.postgresql.org/docs/current/hstore.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("hstore", NpgsqlDbType.Hstore, new[] - { - typeof(Dictionary), - typeof(IDictionary), -#if !NETSTANDARD2_0 && !NETSTANDARD2_1 - typeof(ImmutableDictionary) -#endif - })] - public class HstoreHandlerFactory : NpgsqlTypeHandlerFactory> - { - /// - public override NpgsqlTypeHandler> Create(PostgresType postgresType, NpgsqlConnection conn) - => new HstoreHandler(postgresType, conn); - } - - /// - /// A type handler for the PostgreSQL hstore extension data type, which stores sets of key/value pairs within a - /// single PostgreSQL value. - /// - /// - /// See https://www.postgresql.org/docs/current/hstore.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// -#pragma warning disable CA1061 // Do not hide base class methods - public class HstoreHandler : - NpgsqlTypeHandler>, - INpgsqlTypeHandler> -#if !NETSTANDARD2_0 && !NETSTANDARD2_1 - , INpgsqlTypeHandler> -#endif - { - /// - /// The text handler to which we delegate encoding/decoding of the actual strings - /// - readonly TextHandler _textHandler; - - internal HstoreHandler(PostgresType postgresType, NpgsqlConnection connection) - : base(postgresType) => _textHandler = new TextHandler(postgresType, connection); - - #region Write - - /// - public int ValidateAndGetLength(IDictionary value, ref NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter) - { - if (lengthCache == null) - lengthCache = new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - // Leave empty slot for the entire hstore length, and go ahead an populate the individual string slots - var pos = lengthCache.Position; - lengthCache.Set(0); - - var totalLen = 4; // Number of key-value pairs - foreach (var kv in value) - { - totalLen += 8; // Key length + value length - if (kv.Key == null) - throw new FormatException("HSTORE doesn't support null keys"); - totalLen += _textHandler.ValidateAndGetLength(kv.Key, ref lengthCache, null); - if (kv.Value != null) - totalLen += _textHandler.ValidateAndGetLength(kv.Value!, ref lengthCache, null); - } - - return lengthCache.Lengths[pos] = totalLen; - } - - /// - public override int ValidateAndGetLength(Dictionary value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - /// - public async Task Write(IDictionary value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(value.Count); - if (value.Count == 0) - return; - - foreach (var kv in value) - { - await _textHandler.WriteWithLengthInternal(kv.Key, buf, lengthCache, parameter, async, cancellationToken); - await _textHandler.WriteWithLengthInternal(kv.Value, buf, lengthCache, parameter, async, cancellationToken); - } - } - - /// - public override Task Write(Dictionary value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => Write(value, buf, lengthCache, parameter, async, cancellationToken); - - #endregion - - #region Read - - async ValueTask ReadInto(T dictionary, int numElements, NpgsqlReadBuffer buf, bool async) - where T : IDictionary - { - for (var i = 0; i < numElements; i++) - { - await buf.Ensure(4, async); - var keyLen = buf.ReadInt32(); - Debug.Assert(keyLen != -1); - var key = await _textHandler.Read(buf, keyLen, async); - - await buf.Ensure(4, async); - var valueLen = buf.ReadInt32(); - - dictionary[key] = valueLen == -1 - ? null - : await _textHandler.Read(buf, valueLen, async); - } - return dictionary; - } - - /// - public override async ValueTask> Read(NpgsqlReadBuffer buf, int len, bool async, - FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var numElements = buf.ReadInt32(); - return await ReadInto(new Dictionary(numElements), numElements, buf, async); - } - - ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => new ValueTask>(Read(buf, len, async, fieldDescription).Result); - - #endregion - -#if !NETSTANDARD2_0 && !NETSTANDARD2_1 - #region ImmutableDictionary - - /// - public int ValidateAndGetLength( - ImmutableDictionary value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength((IDictionary)value, ref lengthCache, parameter); - - /// - public Task Write(ImmutableDictionary value, - NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write((IDictionary)value, buf, lengthCache, parameter, async, cancellationToken); - - async ValueTask> INpgsqlTypeHandler>.Read( - NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - { - await buf.Ensure(4, async); - var numElements = buf.ReadInt32(); - return (await ReadInto(ImmutableDictionary.Empty.ToBuilder(), numElements, buf, async)) - .ToImmutable(); - } - - #endregion -#endif - } -#pragma warning restore CA1061 // Do not hide base class methods -} diff --git a/src/Npgsql/TypeHandlers/InternalCharHandler.cs b/src/Npgsql/TypeHandlers/InternalCharHandler.cs deleted file mode 100644 index e66dcb3a5b..0000000000 --- a/src/Npgsql/TypeHandlers/InternalCharHandler.cs +++ /dev/null @@ -1,92 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - /// - /// A type handler for the PostgreSQL "char" type, used only internally. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-character.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("char", NpgsqlDbType.InternalChar)] - public class InternalCharHandler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - /// - public InternalCharHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override char Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => (char)buf.ReadByte(); - - byte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadByte(); - - short INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadByte(); - - int INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - => buf.ReadByte(); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => buf.ReadByte(); - - #endregion - - #region Write - - /// - public int ValidateAndGetLength(byte value, NpgsqlParameter? parameter) => 1; - - /// - public override int ValidateAndGetLength(char value, NpgsqlParameter? parameter) - { - _ = checked((byte)value); - return 1; - } - - /// - public int ValidateAndGetLength(short value, NpgsqlParameter? parameter) - { - _ = checked((byte)value); - return 1; - } - - /// - public int ValidateAndGetLength(int value, NpgsqlParameter? parameter) - { - _ = checked((byte)value); - return 1; - } - - /// - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) - { - _ = checked((byte)value); - return 1; - } - - /// - public override void Write(char value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteByte((byte)value); - /// - public void Write(byte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteByte(value); - /// - public void Write(short value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteByte((byte)value); - /// - public void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteByte((byte)value); - /// - public void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) => buf.WriteByte((byte)value); - - #endregion - } -} diff --git a/src/Npgsql/TypeHandlers/InternalTypeHandlers/Int2VectorHandler.cs b/src/Npgsql/TypeHandlers/InternalTypeHandlers/Int2VectorHandler.cs deleted file mode 100644 index b1af8ca40e..0000000000 --- a/src/Npgsql/TypeHandlers/InternalTypeHandlers/Int2VectorHandler.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandlers.NumericHandlers; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.InternalTypeHandlers -{ - [TypeMapping("int2vector", NpgsqlDbType.Int2Vector)] - class Int2VectorHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler CreateNonGeneric(PostgresType pgType, NpgsqlConnection conn) - => new Int2VectorHandler(pgType, conn.Connector!.TypeMapper.DatabaseInfo.ByName["smallint"] - ?? throw new NpgsqlException("Two types called 'smallint' defined in the database")); - - public override Type DefaultValueType => typeof(short[]); - } - - /// - /// An int2vector is simply a regular array of shorts, with the sole exception that its lower bound must - /// be 0 (we send 1 for regular arrays). 
- /// - class Int2VectorHandler : ArrayHandler - { - public Int2VectorHandler(PostgresType arrayPostgresType, PostgresType postgresShortType) - : base(arrayPostgresType, new Int16Handler(postgresShortType), 0) { } - - public override ArrayHandler CreateArrayHandler(PostgresArrayType arrayBackendType) - => new ArrayHandler>(arrayBackendType, this); - } -} diff --git a/src/Npgsql/TypeHandlers/InternalTypeHandlers/OIDVectorHandler.cs b/src/Npgsql/TypeHandlers/InternalTypeHandlers/OIDVectorHandler.cs deleted file mode 100644 index 50077df3c9..0000000000 --- a/src/Npgsql/TypeHandlers/InternalTypeHandlers/OIDVectorHandler.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandlers.NumericHandlers; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.InternalTypeHandlers -{ - [TypeMapping("oidvector", NpgsqlDbType.Oidvector)] - class OIDVectorHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler CreateNonGeneric(PostgresType pgType, NpgsqlConnection conn) - => new OIDVectorHandler(pgType, conn.Connector!.TypeMapper.DatabaseInfo.ByName["oid"] - ?? throw new NpgsqlException("Two types called 'oid' defined in the database")); - - public override Type DefaultValueType => typeof(uint[]); - } - - /// - /// An OIDVector is simply a regular array of uints, with the sole exception that its lower bound must - /// be 0 (we send 1 for regular arrays). 
- /// - class OIDVectorHandler : ArrayHandler - { - public OIDVectorHandler(PostgresType oidvectorType, PostgresType oidType) - : base(oidvectorType, new UInt32Handler(oidType), 0) { } - - public override ArrayHandler CreateArrayHandler(PostgresArrayType arrayBackendType) - => new ArrayHandler>(arrayBackendType, this); - } -} diff --git a/src/Npgsql/TypeHandlers/InternalTypeHandlers/PgLsnHandler.cs b/src/Npgsql/TypeHandlers/InternalTypeHandlers/PgLsnHandler.cs deleted file mode 100644 index 8ab1116e91..0000000000 --- a/src/Npgsql/TypeHandlers/InternalTypeHandlers/PgLsnHandler.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.InternalTypeHandlers -{ - [TypeMapping("pg_lsn", NpgsqlDbType.PgLsn, typeof(NpgsqlLogSequenceNumber))] - class PgLsnHandler : NpgsqlSimpleTypeHandler - { - public PgLsnHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - public override NpgsqlLogSequenceNumber Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - Debug.Assert(len == 8); - return new NpgsqlLogSequenceNumber(buf.ReadUInt64()); - } - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(NpgsqlLogSequenceNumber value, NpgsqlParameter? parameter) => 8; - - public override void Write(NpgsqlLogSequenceNumber value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => buf.WriteUInt64((ulong)value); - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/InternalTypeHandlers/TidHandler.cs b/src/Npgsql/TypeHandlers/InternalTypeHandlers/TidHandler.cs deleted file mode 100644 index 673a2d65d2..0000000000 --- a/src/Npgsql/TypeHandlers/InternalTypeHandlers/TidHandler.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System.Diagnostics; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.InternalTypeHandlers -{ - [TypeMapping("tid", NpgsqlDbType.Tid, typeof(NpgsqlTid))] - class TidHandler : NpgsqlSimpleTypeHandler - { - public TidHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - public override NpgsqlTid Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - Debug.Assert(len == 6); - - var blockNumber = buf.ReadUInt32(); - var offsetNumber = buf.ReadUInt16(); - - return new NpgsqlTid(blockNumber, offsetNumber); - } - - #endregion Read - - #region Write - - public override int ValidateAndGetLength(NpgsqlTid value, NpgsqlParameter? parameter) - => 6; - - public override void Write(NpgsqlTid value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - buf.WriteUInt32(value.BlockNumber); - buf.WriteUInt16(value.OffsetNumber); - } - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/JsonHandler.cs b/src/Npgsql/TypeHandlers/JsonHandler.cs deleted file mode 100644 index f3854b408a..0000000000 --- a/src/Npgsql/TypeHandlers/JsonHandler.cs +++ /dev/null @@ -1,283 +0,0 @@ -using System; -using System.IO; -using System.Threading.Tasks; -using System.Text.Json; -using System.Threading; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - /// - /// A factory for type handlers for the PostgreSQL jsonb data type. 
- /// - /// - /// See https://www.postgresql.org/docs/current/datatype-json.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("jsonb", NpgsqlDbType.Jsonb, typeof(JsonDocument))] - public class JsonbHandlerFactory : NpgsqlTypeHandlerFactory - { - readonly JsonSerializerOptions? _serializerOptions; - - /// - public JsonbHandlerFactory() => _serializerOptions = null; - - /// - public JsonbHandlerFactory(JsonSerializerOptions serializerOptions) - => _serializerOptions = serializerOptions; - - /// - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new JsonHandler(postgresType, conn, isJsonb: true, _serializerOptions); - } - - /// - /// A factory for type handlers for the PostgreSQL json data type. - /// - /// - /// See https://www.postgresql.org/docs/current/datatype-json.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("json", NpgsqlDbType.Json)] - public class JsonHandlerFactory : NpgsqlTypeHandlerFactory - { - readonly JsonSerializerOptions? _serializerOptions; - - /// - public JsonHandlerFactory() => _serializerOptions = null; - - /// - public JsonHandlerFactory(JsonSerializerOptions serializerOptions) - => _serializerOptions = serializerOptions; - - /// - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new JsonHandler(postgresType, conn, isJsonb: false, _serializerOptions); - } - - /// - /// A type handler for the PostgreSQL json and jsonb data type. 
- /// - /// - /// See https://www.postgresql.org/docs/current/datatype-json.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public class JsonHandler : NpgsqlTypeHandler, ITextReaderHandler - { - readonly JsonSerializerOptions _serializerOptions; - readonly TextHandler _textHandler; - readonly bool _isJsonb; - readonly int _headerLen; - - /// - /// Prepended to the string in the wire encoding - /// - const byte JsonbProtocolVersion = 1; - - static readonly JsonSerializerOptions DefaultSerializerOptions = new JsonSerializerOptions(); - - /// - protected internal JsonHandler(PostgresType postgresType, NpgsqlConnection connection, bool isJsonb, JsonSerializerOptions? serializerOptions = null) - : base(postgresType) - { - _serializerOptions = serializerOptions ?? DefaultSerializerOptions; - _isJsonb = isJsonb; - _headerLen = isJsonb ? 1 : 0; - _textHandler = new TextHandler(postgresType, connection); - } - - /// - protected internal override int ValidateAndGetLength(TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (typeof(TAny) == typeof(string) || - typeof(TAny) == typeof(char[]) || - typeof(TAny) == typeof(ArraySegment) || - typeof(TAny) == typeof(char) || - typeof(TAny) == typeof(byte[])) - { - return _textHandler.ValidateAndGetLength(value, ref lengthCache, parameter) + _headerLen; - } - - if (typeof(TAny) == typeof(JsonDocument)) - { - if (lengthCache == null) - lengthCache = new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - var data = SerializeJsonDocument((JsonDocument)(object)value!); - if (parameter != null) - parameter.ConvertedValue = data; - return lengthCache.Set(data.Length + _headerLen); - } - - // User POCO, need to serialize. 
At least internally ArrayPool buffers are used... - var s = JsonSerializer.Serialize(value, _serializerOptions); - if (parameter != null) - parameter.ConvertedValue = s; - - return _textHandler.ValidateAndGetLength(s, ref lengthCache, parameter) + _headerLen; - } - - /// - protected override async Task WriteWithLength(TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - buf.WriteInt32(ValidateAndGetLength(value, ref lengthCache, parameter)); - - if (_isJsonb) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - buf.WriteByte(JsonbProtocolVersion); - } - - if (typeof(TAny) == typeof(string)) - await _textHandler.Write((string)(object)value!, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(char[])) - await _textHandler.Write((char[])(object)value!, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(ArraySegment)) - await _textHandler.Write((ArraySegment)(object)value!, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(char)) - await _textHandler.Write((char)(object)value!, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(byte[])) - await _textHandler.Write((byte[])(object)value!, buf, lengthCache, parameter, async, cancellationToken); - else if (typeof(TAny) == typeof(JsonDocument)) - { - var data = parameter?.ConvertedValue != null - ? (byte[])parameter.ConvertedValue - : SerializeJsonDocument((JsonDocument)(object)value!); - await buf.WriteBytesRaw(data, async, cancellationToken); - } - else - { - // User POCO, read serialized representation from the validation phase - var s = parameter?.ConvertedValue != null - ? 
(string)parameter.ConvertedValue - : JsonSerializer.Serialize(value!, value!.GetType(), _serializerOptions); - - await _textHandler.Write(s, buf, lengthCache, parameter, async, cancellationToken); - } - } - - /// - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - /// - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (_isJsonb) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - buf.WriteByte(JsonbProtocolVersion); - } - - await _textHandler.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - /// - protected internal override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value switch - { - DBNull _ => base.ValidateObjectAndGetLength(value, ref lengthCache, parameter), - string s => ValidateAndGetLength(s, ref lengthCache, parameter), - char[] s => ValidateAndGetLength(s, ref lengthCache, parameter), - ArraySegment s => ValidateAndGetLength(s, ref lengthCache, parameter), - char s => ValidateAndGetLength(s, ref lengthCache, parameter), - byte[] s => ValidateAndGetLength(s, ref lengthCache, parameter), - JsonDocument jsonDocument => ValidateAndGetLength(jsonDocument, ref lengthCache, parameter), - _ => ValidateAndGetLength(value, ref lengthCache, parameter) - }; - - /// - protected internal override async Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - // We call into WriteWithLength below, which assumes it as at least enough write space for the length - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - await (value switch - { - DBNull _ => base.WriteObjectWithLength(value, buf, lengthCache, parameter, async, cancellationToken), - string s => WriteWithLength(s, buf, lengthCache, parameter, async, cancellationToken), - char[] s => WriteWithLength(s, buf, lengthCache, parameter, async, cancellationToken), - ArraySegment s => WriteWithLength(s, buf, lengthCache, parameter, async, cancellationToken), - char s => WriteWithLength(s, buf, lengthCache, parameter, async, cancellationToken), - byte[] s => WriteWithLength(s, buf, lengthCache, parameter, async, cancellationToken), - JsonDocument jsonDocument => WriteWithLength(jsonDocument, buf, lengthCache, parameter, async, cancellationToken), - _ => WriteWithLength(value, buf, lengthCache, parameter, async, cancellationToken), - }); - } - - /// - protected internal override async ValueTask Read(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription = null) - { - if (_isJsonb) - { - await buf.Ensure(1, async); - var version = buf.ReadByte(); - if (version != JsonbProtocolVersion) - throw new NotSupportedException($"Don't know how to decode JSONB with wire format {version}, your connection is now broken"); - byteLen--; - } - - if (typeof(T) == typeof(string) || - typeof(T) == typeof(char[]) || - typeof(T) == typeof(ArraySegment) || - typeof(T) == typeof(char) || - typeof(T) == typeof(byte[])) - { - return await _textHandler.Read(buf, byteLen, async, fieldDescription); - } - - // See #2818 for possibly returning a JsonDocument directly over our internal buffer, rather - // than deserializing to string. - var s = await _textHandler.Read(buf, byteLen, async, fieldDescription); - return typeof(T) == typeof(JsonDocument) - ? 
(T)(object)JsonDocument.Parse(s) - : JsonSerializer.Deserialize(s, _serializerOptions)!; - } - - /// - public override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => Read(buf, len, async, fieldDescription); - - /// - public TextReader GetTextReader(Stream stream) - { - if (_isJsonb) - { - var version = stream.ReadByte(); - if (version != JsonbProtocolVersion) - throw new NpgsqlException($"Don't know how to decode jsonb with wire format {version}, your connection is now broken"); - } - - return _textHandler.GetTextReader(stream); - } - - byte[] SerializeJsonDocument(JsonDocument document) - { - // TODO: Writing is currently really inefficient - please don't criticize :) - // We need to implement one-pass writing to serialize directly to the buffer (or just switch to pipelines). - using var stream = new MemoryStream(); - using var writer = new Utf8JsonWriter(stream); - document.WriteTo(writer); - writer.Flush(); - return stream.ToArray(); - } - } -} diff --git a/src/Npgsql/TypeHandlers/JsonPathHandler.cs b/src/Npgsql/TypeHandlers/JsonPathHandler.cs deleted file mode 100644 index f7325ffd46..0000000000 --- a/src/Npgsql/TypeHandlers/JsonPathHandler.cs +++ /dev/null @@ -1,91 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace Npgsql.TypeHandlers -{ - /// - /// A factory for type handlers for the PostgreSQL jsonpath data type. - /// - /// - /// See https://www.postgresql.org/docs/current/datatype-json.html#DATATYPE-JSONPATH. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("jsonpath", NpgsqlDbType.JsonPath)] - public class JsonPathHandlerFactory : NpgsqlTypeHandlerFactory - { - /// - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new JsonPathHandler(postgresType, conn); - } - - /// - /// A type handler for the PostgreSQL jsonpath data type. - /// - /// - /// See https://www.postgresql.org/docs/current/datatype-json.html#DATATYPE-JSONPATH. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - public class JsonPathHandler : NpgsqlTypeHandler, ITextReaderHandler - { - readonly TextHandler _textHandler; - - /// - /// Prepended to the string in the wire encoding - /// - const byte JsonPathVersion = 1; - - /// - protected internal JsonPathHandler(PostgresType postgresType, NpgsqlConnection connection) - : base(postgresType) => _textHandler = new TextHandler(postgresType, connection); - - /// - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(1, async); - - var version = buf.ReadByte(); - if (version != JsonPathVersion) - throw new NotSupportedException($"Don't know how to decode JSONPATH with wire format {version}, your connection is now broken"); - - return await _textHandler.Read(buf, len - 1, async, fieldDescription); - } - - /// - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - 1 + _textHandler.ValidateAndGetLength(value, ref lengthCache, parameter); - - /// - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(JsonPathVersion); - - await _textHandler.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - /// - public TextReader GetTextReader(Stream stream) - { - var version = stream.ReadByte(); - if (version != JsonPathVersion) - throw new NotSupportedException($"Don't know how to decode JSONPATH with wire format {version}, your connection is now broken"); - - return _textHandler.GetTextReader(stream); - } - } -} diff --git a/src/Npgsql/TypeHandlers/LQueryHandler.cs b/src/Npgsql/TypeHandlers/LQueryHandler.cs deleted file mode 100644 index 5bfcd8f8a8..0000000000 --- a/src/Npgsql/TypeHandlers/LQueryHandler.cs +++ /dev/null @@ -1,103 +0,0 @@ -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - - [TypeMapping("lquery", NpgsqlDbType.LQuery)] - class LQueryHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new LQueryHandler(postgresType, conn); - } - - /// - /// LQuery binary encoding is a simple UTF8 string, but prepended with a version number. - /// - public class LQueryHandler : TextHandler - { - /// - /// Prepended to the string in the wire encoding - /// - const byte LQueryProtocolVersion = 1; - - internal override bool PreferTextWrite => false; - - protected internal LQueryHandler(PostgresType postgresType, NpgsqlConnection connection) - : base(postgresType, connection) {} - - #region Write - - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - public override int ValidateAndGetLength(char[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(char[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - #endregion - - #region Read - - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - await buf.Ensure(1, async); - - var version = buf.ReadByte(); - if (version != LQueryProtocolVersion) - throw new NotSupportedException($"Don't know how to decode lquery with wire format {version}, your connection is now broken"); - - return await base.Read(buf, len - 1, async, fieldDescription); - } - - #endregion - - public override TextReader GetTextReader(Stream stream) - { - var version = stream.ReadByte(); - if (version != LQueryProtocolVersion) - throw new NpgsqlException($"Don't know how to decode lquery with wire format {version}, your connection is now broken"); - - return base.GetTextReader(stream); - } - } -} diff --git a/src/Npgsql/TypeHandlers/LTreeHandler.cs b/src/Npgsql/TypeHandlers/LTreeHandler.cs deleted file mode 100644 index 5d443fd844..0000000000 --- a/src/Npgsql/TypeHandlers/LTreeHandler.cs +++ /dev/null @@ -1,105 +0,0 @@ -using System; -using System.Data; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - - [TypeMapping("ltree", NpgsqlDbType.LTree)] - class LTreeHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new LTreeHandler(postgresType, conn); - } - - /// - /// Ltree binary encoding is a simple UTF8 string, but prepended with a version number. 
- /// - public class LTreeHandler : TextHandler - { - /// - /// Prepended to the string in the wire encoding - /// - const byte LtreeProtocolVersion = 1; - - internal override bool PreferTextWrite => false; - - protected internal LTreeHandler(PostgresType postgresType, NpgsqlConnection connection) - : base(postgresType, connection) {} - - #region Write - - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override int ValidateAndGetLength(char[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LtreeProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(char[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LtreeProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LtreeProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - #endregion - - #region Read - - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(1, async); - - var version = buf.ReadByte(); - if (version != LtreeProtocolVersion) - throw new NotSupportedException($"Don't know how to decode ltree with wire format {version}, your connection is now broken"); - - return await base.Read(buf, len - 1, async, fieldDescription); - } - - #endregion - - public override TextReader GetTextReader(Stream stream) - { - var version = stream.ReadByte(); - if (version != LtreeProtocolVersion) - throw new NpgsqlException($"Don't know how to decode ltree with wire format {version}, your connection is now broken"); - - return base.GetTextReader(stream); - } - } -} diff --git a/src/Npgsql/TypeHandlers/LTxtQueryHandler.cs b/src/Npgsql/TypeHandlers/LTxtQueryHandler.cs deleted file mode 100644 index 04cb0de44c..0000000000 --- a/src/Npgsql/TypeHandlers/LTxtQueryHandler.cs +++ /dev/null @@ -1,104 +0,0 @@ -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - - [TypeMapping("ltxtquery", NpgsqlDbType.LTxtQuery)] - class LTxtQueryHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new LTxtQueryHandler(postgresType, conn); - } - - /// - /// LTxtQuery binary encoding is a simple UTF8 string, but prepended 
with a version number. - /// - public class LTxtQueryHandler : TextHandler - { - /// - /// Prepended to the string in the wire encoding - /// - const byte LTxtQueryProtocolVersion = 1; - - internal override bool PreferTextWrite => false; - - protected internal LTxtQueryHandler(PostgresType postgresType, NpgsqlConnection connection) - : base(postgresType, connection) {} - - #region Write - - public override int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override int ValidateAndGetLength(char[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) => - base.ValidateAndGetLength(value, ref lengthCache, parameter) + 1; - - - public override async Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LTxtQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(char[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LTxtQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - public override async Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte(LTxtQueryProtocolVersion); - await base.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - #endregion - - #region Read - - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(1, async); - - var version = buf.ReadByte(); - if (version != LTxtQueryProtocolVersion) - throw new NotSupportedException($"Don't know how to decode ltxtquery with wire format {version}, your connection is now broken"); - - return await base.Read(buf, len - 1, async, fieldDescription); - } - - #endregion - - public override TextReader GetTextReader(Stream stream) - { - var version = stream.ReadByte(); - if (version != LTxtQueryProtocolVersion) - throw new NpgsqlException($"Don't know how to decode ltxtquery with wire format {version}, your connection is now broken"); - - return base.GetTextReader(stream); - } - } -} diff --git a/src/Npgsql/TypeHandlers/NetworkHandlers/CidrHandler.cs b/src/Npgsql/TypeHandlers/NetworkHandlers/CidrHandler.cs deleted file mode 100644 index 60336d85e3..0000000000 --- a/src/Npgsql/TypeHandlers/NetworkHandlers/CidrHandler.cs +++ /dev/null @@ -1,54 +0,0 @@ -using System.Net; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -#pragma warning disable 618 - -namespace Npgsql.TypeHandlers.NetworkHandlers -{ - /// - /// A type handler for the PostgreSQL cidr data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. 
- /// Use it at your own risk. - /// - [TypeMapping("cidr", NpgsqlDbType.Cidr)] - public class CidrHandler : NpgsqlSimpleTypeHandler<(IPAddress Address, int Subnet)>, INpgsqlSimpleTypeHandler - { - /// - public CidrHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override (IPAddress Address, int Subnet) Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => InetHandler.DoRead(buf, len, fieldDescription, true); - - NpgsqlInet INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - { - var (address, subnet) = Read(buf, len, fieldDescription); - return new NpgsqlInet(address, subnet); - } - - /// - public override int ValidateAndGetLength((IPAddress Address, int Subnet) value, NpgsqlParameter? parameter) - => InetHandler.GetLength(value.Address); - - /// - public int ValidateAndGetLength(NpgsqlInet value, NpgsqlParameter? parameter) - => InetHandler.GetLength(value.Address); - - /// - public override void Write((IPAddress Address, int Subnet) value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => InetHandler.DoWrite(value.Address, value.Subnet, buf, true); - - /// - public void Write(NpgsqlInet value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => InetHandler.DoWrite(value.Address, value.Netmask, buf, true); - } -} diff --git a/src/Npgsql/TypeHandlers/NetworkHandlers/InetHandler.cs b/src/Npgsql/TypeHandlers/NetworkHandlers/InetHandler.cs deleted file mode 100644 index 968d9a87da..0000000000 --- a/src/Npgsql/TypeHandlers/NetworkHandlers/InetHandler.cs +++ /dev/null @@ -1,161 +0,0 @@ -using System; -using System.Diagnostics; -using System.Net; -using System.Net.Sockets; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -#pragma warning disable 618 - -namespace Npgsql.TypeHandlers.NetworkHandlers -{ - /// - /// A type handler for the PostgreSQL cidr data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping( - "inet", - NpgsqlDbType.Inet, - new[] { typeof(IPAddress), typeof((IPAddress Address, int Subnet)), typeof(NpgsqlInet) })] - public class InetHandler : NpgsqlSimpleTypeHandlerWithPsv, - INpgsqlSimpleTypeHandler - { - // ReSharper disable InconsistentNaming - const byte IPv4 = 2; - const byte IPv6 = 3; - // ReSharper restore InconsistentNaming - - /// - public InetHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override IPAddress Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => DoRead(buf, len, fieldDescription, false).Address; - -#pragma warning disable CA1801 // Review unused parameters - internal static (IPAddress Address, int Subnet) DoRead( - NpgsqlReadBuffer buf, - int len, - FieldDescription? 
fieldDescription, - bool isCidrHandler) - { - buf.ReadByte(); // addressFamily - var mask = buf.ReadByte(); - var isCidr = buf.ReadByte() == 1; - Debug.Assert(isCidrHandler == isCidr); - var numBytes = buf.ReadByte(); - var bytes = new byte[numBytes]; - for (var i = 0; i < numBytes; i++) - bytes[i] = buf.ReadByte(); - - return (new IPAddress(bytes), mask); - } -#pragma warning restore CA1801 // Review unused parameters - - /// - protected override (IPAddress Address, int Subnet) ReadPsv(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => DoRead(buf, len, fieldDescription, false); - - NpgsqlInet INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - { - var (address, subnet) = DoRead(buf, len, fieldDescription, false); - return new NpgsqlInet(address, subnet); - } - - #endregion Read - - #region Write - - /// - protected internal override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value switch { - null => -1, - DBNull _ => -1, - IPAddress ip => ValidateAndGetLength(ip, parameter), - ValueTuple tup => ValidateAndGetLength(tup, parameter), - NpgsqlInet inet => ValidateAndGetLength(inet, parameter), - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType().Name} to database type {PgDisplayName}") - }; - - /// - protected internal override Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => value switch { - DBNull _ => WriteWithLengthInternal(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken), - IPAddress ip => WriteWithLengthInternal(ip, buf, lengthCache, parameter, async, cancellationToken), - ValueTuple tup => WriteWithLengthInternal(tup, buf, lengthCache, parameter, async, cancellationToken), - NpgsqlInet inet => WriteWithLengthInternal(inet, buf, lengthCache, parameter, async, cancellationToken), - _ => throw new InvalidCastException($"Can't write CLR type {value.GetType().Name} to database type {PgDisplayName}") - }; - - /// - public override int ValidateAndGetLength(IPAddress value, NpgsqlParameter? parameter) - => GetLength(value); - - /// - public override int ValidateAndGetLength((IPAddress Address, int Subnet) value, NpgsqlParameter? parameter) - => GetLength(value.Address); - - /// - public int ValidateAndGetLength(NpgsqlInet value, NpgsqlParameter? parameter) - => GetLength(value.Address); - - /// - public override void Write(IPAddress value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => DoWrite(value, -1, buf, false); - - /// - public override void Write((IPAddress Address, int Subnet) value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => DoWrite(value.Address, value.Subnet, buf, false); - - /// - public void Write(NpgsqlInet value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => DoWrite(value.Address, value.Netmask, buf, false); - - internal static void DoWrite(IPAddress ip, int mask, NpgsqlWriteBuffer buf, bool isCidrHandler) - { - switch (ip.AddressFamily) { - case AddressFamily.InterNetwork: - buf.WriteByte(IPv4); - if (mask == -1) - mask = 32; - break; - case AddressFamily.InterNetworkV6: - buf.WriteByte(IPv6); - if (mask == -1) - mask = 128; - break; - default: - throw new InvalidCastException($"Can't handle IPAddress with AddressFamily {ip.AddressFamily}, only InterNetwork or InterNetworkV6!"); - } - - buf.WriteByte((byte)mask); - buf.WriteByte((byte)(isCidrHandler ? 1 : 0)); // Ignored on server side - var bytes = ip.GetAddressBytes(); - buf.WriteByte((byte)bytes.Length); - buf.WriteBytes(bytes, 0, bytes.Length); - } - - internal static int GetLength(IPAddress value) - => value.AddressFamily switch - { - AddressFamily.InterNetwork => 8, - AddressFamily.InterNetworkV6 => 20, - _ => throw new InvalidCastException($"Can't handle IPAddress with AddressFamily {value.AddressFamily}, only InterNetwork or InterNetworkV6!") - }; - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/NetworkHandlers/MacaddrHandler.cs b/src/Npgsql/TypeHandlers/NetworkHandlers/MacaddrHandler.cs deleted file mode 100644 index 7ad6658e98..0000000000 --- a/src/Npgsql/TypeHandlers/NetworkHandlers/MacaddrHandler.cs +++ /dev/null @@ -1,58 +0,0 @@ -using System.Diagnostics; -using System.Net.NetworkInformation; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.NetworkHandlers -{ - /// - /// A type handler for the PostgreSQL macaddr and macaddr8 data types. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-net-types.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. 
However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("macaddr8", NpgsqlDbType.MacAddr8)] - [TypeMapping("macaddr", NpgsqlDbType.MacAddr, typeof(PhysicalAddress))] - public class MacaddrHandler : NpgsqlSimpleTypeHandler - { - /// - public MacaddrHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override PhysicalAddress Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - Debug.Assert(len == 6 || len == 8); - - var bytes = new byte[len]; - - buf.ReadBytes(bytes, 0, len); - return new PhysicalAddress(bytes); - } - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(PhysicalAddress value, NpgsqlParameter? parameter) - => value.GetAddressBytes().Length; - - /// - public override void Write(PhysicalAddress value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - var bytes = value.GetAddressBytes(); - buf.WriteBytes(bytes, 0, bytes.Length); - } - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/NumericHandlers/DecimalRaw.cs b/src/Npgsql/TypeHandlers/NumericHandlers/DecimalRaw.cs deleted file mode 100644 index 2d925f7fd9..0000000000 --- a/src/Npgsql/TypeHandlers/NumericHandlers/DecimalRaw.cs +++ /dev/null @@ -1,151 +0,0 @@ -using System; -using System.Runtime.InteropServices; - -namespace Npgsql.TypeHandlers.NumericHandlers -{ - [StructLayout(LayoutKind.Explicit)] - struct DecimalRaw - { - const int SignMask = unchecked((int)0x80000000); - const int ScaleMask = 0x00FF0000; - const int ScaleShift = 16; - - // Fast access for 10^n where n is 0-9 - internal static readonly uint[] Powers10 = new uint[] - { - 1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, - 100000000, - 1000000000 - }; - - // The maximum power of 10 that a 32 bit unsigned integer can store - internal static readonly int MaxUInt32Scale = Powers10.Length - 1; - - // Do not change the order in which these fields are declared. It - // should be same as in the System.Decimal struct. 
- [FieldOffset(0)] - decimal _value; - [FieldOffset(0)] - int _flags; - [FieldOffset(4)] - uint _high; - [FieldOffset(8)] - uint _low; - [FieldOffset(12)] - uint _mid; - - public bool Negative => (_flags & SignMask) != 0; - - public int Scale - { - get => (_flags & ScaleMask) >> ScaleShift; - set => _flags = (_flags & SignMask) | ((value << ScaleShift) & ScaleMask); - } - - public uint High => _high; - public uint Mid => _mid; - public uint Low => _low; - public decimal Value => _value; - - public DecimalRaw(decimal value) : this() => _value = value; - - public DecimalRaw(long value) : this() - { - if (value >= 0) - _flags = 0; - else - { - _flags = SignMask; - value = -value; - } - - _low = (uint)value; - _mid = (uint)(value >> 32); - _high = 0; - } - - public static void Negate(ref DecimalRaw value) - => value._flags ^= SignMask; - - public static void Add(ref DecimalRaw value, uint addend) - { - uint integer; - uint sum; - - integer = value._low; - value._low = sum = integer + addend; - - if (sum >= integer && sum >= addend) - return; - - integer = value._mid; - value._mid = sum = integer + 1; - - if (sum >= integer && sum >= 1) - return; - - integer = value._high; - value._high = sum = integer + 1; - - if (sum < integer || sum < 1) - throw new OverflowException("Numeric value does not fit in a System.Decimal"); - } - - public static void Multiply(ref DecimalRaw value, uint multiplier) - { - ulong integer; - uint remainder; - - integer = (ulong)value._low * multiplier; - value._low = (uint)integer; - remainder = (uint)(integer >> 32); - - integer = (ulong)value._mid * multiplier + remainder; - value._mid = (uint)integer; - remainder = (uint)(integer >> 32); - - integer = (ulong)value._high * multiplier + remainder; - value._high = (uint)integer; - remainder = (uint)(integer >> 32); - - if (remainder != 0) - throw new OverflowException("Numeric value does not fit in a System.Decimal"); - } - - public static uint Divide(ref DecimalRaw value, uint divisor) - { - 
ulong integer; - uint remainder = 0; - - if (value._high != 0) - { - integer = value._high; - value._high = (uint)(integer / divisor); - remainder = (uint)(integer % divisor); - } - - if (value._mid != 0 || remainder != 0) - { - integer = ((ulong)remainder << 32) | value._mid; - value._mid = (uint)(integer / divisor); - remainder = (uint)(integer % divisor); - } - - if (value._low != 0 || remainder != 0) - { - integer = ((ulong)remainder << 32) | value._low; - value._low = (uint)(integer / divisor); - remainder = (uint)(integer % divisor); - } - - return remainder; - } - } -} diff --git a/src/Npgsql/TypeHandlers/NumericHandlers/DoubleHandler.cs b/src/Npgsql/TypeHandlers/NumericHandlers/DoubleHandler.cs deleted file mode 100644 index 7ca52acd84..0000000000 --- a/src/Npgsql/TypeHandlers/NumericHandlers/DoubleHandler.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.NumericHandlers -{ - /// - /// A type handler for the PostgreSQL double precision data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("double precision", NpgsqlDbType.Double, DbType.Double, typeof(double))] - public class DoubleHandler : NpgsqlSimpleTypeHandler - { - /// - public DoubleHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override double Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadDouble(); - - /// - public override int ValidateAndGetLength(double value, NpgsqlParameter? 
parameter) - => 8; - - /// - public override void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => buf.WriteDouble(value); - } -} diff --git a/src/Npgsql/TypeHandlers/NumericHandlers/Int16Handler.cs b/src/Npgsql/TypeHandlers/NumericHandlers/Int16Handler.cs deleted file mode 100644 index d0ecdc90f7..0000000000 --- a/src/Npgsql/TypeHandlers/NumericHandlers/Int16Handler.cs +++ /dev/null @@ -1,115 +0,0 @@ -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.NumericHandlers -{ - /// - /// A type handler for the PostgreSQL smallint data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("smallint", NpgsqlDbType.Smallint, new[] { DbType.Int16, DbType.Byte, DbType.SByte }, new[] { typeof(short), typeof(byte), typeof(sbyte) }, DbType.Int16)] - public class Int16Handler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - /// - public Int16Handler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override short Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadInt16(); - - byte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((byte)Read(buf, len, fieldDescription)); - - sbyte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription) - => checked((sbyte)Read(buf, len, fieldDescription)); - - int INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - float INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - double INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - decimal INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(short value, NpgsqlParameter? parameter) => 2; - /// - public int ValidateAndGetLength(byte value, NpgsqlParameter? parameter) => 2; - /// - public int ValidateAndGetLength(sbyte value, NpgsqlParameter? parameter) => 2; - /// - public int ValidateAndGetLength(decimal value, NpgsqlParameter? parameter) => 2; - - /// - public int ValidateAndGetLength(int value, NpgsqlParameter? parameter) - { - _ = checked((short)value); - return 2; - } - - /// - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) - { - _ = checked((short)value); - return 2; - } - - /// - public int ValidateAndGetLength(float value, NpgsqlParameter? parameter) - { - _ = checked((short)value); - return 2; - } - - /// - public int ValidateAndGetLength(double value, NpgsqlParameter? parameter) - { - _ = checked((short)value); - return 2; - } - - /// - public override void Write(short value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16(value); - /// - public void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) => buf.WriteInt16((short)value); - /// - public void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16((short)value); - /// - public void Write(byte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16(value); - /// - public void Write(sbyte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16(value); - /// - public void Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16((short)value); - /// - public void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16((short)value); - /// - public void Write(float value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt16((short)value); - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/NumericHandlers/Int32Handler.cs b/src/Npgsql/TypeHandlers/NumericHandlers/Int32Handler.cs deleted file mode 100644 index d3b6fb0d0f..0000000000 --- a/src/Npgsql/TypeHandlers/NumericHandlers/Int32Handler.cs +++ /dev/null @@ -1,103 +0,0 @@ -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.NumericHandlers -{ - /// - /// A type handler for the PostgreSQL integer data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("integer", NpgsqlDbType.Integer, DbType.Int32, typeof(int))] - public class Int32Handler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - /// - public Int32Handler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override int Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadInt32(); - - byte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((byte)Read(buf, len, fieldDescription)); - - short INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((short)Read(buf, len, fieldDescription)); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - float INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - double INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - decimal INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(int value, NpgsqlParameter? parameter) => 4; - /// - public int ValidateAndGetLength(short value, NpgsqlParameter? parameter) => 4; - /// - public int ValidateAndGetLength(byte value, NpgsqlParameter? parameter) => 4; - /// - public int ValidateAndGetLength(decimal value, NpgsqlParameter? parameter) => 4; - - /// - public int ValidateAndGetLength(long value, NpgsqlParameter? 
parameter) - { - _ = checked((int)value); - return 4; - } - - /// - public int ValidateAndGetLength(float value, NpgsqlParameter? parameter) - { - _ = checked((int)value); - return 4; - } - - /// - public int ValidateAndGetLength(double value, NpgsqlParameter? parameter) - { - _ = checked((int)value); - return 4; - } - - /// - public override void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32(value); - /// - public void Write(short value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32(value); - /// - public void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32((int)value); - /// - public void Write(byte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32(value); - /// - public void Write(float value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32((int)value); - /// - public void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32((int)value); - /// - public void Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt32((int)value); - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/NumericHandlers/Int64Handler.cs b/src/Npgsql/TypeHandlers/NumericHandlers/Int64Handler.cs deleted file mode 100644 index 7a3881b74c..0000000000 --- a/src/Npgsql/TypeHandlers/NumericHandlers/Int64Handler.cs +++ /dev/null @@ -1,98 +0,0 @@ -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.NumericHandlers -{ - /// - /// A type handler for the PostgreSQL bigint data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. 
However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("bigint", NpgsqlDbType.Bigint, DbType.Int64, typeof(long))] - public class Int64Handler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - /// - public Int64Handler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override long Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadInt64(); - - byte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((byte)Read(buf, len, fieldDescription)); - - short INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((short)Read(buf, len, fieldDescription)); - - int INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => checked((int)Read(buf, len, fieldDescription)); - - float INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - double INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - decimal INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - /// - public override int ValidateAndGetLength(long value, NpgsqlParameter? parameter) => 8; - /// - public int ValidateAndGetLength(int value, NpgsqlParameter? parameter) => 8; - /// - public int ValidateAndGetLength(short value, NpgsqlParameter? 
parameter) => 8; - /// - public int ValidateAndGetLength(byte value, NpgsqlParameter? parameter) => 8; - /// - public int ValidateAndGetLength(decimal value, NpgsqlParameter? parameter) => 8; - - /// - public int ValidateAndGetLength(float value, NpgsqlParameter? parameter) - { - _ = checked((long)value); - return 8; - } - - /// - public int ValidateAndGetLength(double value, NpgsqlParameter? parameter) - { - _ = checked((long)value); - return 8; - } - - /// - public override void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64(value); - /// - public void Write(short value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64(value); - /// - public void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64(value); - /// - public void Write(byte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64(value); - /// - public void Write(float value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64((long)value); - /// - public void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64((long)value); - /// - public void Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteInt64((long)value); - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/NumericHandlers/MoneyHandler.cs b/src/Npgsql/TypeHandlers/NumericHandlers/MoneyHandler.cs deleted file mode 100644 index 38f0d1207d..0000000000 --- a/src/Npgsql/TypeHandlers/NumericHandlers/MoneyHandler.cs +++ /dev/null @@ -1,58 +0,0 @@ -using System; -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.NumericHandlers -{ - /// - /// A type handler for the PostgreSQL money data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-money.html. 
- /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("money", NpgsqlDbType.Money, dbType: DbType.Currency)] - public class MoneyHandler : NpgsqlSimpleTypeHandler - { - const int MoneyScale = 2; - - /// - public MoneyHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override decimal Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => new DecimalRaw(buf.ReadInt64()) { Scale = MoneyScale }.Value; - - /// - public override int ValidateAndGetLength(decimal value, NpgsqlParameter? parameter) - => value < -92233720368547758.08M || value > 92233720368547758.07M - ? throw new OverflowException($"The supplied value ({value}) is outside the range for a PostgreSQL money value.") - : 8; - - /// - public override void Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - var raw = new DecimalRaw(value); - - var scaleDifference = MoneyScale - raw.Scale; - if (scaleDifference > 0) - DecimalRaw.Multiply(ref raw, DecimalRaw.Powers10[scaleDifference]); - else - { - value = Math.Round(value, MoneyScale, MidpointRounding.AwayFromZero); - raw = new DecimalRaw(value); - } - - var result = (long)raw.Mid << 32 | raw.Low; - if (raw.Negative) result = -result; - buf.WriteInt64(result); - } - } -} diff --git a/src/Npgsql/TypeHandlers/NumericHandlers/NumericHandler.cs b/src/Npgsql/TypeHandlers/NumericHandlers/NumericHandler.cs deleted file mode 100644 index 6e7a36787d..0000000000 --- a/src/Npgsql/TypeHandlers/NumericHandlers/NumericHandler.cs +++ /dev/null @@ -1,228 +0,0 @@ -using System; -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.NumericHandlers -{ - /// - /// A type handler for the PostgreSQL numeric data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("numeric", NpgsqlDbType.Numeric, new[] { DbType.Decimal, DbType.VarNumeric }, typeof(decimal), DbType.Decimal)] - public class NumericHandler : NpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler, - INpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - /// - public NumericHandler(PostgresType postgresType) : base(postgresType) {} - - const int MaxDecimalScale = 28; - - const int SignPositive = 0x0000; - const int SignNegative = 0x4000; - const int SignNan = 0xC000; - - const int MaxGroupCount = 8; - const int MaxGroupScale = 4; - - static readonly uint MaxGroupSize = DecimalRaw.Powers10[MaxGroupScale]; - - #region Read - - /// - public override decimal Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var result = new DecimalRaw(); - var groups = buf.ReadInt16(); - var weight = buf.ReadInt16() - groups + 1; - var sign = buf.ReadUInt16(); - - if (sign == SignNan) - throw new InvalidCastException("Numeric NaN not supported by System.Decimal"); - - if (sign == SignNegative) - DecimalRaw.Negate(ref result); - - var scale = buf.ReadInt16(); - if (scale > MaxDecimalScale) - throw new OverflowException("Numeric value does not fit in a System.Decimal"); - - result.Scale = scale; - - var scaleDifference = scale + weight * MaxGroupScale; - if (groups == MaxGroupCount) - { - while (groups-- > 1) - { - DecimalRaw.Multiply(ref result, MaxGroupSize); - DecimalRaw.Add(ref result, buf.ReadUInt16()); - } - - var group = buf.ReadUInt16(); - var groupSize = DecimalRaw.Powers10[-scaleDifference]; - if (group % groupSize != 0) - throw new OverflowException("Numeric value does not fit in a System.Decimal"); - - DecimalRaw.Multiply(ref result, MaxGroupSize / groupSize); - DecimalRaw.Add(ref result, group / groupSize); - } - else - { - while (groups-- > 0) - { - DecimalRaw.Multiply(ref result, MaxGroupSize); - DecimalRaw.Add(ref result, 
buf.ReadUInt16()); - } - - if (scaleDifference < 0) - DecimalRaw.Divide(ref result, DecimalRaw.Powers10[-scaleDifference]); - else - while (scaleDifference > 0) - { - var scaleChunk = Math.Min(DecimalRaw.MaxUInt32Scale, scaleDifference); - DecimalRaw.Multiply(ref result, DecimalRaw.Powers10[scaleChunk]); - scaleDifference -= scaleChunk; - } - } - return result.Value; - } - - byte INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => (byte)Read(buf, len, fieldDescription); - - short INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => (short)Read(buf, len, fieldDescription); - - int INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => (int)Read(buf, len, fieldDescription); - - long INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => (long)Read(buf, len, fieldDescription); - - float INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => (float)Read(buf, len, fieldDescription); - - double INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => (double)Read(buf, len, fieldDescription); - - #endregion - - #region Write - - /// - public override int ValidateAndGetLength(decimal value, NpgsqlParameter? 
parameter) - { - var groupCount = 0; - var raw = new DecimalRaw(value); - if (raw.Low != 0 || raw.Mid != 0 || raw.High != 0) - { - uint remainder = default; - var scaleChunk = raw.Scale % MaxGroupScale; - if (scaleChunk > 0) - { - var divisor = DecimalRaw.Powers10[scaleChunk]; - var multiplier = DecimalRaw.Powers10[MaxGroupScale - scaleChunk]; - remainder = DecimalRaw.Divide(ref raw, divisor) * multiplier; - } - - while (remainder == 0) - remainder = DecimalRaw.Divide(ref raw, MaxGroupSize); - - groupCount++; - - while (raw.Low != 0 || raw.Mid != 0 || raw.High != 0) - { - DecimalRaw.Divide(ref raw, MaxGroupSize); - groupCount++; - } - } - - return 4 * sizeof(short) + groupCount * sizeof(short); - } - - /// - public int ValidateAndGetLength(short value, NpgsqlParameter? parameter) => ValidateAndGetLength((decimal)value, parameter); - /// - public int ValidateAndGetLength(int value, NpgsqlParameter? parameter) => ValidateAndGetLength((decimal)value, parameter); - /// - public int ValidateAndGetLength(long value, NpgsqlParameter? parameter) => ValidateAndGetLength((decimal)value, parameter); - /// - public int ValidateAndGetLength(float value, NpgsqlParameter? parameter) => ValidateAndGetLength((decimal)value, parameter); - /// - public int ValidateAndGetLength(double value, NpgsqlParameter? parameter) => ValidateAndGetLength((decimal)value, parameter); - /// - public int ValidateAndGetLength(byte value, NpgsqlParameter? parameter) => ValidateAndGetLength((decimal)value, parameter); - - /// - public override void Write(decimal value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - { - var weight = 0; - var groupCount = 0; - Span groups = stackalloc short[MaxGroupCount]; - - var raw = new DecimalRaw(value); - if (raw.Low != 0 || raw.Mid != 0 || raw.High != 0) - { - var scale = raw.Scale; - weight = -scale / MaxGroupScale - 1; - - uint remainder; - var scaleChunk = scale % MaxGroupScale; - if (scaleChunk > 0) - { - var divisor = DecimalRaw.Powers10[scaleChunk]; - var multiplier = DecimalRaw.Powers10[MaxGroupScale - scaleChunk]; - remainder = DecimalRaw.Divide(ref raw, divisor) * multiplier; - - if (remainder != 0) - { - weight--; - goto WriteGroups; - } - } - - while ((remainder = DecimalRaw.Divide(ref raw, MaxGroupSize)) == 0) - weight++; - - WriteGroups: - groups[groupCount++] = (short)remainder; - - while (raw.Low != 0 || raw.Mid != 0 || raw.High != 0) - groups[groupCount++] = (short)DecimalRaw.Divide(ref raw, MaxGroupSize); - } - - buf.WriteInt16(groupCount); - buf.WriteInt16(groupCount + weight); - buf.WriteInt16(raw.Negative ? SignNegative : SignPositive); - buf.WriteInt16(raw.Scale); - - while (groupCount > 0) - buf.WriteInt16(groups[--groupCount]); - } - - /// - public void Write(short value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => Write((decimal)value, buf, parameter); - /// - public void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => Write((decimal)value, buf, parameter); - /// - public void Write(long value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => Write((decimal)value, buf, parameter); - /// - public void Write(byte value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => Write((decimal)value, buf, parameter); - /// - public void Write(float value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => Write((decimal)value, buf, parameter); - /// - public void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) => Write((decimal)value, buf, parameter); - - #endregion - } -} diff --git a/src/Npgsql/TypeHandlers/NumericHandlers/SingleHandler.cs b/src/Npgsql/TypeHandlers/NumericHandlers/SingleHandler.cs deleted file mode 100644 index db129b6e81..0000000000 --- a/src/Npgsql/TypeHandlers/NumericHandlers/SingleHandler.cs +++ /dev/null @@ -1,51 +0,0 @@ -using System.Data; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.NumericHandlers -{ - /// - /// A type handler for the PostgreSQL real data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-numeric.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("real", NpgsqlDbType.Real, DbType.Single, typeof(float))] - public class SingleHandler : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - /// - public SingleHandler(PostgresType postgresType) : base(postgresType) {} - - #region Read - - /// - public override float Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadSingle(); - - double INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => Read(buf, len, fieldDescription); - - #endregion Read - - #region Write - - /// - public int ValidateAndGetLength(double value, NpgsqlParameter? parameter) => 4; - /// - public override int ValidateAndGetLength(float value, NpgsqlParameter? parameter) => 4; - - /// - public void Write(double value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => buf.WriteSingle((float)value); - /// - public override void Write(float value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) => buf.WriteSingle(value); - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/NumericHandlers/UInt32Handler.cs b/src/Npgsql/TypeHandlers/NumericHandlers/UInt32Handler.cs deleted file mode 100644 index 3f2ee7c138..0000000000 --- a/src/Npgsql/TypeHandlers/NumericHandlers/UInt32Handler.cs +++ /dev/null @@ -1,40 +0,0 @@ -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers.NumericHandlers -{ - /// - /// A type handler for the PostgreSQL real data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-oid.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("oid", NpgsqlDbType.Oid)] - [TypeMapping("xid", NpgsqlDbType.Xid)] - [TypeMapping("cid", NpgsqlDbType.Cid)] - [TypeMapping("regtype", NpgsqlDbType.Regtype)] - [TypeMapping("regconfig", NpgsqlDbType.Regconfig)] - public class UInt32Handler : NpgsqlSimpleTypeHandler - { - /// - public UInt32Handler(PostgresType postgresType) : base(postgresType) {} - - /// - public override uint Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => buf.ReadUInt32(); - - /// - public override int ValidateAndGetLength(uint value, NpgsqlParameter? parameter) => 4; - - /// - public override void Write(uint value, NpgsqlWriteBuffer buf, NpgsqlParameter? 
parameter) - => buf.WriteUInt32(value); - } -} diff --git a/src/Npgsql/TypeHandlers/RangeHandler.cs b/src/Npgsql/TypeHandlers/RangeHandler.cs deleted file mode 100644 index cae4365875..0000000000 --- a/src/Npgsql/TypeHandlers/RangeHandler.cs +++ /dev/null @@ -1,224 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - /// - /// An interface implementing by , exposing the handler's supported range - /// CLR types. - /// - public interface IRangeHandler - { - /// - /// Exposes the range CLR types supported by this handler. - /// - Type[] SupportedRangeClrTypes { get; } - } - - /// - /// A type handler for PostgreSQL range types. - /// - /// - /// See https://www.postgresql.org/docs/current/static/rangetypes.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - /// the range subtype - public class RangeHandler : NpgsqlTypeHandler>, IRangeHandler - { - /// - /// The type handler for the element that this range type holds - /// - readonly NpgsqlTypeHandler _elementHandler; - - /// - public Type[] SupportedRangeClrTypes { get; } - - /// - public RangeHandler(PostgresType rangePostgresType, NpgsqlTypeHandler elementHandler) - : this(rangePostgresType, elementHandler, new[] { typeof(NpgsqlRange)}) {} - - /// - protected RangeHandler(PostgresType rangePostgresType, NpgsqlTypeHandler elementHandler, Type[] supportedElementClrTypes) - : base(rangePostgresType) - { - _elementHandler = elementHandler; - SupportedRangeClrTypes = supportedElementClrTypes; - } - - /// - public override ArrayHandler CreateArrayHandler(PostgresArrayType arrayBackendType) - => new ArrayHandler>(arrayBackendType, this); - - internal override Type GetFieldType(FieldDescription? fieldDescription = null) => typeof(NpgsqlRange); - internal override Type GetProviderSpecificFieldType(FieldDescription? fieldDescription = null) => typeof(NpgsqlRange); - - /// - public override IRangeHandler CreateRangeHandler(PostgresType rangeBackendType) - => throw new NotSupportedException(); - - #region Read - - /// - public override TAny Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => Read(buf, len, false, fieldDescription).Result; - - /// - public override ValueTask> Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => DoRead(buf, len, async, fieldDescription); - - private protected async ValueTask> DoRead(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - { - await buf.Ensure(1, async); - - var flags = (RangeFlags)buf.ReadByte(); - if ((flags & RangeFlags.Empty) != 0) - return NpgsqlRange.Empty; - - var lowerBound = default(TAny); - var upperBound = default(TAny); - - if ((flags & RangeFlags.LowerBoundInfinite) == 0) - lowerBound = await _elementHandler.ReadWithLength(buf, async); - - if ((flags & RangeFlags.UpperBoundInfinite) == 0) - upperBound = await _elementHandler.ReadWithLength(buf, async); - - return new NpgsqlRange(lowerBound, upperBound, flags); - } - - #endregion - - #region Write - - /// - public override int ValidateAndGetLength(NpgsqlRange value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - private protected int ValidateAndGetLength(NpgsqlRange value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var totalLen = 1; - var lengthCachePos = lengthCache?.Position ?? 0; - if (!value.IsEmpty) - { - if (!value.LowerBoundInfinite) - { - totalLen += 4; - if (!(value.LowerBound is null) && typeof(TElement) != typeof(DBNull)) - totalLen += _elementHandler.ValidateAndGetLength(value.LowerBound, ref lengthCache, null); - } - - if (!value.UpperBoundInfinite) - { - totalLen += 4; - if (!(value.UpperBound is null) && typeof(TElement) != typeof(DBNull)) - totalLen += _elementHandler.ValidateAndGetLength(value.UpperBound, ref lengthCache, null); - } - } - - // If we're traversing an already-populated length cache, rewind to first element slot so that - // the elements' handlers can access their length cache values - if (lengthCache != null && lengthCache.IsPopulated) - lengthCache.Position = lengthCachePos; - - return totalLen; - } - - internal override Task WriteWithLengthInternal([AllowNull] TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - return WriteWithLengthLong(); - - if (value == null || typeof(TAny) == typeof(DBNull)) - { - buf.WriteInt32(-1); - return Task.CompletedTask; - } - - return WriteWithLengthCore(); - - async Task WriteWithLengthLong() - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - if (value == null || typeof(TAny) == typeof(DBNull)) - { - buf.WriteInt32(-1); - return; - } - - await WriteWithLengthCore(); - } - - Task WriteWithLengthCore() - { - if (this is INpgsqlTypeHandler typedHandler) - { - buf.WriteInt32(typedHandler.ValidateAndGetLength(value, ref lengthCache, parameter)); - return typedHandler.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - else - throw new InvalidCastException($"Can't write CLR type {typeof(TAny)} to database type {PgDisplayName}"); - } - } - - /// - public override Task Write(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write(value, buf, lengthCache, parameter, async, cancellationToken); - - private protected async Task Write(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 1) - await buf.Flush(async, cancellationToken); - - buf.WriteByte((byte)value.Flags); - - if (value.IsEmpty) - return; - - if (!value.LowerBoundInfinite) - await _elementHandler.WriteWithLengthInternal(value.LowerBound, buf, lengthCache, null, async, cancellationToken); - - if (!value.UpperBoundInfinite) - await _elementHandler.WriteWithLengthInternal(value.UpperBound, buf, lengthCache, null, async, cancellationToken); - } - - #endregion - } - - /// - /// Type handler for PostgreSQL range types - /// - /// - /// Introduced in PostgreSQL 9.2. 
- /// https://www.postgresql.org/docs/current/static/rangetypes.html - /// - /// the main range subtype - /// an alternative range subtype - public class RangeHandler : RangeHandler, INpgsqlTypeHandler> - { - /// - public RangeHandler(PostgresType rangePostgresType, NpgsqlTypeHandler elementHandler) - : base(rangePostgresType, elementHandler, new[] { typeof(NpgsqlRange), typeof(NpgsqlRange) }) {} - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - => DoRead(buf, len, async, fieldDescription); - - /// - public int ValidateAndGetLength(NpgsqlRange value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value, ref lengthCache, parameter); - - /// - public Task Write(NpgsqlRange value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => Write(value, buf, lengthCache, parameter, async, cancellationToken); - } -} diff --git a/src/Npgsql/TypeHandlers/RecordHandler.cs b/src/Npgsql/TypeHandlers/RecordHandler.cs deleted file mode 100644 index 48ffce560e..0000000000 --- a/src/Npgsql/TypeHandlers/RecordHandler.cs +++ /dev/null @@ -1,73 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; - -namespace Npgsql.TypeHandlers -{ - [TypeMapping("record")] - class RecordHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler Create(PostgresType pgType, NpgsqlConnection conn) - => new RecordHandler(pgType, conn.Connector!.TypeMapper); - } - - /// - /// Type handler for PostgreSQL record types. 
- /// - /// - /// https://www.postgresql.org/docs/current/static/datatype-pseudo.html - /// - /// Encoding (identical to composite): - /// A 32-bit integer with the number of columns, then for each column: - /// * An OID indicating the type of the column - /// * The length of the column(32-bit integer), or -1 if null - /// * The column data encoded as binary - /// - class RecordHandler : NpgsqlTypeHandler - { - readonly ConnectorTypeMapper _typeMapper; - - public RecordHandler(PostgresType postgresType, ConnectorTypeMapper typeMapper) - : base(postgresType) - { - _typeMapper = typeMapper; - } - - #region Read - - public override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var fieldCount = buf.ReadInt32(); - var result = new object[fieldCount]; - - for (var i = 0; i < fieldCount; i++) - { - await buf.Ensure(8, async); - var typeOID = buf.ReadUInt32(); - var fieldLen = buf.ReadInt32(); - if (fieldLen == -1) // Null field, simply skip it and leave at default - continue; - result[i] = await _typeMapper.GetByOID(typeOID).ReadAsObject(buf, fieldLen, async); - } - - return result; - } - - #endregion - - #region Write (unsupported) - - public override int ValidateAndGetLength(object[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => throw new NotSupportedException("Can't write record types"); - - public override Task Write(object[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => throw new NotSupportedException("Can't write record types"); - - #endregion - } -} diff --git a/src/Npgsql/TypeHandlers/TextHandler.cs b/src/Npgsql/TypeHandlers/TextHandler.cs deleted file mode 100644 index 7299472c27..0000000000 --- a/src/Npgsql/TypeHandlers/TextHandler.cs +++ /dev/null @@ -1,312 +0,0 @@ -using System; -using System.Data; -using System.IO; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - /// - /// A factory for type handlers for PostgreSQL character data types (text, char, varchar, xml...). - /// - /// - /// See https://www.postgresql.org/docs/current/datatype-character.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - [TypeMapping("text", NpgsqlDbType.Text, - new[] { DbType.String, DbType.StringFixedLength, DbType.AnsiString, DbType.AnsiStringFixedLength }, - new[] { typeof(string), typeof(char[]), typeof(char), typeof(ArraySegment) }, - DbType.String - )] - [TypeMapping("xml", NpgsqlDbType.Xml, dbType: DbType.Xml)] - - [TypeMapping("character varying", NpgsqlDbType.Varchar, inferredDbType: DbType.String)] - [TypeMapping("character", NpgsqlDbType.Char, inferredDbType: DbType.String)] - [TypeMapping("name", NpgsqlDbType.Name, inferredDbType: DbType.String)] - [TypeMapping("refcursor", NpgsqlDbType.Refcursor, inferredDbType: DbType.String)] - [TypeMapping("citext", NpgsqlDbType.Citext, inferredDbType: DbType.String)] - [TypeMapping("unknown")] - public class TextHandlerFactory : NpgsqlTypeHandlerFactory - { - /// - public override NpgsqlTypeHandler Create(PostgresType pgType, NpgsqlConnection conn) - => new TextHandler(pgType, conn); - } - - /// - /// A type handler for PostgreSQL character data types (text, char, varchar, xml...). - /// - /// - /// See https://www.postgresql.org/docs/current/datatype-character.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. 
- /// - public class TextHandler : NpgsqlTypeHandler, INpgsqlTypeHandler, INpgsqlTypeHandler>, - INpgsqlTypeHandler, INpgsqlTypeHandler, ITextReaderHandler - { - // Text types are handled a bit more efficiently when sent as text than as binary - // see https://github.com/npgsql/npgsql/issues/1210#issuecomment-235641670 - internal override bool PreferTextWrite => true; - - readonly Encoding _encoding; - - #region State - - readonly char[] _singleCharArray = new char[1]; - - #endregion - - /// - protected internal TextHandler(PostgresType postgresType, NpgsqlConnection connection) - : this(postgresType, connection.Connector!.TextEncoding) { } - - /// - protected internal TextHandler(PostgresType postgresType, Encoding encoding) - : base(postgresType) => _encoding = encoding; - - #region Read - - /// - public override ValueTask Read(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription = null) - { - return buf.ReadBytesLeft >= byteLen - ? new ValueTask(buf.ReadString(byteLen)) - : ReadLong(); - - async ValueTask ReadLong() - { - if (byteLen <= buf.Size) - { - // The string's byte representation can fit in our read buffer, read it. - while (buf.ReadBytesLeft < byteLen) - await buf.ReadMore(async); - return buf.ReadString(byteLen); - } - - // Bad case: the string's byte representation doesn't fit in our buffer. - // This is rare - will only happen in CommandBehavior.Sequential mode (otherwise the - // entire row is in memory). Tweaking the buffer length via the connection string can - // help avoid this. - - // Allocate a temporary byte buffer to hold the entire string and read it in chunks. 
- var tempBuf = new byte[byteLen]; - var pos = 0; - while (true) - { - var len = Math.Min(buf.ReadBytesLeft, byteLen - pos); - buf.ReadBytes(tempBuf, pos, len); - pos += len; - if (pos < byteLen) - { - await buf.ReadMore(async); - continue; - } - break; - } - return buf.TextEncoding.GetString(tempBuf); - } - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription) - { - if (byteLen <= buf.Size) - { - // The string's byte representation can fit in our read buffer, read it. - while (buf.ReadBytesLeft < byteLen) - await buf.ReadMore(async); - return buf.ReadChars(byteLen); - } - - // TODO: The following can be optimized with Decoder - no need to allocate a byte[] - var tempBuf = new byte[byteLen]; - var pos = 0; - while (true) - { - var len = Math.Min(buf.ReadBytesLeft, byteLen - pos); - buf.ReadBytes(tempBuf, pos, len); - pos += len; - if (pos < byteLen) - { - await buf.ReadMore(async); - continue; - } - break; - } - return buf.TextEncoding.GetChars(tempBuf); - } - - async ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription) - { - // Make sure we have enough bytes in the buffer for a single character - var maxBytes = Math.Min(buf.TextEncoding.GetMaxByteCount(1), len); - while (buf.ReadBytesLeft < maxBytes) - await buf.ReadMore(async); - - var decoder = buf.TextEncoding.GetDecoder(); - decoder.Convert(buf.Buffer, buf.ReadPosition, maxBytes, _singleCharArray, 0, 1, true, out var bytesUsed, out var charsUsed, out _); - buf.Skip(len - bytesUsed); - - if (charsUsed < 1) - throw new NpgsqlException("Could not read char - string was empty"); - - return _singleCharArray[0]; - } - - ValueTask> INpgsqlTypeHandler>.Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription) - => throw new NotSupportedException("Only writing ArraySegment to PostgreSQL text is supported, no reading."); - - ValueTask INpgsqlTypeHandler.Read(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription) - { - var bytes = new byte[byteLen]; - if (buf.ReadBytesLeft >= byteLen) - { - buf.ReadBytes(bytes, 0, byteLen); - return new ValueTask(bytes); - } - return ReadLong(); - - async ValueTask ReadLong() - { - if (byteLen <= buf.Size) - { - // The bytes can fit in our read buffer, read it. - while (buf.ReadBytesLeft < byteLen) - await buf.ReadMore(async); - buf.ReadBytes(bytes, 0, byteLen); - return bytes; - } - - // Bad case: the bytes don't fit in our buffer. - // This is rare - will only happen in CommandBehavior.Sequential mode (otherwise the - // entire row is in memory). Tweaking the buffer length via the connection string can - // help avoid this. - - var pos = 0; - while (true) - { - var len = Math.Min(buf.ReadBytesLeft, byteLen - pos); - buf.ReadBytes(bytes, pos, len); - pos += len; - if (pos < byteLen) - { - await buf.ReadMore(async); - continue; - } - break; - } - return bytes; - } - } - - #endregion - - #region Write - - /// - public override unsafe int ValidateAndGetLength(string value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (lengthCache == null) - lengthCache = new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - if (parameter == null || parameter.Size <= 0 || parameter.Size >= value.Length) - return lengthCache.Set(_encoding.GetByteCount(value)); - fixed (char* p = value) - return lengthCache.Set(_encoding.GetByteCount(p, parameter.Size)); - } - - /// - public virtual int ValidateAndGetLength(char[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter) - { - if (lengthCache == null) - lengthCache = new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - return lengthCache.Set( - parameter == null || parameter.Size <= 0 || parameter.Size >= value.Length - ? _encoding.GetByteCount(value) - : _encoding.GetByteCount(value, 0, parameter.Size) - ); - } - - /// - public virtual int ValidateAndGetLength(ArraySegment value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (lengthCache == null) - lengthCache = new NpgsqlLengthCache(1); - if (lengthCache.IsPopulated) - return lengthCache.Get(); - - if (parameter?.Size > 0) - throw new ArgumentException($"Parameter {parameter.ParameterName} is of type ArraySegment and should not have its Size set", parameter.ParameterName); - - return lengthCache.Set(value.Array is null ? 0 : _encoding.GetByteCount(value.Array, value.Offset, value.Count)); - } - - /// - public int ValidateAndGetLength(char value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - _singleCharArray[0] = value; - return _encoding.GetByteCount(_singleCharArray); - } - - /// - public int ValidateAndGetLength(byte[] value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value.Length; - - /// - public override Task Write(string value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => WriteString(value, buf, lengthCache!, parameter, async, cancellationToken); - - /// - public virtual Task Write(char[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var charLen = parameter == null || parameter.Size <= 0 || parameter.Size >= value.Length - ? 
value.Length - : parameter.Size; - return buf.WriteChars(value, 0, charLen, lengthCache!.GetLast(), async, cancellationToken); - } - - /// - public virtual Task Write(ArraySegment value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value.Array is null ? Task.CompletedTask : buf.WriteChars(value.Array, value.Offset, value.Count, lengthCache!.GetLast(), async, cancellationToken); - - Task WriteString(string str, NpgsqlWriteBuffer buf, NpgsqlLengthCache lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var charLen = parameter == null || parameter.Size <= 0 || parameter.Size >= str.Length - ? str.Length - : parameter.Size; - return buf.WriteString(str, charLen, lengthCache!.GetLast(), async, cancellationToken); - } - - /// - public Task Write(char value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - _singleCharArray[0] = value; - var len = _encoding.GetByteCount(_singleCharArray); - return buf.WriteChars(_singleCharArray, 0, 1, len, async, cancellationToken); - } - - /// - public Task Write(byte[] value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => buf.WriteBytesRaw(value, async, cancellationToken); - - #endregion - - /// - public virtual TextReader GetTextReader(Stream stream) => new StreamReader(stream); - } -} diff --git a/src/Npgsql/TypeHandlers/UnknownTypeHandler.cs b/src/Npgsql/TypeHandlers/UnknownTypeHandler.cs deleted file mode 100644 index 2ed5c309e0..0000000000 --- a/src/Npgsql/TypeHandlers/UnknownTypeHandler.cs +++ /dev/null @@ -1,91 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; - -namespace Npgsql.TypeHandlers -{ - /// - /// Handles "conversions" for columns sent by the database with unknown OIDs. - /// This differs from TextHandler in that its a text-only handler (we don't want to receive binary - /// representations of the types registered here). - /// Note that this handler is also used in the very initial query that loads the OID mappings - /// (chicken and egg problem). - /// Also used for sending parameters with unknown types (OID=0) - /// - class UnknownTypeHandler : TextHandler - { - readonly NpgsqlConnector _connector; - - internal UnknownTypeHandler(NpgsqlConnection connection) - : base(UnknownBackendType.Instance, connection) => _connector = connection.Connector!; - - #region Read - - public override ValueTask Read(NpgsqlReadBuffer buf, int byteLen, bool async, FieldDescription? fieldDescription = null) - { - if (fieldDescription == null) - throw new Exception($"Received an unknown field but {nameof(fieldDescription)} is null (i.e. COPY mode)"); - - if (fieldDescription.IsBinaryFormat) - // At least get the name of the PostgreSQL type for the exception - throw new NotSupportedException( - _connector.TypeMapper.DatabaseInfo.ByOID.TryGetValue(fieldDescription.TypeOID, out var pgType) - ? $"The field '{fieldDescription.Name}' has type '{pgType.DisplayName}', which is currently unknown to Npgsql. 
You can retrieve it as a string by marking it as unknown, please see the FAQ." - : $"The field '{fieldDescription.Name}' has a type currently unknown to Npgsql (OID {fieldDescription.TypeOID}). You can retrieve it as a string by marking it as unknown, please see the FAQ." - ); - - return base.Read(buf, byteLen, async, fieldDescription); - } - - #endregion Read - - #region Write - - // Allow writing anything that is a string or can be converted to one via the unknown type handler - - protected internal override int ValidateAndGetLength(T2 value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateObjectAndGetLength(value!, ref lengthCache, parameter); - - protected internal override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - if (value is string asString) - return base.ValidateAndGetLength(asString, ref lengthCache, parameter); - - if (parameter == null) - throw CreateConversionButNoParamException(value.GetType()); - - var converted = Convert.ToString(value)!; - parameter.ConvertedValue = converted; - - return base.ValidateAndGetLength(converted, ref lengthCache, parameter); - } - - protected internal override Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (value is DBNull) - return base.WriteObjectWithLength(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken); - - var convertedValue = value is string asString - ? 
asString - : (string)parameter!.ConvertedValue!; - - if (buf.WriteSpaceLeft < 4) - return WriteWithLengthLong(); - - buf.WriteInt32(ValidateObjectAndGetLength(value, ref lengthCache, parameter)); - return base.Write(convertedValue, buf, lengthCache, parameter, async, cancellationToken); - - async Task WriteWithLengthLong() - { - await buf.Flush(async, cancellationToken); - buf.WriteInt32(ValidateObjectAndGetLength(value!, ref lengthCache, parameter)); - await base.Write(convertedValue, buf, lengthCache, parameter, async, cancellationToken); - } - } - - #endregion Write - } -} diff --git a/src/Npgsql/TypeHandlers/UnmappedEnumHandler.cs b/src/Npgsql/TypeHandlers/UnmappedEnumHandler.cs deleted file mode 100644 index 9b5b96b97b..0000000000 --- a/src/Npgsql/TypeHandlers/UnmappedEnumHandler.cs +++ /dev/null @@ -1,153 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - class UnmappedEnumHandler : TextHandler - { - readonly INpgsqlNameTranslator _nameTranslator; - - readonly Dictionary _enumToLabel = new Dictionary(); - readonly Dictionary _labelToEnum = new Dictionary(); - - Type? _resolvedType; - - internal UnmappedEnumHandler(PostgresType pgType, INpgsqlNameTranslator nameTranslator, NpgsqlConnection connection) - : base(pgType, connection) - { - _nameTranslator = nameTranslator; - } - - #region Read - - protected internal override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null) - { - var s = await base.Read(buf, len, async, fieldDescription); - if (typeof(TAny) == typeof(string)) - return (TAny)(object)s; - - if (_resolvedType != typeof(TAny)) - Map(typeof(TAny)); - - if (!_labelToEnum.TryGetValue(s, out var value)) - throw new InvalidCastException($"Received enum value '{s}' from database which wasn't found on enum {typeof(TAny)}"); - - // TODO: Avoid boxing - return (TAny)(object)value; - } - - public override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => base.Read(buf, len, async, fieldDescription); - - #endregion - - #region Write - - protected internal override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value == null || value is DBNull - ? -1 - : ValidateAndGetLength(value, ref lengthCache, parameter); - - protected internal override int ValidateAndGetLength(TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => ValidateAndGetLength(value!, ref lengthCache, parameter); - - int ValidateAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var type = value.GetType(); - if (type == typeof(string)) - return base.ValidateAndGetLength((string)value, ref lengthCache, parameter); - if (_resolvedType != type) - Map(type); - - // TODO: Avoid boxing - return _enumToLabel.TryGetValue((Enum)value, out var str) - ? base.ValidateAndGetLength(str, ref lengthCache, parameter) - : throw new InvalidCastException($"Can't write value {value} as enum {type}"); - } - - // TODO: This boxes the enum (again) - protected override Task WriteWithLength(TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - => WriteObjectWithLength(value!, buf, lengthCache, parameter, async, cancellationToken); - - protected internal override Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (value is DBNull) - return WriteWithLengthInternal(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken); - - if (buf.WriteSpaceLeft < 4) - return WriteWithLengthLong(); - - buf.WriteInt32(ValidateAndGetLength(value, ref lengthCache, parameter)); - return Write(value, buf, lengthCache, parameter, async, cancellationToken); - - async Task WriteWithLengthLong() - { - await buf.Flush(async, cancellationToken); - buf.WriteInt32(ValidateAndGetLength(value!, ref lengthCache, parameter)); - await Write(value!, buf, lengthCache, parameter, async, cancellationToken); - } - } - - internal Task Write(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - var type = value.GetType(); - if (type == typeof(string)) - return base.Write((string)value, buf, lengthCache, parameter, async, cancellationToken); - if (_resolvedType != type) - Map(type); - - // TODO: Avoid boxing - if (!_enumToLabel.TryGetValue((Enum)value, out var str)) - throw new InvalidCastException($"Can't write value {value} as enum {type}"); - return base.Write(str, buf, lengthCache, parameter, async, cancellationToken); - } - - #endregion - - #region Misc - - void Map(Type type) - { - Debug.Assert(_resolvedType != type); - - _enumToLabel.Clear(); - _labelToEnum.Clear(); - - foreach (var field in type.GetFields(BindingFlags.Static | BindingFlags.Public)) - { - var attribute = (PgNameAttribute?)field.GetCustomAttributes(typeof(PgNameAttribute), false).FirstOrDefault(); - var enumName = attribute?.PgName ?? 
_nameTranslator.TranslateMemberName(field.Name); - var enumValue = (Enum)field.GetValue(null)!; - - _enumToLabel[enumValue] = enumName; - _labelToEnum[enumName] = enumValue; - } - - _resolvedType = type; - } - - #endregion - } - - class UnmappedEnumTypeHandlerFactory : NpgsqlTypeHandlerFactory, IEnumTypeHandlerFactory - { - internal UnmappedEnumTypeHandlerFactory(INpgsqlNameTranslator nameTranslator) - { - NameTranslator = nameTranslator; - } - - public override NpgsqlTypeHandler Create(PostgresType pgType, NpgsqlConnection conn) - => new UnmappedEnumHandler(pgType, NameTranslator, conn); - - public INpgsqlNameTranslator NameTranslator { get; } - } -} diff --git a/src/Npgsql/TypeHandlers/UuidHandler.cs b/src/Npgsql/TypeHandlers/UuidHandler.cs deleted file mode 100644 index eb09ddebf9..0000000000 --- a/src/Npgsql/TypeHandlers/UuidHandler.cs +++ /dev/null @@ -1,82 +0,0 @@ -using System; -using System.Data; -using System.Runtime.InteropServices; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; - -namespace Npgsql.TypeHandlers -{ - /// - /// A type handler for the PostgreSQL uuid data type. - /// - /// - /// See https://www.postgresql.org/docs/current/static/datatype-uuid.html. - /// - /// The type handler API allows customizing Npgsql's behavior in powerful ways. However, although it is public, it - /// should be considered somewhat unstable, and may change in breaking ways, including in non-major releases. - /// Use it at your own risk. - /// - [TypeMapping("uuid", NpgsqlDbType.Uuid, DbType.Guid, typeof(Guid))] - public class UuidHandler : NpgsqlSimpleTypeHandler - { - // The following table shows .NET GUID vs Postgres UUID (RFC 4122) layouts. - // - // Note that the first fields are converted from/to native endianness (handled by the Read* - // and Write* methods), while the last field is always read/written in big-endian format. 
- // - // We're passing BitConverter.IsLittleEndian to prevent reversing endianness on little-endian systems. - // - // | Bits | Bytes | Name | Endianness (GUID) | Endianness (RFC 4122) | - // | ---- | ----- | ----- | ----------------- | --------------------- | - // | 32 | 4 | Data1 | Native | Big | - // | 16 | 2 | Data2 | Native | Big | - // | 16 | 2 | Data3 | Native | Big | - // | 64 | 8 | Data4 | Big | Big | - - /// - public UuidHandler(PostgresType postgresType) : base(postgresType) {} - - /// - public override Guid Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - var raw = new GuidRaw - { - Data1 = buf.ReadInt32(), - Data2 = buf.ReadInt16(), - Data3 = buf.ReadInt16(), - Data4 = buf.ReadInt64(BitConverter.IsLittleEndian) - }; - - return raw.Value; - } - - /// - public override int ValidateAndGetLength(Guid value, NpgsqlParameter? parameter) - => 16; - - /// - public override void Write(Guid value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - var raw = new GuidRaw(value); - - buf.WriteInt32(raw.Data1); - buf.WriteInt16(raw.Data2); - buf.WriteInt16(raw.Data3); - buf.WriteInt64(raw.Data4, BitConverter.IsLittleEndian); - } - - [StructLayout(LayoutKind.Explicit)] - struct GuidRaw - { - [FieldOffset(00)] public Guid Value; - [FieldOffset(00)] public int Data1; - [FieldOffset(04)] public short Data2; - [FieldOffset(06)] public short Data3; - [FieldOffset(08)] public long Data4; - public GuidRaw(Guid value) : this() => Value = value; - } - } -} diff --git a/src/Npgsql/TypeHandlers/VoidHandler.cs b/src/Npgsql/TypeHandlers/VoidHandler.cs deleted file mode 100644 index b2fee6f0cd..0000000000 --- a/src/Npgsql/TypeHandlers/VoidHandler.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; - -namespace Npgsql.TypeHandlers -{ - /// - /// https://www.postgresql.org/docs/current/static/datatype-boolean.html - /// - 
[TypeMapping("void")] - class VoidHandler : NpgsqlSimpleTypeHandler - { - public VoidHandler(PostgresType postgresType) : base(postgresType) {} - - public override DBNull Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => DBNull.Value; - - public override int ValidateAndGetLength(DBNull value, NpgsqlParameter? parameter) - => throw new NotSupportedException(); - - public override void Write(DBNull value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - => throw new NotSupportedException(); - } -} diff --git a/src/Npgsql/TypeHandling/DefaultTypeHandlerFactory.cs b/src/Npgsql/TypeHandling/DefaultTypeHandlerFactory.cs deleted file mode 100644 index 8616694a49..0000000000 --- a/src/Npgsql/TypeHandling/DefaultTypeHandlerFactory.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.Reflection; -using Npgsql.PostgresTypes; - -namespace Npgsql.TypeHandling -{ - /// - /// A type handler factory used to instantiate Npgsql's built-in type handlers. - /// - class DefaultTypeHandlerFactory : NpgsqlTypeHandlerFactory - { - readonly Type _handlerType; - - internal DefaultTypeHandlerFactory(Type handlerType) - { - // Recursively look for the TypeHandler superclass to extract its T as the - // DefaultValueType - Type? 
baseClass = handlerType; - while (!baseClass.GetTypeInfo().IsGenericType || baseClass.GetGenericTypeDefinition() != typeof(NpgsqlTypeHandler<>)) - { - baseClass = baseClass.GetTypeInfo().BaseType; - if (baseClass == null) - throw new Exception($"Npgsql type handler {handlerType} doesn't inherit from TypeHandler<>?"); - } - - DefaultValueType = baseClass.GetGenericArguments()[0]; - _handlerType = handlerType; - } - - public override NpgsqlTypeHandler CreateNonGeneric(PostgresType pgType, NpgsqlConnection conn) - => (NpgsqlTypeHandler)Activator.CreateInstance(_handlerType, pgType)!; - - public override Type DefaultValueType { get; } - } -} diff --git a/src/Npgsql/TypeHandling/INpgsqlSimpleTypeHandler.cs b/src/Npgsql/TypeHandling/INpgsqlSimpleTypeHandler.cs deleted file mode 100644 index 3e278b384d..0000000000 --- a/src/Npgsql/TypeHandling/INpgsqlSimpleTypeHandler.cs +++ /dev/null @@ -1,47 +0,0 @@ -using Npgsql.BackendMessages; - -namespace Npgsql.TypeHandling -{ - /// - /// Type handlers that wish to support reading other types in additional to the main one can - /// implement this interface for all those types. - /// - public interface INpgsqlSimpleTypeHandler - { - /// - /// Reads a value of type with the given length from the provided buffer, - /// with the assumption that it is entirely present in the provided memory buffer and no I/O will be - /// required. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - T Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null); - - /// - /// Responsible for validating that a value represents a value of the correct and which can be - /// written for PostgreSQL - if the value cannot be written for any reason, an exception shold be thrown. 
- /// Also returns the byte length needed to write the value. - /// - /// The value to be written to PostgreSQL - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// The number of bytes required to write the value. - int ValidateAndGetLength(T value, NpgsqlParameter? parameter); - - /// - /// Writes a value to the provided buffer, with the assumption that there is enough space in the buffer - /// (no I/O will occur). The Npgsql core will have taken care of that. - /// - /// The value to write. - /// The buffer to which to write. - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - void Write(T value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter); - } -} diff --git a/src/Npgsql/TypeHandling/INpgsqlTypeHandler.cs b/src/Npgsql/TypeHandling/INpgsqlTypeHandler.cs deleted file mode 100644 index 58a6befdf1..0000000000 --- a/src/Npgsql/TypeHandling/INpgsqlTypeHandler.cs +++ /dev/null @@ -1,54 +0,0 @@ -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; - -namespace Npgsql.TypeHandling -{ - /// - /// Type handlers that wish to support reading other types in additional to the main one can - /// implement this interface for all those types. - /// - public interface INpgsqlTypeHandler - { - /// - /// Reads a value of type with the given length from the provided buffer, - /// using either sync or async I/O. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. 
- ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null); - - /// - /// Responsible for validating that a value represents a value of the correct and which can be - /// written for PostgreSQL - if the value cannot be written for any reason, an exception should be thrown. - /// Also returns the byte length needed to write the value. - /// - /// The value to be written to PostgreSQL - /// A cache where the length calculated during the validation phase can be stored for use at the writing phase. - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// The number of bytes required to write the value. - int ValidateAndGetLength(T value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - - /// - /// Writes a value to the provided buffer. - /// - /// The value to write. - /// The buffer to which to write. - /// A cache where the length calculated during the validation phase can be stored for use at the writing phase. - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// - /// If I/O will be necessary (i.e. the buffer is full), determines whether it will be done synchronously or asynchronously. - /// - /// The that can be used to cancel the operation. - Task Write(T value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default); - } -} diff --git a/src/Npgsql/TypeHandling/ITextReaderHandler.cs b/src/Npgsql/TypeHandling/ITextReaderHandler.cs deleted file mode 100644 index 79eaa382bd..0000000000 --- a/src/Npgsql/TypeHandling/ITextReaderHandler.cs +++ /dev/null @@ -1,16 +0,0 @@ -using System.Data.Common; -using System.IO; - -namespace Npgsql.TypeHandling -{ - /// - /// Implemented by handlers which support , returns a standard - /// TextReader given a binary Stream. - /// - interface ITextReaderHandler - { - TextReader GetTextReader(Stream stream); - } - -#pragma warning disable CA1032 -} diff --git a/src/Npgsql/TypeHandling/NpgsqlSimpleTypeHandler.cs b/src/Npgsql/TypeHandling/NpgsqlSimpleTypeHandler.cs deleted file mode 100644 index f4abbb8e6c..0000000000 --- a/src/Npgsql/TypeHandling/NpgsqlSimpleTypeHandler.cs +++ /dev/null @@ -1,308 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Data.Common; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.Util; - -namespace Npgsql.TypeHandling -{ - /// - /// Base class for all simple type handlers, which read and write short, non-arbitrary lengthed - /// values to PostgreSQL. Provides a simpler API to implement when compared to - - /// Npgsql takes care of all I/O before calling into this type, so no I/O needs to be performed by it. - /// - /// - /// The default CLR type that this handler will read and write. For example, calling - /// on a column with this handler will return a value with type . - /// Type handlers can support additional types by implementing . 
- /// - public abstract class NpgsqlSimpleTypeHandler : NpgsqlTypeHandler, INpgsqlSimpleTypeHandler - { - delegate int NonGenericValidateAndGetLength(NpgsqlTypeHandler handler, object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - - readonly NonGenericValidateAndGetLength _nonGenericValidateAndGetLength; - readonly NonGenericWriteWithLength _nonGenericWriteWithLength; - - static readonly ConcurrentDictionary - NonGenericDelegateCache = new ConcurrentDictionary(); - - /// - /// Constructs an . - /// - protected NpgsqlSimpleTypeHandler(PostgresType postgresType) - : base(postgresType) - { - // Get code-generated delegates for non-generic ValidateAndGetLength/WriteWithLengthInternal - (_nonGenericValidateAndGetLength, _nonGenericWriteWithLength) = - NonGenericDelegateCache.GetOrAdd(GetType(), t => ( - GenerateNonGenericValidationMethod(GetType()), - GenerateNonGenericWriteMethod(GetType(), typeof(INpgsqlSimpleTypeHandler<>))) - ); - } - - #region Read - - /// - /// Reads a value of type with the given length from the provided buffer, - /// with the assumption that it is entirely present in the provided memory buffer and no I/O will be - /// required. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - public abstract TDefault Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null); - - /// - /// Reads a value of type with the given length from the provided buffer, - /// using either sync or async I/O. This method is sealed for , - /// override . - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. 
- /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - public sealed override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => Read(buf, len, async, fieldDescription); - - /// - /// Reads a value of type with the given length from the provided buffer, - /// using either sync or async I/O. This method is sealed for . - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - protected internal sealed override async ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(len, async); - return Read(buf, len, fieldDescription); - } - - /// - /// Reads a value of type with the given length from the provided buffer. - /// with the assumption that it is entirely present in the provided memory buffer and no I/O will be - /// required. Type handlers typically don't need to override this - override - /// - but may do - /// so in exceptional cases where reading of arbitrary types is required. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - public override TAny Read(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription = null) - { - Debug.Assert(len <= buf.ReadBytesLeft); - - var asTypedHandler = this as INpgsqlSimpleTypeHandler; - if (asTypedHandler == null) - throw new InvalidCastException(fieldDescription == null - ? $"Can't cast database type to {typeof(TAny).Name}" - : $"Can't cast database type {fieldDescription.Handler.PgDisplayName} to {typeof(TAny).Name}" - ); - - return asTypedHandler.Read(buf, len, fieldDescription); - } - - #endregion Read - - #region Write - - /// - /// Responsible for validating that a value represents a value of the correct and which can be - /// written for PostgreSQL - if the value cannot be written for any reason, an exception shold be thrown. - /// Also returns the byte length needed to write the value. - /// - /// The value to be written to PostgreSQL - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// The number of bytes required to write the value. - public abstract int ValidateAndGetLength(TDefault value, NpgsqlParameter? parameter); - - /// - /// Writes a value to the provided buffer, with the assumption that there is enough space in the buffer - /// (no I/O will occur). The Npgsql core will have taken care of that. - /// - /// The value to write. - /// The buffer to which to write. - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - public abstract void Write(TDefault value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter); - - /// - /// This method is sealed, override . - /// - protected internal override int ValidateAndGetLength(TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => this is INpgsqlSimpleTypeHandler typedHandler - ? 
typedHandler.ValidateAndGetLength(value, parameter) - : throw new InvalidCastException($"Can't write CLR type {typeof(TAny)} to database type {PgDisplayName}"); - - /// - /// In the vast majority of cases writing a parameter to the buffer won't need to perform I/O. - /// - internal sealed override Task WriteWithLengthInternal([AllowNull] TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - if (value == null || typeof(TAny) == typeof(DBNull)) - { - if (buf.WriteSpaceLeft < 4) - return WriteWithLengthLong(); - buf.WriteInt32(-1); - return Task.CompletedTask; - } - - Debug.Assert(this is INpgsqlSimpleTypeHandler); - var typedHandler = (INpgsqlSimpleTypeHandler)this; - - var elementLen = typedHandler.ValidateAndGetLength(value, parameter); - if (buf.WriteSpaceLeft < 4 + elementLen) - return WriteWithLengthLong(); - buf.WriteInt32(elementLen); - typedHandler.Write(value, buf, parameter); - return Task.CompletedTask; - - async Task WriteWithLengthLong() - { - if (value == null || typeof(TAny) == typeof(DBNull)) - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(-1); - return; - } - - typedHandler = (INpgsqlSimpleTypeHandler)this; - elementLen = typedHandler.ValidateAndGetLength(value, parameter); - if (buf.WriteSpaceLeft < 4 + elementLen) - await buf.Flush(async, cancellationToken); - buf.WriteInt32(elementLen); - typedHandler.Write(value, buf, parameter); - } - } - - /// - /// Simple type handlers override instead of this. - /// - public sealed override Task Write(TDefault value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => throw new NotSupportedException(); - - /// - /// Simple type handlers override instead of this. - /// - public sealed override int ValidateAndGetLength(TDefault value, ref NpgsqlLengthCache? 
lengthCache, NpgsqlParameter? parameter) - => throw new NotSupportedException(); - - // Object overloads for non-generic NpgsqlParameter - - /// - /// Called to validate and get the length of a value of a non-generic . - /// Type handlers generally don't need to override this. - /// - protected internal override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value == null || value is DBNull - ? -1 - : _nonGenericValidateAndGetLength(this, value, ref lengthCache, parameter); - - /// - /// Called to write the value of a non-generic . - /// Type handlers generally don't need to override this. - /// - protected internal override Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value is DBNull // For null just go through the default WriteWithLengthInternal - ? WriteWithLengthInternal(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken) - : _nonGenericWriteWithLength(this, value, buf, lengthCache, parameter, async, cancellationToken); - - #endregion - - #region Code generation for non-generic writing - - // We need to support writing via non-generic NpgsqlParameter, which means we get requests - // to write some object with no generic typing information. - // We need to find out which INpgsqlTypeHandler interfaces our handler implements, and call - // the ValidateAndGetLength/WriteWithLengthInternal methods on the interface which corresponds to the - // value type. - // Since doing this with reflection every time is slow, we generate delegates to do this for us - // for each type handler. 
- - static NonGenericValidateAndGetLength GenerateNonGenericValidationMethod(Type handlerType) - { - var interfaces = handlerType.GetInterfaces().Where(i => - i.GetTypeInfo().IsGenericType && - i.GetGenericTypeDefinition() == typeof(INpgsqlSimpleTypeHandler<>) - ).Reverse().ToList(); - - Expression? ifElseExpression = null; - - var handlerParam = Expression.Parameter(typeof(NpgsqlTypeHandler), "handler"); - var valueParam = Expression.Parameter(typeof(object), "value"); - var lengthCacheParam = Expression.Parameter(typeof(NpgsqlLengthCache).MakeByRefType(), "lengthCache"); - var parameterParam = Expression.Parameter(typeof(NpgsqlParameter), "parameter"); - - var resultVariable = Expression.Variable(typeof(int), "result"); - - foreach (var i in interfaces) - { - var handledType = i.GenericTypeArguments[0]; - - ifElseExpression = Expression.IfThenElse( - // Test whether the type of the value given to the delegate corresponds - // to our current interface's handled type (i.e. the T in INpgsqlTypeHandler) - Expression.TypeEqual(valueParam, handledType), - // If it corresponds, cast the handler type (this) to INpgsqlTypeHandler - // and call its ValidateAndGetLength method - Expression.Assign( - resultVariable, - Expression.Call( - Expression.Convert(handlerParam, i), - i.GetMethod(nameof(INpgsqlSimpleTypeHandler.ValidateAndGetLength))!, - // Cast the value from object down to the interface's T - Expression.Convert(valueParam, handledType), - parameterParam - ) - ), - // If this is the first interface we're looking at, the else clause throws. - // Otherwise we stick the previous interface's IfThenElse in our else clause - ifElseExpression ?? 
Expression.Throw( - Expression.New( - MethodInfos.InvalidCastExceptionCtor, - Expression.Call( // Call string.Format to generate a nice informative exception message - MethodInfos.StringFormat, - new Expression[] - { - Expression.Constant($"Can't write CLR type {{0}} with handler type {handlerType.Name}"), - Expression.Call( // GetType() on the value - valueParam, - MethodInfos.ObjectGetType - ) - } - ) - ) - ) - ); - } - - if (ifElseExpression is null) - throw new Exception($"Type handler {handlerType.GetType().Name} does not implement the proper interface"); - - return Expression.Lambda( - Expression.Block(new[] { resultVariable }, ifElseExpression, resultVariable), - handlerParam, valueParam, lengthCacheParam, parameterParam - ).Compile(); - } - - #endregion Code generation for non-generic writing - } -} diff --git a/src/Npgsql/TypeHandling/NpgsqlSimpleTypeHandlerWithPsv.cs b/src/Npgsql/TypeHandling/NpgsqlSimpleTypeHandlerWithPsv.cs deleted file mode 100644 index 58e0b3afe4..0000000000 --- a/src/Npgsql/TypeHandling/NpgsqlSimpleTypeHandlerWithPsv.cs +++ /dev/null @@ -1,110 +0,0 @@ -using System; -using System.Data.Common; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandlers; -using NpgsqlTypes; - -namespace Npgsql.TypeHandling -{ - /// - /// A simple type handler that supports a provider-specific value in addition to its default value. - /// This is necessary mainly in cases where the CLR type cannot represent the full range of the - /// PostgreSQL type, and a custom CLR type is needed (e.g. and - /// ). The provider-specific type will be returned - /// from calls to . - /// - /// - /// The default CLR type that this handler will read and write. For example, calling - /// on a column with this handler will return a value with type . - /// Type handlers can support additional types by implementing . 
- /// - /// The provider-specific CLR type that this handler will read and write. - public abstract class NpgsqlSimpleTypeHandlerWithPsv : NpgsqlSimpleTypeHandler, INpgsqlSimpleTypeHandler - { - /// - /// Constructs an - /// - /// - protected NpgsqlSimpleTypeHandlerWithPsv(PostgresType postgresType) - : base(postgresType) {} - - #region Read - - /// - /// Reads a value of type with the given length from the provided buffer, - /// with the assumption that it is entirely present in the provided memory buffer and no I/O will be - /// required. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - protected abstract TPsv ReadPsv(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null); - - TPsv INpgsqlSimpleTypeHandler.Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription) - => ReadPsv(buf, len, fieldDescription); - - // Since TAny isn't constrained to class? or struct (C# doesn't have a non-nullable constraint that doesn't limit us to either struct or class), - // we must use the bang operator here to tell the compiler that a null value will never returned. - - /// - /// Reads a column as the type handler's provider-specific type, assuming that it is already entirely - /// in memory (i.e. no I/O is necessary). Called by in non-sequential mode, which - /// buffers entire rows in memory. - /// - internal override object ReadPsvAsObject(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => Read(buf, len, fieldDescription)!; - - /// - /// Reads a column as the type handler's provider-specific type. If it is not already entirely in - /// memory, sync or async I/O will be performed as specified by . 
- /// - internal override async ValueTask ReadPsvAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => (await Read(buf, len, async, fieldDescription))!; - - #endregion Read - - #region Write - - /// - /// Responsible for validating that a value represents a value of the correct and which can be - /// written for PostgreSQL - if the value cannot be written for any reason, an exception shold be thrown. - /// Also returns the byte length needed to write the value. - /// - /// The value to be written to PostgreSQL - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// The number of bytes required to write the value. - public abstract int ValidateAndGetLength(TPsv value, NpgsqlParameter? parameter); - - /// - /// Writes a value to the provided buffer, with the assumption that there is enough space in the buffer - /// (no I/O will occur). The Npgsql core will have taken care of that. - /// - /// The value to write. - /// The buffer to which to write. - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - public abstract void Write(TPsv value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter); - - #endregion Write - - #region Misc - - internal override Type GetProviderSpecificFieldType(FieldDescription? 
fieldDescription = null) - => typeof(TPsv); - - /// - public override ArrayHandler CreateArrayHandler(PostgresArrayType arrayBackendType) - => new ArrayHandlerWithPsv(arrayBackendType, this); - - #endregion Misc - } -} diff --git a/src/Npgsql/TypeHandling/NpgsqlTypeHandler.cs b/src/Npgsql/TypeHandling/NpgsqlTypeHandler.cs deleted file mode 100644 index 9ae3a58537..0000000000 --- a/src/Npgsql/TypeHandling/NpgsqlTypeHandler.cs +++ /dev/null @@ -1,245 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandlers; - -namespace Npgsql.TypeHandling -{ - /// - /// Base class for all type handlers, which read and write CLR types into their PostgreSQL - /// binary representation. - /// Type handler writers shouldn't inherit from this class, inherit - /// or instead. - /// - public abstract class NpgsqlTypeHandler - { - /// - /// The PostgreSQL type handled by this type handler. - /// - internal PostgresType PostgresType { get; } - - /// - /// Constructs a . - /// - protected NpgsqlTypeHandler(PostgresType postgresType) => PostgresType = postgresType; - - #region Read - - /// - /// Reads a value of type with the given length from the provided buffer, - /// using either sync or async I/O. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - protected internal abstract ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null); - - /// - /// Reads a value of type with the given length from the provided buffer, - /// with the assumption that it is entirely present in the provided memory buffer and no I/O will be - /// required. This can save the overhead of async functions and improves performance. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - public abstract TAny Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null); - - /// - /// Reads a column as the type handler's default read type, assuming that it is already entirely - /// in memory (i.e. no I/O is necessary). Called by in non-sequential mode, which - /// buffers entire rows in memory. - /// - internal abstract object ReadAsObject(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null); - - /// - /// Reads a column as the type handler's default read type. If it is not already entirely in - /// memory, sync or async I/O will be performed as specified by . - /// - internal abstract ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null); - - /// - /// Reads a column as the type handler's provider-specific type, assuming that it is already entirely - /// in memory (i.e. no I/O is necessary). Called by in non-sequential mode, which - /// buffers entire rows in memory. - /// - internal virtual object ReadPsvAsObject(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => ReadAsObject(buf, len, fieldDescription); - - /// - /// Reads a column as the type handler's provider-specific type. If it is not already entirely in - /// memory, sync or async I/O will be performed as specified by . 
- /// - internal virtual ValueTask ReadPsvAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => ReadAsObject(buf, len, async, fieldDescription); - - /// - /// Reads a value from the buffer, assuming our read position is at the value's preceding length. - /// If the length is -1 (null), this method will return the default value. - /// - internal async ValueTask ReadWithLength(NpgsqlReadBuffer buf, bool async, FieldDescription? fieldDescription = null) - { - await buf.Ensure(4, async); - var len = buf.ReadInt32(); - return len == -1 - ? default! - : NullableHandler.Exists - ? await NullableHandler.ReadAsync(this, buf, len, async, fieldDescription) - : await Read(buf, len, async, fieldDescription); - } - - #endregion - - #region Write - - /// - /// Called to validate and get the length of a value of a generic . - /// - protected internal abstract int ValidateAndGetLength(TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - - /// - /// Called to write the value of a generic . - /// - internal abstract Task WriteWithLengthInternal([AllowNull] TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default); - - /// - /// Responsible for validating that a value represents a value of the correct and which can be - /// written for PostgreSQL - if the value cannot be written for any reason, an exception shold be thrown. - /// Also returns the byte length needed to write the value. - /// - /// The value to be written to PostgreSQL - /// - /// If the byte length calculation is costly (e.g. for UTF-8 strings), its result can be stored in the - /// length cache to be reused in the writing process, preventing recalculation. - /// - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). 
- /// - /// The number of bytes required to write the value. - protected internal abstract int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - - /// - /// Writes a value to the provided buffer, using either sync or async I/O. - /// - /// The value to write. - /// The buffer to which to write. - /// - /// - /// The instance where this value resides. Can be used to access additional - /// information relevant to the write process (e.g. ). - /// - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// The that can be used to cancel the operation. - protected internal abstract Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default); - - #endregion Write - - #region Misc - - internal abstract Type GetFieldType(FieldDescription? fieldDescription = null); - internal abstract Type GetProviderSpecificFieldType(FieldDescription? fieldDescription = null); - - internal virtual bool PreferTextWrite => false; - - /// - /// Creates a type handler for arrays of this handler's type. - /// - public abstract ArrayHandler CreateArrayHandler(PostgresArrayType arrayBackendType); - - /// - /// Creates a type handler for ranges of this handler's type. - /// - public abstract IRangeHandler CreateRangeHandler(PostgresType rangeBackendType); - - /// - /// Used to create an exception when the provided type can be converted and written, but an - /// instance of is required for caching of the converted value - /// (in . 
- /// - protected Exception CreateConversionButNoParamException(Type clrType) - => new InvalidCastException($"Can't convert .NET type '{clrType}' to PostgreSQL '{PgDisplayName}' within an array"); - - internal string PgDisplayName => PostgresType.DisplayName; - - #endregion Misc - - #region Code generation for non-generic writing - - internal delegate Task NonGenericWriteWithLength(NpgsqlTypeHandler handler, object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken); - - internal static NonGenericWriteWithLength GenerateNonGenericWriteMethod(Type handlerType, Type interfaceType) - { - var interfaces = handlerType.GetInterfaces().Where(i => - i.GetTypeInfo().IsGenericType && - i.GetGenericTypeDefinition() == interfaceType - ).Reverse().ToList(); - - Expression? ifElseExpression = null; - - // NpgsqlTypeHandler handler, object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache lengthCache, NpgsqlParameter parameter, bool async, CancellationToken cancellationToken - var handlerParam = Expression.Parameter(typeof(NpgsqlTypeHandler), "handler"); - var valueParam = Expression.Parameter(typeof(object), "value"); - var bufParam = Expression.Parameter(typeof(NpgsqlWriteBuffer), "buf"); - var lengthCacheParam = Expression.Parameter(typeof(NpgsqlLengthCache), "lengthCache"); - var parameterParam = Expression.Parameter(typeof(NpgsqlParameter), "parameter"); - var asyncParam = Expression.Parameter(typeof(bool), "async"); - var cancellationTokenParam = Expression.Parameter(typeof(CancellationToken), "cancellationToken"); - - var resultVariable = Expression.Variable(typeof(Task), "result"); - - foreach (var i in interfaces) - { - var handledType = i.GenericTypeArguments[0]; - - ifElseExpression = Expression.IfThenElse( - // Test whether the type of the value given to the delegate corresponds - // to our current interface's handled type (i.e. 
the T in INpgsqlTypeHandler) - Expression.TypeEqual(valueParam, handledType), - // If it corresponds, call the handler's Write method with the appropriate generic parameter - Expression.Assign( - resultVariable, - Expression.Call( - handlerParam, - // Call the generic WriteWithLengthInternal with our handled type - nameof(WriteWithLengthInternal), - new[] { handledType }, - // Cast the value from object down to the interface's T - Expression.Convert(valueParam, handledType), - bufParam, - lengthCacheParam, - parameterParam, - asyncParam, - cancellationTokenParam - ) - ), - // If this is the first interface we're looking at, the else clause throws. - // Note that this should never happen since we passed ValidateAndGetLength. - // Otherwise we stick the previous interface's IfThenElse in our else clause - ifElseExpression ?? Expression.Throw(Expression.New(typeof(InvalidCastException))) - ); - } - - if (ifElseExpression is null) - throw new Exception($"Type handler {handlerType.GetType().Name} does not implement the proper interface"); - - return Expression.Lambda( - Expression.Block( - new[] { resultVariable }, - ifElseExpression, resultVariable - ), - handlerParam, valueParam, bufParam, lengthCacheParam, parameterParam, asyncParam, cancellationTokenParam - ).Compile(); - } - - #endregion Code generation for non-generic writing - } -} diff --git a/src/Npgsql/TypeHandling/NpgsqlTypeHandlerFactory.cs b/src/Npgsql/TypeHandling/NpgsqlTypeHandlerFactory.cs deleted file mode 100644 index cad0ac9cfd..0000000000 --- a/src/Npgsql/TypeHandling/NpgsqlTypeHandlerFactory.cs +++ /dev/null @@ -1,50 +0,0 @@ -using System; -using Npgsql.PostgresTypes; -using Npgsql.TypeMapping; - -namespace Npgsql.TypeHandling -{ - /// - /// Base class for all type handler factories, which construct type handlers that know how - /// to read and write CLR types from/to PostgreSQL types. - /// - /// - /// In general, do not inherit from this class, inherit from instead. 
- /// - public abstract class NpgsqlTypeHandlerFactory - { - /// - /// Creates a type handler. - /// - public abstract NpgsqlTypeHandler CreateNonGeneric(PostgresType pgType, NpgsqlConnection conn); - - /// - /// The default CLR type that handlers produced by this factory will read and write. - /// - public abstract Type DefaultValueType { get; } - } - - /// - /// Base class for all type handler factories, which construct type handlers that know how - /// to read and write CLR types from/to PostgreSQL types. Type handler factories are set up - /// via in either the global or connection-specific type mapper. - /// - /// - /// - /// - /// The default CLR type that handlers produced by this factory will read and write. - public abstract class NpgsqlTypeHandlerFactory : NpgsqlTypeHandlerFactory - { - /// - /// Creates a type handler. - /// - public abstract NpgsqlTypeHandler Create(PostgresType pgType, NpgsqlConnection conn); - - /// - public override NpgsqlTypeHandler CreateNonGeneric(PostgresType pgType, NpgsqlConnection conn) - => Create(pgType, conn); - - /// - public override Type DefaultValueType => typeof(TDefault); - } -} diff --git a/src/Npgsql/TypeHandling/NpgsqlTypeHandler`.cs b/src/Npgsql/TypeHandling/NpgsqlTypeHandler`.cs deleted file mode 100644 index e743317b07..0000000000 --- a/src/Npgsql/TypeHandling/NpgsqlTypeHandler`.cs +++ /dev/null @@ -1,293 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Data.Common; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandlers; -using Npgsql.Util; - -namespace Npgsql.TypeHandling -{ - /// - /// Base class for all type handlers, which read and write CLR types into their PostgreSQL - /// binary representation. 
Unless your type is arbitrary-length, consider inheriting from - /// instead. - /// - /// - /// The default CLR type that this handler will read and write. For example, calling - /// on a column with this handler will return a value with type . - /// Type handlers can support additional types by implementing . - /// - public abstract class NpgsqlTypeHandler : NpgsqlTypeHandler, INpgsqlTypeHandler - { - delegate int NonGenericValidateAndGetLength(NpgsqlTypeHandler handler, object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - - readonly NonGenericValidateAndGetLength _nonGenericValidateAndGetLength; - readonly NonGenericWriteWithLength _nonGenericWriteWithLength; - -#pragma warning disable CA1823 - static readonly ConcurrentDictionary - NonGenericDelegateCache = new ConcurrentDictionary(); -#pragma warning restore CA1823 - - /// - /// Constructs an . - /// - protected NpgsqlTypeHandler(PostgresType postgresType) - : base(postgresType) - // Get code-generated delegates for non-generic ValidateAndGetLength/WriteWithLengthInternal - => - (_nonGenericValidateAndGetLength, _nonGenericWriteWithLength) = - NonGenericDelegateCache.GetOrAdd(GetType(), t => ( - GenerateNonGenericValidationMethod(GetType()), - GenerateNonGenericWriteMethod(GetType(), typeof(INpgsqlTypeHandler<>))) - ); - - #region Read - - /// - /// Reads a value of type with the given length from the provided buffer, - /// using either sync or async I/O. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - public abstract ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? 
fieldDescription = null); - - /// - /// Reads a value of type with the given length from the provided buffer, - /// using either sync or async I/O. Type handlers typically don't need to override this - - /// override - but may do - /// so in exceptional cases where reading of arbitrary types is required. - /// - /// The buffer from which to read. - /// The byte length of the value. The buffer might not contain the full length, requiring I/O to be performed. - /// If I/O is required to read the full length of the value, whether it should be performed synchronously or asynchronously. - /// Additional PostgreSQL information about the type, such as the length in varchar(30). - /// The fully-read value. - protected internal override ValueTask Read(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - { - var asTypedHandler = this as INpgsqlTypeHandler; - if (asTypedHandler == null) - throw new InvalidCastException(fieldDescription == null - ? $"Can't cast database type to {typeof(TAny).Name}" - : $"Can't cast database type {fieldDescription.Handler.PgDisplayName} to {typeof(TAny).Name}" - ); - - return asTypedHandler.Read(buf, len, async, fieldDescription); - } - - /// - public override TAny Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - => Read(buf, len, false, fieldDescription).Result; - - // Since TAny isn't constrained to class? or struct (C# doesn't have a non-nullable constraint that doesn't limit us to either struct or class), - // we must use the bang operator here to tell the compiler that a null value will never returned. - internal override async ValueTask ReadAsObject(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null) - => (await Read(buf, len, async, fieldDescription))!; - - internal override object ReadAsObject(NpgsqlReadBuffer buf, int len, FieldDescription? 
fieldDescription = null) - => Read(buf, len, fieldDescription)!; - - #endregion Read - - #region Write - - /// - /// Called to validate and get the length of a value of a generic . - /// - public abstract int ValidateAndGetLength(TDefault value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - - /// - /// Called to write the value of a generic . - /// - public abstract Task Write(TDefault value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default); - - /// - /// Called to validate and get the length of a value of an arbitrary type. - /// Checks that the current handler supports that type and throws an exception otherwise. - /// - protected internal override int ValidateAndGetLength(TAny value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - { - var typedHandler = this as INpgsqlTypeHandler; - if (typedHandler is null) - ThrowHelper.ThrowInvalidCastException_NotSupportedType(this, parameter, typeof(TAny)); - - return typedHandler.ValidateAndGetLength(value, ref lengthCache, parameter); - } - - /// - /// In the vast majority of cases writing a parameter to the buffer won't need to perform I/O. - /// - internal override Task WriteWithLengthInternal([AllowNull] TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - { - if (buf.WriteSpaceLeft < 4) - return WriteWithLengthLong(); - - if (value == null || typeof(TAny) == typeof(DBNull)) - { - buf.WriteInt32(-1); - return Task.CompletedTask; - } - - return WriteWithLength(value, buf, lengthCache, parameter, async, cancellationToken); - - async Task WriteWithLengthLong() - { - if (buf.WriteSpaceLeft < 4) - await buf.Flush(async, cancellationToken); - - if (value == null || typeof(TAny) == typeof(DBNull)) - { - buf.WriteInt32(-1); - return; - } - - await WriteWithLength(value, buf, lengthCache, parameter, async, cancellationToken); - } - } - - /// - /// Typically does not need to be overridden by type handlers, but may be needed in some - /// cases (e.g. . - /// Note that this method assumes it can write 4 bytes of length (already verified by - /// ). - /// - protected virtual Task WriteWithLength(TAny value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - { - Debug.Assert(this is INpgsqlTypeHandler); - - var typedHandler = (INpgsqlTypeHandler)this; - buf.WriteInt32(typedHandler.ValidateAndGetLength(value, ref lengthCache, parameter)); - return typedHandler.Write(value, buf, lengthCache, parameter, async, cancellationToken); - } - - // Object overloads for non-generic NpgsqlParameter - - /// - /// Called to validate and get the length of a value of a non-generic . - /// Type handlers generally don't need to override this. - /// - protected internal override int ValidateObjectAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - => value == null || value is DBNull - ? -1 - : _nonGenericValidateAndGetLength(this, value, ref lengthCache, parameter); - - /// - /// Called to write the value of a non-generic . - /// Type handlers generally don't need to override this. 
- /// - protected internal override Task WriteObjectWithLength(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default) - => value is DBNull - ? WriteWithLengthInternal(DBNull.Value, buf, lengthCache, parameter, async, cancellationToken) - : _nonGenericWriteWithLength(this, value, buf, lengthCache, parameter, async, cancellationToken); - - #endregion Write - - #region Code generation for non-generic writing - - // We need to support writing via non-generic NpgsqlParameter, which means we get requests - // to write some object with no generic typing information. - // We need to find out which INpgsqlTypeHandler interfaces our handler implements, and call - // the ValidateAndGetLength/WriteWithLengthInternal methods on the interface which corresponds to the - // value type. - // Since doing this with reflection every time is slow, we generate delegates to do this for us - // for each type handler. - - static NonGenericValidateAndGetLength GenerateNonGenericValidationMethod(Type handlerType) - { - var interfaces = handlerType.GetInterfaces().Where(i => - i.GetTypeInfo().IsGenericType && - i.GetGenericTypeDefinition() == typeof(INpgsqlTypeHandler<>) - ).Reverse().ToList(); - - Expression? ifElseExpression = null; - - var handlerParam = Expression.Parameter(typeof(NpgsqlTypeHandler), "handler"); - var valueParam = Expression.Parameter(typeof(object), "value"); - var lengthCacheParam = Expression.Parameter(typeof(NpgsqlLengthCache).MakeByRefType(), "lengthCache"); - var parameterParam = Expression.Parameter(typeof(NpgsqlParameter), "parameter"); - - var resultVariable = Expression.Variable(typeof(int), "result"); - - foreach (var i in interfaces) - { - var handledType = i.GenericTypeArguments[0]; - - ifElseExpression = Expression.IfThenElse( - // Test whether the type of the value given to the delegate corresponds - // to our current interface's handled type (i.e. 
the T in INpgsqlTypeHandler) - Expression.TypeEqual(valueParam, handledType), - // If it corresponds, cast the handler type (this) to INpgsqlTypeHandler - // and call its ValidateAndGetLength method - Expression.Assign( - resultVariable, - Expression.Call( - Expression.Convert(handlerParam, i), - i.GetMethod(nameof(INpgsqlTypeHandler.ValidateAndGetLength))!, - // Cast the value from object down to the interface's T - Expression.Convert(valueParam, handledType), - lengthCacheParam, - parameterParam - ) - ), - // If this is the first interface we're looking at, the else clause throws. - // Otherwise we stick the previous interface's IfThenElse in our else clause - ifElseExpression ?? Expression.Throw( - Expression.New( - MethodInfos.InvalidCastExceptionCtor, - Expression.Call( // Call string.Format to generate a nice informative exception message - MethodInfos.StringFormat, - new Expression[] - { - Expression.Constant($"Can't write CLR type {{0}} with handler type {handlerType.Name}"), - Expression.Call( // GetType() on the value - valueParam, - MethodInfos.ObjectGetType - ) - } - ) - ) - ) - ); - } - - if (ifElseExpression is null) - throw new Exception($"Type handler {handlerType.GetType().Name} does not implement the proper interface"); - - return Expression.Lambda( - Expression.Block( - new[] { resultVariable }, - ifElseExpression, resultVariable - ), - handlerParam, valueParam, lengthCacheParam, parameterParam - ).Compile(); - } - - #endregion Code generation for non-generic writing - - #region Misc - - internal override Type GetFieldType(FieldDescription? fieldDescription = null) => typeof(TDefault); - internal override Type GetProviderSpecificFieldType(FieldDescription? 
fieldDescription = null) => typeof(TDefault); - - /// - public override ArrayHandler CreateArrayHandler(PostgresArrayType arrayBackendType) - => new ArrayHandler(arrayBackendType, this); - - /// - public override IRangeHandler CreateRangeHandler(PostgresType rangeBackendType) - => new RangeHandler(rangeBackendType, this); - - #endregion Misc - } -} diff --git a/src/Npgsql/TypeHandling/NullableHandler.cs b/src/Npgsql/TypeHandling/NullableHandler.cs deleted file mode 100644 index c240b83f8d..0000000000 --- a/src/Npgsql/TypeHandling/NullableHandler.cs +++ /dev/null @@ -1,70 +0,0 @@ -using System; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; - -// ReSharper disable StaticMemberInGenericType -namespace Npgsql.TypeHandling -{ - delegate T ReadDelegate(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLength, FieldDescription? fieldDescription = null); - delegate ValueTask ReadAsyncDelegate(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLen, bool async, FieldDescription? fieldDescription = null); - - delegate int ValidateAndGetLengthDelegate(NpgsqlTypeHandler handler, T value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter); - delegate Task WriteAsyncDelegate(NpgsqlTypeHandler handler, T value, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter, bool async, CancellationToken cancellationToken = default); - - static class NullableHandler - { - public static readonly Type? UnderlyingType; - [NotNull] public static readonly ReadDelegate? Read; - [NotNull] public static readonly ReadAsyncDelegate? ReadAsync; - [NotNull] public static readonly ValidateAndGetLengthDelegate? ValidateAndGetLength; - [NotNull] public static readonly WriteAsyncDelegate? 
WriteAsync; - - public static bool Exists => UnderlyingType != null; - - static NullableHandler() - { - UnderlyingType = Nullable.GetUnderlyingType(typeof(T)); - - if (UnderlyingType == null) - return; - - Read = NullableHandler.CreateDelegate>(UnderlyingType, NullableHandler.ReadMethod); - ReadAsync = NullableHandler.CreateDelegate>(UnderlyingType, NullableHandler.ReadAsyncMethod); - ValidateAndGetLength = NullableHandler.CreateDelegate>(UnderlyingType, NullableHandler.ValidateMethod); - WriteAsync = NullableHandler.CreateDelegate>(UnderlyingType, NullableHandler.WriteAsyncMethod); - } - } - - static class NullableHandler - { - internal static readonly MethodInfo ReadMethod = new ReadDelegate(Read).Method.GetGenericMethodDefinition(); - internal static readonly MethodInfo ReadAsyncMethod = new ReadAsyncDelegate(ReadAsync).Method.GetGenericMethodDefinition(); - internal static readonly MethodInfo ValidateMethod = new ValidateAndGetLengthDelegate(ValidateAndGetLength).Method.GetGenericMethodDefinition(); - internal static readonly MethodInfo WriteAsyncMethod = new WriteAsyncDelegate(WriteAsync).Method.GetGenericMethodDefinition(); - - static T? Read(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLength, FieldDescription? fieldDescription) - where T : struct - => handler.Read(buffer, columnLength, fieldDescription); - - static async ValueTask ReadAsync(NpgsqlTypeHandler handler, NpgsqlReadBuffer buffer, int columnLength, bool async, FieldDescription? fieldDescription) - where T : struct - => await handler.Read(buffer, columnLength, async, fieldDescription); - - static int ValidateAndGetLength(NpgsqlTypeHandler handler, T? value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter) - where T : struct - => value.HasValue ? handler.ValidateAndGetLength(value.Value, ref lengthCache, parameter) : 0; - - static Task WriteAsync(NpgsqlTypeHandler handler, T? value, NpgsqlWriteBuffer buffer, NpgsqlLengthCache? lengthCache, NpgsqlParameter? 
parameter, bool async, CancellationToken cancellationToken = default) - where T : struct - => value.HasValue - ? handler.WriteWithLengthInternal(value.Value, buffer, lengthCache, parameter, async, cancellationToken) - : handler.WriteWithLengthInternal(DBNull.Value, buffer, lengthCache, parameter, async, cancellationToken); - - internal static TDelegate CreateDelegate(Type underlyingType, MethodInfo method) - where TDelegate : Delegate - => (TDelegate)method.MakeGenericMethod(underlyingType).CreateDelegate(typeof(TDelegate)); - } -} diff --git a/src/Npgsql/TypeMapping/ConnectorTypeMapper.cs b/src/Npgsql/TypeMapping/ConnectorTypeMapper.cs deleted file mode 100644 index 808ae84ecd..0000000000 --- a/src/Npgsql/TypeMapping/ConnectorTypeMapper.cs +++ /dev/null @@ -1,427 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Data; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Reflection; -using Npgsql.Logging; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandlers; -using Npgsql.TypeHandling; -using NpgsqlTypes; - -namespace Npgsql.TypeMapping -{ - class ConnectorTypeMapper : TypeMapperBase - { - /// - /// The connector to which this type mapper belongs. - /// - readonly NpgsqlConnector _connector; - - NpgsqlDatabaseInfo? _databaseInfo; - - /// - /// Type information for the database of this mapper. - /// - internal NpgsqlDatabaseInfo DatabaseInfo - => _databaseInfo ?? throw new InvalidOperationException("Internal error: this type mapper hasn't yet been bound to a database info object"); - - internal NpgsqlTypeHandler UnrecognizedTypeHandler { get; } - - readonly Dictionary _byOID = new Dictionary(); - readonly Dictionary _byNpgsqlDbType = new Dictionary(); - readonly Dictionary _byDbType = new Dictionary(); - readonly Dictionary _byTypeName = new Dictionary(); - - /// - /// Maps CLR types to their type handlers. 
- /// - readonly Dictionary _byClrType= new Dictionary(); - - /// - /// Maps CLR types to their array handlers. - /// - readonly Dictionary _arrayHandlerByClrType = new Dictionary(); - - /// - /// Copy of at the time when this - /// mapper was created, to detect mapping changes. If changes are made to this connection's - /// mapper, the change counter is set to -1. - /// - internal int ChangeCounter { get; private set; } - - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(ConnectorTypeMapper)); - - #region Construction - - internal ConnectorTypeMapper(NpgsqlConnector connector) : base(GlobalTypeMapper.Instance.DefaultNameTranslator) - { - _connector = connector; - UnrecognizedTypeHandler = new UnknownTypeHandler(_connector.Connection!); - ClearBindings(); - ResetMappings(); - } - - #endregion Constructors - - #region Type handler lookup - - /// - /// Looks up a type handler by its PostgreSQL type's OID. - /// - /// A PostgreSQL type OID - /// A type handler that can be used to encode and decode values. - internal NpgsqlTypeHandler GetByOID(uint oid) - => TryGetByOID(oid, out var result) ? result : UnrecognizedTypeHandler; - - internal bool TryGetByOID(uint oid, [NotNullWhen(true)] out NpgsqlTypeHandler? handler) - => _byOID.TryGetValue(oid, out handler); - - internal NpgsqlTypeHandler GetByNpgsqlDbType(NpgsqlDbType npgsqlDbType) - => _byNpgsqlDbType.TryGetValue(npgsqlDbType, out var handler) - ? handler - : throw new NpgsqlException($"The NpgsqlDbType '{npgsqlDbType}' isn't present in your database. " + - "You may need to install an extension or upgrade to a newer version."); - - - internal NpgsqlTypeHandler GetByDbType(DbType dbType) - => _byDbType.TryGetValue(dbType, out var handler) - ? handler - : throw new NotSupportedException("This DbType is not supported in Npgsql: " + dbType); - - internal NpgsqlTypeHandler GetByDataTypeName(string typeName) - => _byTypeName.TryGetValue(typeName, out var handler) - ? 
handler - : throw new NotSupportedException("Could not find PostgreSQL type " + typeName); - - internal NpgsqlTypeHandler GetByClrType(Type type) - { - if (_byClrType.TryGetValue(type, out var handler)) - return handler; - - if (Nullable.GetUnderlyingType(type) is Type underlyingType && _byClrType.TryGetValue(underlyingType, out handler)) - return handler; - - // Try to see if it is an array type - var arrayElementType = GetArrayElementType(type); - if (arrayElementType != null) - { - if (_arrayHandlerByClrType.TryGetValue(arrayElementType, out var elementHandler)) - return elementHandler; - throw new NotSupportedException($"The CLR array type {type} isn't supported by Npgsql or your PostgreSQL. " + - "If you wish to map it to a PostgreSQL composite type array you need to register it before usage, please refer to the documentation."); - } - - // Nothing worked - if (type.GetTypeInfo().IsEnum) - throw new NotSupportedException($"The CLR enum type {type.Name} must be registered with Npgsql before usage, please refer to the documentation."); - - if (typeof(IEnumerable).IsAssignableFrom(type)) - throw new NotSupportedException("Npgsql 3.x removed support for writing a parameter with an IEnumerable value, use .ToList()/.ToArray() instead"); - - throw new NotSupportedException($"The CLR type {type} isn't natively supported by Npgsql or your PostgreSQL. " + - $"To use it with a PostgreSQL composite you need to specify {nameof(NpgsqlParameter.DataTypeName)} or to map it, please refer to the documentation."); - } - - static Type? GetArrayElementType(Type type) - { - var typeInfo = type.GetTypeInfo(); - if (typeInfo.IsArray) - return GetUnderlyingType(type.GetElementType()!); // The use of bang operator is justified here as Type.GetElementType() only returns null for the Array base class which can't be mapped in a useful way. 
- - var ilist = typeInfo.ImplementedInterfaces.FirstOrDefault(x => x.GetTypeInfo().IsGenericType && x.GetGenericTypeDefinition() == typeof(IList<>)); - if (ilist != null) - return GetUnderlyingType(ilist.GetGenericArguments()[0]); - - if (typeof(IList).IsAssignableFrom(type)) - throw new NotSupportedException("Non-generic IList is a supported parameter, but the NpgsqlDbType parameter must be set on the parameter"); - - return null; - - Type GetUnderlyingType(Type t) - => Nullable.GetUnderlyingType(t) ?? t; - } - - #endregion Type handler lookup - - #region Mapping management - - public override INpgsqlTypeMapper AddMapping(NpgsqlTypeMapping mapping) - { - CheckReady(); - - base.AddMapping(mapping); - BindType(mapping, _connector, externalCall: true); - ChangeCounter = -1; - return this; - } - - public override bool RemoveMapping(string pgTypeName) - { - CheckReady(); - - var removed = base.RemoveMapping(pgTypeName); - if (!removed) - return false; - - // Rebind everything. We redo rather than trying to update the - // existing dictionaries because it's complex to remove arrays, ranges... 
- ClearBindings(); - BindTypes(); - ChangeCounter = -1; - return true; - } - - void CheckReady() - { - if (_connector.State != ConnectorState.Ready) - throw new InvalidOperationException("Connection must be open and idle to perform registration"); - } - - void ResetMappings() - { - var globalMapper = GlobalTypeMapper.Instance; - globalMapper.Lock.EnterReadLock(); - try - { - Mappings.Clear(); - foreach (var kv in globalMapper.Mappings) - Mappings.Add(kv.Key, kv.Value); - } - finally - { - globalMapper.Lock.ExitReadLock(); - } - ChangeCounter = GlobalTypeMapper.Instance.ChangeCounter; - } - - void ClearBindings() - { - _byOID.Clear(); - _byNpgsqlDbType.Clear(); - _byDbType.Clear(); - _byClrType.Clear(); - _arrayHandlerByClrType.Clear(); - - _byNpgsqlDbType[NpgsqlDbType.Unknown] = UnrecognizedTypeHandler; - _byClrType[typeof(DBNull)] = UnrecognizedTypeHandler; - } - - public override void Reset() - { - ClearBindings(); - ResetMappings(); - BindTypes(); - } - - #endregion Mapping management - - #region Binding - - internal void Bind(NpgsqlDatabaseInfo databaseInfo) - { - _databaseInfo = databaseInfo; - BindTypes(); - } - - void BindTypes() - { - foreach (var mapping in Mappings.Values) - BindType(mapping, _connector, externalCall: false); - - // Enums - var enumFactory = new UnmappedEnumTypeHandlerFactory(DefaultNameTranslator); - foreach (var e in DatabaseInfo.EnumTypes.Where(e => !_byOID.ContainsKey(e.OID))) - BindType(enumFactory.Create(e, _connector.Connection!), e); - - // Wire up any domains we find to their base type mappings, this is important - // for reading domain fields of composites - foreach (var domain in DatabaseInfo.DomainTypes) - if (_byOID.TryGetValue(domain.BaseType.OID, out var baseTypeHandler)) - { - _byOID[domain.OID] = baseTypeHandler; - if (domain.Array != null) - BindType(baseTypeHandler.CreateArrayHandler(domain.Array), domain.Array); - } - } - - void BindType(NpgsqlTypeMapping mapping, NpgsqlConnector connector, bool externalCall) - { - // 
Binding can occur at two different times: - // 1. When a user adds a mapping for a specific connection (and exception should bubble up to them) - // 2. When binding the global mappings, in which case we want to log rather than throw - // (i.e. missing database type for some unused defined binding shouldn't fail the connection) - - var pgName = mapping.PgTypeName; - - PostgresType? pgType; - if (pgName.IndexOf('.') > -1) - DatabaseInfo.ByFullName.TryGetValue(pgName, out pgType); // Full type name with namespace - else if (DatabaseInfo.ByName.TryGetValue(pgName, out pgType) && pgType is null) // No dot, partial type name - { - // If the name was found but the value is null, that means that there are - // two db types with the same name (different schemas). - // Try to fall back to pg_catalog, otherwise fail. - if (!DatabaseInfo.ByFullName.TryGetValue($"pg_catalog.{pgName}", out pgType)) - { - var msg = $"More than one PostgreSQL type was found with the name {mapping.PgTypeName}, please specify a full name including schema"; - if (externalCall) - throw new ArgumentException(msg); - Log.Debug(msg); - return; - } - } - - if (pgType is null) - { - var msg = $"A PostgreSQL type with the name {mapping.PgTypeName} was not found in the database"; - if (externalCall) - throw new ArgumentException(msg); - Log.Debug(msg); - return; - } - if (pgType is PostgresDomainType) - { - var msg = "Cannot add a mapping to a PostgreSQL domain type"; - if (externalCall) - throw new NotSupportedException(msg); - Log.Debug(msg); - return; - } - - var handler = mapping.TypeHandlerFactory.CreateNonGeneric(pgType, connector.Connection!); - BindType(handler, pgType, mapping.NpgsqlDbType, mapping.DbTypes, mapping.ClrTypes); - - if (!externalCall) - return; - - foreach (var domain in DatabaseInfo.DomainTypes) - if (domain.BaseType.OID == pgType.OID) - { - _byOID[domain.OID] = handler; - if (domain.Array != null) - BindType(handler.CreateArrayHandler(domain.Array), domain.Array); - } - } - - void 
BindType(NpgsqlTypeHandler handler, PostgresType pgType, NpgsqlDbType? npgsqlDbType = null, DbType[]? dbTypes = null, Type[]? clrTypes = null) - { - _byOID[pgType.OID] = handler; - _byTypeName[pgType.FullName] = handler; - _byTypeName[pgType.Name] = handler; - - if (npgsqlDbType.HasValue) - { - var value = npgsqlDbType.Value; - if (_byNpgsqlDbType.ContainsKey(value)) - throw new InvalidOperationException($"Two type handlers registered on same NpgsqlDbType '{npgsqlDbType}': {_byNpgsqlDbType[value].GetType().Name} and {handler.GetType().Name}"); - _byNpgsqlDbType[npgsqlDbType.Value] = handler; - } - - if (dbTypes != null) - { - foreach (var dbType in dbTypes) - { - if (_byDbType.ContainsKey(dbType)) - throw new InvalidOperationException($"Two type handlers registered on same DbType {dbType}: {_byDbType[dbType].GetType().Name} and {handler.GetType().Name}"); - _byDbType[dbType] = handler; - } - } - - if (clrTypes != null) - { - foreach (var type in clrTypes) - { - if (_byClrType.ContainsKey(type)) - throw new InvalidOperationException($"Two type handlers registered on same .NET type '{type}': {_byClrType[type].GetType().Name} and {handler.GetType().Name}"); - _byClrType[type] = handler; - } - } - - if (pgType.Array != null) - BindArrayType(handler, pgType.Array, npgsqlDbType, clrTypes); - - if (pgType.Range != null) - BindRangeType(handler, pgType.Range, npgsqlDbType, clrTypes); - } - - void BindArrayType(NpgsqlTypeHandler elementHandler, PostgresArrayType pgArrayType, NpgsqlDbType? elementNpgsqlDbType, Type[]? elementClrTypes) - { - var arrayHandler = elementHandler.CreateArrayHandler(pgArrayType); - - var arrayNpgsqlDbType = elementNpgsqlDbType.HasValue - ? NpgsqlDbType.Array | elementNpgsqlDbType.Value - : (NpgsqlDbType?)null; - - BindType(arrayHandler, pgArrayType, arrayNpgsqlDbType); - - // Note that array handlers aren't registered in ByClrType like base types, because they handle all - // dimension types and not just one CLR type (e.g. int[], int[,], int[,,]). 
- // So the by-type lookup is special and goes via _arrayHandlerByClrType, see this[Type type] - // TODO: register single-dimensional in _byType as a specific optimization? But do PSV as well... - if (elementClrTypes != null) - { - foreach (var elementType in elementClrTypes) - { - if (_arrayHandlerByClrType.ContainsKey(elementType)) - throw new Exception( - $"Two array type handlers registered on same .NET type {elementType}: {_arrayHandlerByClrType[elementType].GetType().Name} and {arrayHandler.GetType().Name}"); - _arrayHandlerByClrType[elementType] = arrayHandler; - } - } - } - - void BindRangeType(NpgsqlTypeHandler elementHandler, PostgresRangeType pgRangeType, NpgsqlDbType? elementNpgsqlDbType, Type[]? elementClrTypes) - { - var rangeHandler = elementHandler.CreateRangeHandler(pgRangeType); - - var rangeNpgsqlDbType = elementNpgsqlDbType.HasValue - ? NpgsqlDbType.Range | elementNpgsqlDbType.Value - : (NpgsqlDbType?)null; - - // We only want to bind supported range CLR types whose element CLR types are being bound as well. - var clrTypes = elementClrTypes is null - ? null - : rangeHandler.SupportedRangeClrTypes - .Where(r => elementClrTypes.Contains(r.GenericTypeArguments[0])) - .ToArray(); - - BindType((NpgsqlTypeHandler)rangeHandler, pgRangeType, rangeNpgsqlDbType, null, clrTypes); - } - - #endregion Binding - - internal (NpgsqlDbType? 
npgsqlDbType, PostgresType postgresType) GetTypeInfoByOid(uint oid) - { - if (!DatabaseInfo.ByOID.TryGetValue(oid, out var postgresType)) - throw new InvalidOperationException($"Couldn't find PostgreSQL type with OID {oid}"); - - // Try to find the postgresType in the mappings - if (TryGetMapping(postgresType, out var npgsqlTypeMapping)) - return (npgsqlTypeMapping.NpgsqlDbType, postgresType); - - // Try to find the elements' postgresType in the mappings - if (postgresType is PostgresArrayType arrayType && - TryGetMapping(arrayType.Element, out var elementNpgsqlTypeMapping)) - return (elementNpgsqlTypeMapping.NpgsqlDbType | NpgsqlDbType.Array, postgresType); - - // Try to find the elements' postgresType of the base type in the mappings - // this happens with domains over arrays - if (postgresType is PostgresDomainType domainType && domainType.BaseType is PostgresArrayType baseType && - TryGetMapping(baseType.Element, out var baseTypeElementNpgsqlTypeMapping)) - return (baseTypeElementNpgsqlTypeMapping.NpgsqlDbType | NpgsqlDbType.Array, postgresType); - - // It might be an unmapped enum/composite type, or some other unmapped type - return (null, postgresType); - } - - bool TryGetMapping(PostgresType pgType, [NotNullWhen(true)] out NpgsqlTypeMapping? 
mapping) - => Mappings.TryGetValue(pgType.Name, out mapping) || - Mappings.TryGetValue(pgType.FullName, out mapping) || - pgType is PostgresDomainType domain && ( - Mappings.TryGetValue(domain.BaseType.Name, out mapping) || - Mappings.TryGetValue(domain.BaseType.FullName, out mapping)); - } -} diff --git a/src/Npgsql/TypeMapping/EntityFrameworkCoreCompat.cs b/src/Npgsql/TypeMapping/EntityFrameworkCoreCompat.cs deleted file mode 100644 index db347da38e..0000000000 --- a/src/Npgsql/TypeMapping/EntityFrameworkCoreCompat.cs +++ /dev/null @@ -1,45 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using Npgsql.TypeMapping; - -// This file contains some pretty awful hacks to make current version of the EF Core provider -// compatible with the new type mapping/handling system introduced in Npgsql 4.0. -// The EF Core provider dynamically loads its type mappings from Npgsql, which allows it to -// automatically support any type supported by Npgsql. Unfortunately, the current loading -// system is very tightly coupled to pre-3.2 type mapping types (e.g. TypeHandlerRegistry), -// and so this shim is required. 
- -// ReSharper disable once CheckNamespace -namespace Npgsql -{ - [Obsolete("Purely for EF Core backwards compatibility")] - class TypeHandlerRegistry - { - internal static readonly Dictionary HandlerTypes; - - static TypeHandlerRegistry() - { - HandlerTypes = GlobalTypeMapper.Instance.Mappings.Values.ToDictionary( - m => m.PgTypeName, - m => new TypeAndMapping - { - HandlerType = typeof(TypeHandler<>).MakeGenericType(m.TypeHandlerFactory.DefaultValueType), - Mapping = new TypeMappingAttribute(m.PgTypeName, m.NpgsqlDbType, - m.DbTypes, m.ClrTypes, m.InferredDbType) - } - ); - } - } - - // ReSharper disable once UnusedTypeParameter - [Obsolete("Purely for EF Core backwards compatibility")] - class TypeHandler {} - - [Obsolete("Purely for EF Core backwards compatibility")] - struct TypeAndMapping - { - internal Type HandlerType; - internal TypeMappingAttribute Mapping; - } -} diff --git a/src/Npgsql/TypeMapping/GlobalTypeMapper.cs b/src/Npgsql/TypeMapping/GlobalTypeMapper.cs index 56dbaa11f7..4ef1313adf 100644 --- a/src/Npgsql/TypeMapping/GlobalTypeMapper.cs +++ b/src/Npgsql/TypeMapping/GlobalTypeMapper.cs @@ -1,233 +1,375 @@ using System; using System.Collections.Generic; -using System.Data; +using System.Diagnostics.CodeAnalysis; using System.Linq; -using System.Net; -using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; using System.Threading; -using Npgsql.NameTranslation; -using Npgsql.TypeHandling; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; +using Npgsql.Internal.ResolverFactories; using NpgsqlTypes; -namespace Npgsql.TypeMapping +namespace Npgsql.TypeMapping; + +/// +sealed class GlobalTypeMapper : INpgsqlTypeMapper { - class GlobalTypeMapper : TypeMapperBase - { - public static GlobalTypeMapper Instance { get; } + readonly UserTypeMapper _userTypeMapper = new(); + readonly List _pluginResolverFactories = new(); + readonly ReaderWriterLockSlim _lock = new(); + PgTypeInfoResolverFactory[] _typeMappingResolvers = 
Array.Empty(); - /// - /// A counter that is incremented whenever a global mapping change occurs. - /// Used to invalidate bound type mappers. - /// - internal int ChangeCounter => _changeCounter; + internal List HackyEnumTypeMappings { get; } = new(); - internal ReaderWriterLockSlim Lock { get; } - = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion); + internal IEnumerable GetPluginResolverFactories() + { + var resolvers = new List(); + _lock.EnterReadLock(); + try + { + resolvers.AddRange(_pluginResolverFactories); + } + finally + { + _lock.ExitReadLock(); + } - int _changeCounter; + return resolvers; + } - static GlobalTypeMapper() + internal PgTypeInfoResolverFactory? GetUserMappingsResolverFactory() + { + _lock.EnterReadLock(); + try + { + return _userTypeMapper.Items.Count > 0 ? _userTypeMapper : null; + } + finally { - var instance = new GlobalTypeMapper(); - instance.SetupGlobalTypeMapper(); - Instance = instance; + _lock.ExitReadLock(); } + } - internal GlobalTypeMapper() : base(new NpgsqlSnakeCaseNameTranslator()) {} + internal void AddGlobalTypeMappingResolvers(PgTypeInfoResolverFactory[] factories, Func? builderFactory = null, bool overwrite = false) + { + // Good enough logic to prevent SlimBuilder overriding the normal Builder. + if (overwrite || factories.Length > _typeMappingResolvers.Length) + { + _builderFactory = builderFactory; + _typeMappingResolvers = factories; + ResetTypeMappingCache(); + } + } - #region Mapping management + void ResetTypeMappingCache() => _typeMappingOptions = null; - public override INpgsqlTypeMapper AddMapping(NpgsqlTypeMapping mapping) + PgSerializerOptions? _typeMappingOptions; + Func? _builderFactory; + JsonSerializerOptions? 
_jsonSerializerOptions; + + PgSerializerOptions TypeMappingOptions + { + get { - Lock.EnterWriteLock(); + if (_typeMappingOptions is not null) + return _typeMappingOptions; + + _lock.EnterReadLock(); try { - base.AddMapping(mapping); - RecordChange(); - - if (mapping.NpgsqlDbType.HasValue) + var builder = _builderFactory?.Invoke() ?? new(); + builder.AppendResolverFactory(_userTypeMapper); + foreach (var factory in _pluginResolverFactories) + builder.AppendResolverFactory(factory); + foreach (var factory in _typeMappingResolvers) + builder.AppendResolverFactory(factory); + var chain = builder.Build(); + return _typeMappingOptions = new(PostgresMinimalDatabaseInfo.DefaultTypeCatalog, chain) { - foreach (var dbType in mapping.DbTypes) - _dbTypeToNpgsqlDbType[dbType] = mapping.NpgsqlDbType.Value; - - if (mapping.InferredDbType.HasValue) - _npgsqlDbTypeToDbType[mapping.NpgsqlDbType.Value] = mapping.InferredDbType.Value; - - foreach (var clrType in mapping.ClrTypes) - _typeToNpgsqlDbType[clrType] = mapping.NpgsqlDbType.Value; - } - - if (mapping.InferredDbType.HasValue) - foreach (var clrType in mapping.ClrTypes) - _typeToDbType[clrType] = mapping.InferredDbType.Value; - - return this; + // This means we don't ever have a missing oid for a datatypename as our canonical format is datatypenames. + PortableTypeIds = true, + // Don't throw if our catalog doesn't know the datatypename. + IntrospectionMode = true + }; } finally { - Lock.ExitWriteLock(); + _lock.ExitReadLock(); } } + } - public override bool RemoveMapping(string pgTypeName) + internal DataTypeName? FindDataTypeName(Type type, object value) + { + DataTypeName? 
dataTypeName; + try { - Lock.EnterWriteLock(); - try - { - var result = base.RemoveMapping(pgTypeName); - RecordChange(); - return result; - } - finally - { - Lock.ExitWriteLock(); - } + var typeInfo = TypeMappingOptions.GetTypeInfo(type); + if (typeInfo is PgResolverTypeInfo info) + dataTypeName = info.GetObjectResolution(value).PgTypeId.DataTypeName; + else + dataTypeName = typeInfo?.GetResolution().PgTypeId.DataTypeName; + } + catch + { + dataTypeName = null; } + return dataTypeName; + } + + internal static GlobalTypeMapper Instance { get; } - public override void Reset() + static GlobalTypeMapper() + => Instance = new GlobalTypeMapper(); + + /// + public void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) + { + _lock.EnterWriteLock(); + try { - Lock.EnterWriteLock(); - try + var type = factory.GetType(); + + // Since EFCore.PG plugins (and possibly other users) repeatedly call NpgsqlConnection.GlobalTypeMapper.UseNodaTime, + // we replace an existing resolver of the same CLR type. 
+ if (_pluginResolverFactories.Count > 0 && _pluginResolverFactories[0].GetType() == type) + _pluginResolverFactories[0] = factory; + for (var i = 0; i < _pluginResolverFactories.Count; i++) { - Mappings.Clear(); - SetupGlobalTypeMapper(); - RecordChange(); + if (_pluginResolverFactories[i].GetType() == type) + { + _pluginResolverFactories.RemoveAt(i); + break; + } } - finally + + _pluginResolverFactories.Insert(0, factory); + ResetTypeMappingCache(); + } + finally + { + _lock.ExitWriteLock(); + } + } + + void ReplaceTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) + { + _lock.EnterWriteLock(); + try + { + var type = factory.GetType(); + + for (var i = 0; i < _pluginResolverFactories.Count; i++) { - Lock.ExitWriteLock(); + if (_pluginResolverFactories[i].GetType() == type) + { + _pluginResolverFactories[i] = factory; + break; + } } - } - internal void RecordChange() => Interlocked.Increment(ref _changeCounter); + ResetTypeMappingCache(); + } + finally + { + _lock.ExitWriteLock(); + } + } - #endregion Mapping management + /// + public void Reset() + { + _lock.EnterWriteLock(); + try + { + _pluginResolverFactories.Clear(); + _userTypeMapper.Items.Clear(); + HackyEnumTypeMappings.Clear(); + } + finally + { + _lock.ExitWriteLock(); + } + } - #region NpgsqlDbType/DbType inference for NpgsqlParameter + /// + public INpgsqlNameTranslator DefaultNameTranslator + { + get => _userTypeMapper.DefaultNameTranslator; + set => _userTypeMapper.DefaultNameTranslator = value; + } - readonly Dictionary _npgsqlDbTypeToDbType = new Dictionary(); - readonly Dictionary _dbTypeToNpgsqlDbType = new Dictionary(); - readonly Dictionary _typeToNpgsqlDbType = new Dictionary(); - readonly Dictionary _typeToDbType = new Dictionary(); + /// + public INpgsqlTypeMapper ConfigureJsonOptions(JsonSerializerOptions serializerOptions) + { + _jsonSerializerOptions = serializerOptions; + // If JsonTypeInfoResolverFactory exists we replace it with a configured instance on the same index of the 
array. + ReplaceTypeInfoResolverFactory(new JsonTypeInfoResolverFactory(serializerOptions)); + return this; + } - internal DbType ToDbType(NpgsqlDbType npgsqlDbType) - => _npgsqlDbTypeToDbType.TryGetValue(npgsqlDbType, out var dbType) ? dbType : DbType.Object; + /// + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + public INpgsqlTypeMapper EnableDynamicJson( + Type[]? jsonbClrTypes = null, + Type[]? jsonClrTypes = null) + { + AddTypeInfoResolverFactory(new JsonDynamicTypeInfoResolverFactory(jsonbClrTypes, jsonClrTypes, _jsonSerializerOptions)); + return this; + } - internal NpgsqlDbType ToNpgsqlDbType(DbType dbType) - { - if (!_dbTypeToNpgsqlDbType.TryGetValue(dbType, out var npgsqlDbType)) - throw new NotSupportedException($"The parameter type DbType.{dbType} isn't supported by PostgreSQL or Npgsql"); - return npgsqlDbType; - } + /// + [RequiresUnreferencedCode("The mapping of PostgreSQL records as .NET tuples requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The mapping of PostgreSQL records as .NET tuples requires dynamic code usage which is incompatible with NativeAOT.")] + public INpgsqlTypeMapper EnableRecordsAsTuples() + { + AddTypeInfoResolverFactory(new TupledRecordTypeInfoResolverFactory()); + return this; + } - internal DbType ToDbType(Type type) - => _typeToDbType.TryGetValue(type, out var dbType) ? 
dbType : DbType.Object; + /// + [RequiresUnreferencedCode("The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode("The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + public INpgsqlTypeMapper EnableUnmappedTypes() + { + AddTypeInfoResolverFactory(new UnmappedTypeInfoResolverFactory()); + return this; + } - internal NpgsqlDbType ToNpgsqlDbType(Type type) + /// + public INpgsqlTypeMapper MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where TEnum : struct, Enum + { + _lock.EnterWriteLock(); + try { - if (_typeToNpgsqlDbType.TryGetValue(type, out var npgsqlDbType)) - return npgsqlDbType; + _userTypeMapper.MapEnum(pgName, nameTranslator); - if (type.IsArray) - { - if (type == typeof(byte[])) - return NpgsqlDbType.Bytea; - return NpgsqlDbType.Array | ToNpgsqlDbType(type.GetElementType()!); - } + // Temporary hack for EFCore.PG enum mapping compat + if (_userTypeMapper.Items.FirstOrDefault(i => i.ClrType == typeof(TEnum)) is UserTypeMapping userTypeMapping) + HackyEnumTypeMappings.Add(new(typeof(TEnum), userTypeMapping.PgTypeName, nameTranslator ?? 
DefaultNameTranslator)); - var typeInfo = type.GetTypeInfo(); + ResetTypeMappingCache(); - var ilist = typeInfo.ImplementedInterfaces.FirstOrDefault(x => x.GetTypeInfo().IsGenericType && x.GetGenericTypeDefinition() == typeof(IList<>)); - if (ilist != null) - return NpgsqlDbType.Array | ToNpgsqlDbType(ilist.GetGenericArguments()[0]); + return this; + } + finally + { + _lock.ExitWriteLock(); + } + } - if (typeInfo.IsGenericType && type.GetGenericTypeDefinition() == typeof(NpgsqlRange<>)) - return NpgsqlDbType.Range | ToNpgsqlDbType(type.GetGenericArguments()[0]); + /// + public bool UnmapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where TEnum : struct, Enum + { + _lock.EnterWriteLock(); + try + { + var removed = _userTypeMapper.UnmapEnum(pgName, nameTranslator); - if (type == typeof(DBNull)) - return NpgsqlDbType.Unknown; + // Temporary hack for EFCore.PG enum mapping compat + if (removed && ((List)_userTypeMapper.Items).FindIndex(m => m.ClrType == typeof(TEnum)) is > -1 and var index) + HackyEnumTypeMappings.RemoveAt(index); - throw new NotSupportedException("Can't infer NpgsqlDbType for type " + type); - } + ResetTypeMappingCache(); + return removed; + } + finally + { + _lock.ExitWriteLock(); + } + } - #endregion NpgsqlDbType/DbType inference for NpgsqlParameter + /// + [RequiresDynamicCode("Calling MapEnum with a Type can require creating new generic types or methods. This may not work when AOT compiling.")] + public INpgsqlTypeMapper MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) + { + _lock.EnterWriteLock(); + try + { + _userTypeMapper.MapEnum(clrType, pgName, nameTranslator); - #region Setup for built-in handlers + // Temporary hack for EFCore.PG enum mapping compat + if (_userTypeMapper.Items.FirstOrDefault(i => i.ClrType == clrType) is UserTypeMapping userTypeMapping) + HackyEnumTypeMappings.Add(new(clrType, userTypeMapping.PgTypeName, nameTranslator ?? DefaultNameTranslator)); - void SetupGlobalTypeMapper() + ResetTypeMappingCache(); + return this; + } + finally { - // Look for TypeHandlerFactories with mappings in our assembly, set them up - foreach (var t in typeof(TypeMapperBase).GetTypeInfo().Assembly.GetTypes().Where(t => typeof(NpgsqlTypeHandlerFactory).IsAssignableFrom(t.GetTypeInfo()))) - { - var mappingAttributes = t.GetTypeInfo().GetCustomAttributes(typeof(TypeMappingAttribute), false); - if (!mappingAttributes.Any()) - continue; + _lock.ExitWriteLock(); + } + } - var factory = (NpgsqlTypeHandlerFactory)Activator.CreateInstance(t)!; + /// + public bool UnmapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + _lock.EnterWriteLock(); + try + { + var removed = _userTypeMapper.UnmapEnum(clrType, pgName, nameTranslator); - foreach (TypeMappingAttribute m in mappingAttributes) - { - // TODO: Duplication between TypeMappingAttribute and TypeMapping. Look at this later. 
- AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = m.PgName, - NpgsqlDbType = m.NpgsqlDbType, - DbTypes = m.DbTypes, - ClrTypes = m.ClrTypes, - InferredDbType = m.InferredDbType, - TypeHandlerFactory = factory, - }.Build()); - } - } + // Temporary hack for EFCore.PG enum mapping compat + if (removed && ((List)_userTypeMapper.Items).FindIndex(m => m.ClrType == clrType) is > -1 and var index) + HackyEnumTypeMappings.RemoveAt(index); - // Look for NpgsqlTypeHandler classes with mappings in our assembly, set them up with the DefaultTypeHandlerFactory. - // This is a shortcut that allows us to not specify a factory for each and every type handler - foreach (var t in typeof(TypeMapperBase).GetTypeInfo().Assembly.GetTypes().Where(t => t.GetTypeInfo().IsSubclassOf(typeof(NpgsqlTypeHandler)))) - { - var mappingAttributes = t.GetTypeInfo().GetCustomAttributes(typeof(TypeMappingAttribute), false); - if (!mappingAttributes.Any()) - continue; + ResetTypeMappingCache(); + return removed; + } + finally + { + _lock.ExitWriteLock(); + } + } - var factory = new DefaultTypeHandlerFactory(t); + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public INpgsqlTypeMapper MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => MapComposite(typeof(T), pgName, nameTranslator); - foreach (TypeMappingAttribute m in mappingAttributes) - { - // TODO: Duplication between TypeMappingAttribute and TypeMapping. Look at this later. 
- AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = m.PgName, - NpgsqlDbType = m.NpgsqlDbType, - DbTypes = m.DbTypes, - ClrTypes = m.ClrTypes, - InferredDbType = m.InferredDbType, - TypeHandlerFactory = factory - }.Build()); - } - } + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public bool UnmapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => UnmapComposite(typeof(T), pgName, nameTranslator); - // This is an extremely ugly hack to support ReadOnlyIPAddress, which as an internal subclass of IPAddress - // added to .NET Core 3.0 (see https://github.com/dotnet/corefx/issues/33373) - if (_typeToNpgsqlDbType.ContainsKey(typeof(IPAddress)) && - Mappings.TryGetValue("inet", out var inetMapping) && - typeof(IPAddress).GetNestedType("ReadOnlyIPAddress", BindingFlags.NonPublic) is Type readOnlyIpType) - { - _typeToNpgsqlDbType[readOnlyIpType] = _typeToNpgsqlDbType[typeof(IPAddress)]; - var augmentedClrType = new Type[inetMapping.ClrTypes.Length + 1]; - Array.Copy(inetMapping.ClrTypes, augmentedClrType, inetMapping.ClrTypes.Length); - augmentedClrType[augmentedClrType.Length - 1] = readOnlyIpType; - Mappings["inet"] = new NpgsqlTypeMappingBuilder - { - PgTypeName = "inet", - NpgsqlDbType = inetMapping.NpgsqlDbType, - DbTypes = inetMapping.DbTypes, - ClrTypes = augmentedClrType, - InferredDbType = inetMapping.InferredDbType, - TypeHandlerFactory = inetMapping.TypeHandlerFactory - }.Build(); - } + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. 
This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public INpgsqlTypeMapper MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + _lock.EnterWriteLock(); + try + { + _userTypeMapper.MapComposite(clrType, pgName, nameTranslator); + ResetTypeMappingCache(); + return this; } + finally + { + _lock.ExitWriteLock(); + } + } - #endregion Setup for built-in handlers + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public bool UnmapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) + { + _lock.EnterWriteLock(); + try + { + var result = _userTypeMapper.UnmapComposite(clrType, pgName, nameTranslator); + ResetTypeMappingCache(); + return result; + } + finally + { + _lock.ExitWriteLock(); + } } } diff --git a/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs b/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs index aa3ce1cf50..83728785d6 100644 --- a/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs +++ b/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs @@ -1,174 +1,251 @@ using System; -using System.Collections.Generic; -using JetBrains.Annotations; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Nodes; +using Npgsql.Internal; +using Npgsql.Internal.ResolverFactories; using Npgsql.NameTranslation; using NpgsqlTypes; // ReSharper disable UnusedMember.Global -namespace Npgsql.TypeMapping +namespace Npgsql.TypeMapping; + +/// +/// A type mapper, managing how to read and write CLR values to PostgreSQL data types. +/// +/// +/// The preferred way to manage type mappings is on . An alternative, but discouraged, method, is to +/// manage them globally via ). +/// +public interface INpgsqlTypeMapper { /// - /// A type mapper, managing how to read and write CLR values to PostgreSQL data types. - /// A type mapper exists for each connection, as well as a single global type mapper - /// (accessible via ). + /// The default name translator to convert CLR type names and member names. Defaults to . + /// + INpgsqlNameTranslator DefaultNameTranslator { get; set; } + + /// + /// Maps a CLR enum to a PostgreSQL enum type. + /// + /// + /// CLR enum labels are mapped by name to PostgreSQL enum labels. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// You can also use the on your enum fields to manually specify a PostgreSQL enum label. + /// If there is a discrepancy between the .NET and database labels while an enum is read or written, + /// an exception will be raised. 
+ /// + /// + /// A PostgreSQL type name for the corresponding enum type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + /// The .NET enum type to be mapped + INpgsqlTypeMapper MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>( + string? pgName = null, + INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum; + + /// + /// Removes an existing enum mapping. + /// + /// + /// A PostgreSQL type name for the corresponding enum type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + bool UnmapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>( + string? pgName = null, + INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum; + + /// + /// Maps a CLR enum to a PostgreSQL enum type. + /// + /// + /// CLR enum labels are mapped by name to PostgreSQL enum labels. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// You can also use the on your enum fields to manually specify a PostgreSQL enum label. + /// If there is a discrepancy between the .NET and database labels while an enum is read or written, + /// an exception will be raised. + /// + /// The .NET enum type to be mapped + /// + /// A PostgreSQL type name for the corresponding enum type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . 
+ /// + [RequiresDynamicCode("Calling MapEnum with a Type can require creating new generic types or methods. This may not work when AOT compiling.")] + INpgsqlTypeMapper MapEnum( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]Type clrType, + string? pgName = null, + INpgsqlNameTranslator? nameTranslator = null); + + /// + /// Removes an existing enum mapping. + /// + /// The .NET enum type to be mapped + /// + /// A PostgreSQL type name for the corresponding enum type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + bool UnmapEnum( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]Type clrType, + string? pgName = null, + INpgsqlNameTranslator? nameTranslator = null); + + /// + /// Maps a CLR type to a PostgreSQL composite type. /// /// + /// CLR fields and properties by string to PostgreSQL names. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// You can also use the on your members to manually specify a PostgreSQL name. + /// If there is a discrepancy between the .NET type and database type while a composite is read or written, + /// an exception will be raised. /// - public interface INpgsqlTypeMapper - { - /// - /// The default name translator to convert CLR type names and member names. - /// - [NotNull] - INpgsqlNameTranslator DefaultNameTranslator { get; } - - /// - /// Enumerates all mappings currently set up on this type mapper. - /// - [NotNull] - [ItemNotNull] - IEnumerable Mappings { get; } - - /// - /// Adds a new type mapping to this mapper, overwriting any existing mapping in the process. 
- /// - [NotNull] - INpgsqlTypeMapper AddMapping([NotNull] NpgsqlTypeMapping mapping); - - /// - /// Removes an existing mapping from this mapper. Attempts to read or write this type - /// after removal will result in an exception. - /// - /// A PostgreSQL type name for the type in the database. - bool RemoveMapping([NotNull] string pgTypeName); - - /// - /// Maps a CLR enum to a PostgreSQL enum type. - /// - /// - /// CLR enum labels are mapped by name to PostgreSQL enum labels. - /// The translation strategy can be controlled by the parameter, - /// which defaults to . - /// You can also use the on your enum fields to manually specify a PostgreSQL enum label. - /// If there is a discrepancy between the .NET and database labels while an enum is read or written, - /// an exception will be raised. - /// - /// - /// A PostgreSQL type name for the corresponding enum type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - /// The .NET enum type to be mapped - [NotNull] - INpgsqlTypeMapper MapEnum( - string? pgName = null, - INpgsqlNameTranslator? nameTranslator = null) - where TEnum : struct, Enum; - - /// - /// Removes an existing enum mapping. - /// - /// - /// A PostgreSQL type name for the corresponding enum type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - bool UnmapEnum( - string? pgName = null, - INpgsqlNameTranslator? nameTranslator = null) - where TEnum : struct, Enum; - - /// - /// Maps a CLR type to a PostgreSQL composite type. - /// - /// - /// CLR fields and properties by string to PostgreSQL names. - /// The translation strategy can be controlled by the parameter, - /// which defaults to . 
- /// You can also use the on your members to manually specify a PostgreSQL name. - /// If there is a discrepancy between the .NET type and database type while a composite is read or written, - /// an exception will be raised. - /// - /// - /// A PostgreSQL type name for the corresponding composite type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - /// The .NET type to be mapped - [NotNull] - INpgsqlTypeMapper MapComposite( - string? pgName = null, - INpgsqlNameTranslator? nameTranslator = null); - - /// - /// Removes an existing composite mapping. - /// - /// - /// A PostgreSQL type name for the corresponding composite type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - bool UnmapComposite( - string? pgName = null, - INpgsqlNameTranslator? nameTranslator = null); - - /// - /// Maps a CLR type to a composite type. - /// - /// - /// Maps CLR fields and properties by string to PostgreSQL names. - /// The translation strategy can be controlled by the parameter, - /// which defaults to . - /// If there is a discrepancy between the .NET type and database type while a composite is read or written, - /// an exception will be raised. - /// - /// The .NET type to be mapped. - /// - /// A PostgreSQL type name for the corresponding composite type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - [NotNull] - INpgsqlTypeMapper MapComposite( - Type clrType, - string? pgName = null, - INpgsqlNameTranslator? 
nameTranslator = null); - - /// - /// Removes an existing composite mapping. - /// - /// The .NET type to be unmapped. - /// - /// A PostgreSQL type name for the corresponding composite type in the database. - /// If null, the name translator given in will be used. - /// - /// - /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). - /// Defaults to - /// - bool UnmapComposite( - Type clrType, - string? pgName = null, - INpgsqlNameTranslator? nameTranslator = null); - - /// - /// Resets all mapping changes performed on this type mapper and reverts it to its original, starting state. - /// - void Reset(); - } + /// + /// A PostgreSQL type name for the corresponding composite type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + /// The .NET type to be mapped + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + INpgsqlTypeMapper MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + string? pgName = null, + INpgsqlNameTranslator? nameTranslator = null); + + /// + /// Removes an existing composite mapping. + /// + /// + /// A PostgreSQL type name for the corresponding composite type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). 
+ /// Defaults to + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + bool UnmapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + string? pgName = null, + INpgsqlNameTranslator? nameTranslator = null); + + /// + /// Maps a CLR type to a composite type. + /// + /// + /// Maps CLR fields and properties by string to PostgreSQL names. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// If there is a discrepancy between the .NET type and database type while a composite is read or written, + /// an exception will be raised. + /// + /// The .NET type to be mapped. + /// + /// A PostgreSQL type name for the corresponding composite type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + INpgsqlTypeMapper MapComposite( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type clrType, + string? pgName = null, + INpgsqlNameTranslator? nameTranslator = null); + + /// + /// Removes an existing composite mapping. + /// + /// The .NET type to be unmapped. + /// + /// A PostgreSQL type name for the corresponding composite type in the database. 
+ /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + bool UnmapComposite( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type clrType, + string? pgName = null, + INpgsqlNameTranslator? nameTranslator = null); + + /// + /// Adds a type info resolver factory which can add or modify support for PostgreSQL types. + /// Typically used by plugins. + /// + /// The type resolver factory to be added. + void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory); + + /// + /// Configures the JSON serializer options used when reading and writing all System.Text.Json data. + /// + /// Options to customize JSON serialization and deserialization. + /// + INpgsqlTypeMapper ConfigureJsonOptions(JsonSerializerOptions serializerOptions); + + /// + /// Sets up dynamic System.Text.Json mappings. This allows mapping arbitrary .NET types to PostgreSQL json and jsonb + /// types, as well as and its derived types. + /// + /// + /// A list of CLR types to map to PostgreSQL jsonb (no need to specify ). + /// + /// + /// A list of CLR types to map to PostgreSQL json (no need to specify ). + /// + /// + /// Due to the dynamic nature of these mappings, they are not compatible with NativeAOT or trimming. + /// + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode( + "Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. 
This may not work when AOT compiling.")] + INpgsqlTypeMapper EnableDynamicJson(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null); + + /// + /// Sets up mappings for the PostgreSQL record type as a .NET or . + /// + /// The same builder instance so that multiple calls can be chained. + [RequiresUnreferencedCode( + "The mapping of PostgreSQL records as .NET tuples requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode( + "The mapping of PostgreSQL records as .NET tuples requires dynamic code usage which is incompatible with NativeAOT.")] + INpgsqlTypeMapper EnableRecordsAsTuples(); + + /// + /// Sets up mappings allowing the use of unmapped enum, range and multirange types. + /// + /// The same builder instance so that multiple calls can be chained. + [RequiresUnreferencedCode( + "The use of unmapped enums, ranges or multiranges requires reflection usage which is incompatible with trimming.")] + [RequiresDynamicCode( + "The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] + INpgsqlTypeMapper EnableUnmappedTypes(); + + /// + /// Resets all mapping changes performed on this type mapper and reverts it to its original, starting state. + /// + void Reset(); } diff --git a/src/Npgsql/TypeMapping/NpgsqlTypeMapping.cs b/src/Npgsql/TypeMapping/NpgsqlTypeMapping.cs deleted file mode 100644 index 1fca1616bf..0000000000 --- a/src/Npgsql/TypeMapping/NpgsqlTypeMapping.cs +++ /dev/null @@ -1,151 +0,0 @@ -using System; -using System.Data; -using System.Diagnostics.CodeAnalysis; -using Npgsql.TypeHandling; -using NpgsqlTypes; - -namespace Npgsql.TypeMapping -{ - /// - /// Builds instances of for addition into . - /// - public class NpgsqlTypeMappingBuilder - { - /// - /// The name of the PostgreSQL type name, as it appears in the pg_type catalog. 
- /// - /// - /// This can a a partial name (without the schema), or a fully-qualified name - /// (schema.typename) - the latter can be used if you have two types with the same - /// name in different schemas. - /// - [DisallowNull] - public string? PgTypeName { get; set; } - - /// - /// The that corresponds to this type. Setting an - /// 's property - /// to this value will make Npgsql write its value to PostgreSQL with this mapping. - /// - public NpgsqlDbType? NpgsqlDbType { get; set; } - - /// - /// A set of s that correspond to this type. Setting an - /// 's property - /// to one of these values will make Npgsql write its value to PostgreSQL with this mapping. - /// - public DbType[]? DbTypes { get; set; } - - /// - /// A set of CLR types that correspond to this type. Setting an - /// 's property - /// to one of these types will make Npgsql write its value to PostgreSQL with this mapping. - /// - public Type[]? ClrTypes { get; set; } - - /// - /// Determines what is returned from when this mapping - /// is used. - /// - public DbType? InferredDbType { get; set; } - - /// - /// A factory for a type handler that will be used to read and write values for PostgreSQL type. - /// - [DisallowNull] - public NpgsqlTypeHandlerFactory? TypeHandlerFactory { get; set; } - - /// - /// Builds an that can be added to an . - /// - /// - public NpgsqlTypeMapping Build() - { - if (string.IsNullOrWhiteSpace(PgTypeName)) - throw new ArgumentException($"{nameof(PgTypeName)} must contain the name of a PostgreSQL data type", nameof(PgTypeName)); - - if (TypeHandlerFactory is null) - throw new ArgumentException($"{nameof(TypeHandlerFactory)} must refer to a type handler factory"); - - return new NpgsqlTypeMapping(PgTypeName!, NpgsqlDbType, DbTypes, ClrTypes, InferredDbType, TypeHandlerFactory); - } - } - - /// - /// Represents a type mapping for a PostgreSQL data type, which can be added to a type mapper, - /// managing when that data type will be read and written and how. 
- /// - /// - /// - public sealed class NpgsqlTypeMapping - { - internal NpgsqlTypeMapping( - string pgTypeName, - NpgsqlDbType? npgsqlDbType, DbType[]? dbTypes, Type[]? clrTypes, DbType? inferredDbType, - NpgsqlTypeHandlerFactory typeHandlerFactory) - { - PgTypeName = pgTypeName; - NpgsqlDbType = npgsqlDbType; - DbTypes = dbTypes ?? EmptyDbTypes; - ClrTypes = clrTypes ?? EmptyClrTypes; - InferredDbType = inferredDbType; - TypeHandlerFactory = typeHandlerFactory; - } - - /// - /// The name of the PostgreSQL type name, as it appears in the pg_type catalog. - /// - /// - /// This can a a partial name (without the schema), or a fully-qualified name - /// (schema.typename) - the latter can be used if you have two types with the same - /// name in different schemas. - /// - public string PgTypeName { get; } - - /// - /// The that corresponds to this type. Setting an - /// 's property - /// to this value will make Npgsql write its value to PostgreSQL with this mapping. - /// - public NpgsqlDbType? NpgsqlDbType { get; } - - /// - /// A set of s that correspond to this type. Setting an - /// 's property - /// to one of these values will make Npgsql write its value to PostgreSQL with this mapping. - /// - public DbType[] DbTypes { get; } - - /// - /// A set of CLR types that correspond to this type. Setting an - /// 's property - /// to one of these types will make Npgsql write its value to PostgreSQL with this mapping. - /// - public Type[] ClrTypes { get; } - - /// - /// Determines what is returned from when this mapping - /// is used. - /// - public DbType? InferredDbType { get; } - - /// - /// A factory for a type handler that will be used to read and write values for PostgreSQL type. - /// - public NpgsqlTypeHandlerFactory TypeHandlerFactory { get; } - - /// - /// The default CLR type that handlers produced by this factory will read and write. - /// Used by the EF Core provider (and possibly others in the future). 
- /// - internal Type DefaultClrType => TypeHandlerFactory.DefaultValueType; - - /// - /// Returns a string that represents the current object. - /// - public override string ToString() => $"{PgTypeName} => {TypeHandlerFactory.GetType().Name}"; - - static readonly DbType[] EmptyDbTypes = new DbType[0]; - static readonly Type[] EmptyClrTypes = new Type[0]; - } -} diff --git a/src/Npgsql/TypeMapping/TypeMapperBase.cs b/src/Npgsql/TypeMapping/TypeMapperBase.cs deleted file mode 100644 index 5c13c9c4dd..0000000000 --- a/src/Npgsql/TypeMapping/TypeMapperBase.cs +++ /dev/null @@ -1,133 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Reflection; -using Npgsql.TypeHandlers; -using Npgsql.TypeHandlers.CompositeHandlers; -using Npgsql.TypeHandling; -using NpgsqlTypes; - -namespace Npgsql.TypeMapping -{ - abstract class TypeMapperBase : INpgsqlTypeMapper - { - internal Dictionary Mappings { get; } = new Dictionary(); - - public INpgsqlNameTranslator DefaultNameTranslator { get; } - - protected TypeMapperBase(INpgsqlNameTranslator defaultNameTranslator) - { - if (defaultNameTranslator == null) - throw new ArgumentNullException(nameof(defaultNameTranslator)); - - DefaultNameTranslator = defaultNameTranslator; - } - - #region Mapping management - - public virtual INpgsqlTypeMapper AddMapping(NpgsqlTypeMapping mapping) - { - if (Mappings.ContainsKey(mapping.PgTypeName)) - RemoveMapping(mapping.PgTypeName); - Mappings[mapping.PgTypeName] = mapping; - return this; - } - - public virtual bool RemoveMapping(string pgTypeName) => Mappings.Remove(pgTypeName); - - IEnumerable INpgsqlTypeMapper.Mappings => Mappings.Values; - - public abstract void Reset(); - - #endregion Mapping management - - #region Enum mapping - - public INpgsqlTypeMapper MapEnum(string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) - where TEnum : struct, Enum - { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); - - if (nameTranslator == null) - nameTranslator = DefaultNameTranslator; - if (pgName == null) - pgName = GetPgName(typeof(TEnum), nameTranslator); - - return AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = pgName, - ClrTypes = new[] { typeof(TEnum) }, - TypeHandlerFactory = new EnumTypeHandlerFactory(nameTranslator) - }.Build()); - } - - public bool UnmapEnum(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - where TEnum : struct, Enum - { - if (pgName != null && pgName.Trim() == "") - throw new ArgumentException("pgName can't be empty", nameof(pgName)); - - if (nameTranslator == null) - nameTranslator = DefaultNameTranslator; - if (pgName == null) - pgName = GetPgName(typeof(TEnum), nameTranslator); - - return RemoveMapping(pgName); - } - - #endregion Enum mapping - - #region Composite mapping - - public INpgsqlTypeMapper MapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - => MapComposite(pgName, nameTranslator, typeof(T), t => new CompositeTypeHandlerFactory(t)); - - public INpgsqlTypeMapper MapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - => MapComposite(pgName, nameTranslator, clrType, t => (NpgsqlTypeHandlerFactory) - Activator.CreateInstance(typeof(CompositeTypeHandlerFactory<>).MakeGenericType(clrType), BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, new object[] { t }, null)!); - - INpgsqlTypeMapper MapComposite(string? pgName, INpgsqlNameTranslator? 
nameTranslator, Type type, Func factory) - { - if (pgName != null && string.IsNullOrWhiteSpace(pgName)) - throw new ArgumentException("pgName can't be empty.", nameof(pgName)); - - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(type, nameTranslator); - - return AddMapping( - new NpgsqlTypeMappingBuilder - { - PgTypeName = pgName, - ClrTypes = new[] { type }, - TypeHandlerFactory = factory(nameTranslator), - } - .Build()); - } - - public bool UnmapComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - => UnmapComposite(typeof(T), pgName, nameTranslator); - - public bool UnmapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - { - if (pgName != null && string.IsNullOrWhiteSpace(pgName)) - throw new ArgumentException("pgName can't be empty.", nameof(pgName)); - - nameTranslator ??= DefaultNameTranslator; - pgName ??= GetPgName(clrType, nameTranslator); - - return RemoveMapping(pgName); - } - - #endregion Composite mapping - - #region Misc - - // TODO: why does ReSharper think `GetCustomAttribute` is non-nullable? - // ReSharper disable once ConstantConditionalAccessQualifier ConstantNullCoalescingCondition - static string GetPgName(Type clrType, INpgsqlNameTranslator nameTranslator) - => clrType.GetCustomAttribute()?.PgName - ?? 
nameTranslator.TranslateTypeName(clrType.Name); - - #endregion Misc - } -} diff --git a/src/Npgsql/TypeMapping/TypeMappingAttribute.cs b/src/Npgsql/TypeMapping/TypeMappingAttribute.cs deleted file mode 100644 index 1ceba444b9..0000000000 --- a/src/Npgsql/TypeMapping/TypeMappingAttribute.cs +++ /dev/null @@ -1,132 +0,0 @@ -using System; -using System.Data; -using System.Linq; -using System.Text; -using JetBrains.Annotations; -using NpgsqlTypes; - -namespace Npgsql.TypeMapping -{ - [AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] - [MeansImplicitUse] - class TypeMappingAttribute : Attribute - { - /// - /// Maps an Npgsql type handler to a PostgreSQL type. - /// - /// A PostgreSQL type name as it appears in the pg_type table. - /// - /// A member of which represents this PostgreSQL type. - /// An with set to - /// this value will be sent with the type handler mapped by this attribute. - /// - /// - /// All members of which represent this PostgreSQL type. - /// An with set to - /// one of these values will be sent with the type handler mapped by this attribute. - /// - /// - /// Any .NET type which corresponds to this PostgreSQL type. - /// An with set to - /// one of these values will be sent with the type handler mapped by this attribute. - /// - /// - /// The "primary" which best corresponds to this PostgreSQL type. - /// When or - /// set, will be set to this value. - /// - internal TypeMappingAttribute(string pgName, NpgsqlDbType? npgsqlDbType, DbType[]? dbTypes, Type[]? clrTypes, DbType? inferredDbType) - { - if (string.IsNullOrWhiteSpace(pgName)) - throw new ArgumentException("pgName can't be empty", nameof(pgName)); - - PgName = pgName; - NpgsqlDbType = npgsqlDbType; - DbTypes = dbTypes ?? new DbType[0]; - ClrTypes = clrTypes ?? new Type[0]; - InferredDbType = inferredDbType; - } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, DbType[] dbTypes, Type[]? 
clrTypes, DbType inferredDbType) - : this(pgName, (NpgsqlDbType?)npgsqlDbType, dbTypes, clrTypes, inferredDbType) - { } - - //internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, DbType[] dbTypes=null, Type type=null) - // : this(pgName, npgsqlDbType, dbTypes, type == null ? null : new[] { type }) {} - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType) - : this(pgName, npgsqlDbType, new DbType[0], new Type[0], null) - { } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, DbType inferredDbType) - : this(pgName, npgsqlDbType, new DbType[0], new Type[0], inferredDbType) - { } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, DbType[] dbTypes, Type clrType, DbType inferredDbType) - : this(pgName, npgsqlDbType, dbTypes, new[] { clrType }, inferredDbType) - { } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, DbType[] dbTypes) - : this(pgName, npgsqlDbType, dbTypes, new Type[0], null) - { } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, DbType dbType, Type[] clrTypes) - : this(pgName, npgsqlDbType, new[] { dbType }, clrTypes, dbType) - { } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, DbType dbType, Type? clrType = null) - : this(pgName, npgsqlDbType, new[] { dbType }, clrType == null ? 
null : new[] { clrType }, dbType) - { } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, Type[] clrTypes, DbType inferredDbType) - : this(pgName, npgsqlDbType, new DbType[0], clrTypes, inferredDbType) - { } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, Type[] clrTypes) - : this(pgName, npgsqlDbType, new DbType[0], clrTypes, null) - { } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, Type clrType, DbType inferredDbType) - : this(pgName, npgsqlDbType, new DbType[0], new[] { clrType }, inferredDbType) - { } - - internal TypeMappingAttribute(string pgName, NpgsqlDbType npgsqlDbType, Type clrType) - : this(pgName, npgsqlDbType, new DbType[0], new[] { clrType }, null) - { } - - /// - /// Read-only parameter - /// - internal TypeMappingAttribute(string pgName) - : this(pgName, null, null, null, null) - { } - - internal string PgName { get; } - internal NpgsqlDbType? NpgsqlDbType { get; } - internal DbType[] DbTypes { get; } - internal Type[] ClrTypes { get; } - internal DbType? InferredDbType { get; } - - /// - /// Returns a string that represents the current object. 
- /// - /// - public override string ToString() - { - var sb = new StringBuilder(); - sb.AppendFormat("[{0} NpgsqlDbType={1}", PgName, NpgsqlDbType); - if (DbTypes.Length > 0) - { - sb.Append(" DbTypes="); - sb.Append(string.Join(",", DbTypes.Select(t => t.ToString()))); - } - if (ClrTypes.Length > 0) - { - sb.Append(" ClrTypes="); - sb.Append(string.Join(",", ClrTypes.Select(t => t.Name))); - } - sb.AppendFormat("]"); - return sb.ToString(); - } - } -} diff --git a/src/Npgsql/TypeMapping/UserTypeMapper.cs b/src/Npgsql/TypeMapping/UserTypeMapper.cs new file mode 100644 index 0000000000..35fabb90fe --- /dev/null +++ b/src/Npgsql/TypeMapping/UserTypeMapper.cs @@ -0,0 +1,278 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Npgsql.Internal; +using Npgsql.Internal.Composites; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.NameTranslation; +using Npgsql.PostgresTypes; +using NpgsqlTypes; + +namespace Npgsql.TypeMapping; + +/// +/// The base class for user type mappings. +/// +public abstract class UserTypeMapping +{ + /// + /// The name of the PostgreSQL type that this mapping is for. + /// + public string PgTypeName { get; } + /// + /// The CLR type that this mapping is for. 
+ /// + public Type ClrType { get; } + + internal UserTypeMapping(string pgTypeName, Type type) + => (PgTypeName, ClrType) = (pgTypeName, type); + + internal abstract void AddMapping(TypeInfoMappingCollection mappings); + internal abstract void AddArrayMapping(TypeInfoMappingCollection mappings); +} + +sealed class UserTypeMapper : PgTypeInfoResolverFactory +{ + readonly List _mappings; + public IList Items => _mappings; + + public INpgsqlNameTranslator DefaultNameTranslator { get; set; } = NpgsqlSnakeCaseNameTranslator.Instance; + + UserTypeMapper(IEnumerable mappings) => _mappings = new List(mappings); + public UserTypeMapper() => _mappings = new(); + + public UserTypeMapper Clone() => new(_mappings) { DefaultNameTranslator = DefaultNameTranslator }; + + public UserTypeMapper MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum + { + Unmap(typeof(TEnum), out var resolvedName, pgName, nameTranslator); + Items.Add(new EnumMapping(resolvedName, nameTranslator ?? DefaultNameTranslator)); + return this; + } + + public bool UnmapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>( + string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + where TEnum : struct, Enum + => Unmap(typeof(TEnum), out _, pgName, nameTranslator ?? DefaultNameTranslator); + + [UnconditionalSuppressMessage("Trimming", "IL2111", Justification = "MapEnum TEnum has less DAM annotations than clrType.")] + [RequiresDynamicCode("Calling MapEnum with a Type can require creating new generic types or methods. This may not work when AOT compiling.")] + public UserTypeMapper MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]Type clrType, string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) + { + if (!clrType.IsEnum || !clrType.IsValueType) + throw new ArgumentException("Type must be a concrete Enum", nameof(clrType)); + + var openMethod = typeof(UserTypeMapper).GetMethod(nameof(MapEnum), new[] { typeof(string), typeof(INpgsqlNameTranslator) })!; + var method = openMethod.MakeGenericMethod(clrType); + method.Invoke(this, new object?[] { pgName, nameTranslator }); + return this; + } + + public bool UnmapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)]Type clrType,string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + if (!clrType.IsEnum || !clrType.IsValueType) + throw new ArgumentException("Type must be a concrete Enum", nameof(clrType)); + + return Unmap(clrType, out _, pgName, nameTranslator ?? DefaultNameTranslator); + } + + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public UserTypeMapper MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicProperties)] T>( + string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where T : class + { + Unmap(typeof(T), out var resolvedName, pgName, nameTranslator); + Items.Add(new CompositeMapping(resolvedName, nameTranslator ?? DefaultNameTranslator)); + return this; + } + + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. 
This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public UserTypeMapper MapStructComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicProperties)] T>( + string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where T : struct + { + Unmap(typeof(T), out var resolvedName, pgName, nameTranslator); + Items.Add(new StructCompositeMapping(resolvedName, nameTranslator ?? DefaultNameTranslator)); + return this; + } + + [UnconditionalSuppressMessage("Trimming", "IL2111", Justification = "MapStructComposite and MapComposite have identical DAM annotations to clrType.")] + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public UserTypeMapper MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + if (clrType.IsConstructedGenericType && clrType.GetGenericTypeDefinition() == typeof(Nullable<>)) + throw new ArgumentException("Cannot map nullable.", nameof(clrType)); + + var openMethod = typeof(UserTypeMapper).GetMethod( + clrType.IsValueType ? nameof(MapStructComposite) : nameof(MapComposite), + new[] { typeof(string), typeof(INpgsqlNameTranslator) })!; + + var method = openMethod.MakeGenericMethod(clrType); + + method.Invoke(this, new object?[] { pgName, nameTranslator }); + + return this; + } + + public bool UnmapComposite(string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) where T : class + => UnmapComposite(typeof(T), pgName, nameTranslator); + + public bool UnmapStructComposite(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where T : struct + => UnmapComposite(typeof(T), pgName, nameTranslator); + + public bool UnmapComposite(Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => Unmap(clrType, out _, pgName, nameTranslator); + + bool Unmap(Type type, out string resolvedName, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + { + if (pgName != null && pgName.Trim() == "") + throw new ArgumentException("pgName can't be empty", nameof(pgName)); + + nameTranslator ??= DefaultNameTranslator; + resolvedName = pgName ??= GetPgName(type, nameTranslator); + + UserTypeMapping? toRemove = null; + foreach (var item in _mappings) + if (item.PgTypeName == pgName) + toRemove = item; + + return toRemove is not null && _mappings.Remove(toRemove); + } + + static string GetPgName(Type type, INpgsqlNameTranslator nameTranslator) + => type.GetCustomAttribute()?.PgName + ?? nameTranslator.TranslateTypeName(type.Name); + + public override IPgTypeInfoResolver CreateResolver() => new Resolver(new(_mappings)); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(new(_mappings)); + + class Resolver : IPgTypeInfoResolver + { + protected readonly List _userTypeMappings; + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public Resolver(List userTypeMappings) => _userTypeMappings = userTypeMappings; + + PgTypeInfo? IPgTypeInfoResolver.GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + foreach (var userTypeMapping in _userTypeMappings) + userTypeMapping.AddMapping(mappings); + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public ArrayResolver(List userTypeMappings) : base(userTypeMappings) { } + + PgTypeInfo? IPgTypeInfoResolver.GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + foreach (var userTypeMapping in _userTypeMappings) + userTypeMapping.AddArrayMapping(mappings); + + return mappings; + } + } + + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. 
This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + sealed class CompositeMapping<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicProperties)] T> : UserTypeMapping where T : class + { + readonly INpgsqlNameTranslator _nameTranslator; + + public CompositeMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) + : base(pgTypeName, typeof(T)) + => _nameTranslator = nameTranslator; + + internal override void AddMapping(TypeInfoMappingCollection mappings) + { + mappings.AddType(PgTypeName, (options, mapping, _) => + { + var pgType = mapping.GetPgType(options); + if (pgType is not PostgresCompositeType compositeType) + throw new InvalidOperationException("Composite mapping must be to a composite type"); + + return mapping.CreateInfo(options, new CompositeConverter( + ReflectionCompositeInfoFactory.CreateCompositeInfo(compositeType, _nameTranslator, options))); + }, isDefault: true); + } + + internal override void AddArrayMapping(TypeInfoMappingCollection mappings) => mappings.AddArrayType(PgTypeName); + } + + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. 
This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + sealed class StructCompositeMapping<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicProperties)] T> : UserTypeMapping where T : struct + { + readonly INpgsqlNameTranslator _nameTranslator; + + public StructCompositeMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) + : base(pgTypeName, typeof(T)) + => _nameTranslator = nameTranslator; + + internal override void AddMapping(TypeInfoMappingCollection mappings) + { + mappings.AddStructType(PgTypeName, (options, mapping, dataTypeNameMatch) => + { + var pgType = mapping.GetPgType(options); + if (pgType is not PostgresCompositeType compositeType) + throw new InvalidOperationException("Composite mapping must be to a composite type"); + + return mapping.CreateInfo(options, new CompositeConverter( + ReflectionCompositeInfoFactory.CreateCompositeInfo(compositeType, _nameTranslator, options))); + }, isDefault: true); + } + + internal override void AddArrayMapping(TypeInfoMappingCollection mappings) => mappings.AddStructArrayType(PgTypeName); + } + + internal abstract class EnumMapping : UserTypeMapping + { + internal INpgsqlNameTranslator NameTranslator { get; } + + public EnumMapping(string pgTypeName, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)]Type enumClrType, INpgsqlNameTranslator nameTranslator) + : base(pgTypeName, enumClrType) + => NameTranslator = nameTranslator; + } + + sealed class EnumMapping<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum> : EnumMapping + where TEnum : struct, Enum + { + readonly Dictionary _enumToLabel = new(); + readonly Dictionary _labelToEnum = new(); + + public EnumMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) + : base(pgTypeName, typeof(TEnum), nameTranslator) + { + foreach (var field in 
typeof(TEnum).GetFields(BindingFlags.Static | BindingFlags.Public)) + { + var attribute = (PgNameAttribute?)field.GetCustomAttribute(typeof(PgNameAttribute), false); + var enumName = attribute is null + ? nameTranslator.TranslateMemberName(field.Name) + : attribute.PgName; + var enumValue = (TEnum)field.GetValue(null)!; + + _enumToLabel[enumValue] = enumName; + _labelToEnum[enumName] = enumValue; + } + } + + internal override void AddMapping(TypeInfoMappingCollection mappings) + => mappings.AddStructType(PgTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new EnumConverter(_enumToLabel, _labelToEnum, options.TextEncoding), preferredFormat: DataFormat.Text), isDefault: true); + + internal override void AddArrayMapping(TypeInfoMappingCollection mappings) => mappings.AddStructArrayType(PgTypeName); + } +} + diff --git a/src/Npgsql/UnpooledDataSource.cs b/src/Npgsql/UnpooledDataSource.cs new file mode 100644 index 0000000000..549a45f9b8 --- /dev/null +++ b/src/Npgsql/UnpooledDataSource.cs @@ -0,0 +1,50 @@ +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Internal; +using Npgsql.Util; + +namespace Npgsql; + +sealed class UnpooledDataSource : NpgsqlDataSource +{ + public UnpooledDataSource(NpgsqlConnectionStringBuilder settings, NpgsqlDataSourceConfiguration dataSourceConfig) + : base(settings, dataSourceConfig) + { + } + + volatile int _numConnectors; + + internal override (int Total, int Idle, int Busy) Statistics => (_numConnectors, 0, _numConnectors); + + internal override bool OwnsConnectors => true; + + internal override async ValueTask Get( + NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + { + CheckDisposed(); + + var connector = new NpgsqlConnector(this, conn); + await connector.Open(timeout, async, cancellationToken).ConfigureAwait(false); + Interlocked.Increment(ref _numConnectors); + return connector; + } + + internal override bool 
TryGetIdleConnector([NotNullWhen(true)] out NpgsqlConnector? connector) + { + connector = null; + return false; + } + + internal override ValueTask OpenNewConnector( + NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + => new((NpgsqlConnector?)null); + + internal override void Return(NpgsqlConnector connector) + { + Interlocked.Decrement(ref _numConnectors); + connector.Close(); + } + + internal override void Clear() {} +} diff --git a/src/Npgsql/Util/CodeAnnotations.cs b/src/Npgsql/Util/CodeAnnotations.cs deleted file mode 100644 index 63c38a4d87..0000000000 --- a/src/Npgsql/Util/CodeAnnotations.cs +++ /dev/null @@ -1,237 +0,0 @@ -using System; - -#pragma warning disable 1591 -// ReSharper disable UnusedMember.Global -// ReSharper disable MemberCanBePrivate.Global -// ReSharper disable UnusedAutoPropertyAccessor.Global -// ReSharper disable IntroduceOptionalParameters.Global -// ReSharper disable MemberCanBeProtected.Global -// ReSharper disable InconsistentNaming -// ReSharper disable CheckNamespace - -namespace JetBrains.Annotations -{ - /// - /// Indicates that the value of the marked element could be null sometimes, - /// so the check for null is necessary before its usage. - /// - /// - /// [CanBeNull] public object Test() { return null; } - /// public void UseTest() { - /// var p = Test(); - /// var s = p.ToString(); // Warning: Possible 'System.NullReferenceException' - /// } - /// - [AttributeUsage( - AttributeTargets.Method | AttributeTargets.Parameter | AttributeTargets.Property | - AttributeTargets.Delegate | AttributeTargets.Field | AttributeTargets.Event)] - sealed class CanBeNullAttribute : Attribute { - // ReSharper disable once EmptyConstructor - public CanBeNullAttribute() {} - } - - /// - /// Indicates that the value of the marked element could never be null. 
- /// - /// - /// [NotNull] public object Foo() { - /// return null; // Warning: Possible 'null' assignment - /// } - /// - [AttributeUsage( - AttributeTargets.Method | AttributeTargets.Parameter | AttributeTargets.Property | - AttributeTargets.Delegate | AttributeTargets.Field | AttributeTargets.Event)] - sealed class NotNullAttribute : Attribute { } - - /// - /// Can be appplied to symbols of types derived from IEnumerable as well as to symbols of Task - /// and Lazy classes to indicate that the value of a collection item, of the Task.Result property - /// or of the Lazy.Value property can never be null. - /// - [AttributeUsage( - AttributeTargets.Method | AttributeTargets.Parameter | AttributeTargets.Property | - AttributeTargets.Delegate | AttributeTargets.Field)] - sealed class ItemNotNullAttribute : Attribute { } - - /// - /// Can be appplied to symbols of types derived from IEnumerable as well as to symbols of Task - /// and Lazy classes to indicate that the value of a collection item, of the Task.Result property - /// or of the Lazy.Value property can be null. - /// - [AttributeUsage( - AttributeTargets.Method | AttributeTargets.Parameter | AttributeTargets.Property | - AttributeTargets.Delegate | AttributeTargets.Field)] - sealed class ItemCanBeNullAttribute : Attribute { } - - /// - /// Indicates that the marked symbol is used implicitly (e.g. via reflection, in external library), - /// so this symbol will not be marked as unused (as well as by other usage inspections). 
- /// - [AttributeUsage(AttributeTargets.All)] - sealed class UsedImplicitlyAttribute : Attribute - { - public UsedImplicitlyAttribute() - : this(ImplicitUseKindFlags.Default, ImplicitUseTargetFlags.Default) - { } - - public UsedImplicitlyAttribute(ImplicitUseKindFlags useKindFlags) - : this(useKindFlags, ImplicitUseTargetFlags.Default) - { } - - public UsedImplicitlyAttribute(ImplicitUseTargetFlags targetFlags) - : this(ImplicitUseKindFlags.Default, targetFlags) - { } - - public UsedImplicitlyAttribute(ImplicitUseKindFlags useKindFlags, ImplicitUseTargetFlags targetFlags) - { - UseKindFlags = useKindFlags; - TargetFlags = targetFlags; - } - - public ImplicitUseKindFlags UseKindFlags { get; private set; } - public ImplicitUseTargetFlags TargetFlags { get; private set; } - } - - /// - /// Should be used on attributes and causes ReSharper to not mark symbols marked with such attributes - /// as unused (as well as by other usage inspections) - /// - [AttributeUsage(AttributeTargets.Class | AttributeTargets.GenericParameter)] - sealed class MeansImplicitUseAttribute : Attribute - { - public MeansImplicitUseAttribute() - : this(ImplicitUseKindFlags.Default, ImplicitUseTargetFlags.Default) - { } - - public MeansImplicitUseAttribute(ImplicitUseKindFlags useKindFlags) - : this(useKindFlags, ImplicitUseTargetFlags.Default) - { } - - public MeansImplicitUseAttribute(ImplicitUseTargetFlags targetFlags) - : this(ImplicitUseKindFlags.Default, targetFlags) - { } - - public MeansImplicitUseAttribute(ImplicitUseKindFlags useKindFlags, ImplicitUseTargetFlags targetFlags) - { - UseKindFlags = useKindFlags; - TargetFlags = targetFlags; - } - - [UsedImplicitly] - public ImplicitUseKindFlags UseKindFlags { get; private set; } - [UsedImplicitly] - public ImplicitUseTargetFlags TargetFlags { get; private set; } - } - - [Flags] - internal enum ImplicitUseKindFlags - { - Default = Access | Assign | InstantiatedWithFixedConstructorSignature, - /// Only entity marked with attribute 
considered used. - Access = 1, - /// Indicates implicit assignment to a member. - Assign = 2, - /// - /// Indicates implicit instantiation of a type with fixed constructor signature. - /// That means any unused constructor parameters won't be reported as such. - /// - InstantiatedWithFixedConstructorSignature = 4, - /// Indicates implicit instantiation of a type. - InstantiatedNoFixedConstructorSignature = 8, - } - - /// - /// Specify what is considered used implicitly when marked - /// with or . - /// - [Flags] - internal enum ImplicitUseTargetFlags - { - Default = Itself, - Itself = 1, - /// Members of entity marked with attribute are considered used. - Members = 2, - /// Entity marked with attribute and all its members considered used. - WithMembers = Itself | Members - } - - /// - /// Describes dependency between method input and output. - /// - /// - ///

Function Definition Table syntax:

- /// - /// FDT ::= FDTRow [;FDTRow]* - /// FDTRow ::= Input => Output | Output <= Input - /// Input ::= ParameterName: Value [, Input]* - /// Output ::= [ParameterName: Value]* {halt|stop|void|nothing|Value} - /// Value ::= true | false | null | notnull | canbenull - /// - /// If method has single input parameter, it's name could be omitted.
- /// Using halt (or void/nothing, which is the same) - /// for method output means that the methos doesn't return normally.
- /// canbenull annotation is only applicable for output parameters.
- /// You can use multiple [ContractAnnotation] for each FDT row, - /// or use single attribute with rows separated by semicolon.
- ///
- /// - /// - /// [ContractAnnotation("=> halt")] - /// public void TerminationMethod() - /// - /// - /// [ContractAnnotation("halt <= condition: false")] - /// public void Assert(bool condition, string text) // regular assertion method - /// - /// - /// [ContractAnnotation("s:null => true")] - /// public bool IsNullOrEmpty(string s) // string.IsNullOrEmpty() - /// - /// - /// // A method that returns null if the parameter is null, - /// // and not null if the parameter is not null - /// [ContractAnnotation("null => null; notnull => notnull")] - /// public object Transform(object data) - /// - /// - /// [ContractAnnotation("s:null=>false; =>true,result:notnull; =>false, result:null")] - /// public bool TryParse(string s, out Person result) - /// - /// - [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] - sealed class ContractAnnotationAttribute : Attribute - { - public ContractAnnotationAttribute([NotNull] string contract) - : this(contract, false) - { } - - public ContractAnnotationAttribute([NotNull] string contract, bool forceFullStates) - { - Contract = contract; - ForceFullStates = forceFullStates; - } - - public string Contract { get; private set; } - public bool ForceFullStates { get; private set; } - } - - /// - /// Indicates that the function argument should be string literal and match one - /// of the parameters of the caller function. For example, ReSharper annotates - /// the parameter of . - /// - /// - /// public void Foo(string param) { - /// if (param == null) - /// throw new ArgumentNullException("par"); // Warning: Cannot resolve symbol - /// } - /// - [AttributeUsage(AttributeTargets.Parameter)] - sealed class InvokerParameterNameAttribute : Attribute { } - - /// - /// Indicates that IEnumerable, passed as parameter, is not enumerated. 
- /// - [AttributeUsage(AttributeTargets.Parameter)] - sealed class NoEnumerationAttribute : Attribute { } -} diff --git a/src/Npgsql/Util/ManualResetValueTaskSource.cs b/src/Npgsql/Util/ManualResetValueTaskSource.cs index e4656f9675..55e45aa225 100644 --- a/src/Npgsql/Util/ManualResetValueTaskSource.cs +++ b/src/Npgsql/Util/ManualResetValueTaskSource.cs @@ -1,22 +1,21 @@ using System; using System.Threading.Tasks.Sources; -namespace Npgsql.Util +namespace Npgsql.Util; + +sealed class ManualResetValueTaskSource : IValueTaskSource, IValueTaskSource { - sealed class ManualResetValueTaskSource : IValueTaskSource, IValueTaskSource - { - ManualResetValueTaskSourceCore _core; // mutable struct; do not make this readonly + ManualResetValueTaskSourceCore _core; // mutable struct; do not make this readonly - public bool RunContinuationsAsynchronously { get => _core.RunContinuationsAsynchronously; set => _core.RunContinuationsAsynchronously = value; } - public short Version => _core.Version; - public void Reset() => _core.Reset(); - public void SetResult(T result) => _core.SetResult(result); - public void SetException(Exception error) => _core.SetException(error); + public bool RunContinuationsAsynchronously { get => _core.RunContinuationsAsynchronously; set => _core.RunContinuationsAsynchronously = value; } + public short Version => _core.Version; + public void Reset() => _core.Reset(); + public void SetResult(T result) => _core.SetResult(result); + public void SetException(Exception error) => _core.SetException(error); - public T GetResult(short token) => _core.GetResult(token); - void IValueTaskSource.GetResult(short token) => _core.GetResult(token); - public ValueTaskSourceStatus GetStatus(short token) => _core.GetStatus(token); - public void OnCompleted(Action continuation, object? 
state, short token, ValueTaskSourceOnCompletedFlags flags) - => _core.OnCompleted(continuation, state, token, flags); - } -} + public T GetResult(short token) => _core.GetResult(token); + void IValueTaskSource.GetResult(short token) => _core.GetResult(token); + public ValueTaskSourceStatus GetStatus(short token) => _core.GetStatus(token); + public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) + => _core.OnCompleted(continuation, state, token, flags); +} \ No newline at end of file diff --git a/src/Npgsql/Util/NpgsqlTimeout.cs b/src/Npgsql/Util/NpgsqlTimeout.cs new file mode 100644 index 0000000000..79c44d6c4b --- /dev/null +++ b/src/Npgsql/Util/NpgsqlTimeout.cs @@ -0,0 +1,52 @@ +using System; +using System.Threading; +using Npgsql.Internal; + +namespace Npgsql.Util; + +/// +/// Represents a timeout that will expire at some point. +/// +public readonly struct NpgsqlTimeout +{ + readonly DateTime _expiration; + + internal static readonly NpgsqlTimeout Infinite = new(TimeSpan.Zero); + + internal NpgsqlTimeout(TimeSpan expiration) + => _expiration = expiration > TimeSpan.Zero + ? DateTime.UtcNow + expiration + : expiration == TimeSpan.Zero + ? 
DateTime.MaxValue + : DateTime.MinValue; + + internal void Check() + { + if (HasExpired) + ThrowHelper.ThrowNpgsqlExceptionWithInnerTimeoutException("The operation has timed out"); + } + + internal void CheckAndApply(NpgsqlConnector connector) + { + if (!IsSet) + return; + + var timeLeft = CheckAndGetTimeLeft(); + // Set the remaining timeout on the read and write buffers + connector.ReadBuffer.Timeout = connector.WriteBuffer.Timeout = timeLeft; + } + + internal bool IsSet => _expiration != DateTime.MaxValue; + + internal bool HasExpired => DateTime.UtcNow >= _expiration; + + internal TimeSpan CheckAndGetTimeLeft() + { + if (!IsSet) + return Timeout.InfiniteTimeSpan; + var timeLeft = _expiration - DateTime.UtcNow; + if (timeLeft <= TimeSpan.Zero) + Check(); + return timeLeft; + } +} diff --git a/src/Npgsql/Util/PGUtil.cs b/src/Npgsql/Util/PGUtil.cs deleted file mode 100644 index a1fd07ea07..0000000000 --- a/src/Npgsql/Util/PGUtil.cs +++ /dev/null @@ -1,204 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace Npgsql.Util -{ - static class Statics - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static T Expect(IBackendMessage msg, NpgsqlConnector connector) - { - if (msg is T asT) - return asT; - - throw connector.Break( - new NpgsqlException($"Received backend message {msg.Code} while expecting {typeof(T).Name}. 
" + - "Please file a bug.")); - } - - internal static DeferDisposable Defer(Action action) => new DeferDisposable(action); - internal static DeferDisposable Defer(Action action, T arg) => new DeferDisposable(action, arg); - internal static DeferDisposable Defer(Action action, T1 arg1, T2 arg2) => new DeferDisposable(action, arg1, arg2); - // internal static AsyncDeferDisposable DeferAsync(Func func) => new AsyncDeferDisposable(func); - internal static AsyncDeferDisposable DeferAsync(Func func) => new AsyncDeferDisposable(func); - - internal readonly struct DeferDisposable : IDisposable - { - readonly Action _action; - public DeferDisposable(Action action) => _action = action; - public void Dispose() => _action(); - } - - internal readonly struct DeferDisposable : IDisposable - { - readonly Action _action; - readonly T _arg; - public DeferDisposable(Action action, T arg) - { - _action = action; - _arg = arg; - } - public void Dispose() => _action(_arg); - } - - internal readonly struct DeferDisposable : IDisposable - { - readonly Action _action; - readonly T1 _arg1; - readonly T2 _arg2; - public DeferDisposable(Action action, T1 arg1, T2 arg2) - { - _action = action; - _arg1 = arg1; - _arg2 = arg2; - } - public void Dispose() => _action(_arg1, _arg2); - } - - internal readonly struct AsyncDeferDisposable : IAsyncDisposable - { - readonly Func _func; - public AsyncDeferDisposable(Func func) => _func = func; - public async ValueTask DisposeAsync() => await _func(); - } - } - - // ReSharper disable once InconsistentNaming - static class PGUtil - { - internal static readonly UTF8Encoding UTF8Encoding = new UTF8Encoding(false, true); - internal static readonly UTF8Encoding RelaxedUTF8Encoding = new UTF8Encoding(false, false); - - internal const int BitsInInt = sizeof(int) * 8; - - internal static void ValidateBackendMessageCode(BackendMessageCode code) - { - switch (code) - { - case BackendMessageCode.AuthenticationRequest: - case BackendMessageCode.BackendKeyData: - 
case BackendMessageCode.BindComplete: - case BackendMessageCode.CloseComplete: - case BackendMessageCode.CommandComplete: - case BackendMessageCode.CopyData: - case BackendMessageCode.CopyDone: - case BackendMessageCode.CopyBothResponse: - case BackendMessageCode.CopyInResponse: - case BackendMessageCode.CopyOutResponse: - case BackendMessageCode.DataRow: - case BackendMessageCode.EmptyQueryResponse: - case BackendMessageCode.ErrorResponse: - case BackendMessageCode.FunctionCall: - case BackendMessageCode.FunctionCallResponse: - case BackendMessageCode.NoData: - case BackendMessageCode.NoticeResponse: - case BackendMessageCode.NotificationResponse: - case BackendMessageCode.ParameterDescription: - case BackendMessageCode.ParameterStatus: - case BackendMessageCode.ParseComplete: - case BackendMessageCode.PasswordPacket: - case BackendMessageCode.PortalSuspended: - case BackendMessageCode.ReadyForQuery: - case BackendMessageCode.RowDescription: - return; - default: - throw new NpgsqlException("Unknown message code: " + code); - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static int RotateShift(int val, int shift) - => (val << shift) | (val >> (BitsInInt - shift)); - - internal static readonly Task TrueTask = Task.FromResult(true); - internal static readonly Task FalseTask = Task.FromResult(false); - - internal static StringComparer InvariantCaseIgnoringStringComparer => StringComparer.InvariantCultureIgnoreCase; - - internal static bool IsWindows => - System.Runtime.InteropServices.RuntimeInformation.IsOSPlatform(System.Runtime.InteropServices.OSPlatform.Windows); - } - - enum FormatCode : short - { - Text = 0, - Binary = 1 - } - - static class EnumerableExtensions - { - internal static string Join(this IEnumerable values, string separator) - { - return string.Join(separator, values); - } - } - - static class ExceptionExtensions - { - internal static Exception UnwrapAggregate(this Exception exception) - => exception is AggregateException 
agg ? agg.InnerException! : exception; - } - - /// - /// Represents a timeout that will expire at some point. - /// - public readonly struct NpgsqlTimeout - { - readonly DateTime _expiration; - internal DateTime Expiration => _expiration; - - internal static NpgsqlTimeout Infinite = new NpgsqlTimeout(TimeSpan.Zero); - - internal NpgsqlTimeout(TimeSpan expiration) - => _expiration = expiration == TimeSpan.Zero ? DateTime.MaxValue : DateTime.UtcNow + expiration; - - internal void Check() - { - if (HasExpired) - throw new TimeoutException(); - } - - internal void CheckAndApply(NpgsqlConnector connector) - { - if (!IsSet) - return; - - var timeLeft = TimeLeft; - if (timeLeft > TimeSpan.Zero) - { - // Set the remaining timeout on the read and write buffers - connector.ReadBuffer.Timeout = connector.WriteBuffer.Timeout = timeLeft; - - // Note that we set UserTimeout as well, otherwise the read timeout will get overwritten in ReadMessage - // Note also that we must set the read buffer's timeout directly (above), since the SSL handshake - // reads data directly from the buffer, without going through ReadMessage. - connector.UserTimeout = (int)Math.Ceiling(timeLeft.TotalMilliseconds); - } - - Check(); - } - - internal bool IsSet => _expiration != DateTime.MaxValue; - - internal bool HasExpired => DateTime.UtcNow >= Expiration; - - internal TimeSpan TimeLeft => IsSet ? 
Expiration - DateTime.UtcNow : Timeout.InfiniteTimeSpan; - } - - static class MethodInfos - { - internal static readonly ConstructorInfo InvalidCastExceptionCtor = - typeof(InvalidCastException).GetConstructor(new[] { typeof(string) })!; - - internal static readonly MethodInfo StringFormat = - typeof(string).GetMethod(nameof(string.Format), new[] { typeof(string), typeof(object) })!; - - internal static readonly MethodInfo ObjectGetType = - typeof(object).GetMethod(nameof(GetType), new Type[0])!; - } -} diff --git a/src/Npgsql/Util/ResettableCancellationTokenSource.cs b/src/Npgsql/Util/ResettableCancellationTokenSource.cs index 027c3353be..874d7a40f8 100644 --- a/src/Npgsql/Util/ResettableCancellationTokenSource.cs +++ b/src/Npgsql/Util/ResettableCancellationTokenSource.cs @@ -3,192 +3,228 @@ using System.Threading; using static System.Threading.Timeout; -namespace Npgsql.Util +namespace Npgsql.Util; + +/// +/// A wrapper around to simplify reset management. +/// +/// +/// Since there's no way to reset a once it was cancelled, +/// we need to make sure that an existing cancellation token source hasn't been cancelled, +/// every time we start it (see https://github.com/dotnet/runtime/issues/4694). +/// +sealed class ResettableCancellationTokenSource : IDisposable { - /// - /// A wrapper around to simplify reset management. - /// - /// - /// Since there's no way to reset a once it was cancelled, - /// we need to make sure that an existing cancellation token source hasn't been cancelled, - /// every time we start it (see https://github.com/dotnet/runtime/issues/4694). - /// - class ResettableCancellationTokenSource : IDisposable - { - bool isDisposed; + bool isDisposed; - public TimeSpan Timeout { get; set; } + public TimeSpan Timeout { get; set; } - volatile CancellationTokenSource _cts = new CancellationTokenSource(); - CancellationTokenRegistration _registration; + CancellationTokenSource _cts = new(); + CancellationTokenRegistration? 
_registration; - /// - /// Used, so we wouldn't concurently use the cts for the cancellation, while it's being disposed - /// - readonly object lockObject = new object(); + /// + /// Used, so we wouldn't concurently use the cts for the cancellation, while it's being disposed + /// + readonly object lockObject = new(); #if DEBUG - bool _isRunning; + bool _isRunning; #endif - public ResettableCancellationTokenSource() => Timeout = InfiniteTimeSpan; + public ResettableCancellationTokenSource() => Timeout = InfiniteTimeSpan; - public ResettableCancellationTokenSource(TimeSpan timeout) => Timeout = timeout; + public ResettableCancellationTokenSource(TimeSpan timeout) => Timeout = timeout; - /// - /// Set the timeout on the wrapped - /// and make sure that it hasn't been cancelled yet - /// - /// An optional cancellation token that will be linked with the - /// contained - /// The from the wrapped - public CancellationToken Start(CancellationToken cancellationToken = default) + /// + /// Set the timeout on the wrapped + /// and make sure that it hasn't been cancelled yet + /// + /// + /// An optional token to cancel the asynchronous operation. The default value is . 
+ /// + /// The from the wrapped + public CancellationToken Start(CancellationToken cancellationToken = default) + { +#if DEBUG + Debug.Assert(!_isRunning); +#endif + lock (lockObject) { + // if there was an attempt to cancel while the connector was breaking + // we do nothing and return the default token + // as we're going to fail while reading or writing anyway + if (isDisposed) + { #if DEBUG - Debug.Assert(!_isRunning); + _isRunning = true; #endif + return CancellationToken.None; + } + _cts.CancelAfter(Timeout); if (_cts.IsCancellationRequested) { - lock (lockObject) - { - _cts.Dispose(); - _cts = new CancellationTokenSource(Timeout); - } + _cts.Dispose(); + _cts = new CancellationTokenSource(Timeout); } - if (cancellationToken.CanBeCanceled) - _registration = cancellationToken.Register(cts => ((CancellationTokenSource)cts!).Cancel(), _cts); + } + if (cancellationToken.CanBeCanceled) + _registration = cancellationToken.Register(cts => ((CancellationTokenSource)cts!).Cancel(), _cts); #if DEBUG - _isRunning = true; + _isRunning = true; #endif - return _cts.Token; + return _cts.Token; + } + + /// + /// Restart the timeout on the wrapped without reinitializing it, + /// even if is already set to + /// + public void RestartTimeoutWithoutReset() + { + lock (lockObject) + { + // if there was an attempt to cancel while the connector was breaking + // we do nothing and return the default token + // as we're going to fail while reading or writing anyway + if (!isDisposed) + _cts.CancelAfter(Timeout); } + } - /// - /// Restart the timeout on the wrapped without reinitializing it, - /// even if is already set to - /// - public void RestartTimeoutWithoutReset() => _cts.CancelAfter(Timeout); - - /// - /// Reset the wrapper to contain a unstarted and uncancelled - /// in order make sure the next call to will not invalidate - /// the cancellation token. 
- /// - /// An optional cancellation token that will be linked with the - /// contained - /// The from the wrapped - public CancellationToken Reset(CancellationToken cancellationToken = default) + /// + /// Reset the wrapper to contain a unstarted and uncancelled + /// in order make sure the next call to will not invalidate + /// the cancellation token. + /// + /// The from the wrapped + public CancellationToken Reset() + { + _registration?.Dispose(); + _registration = null; + lock (lockObject) { - _registration.Dispose(); - _cts.CancelAfter(InfiniteTimeSpan); - if (_cts.IsCancellationRequested) + // if there was an attempt to cancel while the connector was breaking + // we do nothing and return + // as we're going to fail while reading or writing anyway + if (isDisposed) { - lock (lockObject) - { - _cts.Dispose(); - _cts = new CancellationTokenSource(); - } - } - if (cancellationToken.CanBeCanceled) - _registration = cancellationToken.Register(cts => ((CancellationTokenSource)cts!).Cancel(), _cts); #if DEBUG - _isRunning = false; + _isRunning = false; #endif - return _cts.Token; - } + return CancellationToken.None; + } - /// - /// Reset the wrapper to contain a unstarted and uncancelled - /// in order make sure the next call to will not invalidate - /// the cancellation token. - /// - public void ResetCts() - { + _cts.CancelAfter(InfiniteTimeSpan); if (_cts.IsCancellationRequested) { _cts.Dispose(); _cts = new CancellationTokenSource(); } } - - /// - /// Set the timeout on the wrapped - /// to - /// - /// - /// can still arrive at a state - /// where it's value is if the - /// passed to gets a cancellation request. - /// If this is the case it will be resolved upon the next call to - /// or . Calling multiple times or without calling - /// first will do no any harm (besides eating a tiny amount of CPU cycles). 
- /// - public void Stop() - { - _cts.CancelAfter(InfiniteTimeSpan); - _registration.Dispose(); #if DEBUG - _isRunning = false; + _isRunning = false; #endif + return _cts.Token; + } + + /// + /// Reset the wrapper to contain a unstarted and uncancelled + /// in order make sure the next call to will not invalidate + /// the cancellation token. + /// + public void ResetCts() + { + if (_cts.IsCancellationRequested) + { + _cts.Dispose(); + _cts = new CancellationTokenSource(); } + } - /// - /// Cancel the wrapped - /// - public void Cancel() + /// + /// Set the timeout on the wrapped + /// to + /// + /// + /// can still arrive at a state + /// where it's value is if the + /// passed to gets a cancellation request. + /// If this is the case it will be resolved upon the next call to + /// or . Calling multiple times or without calling + /// first will do no any harm (besides eating a tiny amount of CPU cycles). + /// + public void Stop() + { + _registration?.Dispose(); + _registration = null; + lock (lockObject) { - lock (lockObject) - { - // if there was an attempt to cancel while the connector was breaking - // we do nothing - if (isDisposed) - return; - _cts.Cancel(); - } + // if there was an attempt to cancel while the connector was breaking + // we do nothing + if (!isDisposed) + _cts.CancelAfter(InfiniteTimeSpan); } +#if DEBUG + _isRunning = false; +#endif + } - /// - /// Cancel the wrapped after delay - /// - public void CancelAfter(int delay) + /// + /// Cancel the wrapped + /// + public void Cancel() + { + lock (lockObject) { - lock (lockObject) - { - // if there was an attempt to cancel while the connector was breaking - // we do nothing - if (isDisposed) - return; - _cts.CancelAfter(delay); - } + // if there was an attempt to cancel while the connector was breaking + // we do nothing + if (isDisposed) + return; + _cts.Cancel(); } + } - /// - /// The from the wrapped - /// . 
- /// - /// - /// The token is only valid after calling - /// and before calling the next time. - /// Otherwise you may end up with a token that has already been - /// cancelled or belongs to a cancellation token source that has - /// been disposed. - /// - public CancellationToken Token => _cts.Token; - - public bool IsCancellationRequested => _cts.IsCancellationRequested; - - public void Dispose() + /// + /// Cancel the wrapped after delay + /// + public void CancelAfter(int delay) + { + lock (lockObject) { - Debug.Assert(!isDisposed); + // if there was an attempt to cancel while the connector was breaking + // we do nothing + if (isDisposed) + return; + _cts.CancelAfter(delay); + } + } - lock (lockObject) - { - _registration.Dispose(); - _cts.Dispose(); + /// + /// The from the wrapped + /// . + /// + /// + /// The token is only valid after calling + /// and before calling the next time. + /// Otherwise you may end up with a token that has already been + /// cancelled or belongs to a cancellation token source that has + /// been disposed. 
+ /// + public CancellationToken Token => _cts.Token; - isDisposed = true; - } + public bool IsCancellationRequested => _cts.IsCancellationRequested; + + public void Dispose() + { + Debug.Assert(!isDisposed); + + lock (lockObject) + { + _registration?.Dispose(); + _cts.Dispose(); + + isDisposed = true; } } } diff --git a/src/Npgsql/Util/Statics.cs b/src/Npgsql/Util/Statics.cs new file mode 100644 index 0000000000..2b1101171b --- /dev/null +++ b/src/Npgsql/Util/Statics.cs @@ -0,0 +1,96 @@ +using Npgsql.Internal; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace Npgsql.Util; + +static class Statics +{ + internal static readonly bool EnableAssertions; +#if DEBUG + internal static bool LegacyTimestampBehavior; + internal static bool DisableDateTimeInfinityConversions; +#else + internal static readonly bool LegacyTimestampBehavior; + internal static readonly bool DisableDateTimeInfinityConversions; +#endif + + static Statics() + { + EnableAssertions = AppContext.TryGetSwitch("Npgsql.EnableAssertions", out var enabled) && enabled; + LegacyTimestampBehavior = AppContext.TryGetSwitch("Npgsql.EnableLegacyTimestampBehavior", out enabled) && enabled; + DisableDateTimeInfinityConversions = AppContext.TryGetSwitch("Npgsql.DisableDateTimeInfinityConversions", out enabled) && enabled; + } + + internal static T Expect(IBackendMessage msg, NpgsqlConnector connector) + { + if (msg.GetType() != typeof(T)) + ThrowIfMsgWrongType(msg, connector); + + return (T)msg; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static T ExpectAny(IBackendMessage msg, NpgsqlConnector connector) + { + if (msg is T t) + return t; + + ThrowIfMsgWrongType(msg, connector); + return default; + } + + [DoesNotReturn] + static void ThrowIfMsgWrongType(IBackendMessage msg, NpgsqlConnector connector) + => throw connector.Break( + new NpgsqlException($"Received backend 
message {msg.Code} while expecting {typeof(T).Name}. Please file a bug.")); + + [Conditional("DEBUG")] + internal static void ValidateBackendMessageCode(BackendMessageCode code) + { + switch (code) + { + case BackendMessageCode.AuthenticationRequest: + case BackendMessageCode.BackendKeyData: + case BackendMessageCode.BindComplete: + case BackendMessageCode.CloseComplete: + case BackendMessageCode.CommandComplete: + case BackendMessageCode.CopyData: + case BackendMessageCode.CopyDone: + case BackendMessageCode.CopyBothResponse: + case BackendMessageCode.CopyInResponse: + case BackendMessageCode.CopyOutResponse: + case BackendMessageCode.DataRow: + case BackendMessageCode.EmptyQueryResponse: + case BackendMessageCode.ErrorResponse: + case BackendMessageCode.FunctionCall: + case BackendMessageCode.FunctionCallResponse: + case BackendMessageCode.NoData: + case BackendMessageCode.NoticeResponse: + case BackendMessageCode.NotificationResponse: + case BackendMessageCode.ParameterDescription: + case BackendMessageCode.ParameterStatus: + case BackendMessageCode.ParseComplete: + case BackendMessageCode.PasswordPacket: + case BackendMessageCode.PortalSuspended: + case BackendMessageCode.ReadyForQuery: + case BackendMessageCode.RowDescription: + return; + default: + ThrowUnknownMessageCode(code); + return; + } + + static void ThrowUnknownMessageCode(BackendMessageCode code) + => ThrowHelper.ThrowNpgsqlException($"Unknown message code: {code}"); + } +} + +static class EnumerableExtensions +{ + internal static string Join(this IEnumerable values, string separator) + => string.Join(separator, values); +} diff --git a/src/Npgsql/Util/StrongBox.cs b/src/Npgsql/Util/StrongBox.cs new file mode 100644 index 0000000000..d72c3140e0 --- /dev/null +++ b/src/Npgsql/Util/StrongBox.cs @@ -0,0 +1,41 @@ +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Util; + +abstract class StrongBox +{ + private protected StrongBox() { } + public abstract bool HasValue { get; } + public abstract 
object? Value { get; set; } + public abstract void Clear(); +} + +sealed class StrongBox : StrongBox +{ + bool _hasValue; + + [MaybeNull] T _typedValue; + [MaybeNull] + public T TypedValue { + get => _typedValue; + set + { + _hasValue = true; + _typedValue = value; + } + } + + public override bool HasValue => _hasValue; + + public override object? Value + { + get => TypedValue; + set => TypedValue = (T)value!; + } + + public override void Clear() + { + _hasValue = false; + TypedValue = default!; + } +} diff --git a/src/Npgsql/Util/SubReadStream.cs b/src/Npgsql/Util/SubReadStream.cs new file mode 100644 index 0000000000..6aaee9651a --- /dev/null +++ b/src/Npgsql/Util/SubReadStream.cs @@ -0,0 +1,227 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Util; + +// Adapted from https://github.com/dotnet/runtime/blob/83adfae6a6273d8fb4c69554aa3b1cc7cbf01c71/src/libraries/System.IO.Compression/src/System/IO/Compression/ZipCustomStreams.cs#L221 +sealed class SubReadStream : Stream +{ + readonly long _startInSuperStream; + long _positionInSuperStream; + readonly long _endInSuperStream; + readonly Stream _superStream; + readonly bool _canSeek; + bool _isDisposed; + + public SubReadStream(Stream superStream, long maxLength) + { + _startInSuperStream = -1; + _positionInSuperStream = 0; + _endInSuperStream = maxLength; + _superStream = superStream; + _canSeek = false; + _isDisposed = false; + } + + public SubReadStream(Stream superStream, long startPosition, long maxLength) + { + _startInSuperStream = startPosition; + _positionInSuperStream = startPosition; + _endInSuperStream = startPosition + maxLength; + _superStream = superStream; + _canSeek = superStream.CanSeek; + _isDisposed = false; + } + + public override long Length + { + get + { + ThrowIfDisposed(); + + if (!_canSeek) + throw new NotSupportedException(); + + return _endInSuperStream - _startInSuperStream; + } + } + + public 
override long Position + { + get + { + ThrowIfDisposed(); + + if (!_canSeek) + throw new NotSupportedException(); + + return _positionInSuperStream - _startInSuperStream; + } + set + { + ThrowIfDisposed(); + + throw new NotSupportedException(); + } + } + + public override bool CanRead => _superStream.CanRead && !_isDisposed; + + public override bool CanSeek => false; + + public override bool CanWrite => false; + + void ThrowIfDisposed() + { + if (_isDisposed) + throw new ObjectDisposedException(GetType().ToString()); + } + + void ThrowIfCantRead() + { + if (!CanRead) + throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + // parameter validation sent to _superStream.Read + var origCount = count; + + ThrowIfDisposed(); + ThrowIfCantRead(); + + if (_canSeek && _superStream.Position != _positionInSuperStream) + _superStream.Seek(_positionInSuperStream, SeekOrigin.Begin); + if (_positionInSuperStream > _endInSuperStream - count) + count = (int)(_endInSuperStream - _positionInSuperStream); + + Debug.Assert(count >= 0); + Debug.Assert(count <= origCount); + + var ret = _superStream.Read(buffer, offset, count); + + _positionInSuperStream += ret; + return ret; + } + +#if !NETSTANDARD2_0 + public override int Read(Span destination) +#else + int Read(Span destination) +#endif + { + // parameter validation sent to _superStream.Read + var origCount = destination.Length; + var count = destination.Length; + + ThrowIfDisposed(); + ThrowIfCantRead(); + + if (_canSeek && _superStream.Position != _positionInSuperStream) + _superStream.Seek(_positionInSuperStream, SeekOrigin.Begin); + if (_positionInSuperStream + count > _endInSuperStream) + count = (int)(_endInSuperStream - _positionInSuperStream); + + Debug.Assert(count >= 0); + Debug.Assert(count <= origCount); + + var ret = _superStream.Read(destination.Slice(0, count)); + + _positionInSuperStream += ret; + return ret; + } + + public override int ReadByte() + { + Span b = 
stackalloc byte[1]; + return Read(b) == 1 ? b[0] : -1; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateBufferArguments(buffer, offset, count); + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } + +#if !NETSTANDARD2_0 + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) +#else + ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) +#endif + { + ThrowIfDisposed(); + ThrowIfCantRead(); + if (_canSeek && _superStream.Position != _positionInSuperStream) + { + _superStream.Seek(_positionInSuperStream, SeekOrigin.Begin); + } + + if (_positionInSuperStream > _endInSuperStream - buffer.Length) + { + buffer = buffer.Slice(0, (int)(_endInSuperStream - _positionInSuperStream)); + } + + return Core(buffer, cancellationToken); + + async ValueTask Core(Memory buffer, CancellationToken cancellationToken) + { + var ret = await _superStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + _positionInSuperStream += ret; + return ret; + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + ThrowIfDisposed(); + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + ThrowIfDisposed(); + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + ThrowIfDisposed(); + throw new NotSupportedException(); + } + + public override void Flush() + { + ThrowIfDisposed(); + throw new NotSupportedException(); + } + + // Close the stream for reading. Note that this does NOT close the superStream (since + // the substream is just 'a chunk' of the super-stream + protected override void Dispose(bool disposing) + { + if (disposing && !_isDisposed) + { + _isDisposed = true; + } + base.Dispose(disposing); + } + +#if NETSTANDARD + void ValidateBufferArguments(byte[]? 
buffer, int offset, int count) + { + if (buffer is null) + ThrowHelper.ThrowArgumentNullException(nameof(buffer)); + + if (offset < 0) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(offset), "Offset is less than 0"); + + if ((uint)count > buffer.Length - offset) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count), "Count larger than buffer minus offset"); + + } +#endif +} diff --git a/src/Npgsql/Util/TaskSchedulerAwaitable.cs b/src/Npgsql/Util/TaskSchedulerAwaitable.cs new file mode 100644 index 0000000000..be16d8fa55 --- /dev/null +++ b/src/Npgsql/Util/TaskSchedulerAwaitable.cs @@ -0,0 +1,38 @@ +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Npgsql.Util; + +readonly struct TaskSchedulerAwaitable : ICriticalNotifyCompletion +{ + readonly TaskScheduler _scheduler; + public TaskSchedulerAwaitable(TaskScheduler scheduler) => _scheduler = scheduler; + + public void GetResult() {} + public bool IsCompleted => false; + + public void OnCompleted(Action continuation) + { + var task = Task.Factory.StartNew(continuation, CancellationToken.None, + TaskCreationOptions.DenyChildAttach, + scheduler: _scheduler); + + // Exceptions should never happen as the continuation should be the async statemachine. + // It normally does its own error handling through the returned task unless it's an async void returning method. + // In which case we should absolutely let it bubble up to TaskScheduler.UnobservedTaskException. 
+ OnFaulted(task); + + [Conditional("DEBUG")] + static void OnFaulted(Task task) + { + task.ContinueWith(t => Debug.Fail("Task scheduler task threw an unobserved exception"), TaskContinuationOptions.OnlyOnFaulted); + } + } + + public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation); + + public TaskSchedulerAwaitable GetAwaiter() => this; +} diff --git a/src/Npgsql/Util/VersionExtensions.cs b/src/Npgsql/Util/VersionExtensions.cs index c42585bcda..4501dd78d2 100644 --- a/src/Npgsql/Util/VersionExtensions.cs +++ b/src/Npgsql/Util/VersionExtensions.cs @@ -1,17 +1,14 @@ using System; -namespace Npgsql.Util +namespace Npgsql.Util; + +static class VersionExtensions { - static class VersionExtensions - { - /// - /// Allocation free helper function to find if version is greater than expected - /// - public static bool IsGreaterOrEqual(this Version version, int major, int minor, int build) - => version.Major != major - ? version.Major > major - : version.Minor != minor - ? version.Minor > minor - : version.Build >= build; - } -} + /// + /// Allocation free helper function to find if version is greater than expected + /// + public static bool IsGreaterOrEqual(this Version version, int major, int minor = 0) + => version.Major != major + ? version.Major > major + : version.Minor >= minor; +} \ No newline at end of file diff --git a/src/Npgsql/VolatileResourceManager.cs b/src/Npgsql/VolatileResourceManager.cs index 53107c2ad2..239b62fe8e 100644 --- a/src/Npgsql/VolatileResourceManager.cs +++ b/src/Npgsql/VolatileResourceManager.cs @@ -1,187 +1,191 @@ using System; -using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Transactions; -using Npgsql.Logging; - -namespace Npgsql +using Microsoft.Extensions.Logging; +using Npgsql.Internal; + +namespace Npgsql; + +/// +/// +/// +/// +/// Note that a connection may be closed before its TransactionScope completes. 
In this case we close the NpgsqlConnection +/// as usual but the connector in a special list in the pool; it will be closed only when the scope completes. +/// +sealed class VolatileResourceManager : ISinglePhaseNotification { - /// - /// - /// - /// - /// Note that a connection may be closed before its TransactionScope completes. In this case we close the NpgsqlConnection - /// as usual but the connector in a special list in the pool; it will be closed only when the scope completes. - /// - class VolatileResourceManager : ISinglePhaseNotification + NpgsqlConnector _connector; + Transaction _transaction; + readonly string _txId; + NpgsqlTransaction _localTx = null!; + string? _preparedTxName; + bool IsPrepared => _preparedTxName != null; + bool _isDisposed; + + readonly ILogger _transactionLogger; + + const int MaximumRollbackAttempts = 20; + + internal VolatileResourceManager(NpgsqlConnection connection, Transaction transaction) { - NpgsqlConnector _connector; - Transaction _transaction; - readonly string _txId; - NpgsqlTransaction _localTx; - string? 
_preparedTxName; - bool IsPrepared => _preparedTxName != null; - bool _isDisposed; + _connector = connection.Connector!; + _transaction = transaction; + // _tx gets disposed by System.Transactions at some point, but we want to be able to log its local ID + _txId = transaction.TransactionInformation.LocalIdentifier; + _transactionLogger = _connector.LoggingConfiguration.TransactionLogger; + } - static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(VolatileResourceManager)); + internal void Init() + => _localTx = _connector.Connection!.BeginTransaction(ConvertIsolationLevel(_transaction.IsolationLevel)); - const int MaximumRollbackAttempts = 20; + public void SinglePhaseCommit(SinglePhaseEnlistment singlePhaseEnlistment) + { + CheckDisposed(); + LogMessages.CommittingSinglePhaseTransaction(_transactionLogger, _txId, _connector.Id); - internal VolatileResourceManager(NpgsqlConnection connection, Transaction transaction) + try + { + _localTx.Commit(); + singlePhaseEnlistment.Committed(); + } + catch (PostgresException e) { - _connector = connection.Connector!; - _transaction = transaction; - // _tx gets disposed by System.Transactions at some point, but we want to be able to log its local ID - _txId = transaction.TransactionInformation.LocalIdentifier; - _localTx = connection.BeginTransaction(ConvertIsolationLevel(_transaction.IsolationLevel)); + singlePhaseEnlistment.Aborted(e); } + catch (Exception e) + { + singlePhaseEnlistment.InDoubt(e); + } + finally + { + Dispose(); + } + } + + public void Prepare(PreparingEnlistment preparingEnlistment) + { + CheckDisposed(); + LogMessages.PreparingTwoPhaseTransaction(_transactionLogger, _txId, _connector.Id); + + // The PostgreSQL prepared transaction name is the distributed GUID + our connection's process ID, for uniqueness + _preparedTxName = $"{_transaction.TransactionInformation.DistributedIdentifier}/{_connector.BackendProcessId}"; - public void SinglePhaseCommit(SinglePhaseEnlistment 
singlePhaseEnlistment) + try { - CheckDisposed(); - Log.Debug($"Single Phase Commit (localid={_txId})", _connector.Id); + using (_connector.StartUserAction()) + _connector.ExecuteInternalCommand($"PREPARE TRANSACTION '{_preparedTxName}'"); + + // The MSDTC, which manages escalated distributed transactions, performs the 2nd phase + // asynchronously - this means that TransactionScope.Dispose() will return before all + // resource managers have actually commit. + // If the same connection tries to enlist to a new TransactionScope immediately after + // disposing an old TransactionScope, its EnlistedTransaction must have been cleared + // (or we'll throw a double enlistment exception). This must be done here at the 1st phase + // (which is sync). + if (_connector.Connection != null) + _connector.Connection.EnlistedTransaction = null; - try - { - _localTx.Commit(); - singlePhaseEnlistment.Committed(); - } - catch (PostgresException e) - { - singlePhaseEnlistment.Aborted(e); - } - catch (Exception e) - { - singlePhaseEnlistment.InDoubt(e); - } - finally - { - Dispose(); - } + preparingEnlistment.Prepared(); } - - public void Prepare(PreparingEnlistment preparingEnlistment) + catch (Exception e) { - CheckDisposed(); - Log.Debug($"Two-phase transaction prepare (localid={_txId})", _connector.Id); + Dispose(); + preparingEnlistment.ForceRollback(e); + } + } - // The PostgreSQL prepared transaction name is the distributed GUID + our connection's process ID, for uniqueness - _preparedTxName = $"{_transaction.TransactionInformation.DistributedIdentifier}/{_connector.BackendProcessId}"; + [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Changing Enlist to be false does not affect potentially trimmed out functionality.")] + [UnconditionalSuppressMessage("Aot", "IL3050", Justification = "Changing Enlist to be false does not cause dynamic codegen.")] + public void Commit(Enlistment enlistment) + { + CheckDisposed(); + 
LogMessages.CommittingTwoPhaseTransaction(_transactionLogger, _txId, _connector.Id); - try + try + { + if (_connector.Connection == null) { + // The connection has been closed before the TransactionScope was disposed. + // The connector is unbound from its connection and is sitting in the pool's + // pending enlisted connector list. Since there's no risk of the connector being + // used by anyone we can executed the 2nd phase on it directly (see below). using (_connector.StartUserAction()) - _connector.ExecuteInternalCommand($"PREPARE TRANSACTION '{_preparedTxName}'"); - - // The MSDTC, which manages escalated distributed transactions, performs the 2nd phase - // asynchronously - this means that TransactionScope.Dispose() will return before all - // resource managers have actually commit. - // If the same connection tries to enlist to a new TransactionScope immediately after - // disposing an old TransactionScope, its EnlistedTransaction must have been cleared - // (or we'll throw a double enlistment exception). This must be done here at the 1st phase - // (which is sync). - if (_connector.Connection != null) - _connector.Connection.EnlistedTransaction = null; - - preparingEnlistment.Prepared(); + _connector.ExecuteInternalCommand($"COMMIT PREPARED '{_preparedTxName}'"); } - catch (Exception e) + else { - Dispose(); - preparingEnlistment.ForceRollback(e); + // The connection is still open and potentially will be reused by by the user. + // The MSDTC, which manages escalated distributed transactions, performs the 2nd phase + // asynchronously - this means that TransactionScope.Dispose() will return before all + // resource managers have actually commit. This can cause a concurrent connection use scenario + // if the user continues to use their connection after disposing the scope, and the MSDTC + // requests a commit at that exact time. + // To avoid this, we open a new connection for performing the 2nd phase. 
+ var settings = _connector.Connection.Settings.Clone(); + // Set Enlist to false because we might be in TransactionScope and we can't prepare transaction while being in an open transaction + // see #5246 + settings.Enlist = false; + using var conn2 = _connector.Connection.CloneWith(settings.ConnectionString); + conn2.Open(); + + var connector = conn2.Connector!; + using (connector.StartUserAction()) + connector.ExecuteInternalCommand($"COMMIT PREPARED '{_preparedTxName}'"); } } - - public void Commit(Enlistment enlistment) + catch (Exception e) { - CheckDisposed(); - Log.Debug($"Two-phase transaction commit (localid={_txId})", _connector.Id); - - try - { - if (_connector.Connection == null) - { - // The connection has been closed before the TransactionScope was disposed. - // The connector is unbound from its connection and is sitting in the pool's - // pending enlisted connector list. Since there's no risk of the connector being - // used by anyone we can executed the 2nd phase on it directly (see below). - using (_connector.StartUserAction()) - _connector.ExecuteInternalCommand($"COMMIT PREPARED '{_preparedTxName}'"); - } - else - { - // The connection is still open and potentially will be reused by by the user. - // The MSDTC, which manages escalated distributed transactions, performs the 2nd phase - // asynchronously - this means that TransactionScope.Dispose() will return before all - // resource managers have actually commit. This can cause a concurrent connection use scenario - // if the user continues to use their connection after disposing the scope, and the MSDTC - // requests a commit at that exact time. - // To avoid this, we open a new connection for performing the 2nd phase. 
- using var conn2 = (NpgsqlConnection)((ICloneable)_connector.Connection).Clone(); - conn2.Open(); - - var connector = conn2.Connector!; - using (connector.StartUserAction()) - connector.ExecuteInternalCommand($"COMMIT PREPARED '{_preparedTxName}'"); - } - } - catch (Exception e) - { - Log.Error("Exception during two-phase transaction commit (localid={TransactionId})", e, _connector.Id); - } - finally - { - Dispose(); - enlistment.Done(); - } + LogMessages.TwoPhaseTransactionCommitFailed(_transactionLogger, _txId, _connector.Id, e); } - - public void Rollback(Enlistment enlistment) + finally { - CheckDisposed(); - - try - { - if (IsPrepared) - RollbackTwoPhase(); - else - RollbackLocal(); - } - catch (Exception e) - { - Log.Error($"Exception during transaction rollback (localid={_txId})", e, _connector.Id); - } - finally - { - Dispose(); - enlistment.Done(); - } + Dispose(); + enlistment.Done(); } + } - public void InDoubt(Enlistment enlistment) - { - Log.Warn($"Two-phase transaction in doubt (localid={_txId})", _connector.Id); + public void Rollback(Enlistment enlistment) + { + CheckDisposed(); - // TODO: Is this the correct behavior? - try - { + try + { + if (IsPrepared) RollbackTwoPhase(); - } - catch (Exception e) - { - Log.Error($"Exception during transaction rollback (localid={_txId})", e, _connector.Id); - } - finally - { - Dispose(); - enlistment.Done(); - } + else + RollbackLocal(); + } + finally + { + Dispose(); + enlistment.Done(); } + } + + public void InDoubt(Enlistment enlistment) + { + LogMessages.TwoPhaseTransactionInDoubt(_transactionLogger, _txId, _connector.Id); - void RollbackLocal() + // TODO: Is this the correct behavior? 
+ try { - Log.Debug($"Single-phase transaction rollback (localid={_txId})", _connector.Id); + RollbackTwoPhase(); + } + finally + { + Dispose(); + enlistment.Done(); + } + } + void RollbackLocal() + { + LogMessages.RollingBackSinglePhaseTransaction(_transactionLogger, _txId, _connector.Id); + + try + { var attempt = 0; while (true) { @@ -197,21 +201,29 @@ void RollbackLocal() // This really shouldn't be necessary, but just in case if (attempt++ == MaximumRollbackAttempts) - throw new Exception($"Could not roll back after {MaximumRollbackAttempts} attempts, aborting. Transaction is in an unknown state."); + throw new Exception( + $"Could not roll back after {MaximumRollbackAttempts} attempts, aborting. Transaction is in an unknown state."); - Log.Warn($"Connection in use while trying to rollback, will cancel and retry (localid={_txId}", _connector.Id); + LogMessages.ConnectionInUseWhenRollingBack(_transactionLogger, _txId, _connector.Id); _connector.PerformPostgresCancellation(); // Cancellations are asynchronous, give it some time Thread.Sleep(500); } } } - - void RollbackTwoPhase() + catch { - // This only occurs if we've started a two-phase commit but one of the commits has failed. - Log.Debug($"Two-phase transaction rollback (localid={_txId})", _connector.Id); + LogMessages.SinglePhaseTransactionRollbackFailed(_transactionLogger, _txId, _connector.Id); + } + } + + void RollbackTwoPhase() + { + // This only occurs if we've started a two-phase commit but one of the commits has failed. + LogMessages.RollingBackTwoPhaseTransaction(_transactionLogger, _txId, _connector.Id); + try + { if (_connector.Connection == null) { // The connection has been closed before the TransactionScope was disposed. 
@@ -238,63 +250,61 @@ void RollbackTwoPhase() connector.ExecuteInternalCommand($"ROLLBACK PREPARED '{_preparedTxName}'"); } } - - #region Dispose/Cleanup - -#pragma warning disable CS8625 - void Dispose() + catch (Exception e) { - if (_isDisposed) - return; + LogMessages.TwoPhaseTransactionRollbackFailed(_transactionLogger, _txId, _connector.Id, e); + } + } - Log.Trace($"Cleaning up resource manager (localid={_txId}", _connector.Id); - if (_localTx != null) - { - _localTx.Dispose(); - _localTx = null; - } + #region Dispose/Cleanup - if (_connector.Connection != null) - _connector.Connection.EnlistedTransaction = null; - else - { - // We're here for connections which were closed before their TransactionScope completes. - // These need to be closed now. - if (_connector.Settings.Pooling) - { - var found = PoolManager.TryGetValue(_connector.ConnectionString, out var pool); - Debug.Assert(found); - pool!.TryRemovePendingEnlistedConnector(_connector, _transaction); - pool.Return(_connector); - } - else - _connector.Close(); - } +#pragma warning disable CS8625 + void Dispose() + { + if (_isDisposed) + return; - _connector = null!; - _transaction = null!; - _isDisposed = true; + LogMessages.CleaningUpResourceManager(_transactionLogger, _txId, _connector.Id); + if (_localTx != null) + { + _localTx.Dispose(); + _localTx = null; } -#pragma warning restore CS8625 - void CheckDisposed() + if (_connector.Connection != null) + _connector.Connection.EnlistedTransaction = null; + else { - if (_isDisposed) - throw new ObjectDisposedException(nameof(VolatileResourceManager)); + // We're here for connections which were closed before their TransactionScope completes. + // These need to be closed now. 
+ // We should return the connector to the pool only if we've successfully removed it from the pending list + if (_connector.TryRemovePendingEnlistedConnector(_transaction)) + _connector.Return(); } - #endregion + _connector = null!; + _transaction = null!; + _isDisposed = true; + } +#pragma warning restore CS8625 - static System.Data.IsolationLevel ConvertIsolationLevel(IsolationLevel isolationLevel) - => isolationLevel switch - { - IsolationLevel.Chaos => System.Data.IsolationLevel.Chaos, - IsolationLevel.ReadCommitted => System.Data.IsolationLevel.ReadCommitted, - IsolationLevel.ReadUncommitted => System.Data.IsolationLevel.ReadUncommitted, - IsolationLevel.RepeatableRead => System.Data.IsolationLevel.RepeatableRead, - IsolationLevel.Serializable => System.Data.IsolationLevel.Serializable, - IsolationLevel.Snapshot => System.Data.IsolationLevel.Snapshot, - _ => System.Data.IsolationLevel.Unspecified - }; + void CheckDisposed() + { + if (_isDisposed) + throw new ObjectDisposedException(nameof(VolatileResourceManager)); } + + #endregion + + static System.Data.IsolationLevel ConvertIsolationLevel(IsolationLevel isolationLevel) + => isolationLevel switch + { + IsolationLevel.Chaos => System.Data.IsolationLevel.Chaos, + IsolationLevel.ReadCommitted => System.Data.IsolationLevel.ReadCommitted, + IsolationLevel.ReadUncommitted => System.Data.IsolationLevel.ReadUncommitted, + IsolationLevel.RepeatableRead => System.Data.IsolationLevel.RepeatableRead, + IsolationLevel.Serializable => System.Data.IsolationLevel.Serializable, + IsolationLevel.Snapshot => System.Data.IsolationLevel.Snapshot, + _ => System.Data.IsolationLevel.Unspecified + }; } diff --git a/src/Shared/CodeAnalysis.cs b/src/Shared/CodeAnalysis.cs new file mode 100644 index 0000000000..8e8e3b3d9e --- /dev/null +++ b/src/Shared/CodeAnalysis.cs @@ -0,0 +1,257 @@ +using System; +using System.Diagnostics.CodeAnalysis; + +#pragma warning disable 1591 + +namespace System.Diagnostics.CodeAnalysis +{ +#if 
!NET7_0_OR_GREATER + /// + /// Indicates that the specified method requires the ability to generate new code at runtime, + /// for example through . + /// + /// + /// This allows tools to understand which methods are unsafe to call when compiling ahead of time. + /// + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Constructor | AttributeTargets.Class, Inherited = false)] + sealed class RequiresDynamicCodeAttribute : Attribute + { + /// + /// Initializes a new instance of the class + /// with the specified message. + /// + /// + /// A message that contains information about the usage of dynamic code. + /// + public RequiresDynamicCodeAttribute(string message) + { + Message = message; + } + + /// + /// Gets a message that contains information about the usage of dynamic code. + /// + public string Message { get; } + + /// + /// Gets or sets an optional URL that contains more information about the method, + /// why it requires dynamic code, and what options a consumer has to deal with it. + /// + public string? Url { get; set; } + } + + [AttributeUsage(AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)] + sealed class SetsRequiredMembersAttribute : Attribute + { + } + [AttributeUsageAttribute(AttributeTargets.Method | AttributeTargets.Property | AttributeTargets.Parameter, AllowMultiple = false, Inherited = false)] + sealed class UnscopedRefAttribute : Attribute + { + /// + /// Initializes a new instance of the class. 
+ /// + public UnscopedRefAttribute() { } + } +#endif +#if NETSTANDARD2_0 + [AttributeUsageAttribute(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property)] + sealed class AllowNullAttribute : Attribute + { + } + + [AttributeUsageAttribute(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property)] + sealed class DisallowNullAttribute : Attribute + { + } + + [AttributeUsageAttribute(AttributeTargets.Method)] + sealed class DoesNotReturnAttribute : Attribute + { + } + + [AttributeUsageAttribute(AttributeTargets.Parameter)] + sealed class DoesNotReturnIfAttribute : Attribute + { + public DoesNotReturnIfAttribute(bool parameterValue) => ParameterValue = parameterValue; + public bool ParameterValue { get; } + } + + [AttributeUsageAttribute(AttributeTargets.Assembly | AttributeTargets.Class | AttributeTargets.Constructor | AttributeTargets.Event | AttributeTargets.Method | AttributeTargets.Property | AttributeTargets.Struct, AllowMultiple = false)] + sealed class ExcludeFromCodeCoverageAttribute : Attribute + { + } + + [AttributeUsageAttribute(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue)] + sealed class MaybeNullAttribute : Attribute + { + } + + [AttributeUsageAttribute(AttributeTargets.Parameter)] + sealed class MaybeNullWhenAttribute : Attribute + { + public MaybeNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + public bool ReturnValue { get; } + } + + [AttributeUsageAttribute(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue)] + sealed class NotNullAttribute : Attribute + { + } + + [AttributeUsageAttribute(AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, AllowMultiple = true)] + sealed class NotNullIfNotNullAttribute : Attribute + { + public NotNullIfNotNullAttribute(string parameterName) => ParameterName = parameterName; + public string 
ParameterName { get; } + } + + [AttributeUsageAttribute(AttributeTargets.Parameter)] + sealed class NotNullWhenAttribute : Attribute + { + public NotNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + public bool ReturnValue { get; } + } +#endif + +#if !NET5_0_OR_GREATER + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, AllowMultiple = true, Inherited = false)] + sealed class MemberNotNullAttribute : Attribute + { + public MemberNotNullAttribute(string member) => Members = new string[] + { + member + }; + + public MemberNotNullAttribute(params string[] members) => Members = members; + + public string[] Members { get; } + } + + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, AllowMultiple = true, Inherited = false)] + sealed class MemberNotNullWhenAttribute : Attribute + { + public MemberNotNullWhenAttribute(bool returnValue, string member) + { + ReturnValue = returnValue; + Members = new string[1] { member }; + } + + public MemberNotNullWhenAttribute(bool returnValue, params string[] members) + { + ReturnValue = returnValue; + Members = members; + } + + public bool ReturnValue { get; } + + public string[] Members { get; } + } + + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Constructor | AttributeTargets.Class, Inherited = false)] + sealed class RequiresUnreferencedCodeAttribute : Attribute + { + public RequiresUnreferencedCodeAttribute(string message) + { + Message = message; + } + + public string Message { get; } + + public string? 
Url { get; set; } + } + + [AttributeUsage( + AttributeTargets.Field | AttributeTargets.ReturnValue | AttributeTargets.GenericParameter | + AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.Method | + AttributeTargets.Class | AttributeTargets.Interface | AttributeTargets.Struct, + Inherited = false)] + sealed class DynamicallyAccessedMembersAttribute : Attribute + { + public DynamicallyAccessedMembersAttribute(DynamicallyAccessedMemberTypes memberTypes) + { + MemberTypes = memberTypes; + } + + public DynamicallyAccessedMemberTypes MemberTypes { get; } + } + + [Flags] + enum DynamicallyAccessedMemberTypes + { + None = 0, + PublicParameterlessConstructor = 0x0001, + PublicConstructors = 0x0002 | PublicParameterlessConstructor, + NonPublicConstructors = 0x0004, + PublicMethods = 0x0008, + NonPublicMethods = 0x0010, + PublicFields = 0x0020, + NonPublicFields = 0x0040, + PublicNestedTypes = 0x0080, + NonPublicNestedTypes = 0x0100, + PublicProperties = 0x0200, + NonPublicProperties = 0x0400, + PublicEvents = 0x0800, + NonPublicEvents = 0x1000, + Interfaces = 0x2000, + All = ~None + } + + [AttributeUsage(AttributeTargets.All, Inherited = false, AllowMultiple = true)] + sealed class UnconditionalSuppressMessageAttribute : Attribute + { + public UnconditionalSuppressMessageAttribute(string category, string checkId) + { + Category = category; + CheckId = checkId; + } + + public string Category { get; } + public string CheckId { get; } + public string? Scope { get; set; } + public string? Target { get; set; } + public string? MessageId { get; set; } + public string? 
Justification { get; set; } + } +#endif +} + +namespace System.Runtime.CompilerServices +{ +#if !NET5_0_OR_GREATER + static class IsExternalInit {} +#endif +#if !NET7_0_OR_GREATER + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = false)] + sealed class RequiredMemberAttribute : Attribute + { } + + [AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)] + sealed class CompilerFeatureRequiredAttribute : Attribute + { + public CompilerFeatureRequiredAttribute(string featureName) + { + FeatureName = featureName; + } + + /// + /// The name of the compiler feature. + /// + public string FeatureName { get; } + + /// + /// If true, the compiler can choose to allow access to the location where this attribute is applied if it does not understand . + /// + public bool IsOptional { get; init; } + + /// + /// The used for the ref structs C# feature. + /// + public const string RefStructs = nameof(RefStructs); + + /// + /// The used for the required members C# feature. 
+ /// + public const string RequiredMembers = nameof(RequiredMembers); + } +#endif +} diff --git a/test/Directory.Build.props b/test/Directory.Build.props index 3e8eda0be7..59f7665837 100644 --- a/test/Directory.Build.props +++ b/test/Directory.Build.props @@ -1,10 +1,22 @@  - - net5.0 + net8.0;netcoreapp3.1 + net8.0 false + + + $(NoWarn);CA2252 + + + + + true diff --git a/test/MStatDumper/MStatDumper.csproj b/test/MStatDumper/MStatDumper.csproj new file mode 100644 index 0000000000..3cab4d57fd --- /dev/null +++ b/test/MStatDumper/MStatDumper.csproj @@ -0,0 +1,15 @@ + + + + Exe + + net8.0 + enable + disable + + + + + + + diff --git a/test/MStatDumper/Program.cs b/test/MStatDumper/Program.cs new file mode 100644 index 0000000000..9a9fe89dfb --- /dev/null +++ b/test/MStatDumper/Program.cs @@ -0,0 +1,368 @@ +using Mono.Cecil; +using Mono.Cecil.Rocks; + +namespace MStatDumper +{ + internal class Program + { + static void Main(string[] args) + { + if (args.Length == 0) + { + throw new Exception("Must provide the path to mstat file. It's in {project}/obj/Release/{TFM}/{os}/native/{project}.mstat"); + } + + var markDownStyleOutput = args.Length > 1 && args[1] == "md"; + + var asm = AssemblyDefinition.ReadAssembly(args[0]); + var globalType = (TypeDefinition)asm.MainModule.LookupToken(0x02000001); + + var versionMajor = asm.Name.Version.Major; + + var types = globalType.Methods.First(x => x.Name == "Types"); + var typeStats = GetTypes(versionMajor, types).ToList(); + var typeSize = typeStats.Sum(x => x.Size); + var typesByModules = typeStats.GroupBy(x => x.Type.Scope).Select(x => new { x.Key.Name, Sum = x.Sum(x => x.Size) }).ToList(); + if (markDownStyleOutput) + { + Console.WriteLine("
"); + Console.WriteLine($"Types Total Size {typeSize:n0}"); + Console.WriteLine(); + Console.WriteLine("
"); + Console.WriteLine(); + Console.WriteLine("| Name | Size |"); + Console.WriteLine("| --- | --- |"); + foreach (var m in typesByModules.OrderByDescending(x => x.Sum)) + { + var name = m.Name + .Replace("`", "\\`") + .Replace("<", "<") + .Replace(">", ">") + .Replace("|", "\\|"); + Console.WriteLine($"| {name} | {m.Sum:n0} |"); + } + Console.WriteLine(); + Console.WriteLine("
"); + } + else + { + Console.WriteLine($"// ********** Types Total Size {typeSize:n0}"); + foreach (var m in typesByModules.OrderByDescending(x => x.Sum)) + { + Console.WriteLine($"{m.Name,-70} {m.Sum,9:n0}"); + } + Console.WriteLine("// **********"); + } + + Console.WriteLine(); + + var methods = globalType.Methods.First(x => x.Name == "Methods"); + var methodStats = GetMethods(versionMajor, methods).ToList(); + var methodSize = methodStats.Sum(x => x.Size + x.GcInfoSize + x.EhInfoSize); + var methodsByModules = methodStats.GroupBy(x => x.Method.DeclaringType.Scope).Select(x => new { x.Key.Name, Sum = x.Sum(x => x.Size + x.GcInfoSize + x.EhInfoSize) }).ToList(); + if (markDownStyleOutput) + { + Console.WriteLine("
"); + Console.WriteLine($"Methods Total Size {methodSize:n0}"); + Console.WriteLine(); + Console.WriteLine("
"); + Console.WriteLine(); + Console.WriteLine("| Name | Size |"); + Console.WriteLine("| --- | --- |"); + foreach (var m in methodsByModules.OrderByDescending(x => x.Sum)) + { + var name = m.Name + .Replace("`", "\\`") + .Replace("<", "<") + .Replace(">", ">") + .Replace("|", "\\|"); + Console.WriteLine($"| {name} | {m.Sum:n0} |"); + } + Console.WriteLine(); + Console.WriteLine("
"); + } + else + { + Console.WriteLine($"// ********** Methods Total Size {methodSize:n0}"); + foreach (var m in methodsByModules.OrderByDescending(x => x.Sum)) + { + Console.WriteLine($"{m.Name,-70} {m.Sum,9:n0}"); + } + Console.WriteLine("// **********"); + } + + Console.WriteLine(); + + string FindNamespace(TypeReference type) + { + var current = type; + while (true) + { + if (!string.IsNullOrEmpty(current.Namespace)) + { + return current.Namespace; + } + + if (current.DeclaringType == null) + { + return current.Name; + } + + current = current.DeclaringType; + } + } + + var methodsByNamespace = methodStats.Select(x => new TypeStats { Type = x.Method.DeclaringType, Size = x.Size + x.GcInfoSize + x.EhInfoSize }).Concat(typeStats).GroupBy(x => FindNamespace(x.Type)).Select(x => new { x.Key, Sum = x.Sum(x => x.Size) }).ToList(); + if (markDownStyleOutput) + { + Console.WriteLine("
"); + Console.WriteLine("Size By Namespace"); + Console.WriteLine(); + Console.WriteLine("
"); + Console.WriteLine(); + Console.WriteLine("| Name | Size |"); + Console.WriteLine("| --- | --- |"); + foreach (var m in methodsByNamespace.OrderByDescending(x => x.Sum)) + { + var name = m.Key + .Replace("`", "\\`") + .Replace("<", "<") + .Replace(">", ">") + .Replace("|", "\\|"); + Console.WriteLine($"| {name} | {m.Sum:n0} |"); + } + Console.WriteLine(); + Console.WriteLine("
"); + } + else + { + Console.WriteLine("// ********** Size By Namespace"); + foreach (var m in methodsByNamespace.OrderByDescending(x => x.Sum)) + { + Console.WriteLine($"{m.Key,-70} {m.Sum,9:n0}"); + } + Console.WriteLine("// **********"); + } + + Console.WriteLine(); + + var blobs = globalType.Methods.First(x => x.Name == "Blobs"); + var blobStats = GetBlobs(blobs).ToList(); + var blobSize = blobStats.Sum(x => x.Size); + if (markDownStyleOutput) + { + Console.WriteLine("
"); + Console.WriteLine($"Blobs Total Size {blobSize:n0}"); + Console.WriteLine(); + Console.WriteLine("
"); + Console.WriteLine(); + Console.WriteLine("| Name | Size |"); + Console.WriteLine("| --- | --- |"); + foreach (var m in blobStats.OrderByDescending(x => x.Size)) + { + var name = m.Name + .Replace("`", "\\`") + .Replace("<", "<") + .Replace(">", ">") + .Replace("|", "\\|"); + Console.WriteLine($"| {name} | {m.Size:n0} |"); + } + Console.WriteLine(); + Console.WriteLine("
"); + } + else + { + Console.WriteLine($"// ********** Blobs Total Size {blobSize:n0}"); + foreach (var m in blobStats.OrderByDescending(x => x.Size)) + { + Console.WriteLine($"{m.Name,-70} {m.Size,9:n0}"); + } + Console.WriteLine("// **********"); + } + + if (markDownStyleOutput) + { + var methodsByClass = methodStats + .Where(x => x.Method.DeclaringType.Scope.Name == "Npgsql") + .GroupBy(x => GetClassName(x.Method)) + .OrderByDescending(x => x.Sum(x => x.Size + x.GcInfoSize + x.EhInfoSize)) + .Take(100) + .ToList(); + + static string GetClassName(MethodReference methodReference) + { + var type = methodReference.DeclaringType.DeclaringType ?? methodReference.DeclaringType; + return type.Namespace + "." + type.Name; + } + + Console.WriteLine("
"); + Console.WriteLine("Top 100 Npgsql Classes By Methods Size"); + Console.WriteLine(); + Console.WriteLine("
"); + Console.WriteLine(); + Console.WriteLine("| Name | Size | Total Instantiations |"); + Console.WriteLine("| --- | --- | --- |"); + foreach (var m in methodsByClass + .Select(x => new { Name = x.Key, Sum = x.Sum(x => x.Size + x.GcInfoSize + x.EhInfoSize), Count = x.Count() }) + .OrderByDescending(x => x.Sum)) + { + var name = m.Name + .Replace("`", "\\`") + .Replace("<", "<") + .Replace(">", ">") + .Replace("|", "\\|"); + Console.WriteLine($"| {name} | {m.Sum:n0} | {m.Count} |"); + } + + Console.WriteLine(); + Console.WriteLine("
"); + + foreach (var g in methodsByClass + .OrderByDescending(x => x.Sum(x => x.Size + x.GcInfoSize + x.EhInfoSize))) + { + Console.WriteLine(); + Console.WriteLine("
"); + Console.WriteLine($"\"{g.Key}\" Methods ({g.Sum(x => x.Size + x.GcInfoSize + x.EhInfoSize):n0} bytes)"); + Console.WriteLine(); + Console.WriteLine("
"); + Console.WriteLine(); + Console.WriteLine("| Name | Size | Instantiations |"); + Console.WriteLine("| --- | --- | --- |"); + foreach (var m in g + .GroupBy(x => GetMethodName(x.Method)) + .Select(x => new { Name = x.Key, Size = x.Sum(x => x.Size + x.GcInfoSize + x.EhInfoSize), Count = x.Count()}) + .OrderByDescending(x => x.Size)) + { + var methodName = m.Name + .Replace("`", "\\`") + .Replace("<", "<") + .Replace(">", ">") + .Replace("|", "\\|"); + Console.WriteLine($"| {methodName} | {m.Size:n0} | {m.Count} |"); + } + Console.WriteLine(); + Console.WriteLine("
"); + Console.WriteLine(); + Console.WriteLine("
"); + + static string GetMethodName(MethodReference methodReference) + { + if (methodReference.DeclaringType.DeclaringType is null) + { + return methodReference.Name; + } + + return methodReference.DeclaringType.Name; + } + } + + Console.WriteLine(); + Console.WriteLine("
"); + + var filteredTypeStats = GetTypes(versionMajor, types) + .Where(x => x.Type.Scope.Name == "Npgsql") + .GroupBy(x => x.Type.Name) + .OrderByDescending(x => x.Sum(x => x.Size)) + .Take(100) + .ToList(); + Console.WriteLine("
"); + Console.WriteLine($"Top 100 Npgsql Types By Size"); + Console.WriteLine(); + Console.WriteLine("
"); + Console.WriteLine(); + Console.WriteLine("| Name | Size | Instantiations |"); + Console.WriteLine("| --- | --- | --- |"); + foreach (var m in filteredTypeStats) + { + var name = m.Key + .Replace("`", "\\`") + .Replace("<", "<") + .Replace(">", ">") + .Replace("|", "\\|"); + Console.WriteLine($"| {name} | {m.Sum(x => x.Size):n0} | {m.Count()} |"); + } + Console.WriteLine(); + Console.WriteLine("
"); + + Console.WriteLine(); + } + } + + public static IEnumerable GetTypes(int formatVersion, MethodDefinition types) + { + var entrySize = formatVersion == 1 ? 2 : 3; + + types.Body.SimplifyMacros(); + var il = types.Body.Instructions; + for (var i = 0; i + entrySize < il.Count; i += entrySize) + { + var type = (TypeReference)il[i + 0].Operand; + var size = (int)il[i + 1].Operand; + yield return new TypeStats + { + Type = type, + Size = size + }; + } + } + + public static IEnumerable GetMethods(int formatVersion, MethodDefinition methods) + { + var entrySize = formatVersion == 1 ? 4 : 5; + + methods.Body.SimplifyMacros(); + var il = methods.Body.Instructions; + for (var i = 0; i + entrySize < il.Count; i += entrySize) + { + var method = (MethodReference)il[i + 0].Operand; + var size = (int)il[i + 1].Operand; + var gcInfoSize = (int)il[i + 2].Operand; + var ehInfoSize = (int)il[i + 3].Operand; + yield return new MethodStats + { + Method = method, + Size = size, + GcInfoSize = gcInfoSize, + EhInfoSize = ehInfoSize + }; + } + } + + public static IEnumerable GetBlobs(MethodDefinition blobs) + { + blobs.Body.SimplifyMacros(); + var il = blobs.Body.Instructions; + for (var i = 0; i + 2 < il.Count; i += 2) + { + var name = (string)il[i + 0].Operand; + var size = (int)il[i + 1].Operand; + yield return new BlobStats + { + Name = name, + Size = size + }; + } + } + } + + public class TypeStats + { + public string MethodName { get; set; } + public TypeReference Type { get; set; } + public int Size { get; set; } + } + + public class MethodStats + { + public MethodReference Method { get; set; } + public int Size { get; set; } + public int GcInfoSize { get; set; } + public int EhInfoSize { get; set; } + } + + public class BlobStats + { + public string Name { get; set; } + public int Size { get; set; } + } +} diff --git a/test/Npgsql.Benchmarks/BenchmarkEnvironment.cs b/test/Npgsql.Benchmarks/BenchmarkEnvironment.cs index d05267ae85..4704cc90e3 100644 --- 
a/test/Npgsql.Benchmarks/BenchmarkEnvironment.cs +++ b/test/Npgsql.Benchmarks/BenchmarkEnvironment.cs @@ -1,24 +1,23 @@ using System; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +static class BenchmarkEnvironment { - static class BenchmarkEnvironment - { - internal static string ConnectionString => Environment.GetEnvironmentVariable("NPGSQL_TEST_DB") ?? DefaultConnectionString; + internal static string ConnectionString => Environment.GetEnvironmentVariable("NPGSQL_TEST_DB") ?? DefaultConnectionString; - /// - /// Unless the NPGSQL_TEST_DB environment variable is defined, this is used as the connection string for the - /// test database. - /// - const string DefaultConnectionString = "Server=localhost;User ID=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests"; + /// + /// Unless the NPGSQL_TEST_DB environment variable is defined, this is used as the connection string for the + /// test database. + /// + const string DefaultConnectionString = "Server=localhost;User ID=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests"; - internal static NpgsqlConnection GetConnection() => new NpgsqlConnection(ConnectionString); + internal static NpgsqlConnection GetConnection() => new(ConnectionString); - internal static NpgsqlConnection OpenConnection() - { - var conn = GetConnection(); - conn.Open(); - return conn; - } - } -} + internal static NpgsqlConnection OpenConnection() + { + var conn = GetConnection(); + conn.Open(); + return conn; + } +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/CommandExecuteBenchmarks.cs b/test/Npgsql.Benchmarks/CommandExecuteBenchmarks.cs index ca82f27d1d..c75febe708 100644 --- a/test/Npgsql.Benchmarks/CommandExecuteBenchmarks.cs +++ b/test/Npgsql.Benchmarks/CommandExecuteBenchmarks.cs @@ -6,58 +6,57 @@ // ReSharper disable UnusedMember.Global -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +[SuppressMessage("ReSharper", "AssignNullToNotNullAttribute")] +[Config(typeof(Config))] +public 
class CommandExecuteBenchmarks { - [SuppressMessage("ReSharper", "AssignNullToNotNullAttribute")] - [Config(typeof(Config))] - public class CommandExecuteBenchmarks - { - readonly NpgsqlCommand _executeNonQueryCmd; - readonly NpgsqlCommand _executeNonQueryWithParamCmd; - readonly NpgsqlCommand _executeNonQueryPreparedCmd; - readonly NpgsqlCommand _executeScalarCmd; - readonly NpgsqlCommand _executeReaderCmd; + readonly NpgsqlCommand _executeNonQueryCmd; + readonly NpgsqlCommand _executeNonQueryWithParamCmd; + readonly NpgsqlCommand _executeNonQueryPreparedCmd; + readonly NpgsqlCommand _executeScalarCmd; + readonly NpgsqlCommand _executeReaderCmd; - public CommandExecuteBenchmarks() - { - var conn = BenchmarkEnvironment.OpenConnection(); - _executeNonQueryCmd = new NpgsqlCommand("SET lock_timeout = 1000", conn); - _executeNonQueryWithParamCmd = new NpgsqlCommand("SET lock_timeout = 1000", conn); - _executeNonQueryWithParamCmd.Parameters.AddWithValue("not_used", DBNull.Value); - _executeNonQueryPreparedCmd = new NpgsqlCommand("SET lock_timeout = 1000", conn); - _executeNonQueryPreparedCmd.Prepare(); - _executeScalarCmd = new NpgsqlCommand("SELECT 1", conn); - _executeReaderCmd = new NpgsqlCommand("SELECT 1", conn); - } + public CommandExecuteBenchmarks() + { + var conn = BenchmarkEnvironment.OpenConnection(); + _executeNonQueryCmd = new NpgsqlCommand("SET lock_timeout = 1000", conn); + _executeNonQueryWithParamCmd = new NpgsqlCommand("SET lock_timeout = 1000", conn); + _executeNonQueryWithParamCmd.Parameters.AddWithValue("not_used", DBNull.Value); + _executeNonQueryPreparedCmd = new NpgsqlCommand("SET lock_timeout = 1000", conn); + _executeNonQueryPreparedCmd.Prepare(); + _executeScalarCmd = new NpgsqlCommand("SELECT 1", conn); + _executeReaderCmd = new NpgsqlCommand("SELECT 1", conn); + } - [Benchmark] - public int ExecuteNonQuery() => _executeNonQueryCmd.ExecuteNonQuery(); + [Benchmark] + public int ExecuteNonQuery() => _executeNonQueryCmd.ExecuteNonQuery(); - 
[Benchmark] - public int ExecuteNonQueryWithParam() => _executeNonQueryWithParamCmd.ExecuteNonQuery(); + [Benchmark] + public int ExecuteNonQueryWithParam() => _executeNonQueryWithParamCmd.ExecuteNonQuery(); - [Benchmark] - public int ExecuteNonQueryPrepared() => _executeNonQueryPreparedCmd.ExecuteNonQuery(); + [Benchmark] + public int ExecuteNonQueryPrepared() => _executeNonQueryPreparedCmd.ExecuteNonQuery(); - [Benchmark] - public object ExecuteScalar() => _executeScalarCmd.ExecuteScalar()!; + [Benchmark] + public object ExecuteScalar() => _executeScalarCmd.ExecuteScalar()!; - [Benchmark] - public object ExecuteReader() + [Benchmark] + public object ExecuteReader() + { + using (var reader = _executeReaderCmd.ExecuteReader()) { - using (var reader = _executeReaderCmd.ExecuteReader()) - { - reader.Read(); - return reader.GetValue(0); - } + reader.Read(); + return reader.GetValue(0); } + } - class Config : ManualConfig + class Config : ManualConfig + { + public Config() { - public Config() - { - AddColumn(StatisticColumn.OperationsPerSecond); - } + AddColumn(StatisticColumn.OperationsPerSecond); } } -} +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/Commit.cs b/test/Npgsql.Benchmarks/Commit.cs index 19d10a6896..96e04ade96 100644 --- a/test/Npgsql.Benchmarks/Commit.cs +++ b/test/Npgsql.Benchmarks/Commit.cs @@ -4,34 +4,33 @@ // ReSharper disable AssignNullToNotNullAttribute.Global -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +[Config(typeof(Config))] +public class Commit { - [Config(typeof(Config))] - public class Commit - { - readonly NpgsqlConnection _conn; - readonly NpgsqlCommand _cmd; + readonly NpgsqlConnection _conn; + readonly NpgsqlCommand _cmd; - public Commit() - { - _conn = BenchmarkEnvironment.OpenConnection(); - _cmd = new NpgsqlCommand("SELECT 1", _conn); - } + public Commit() + { + _conn = BenchmarkEnvironment.OpenConnection(); + _cmd = new NpgsqlCommand("SELECT 1", _conn); + } - [Benchmark] - public void Basic() - { 
- var tx = _conn.BeginTransaction(); - _cmd.ExecuteNonQuery(); - tx.Commit(); - } + [Benchmark] + public void Basic() + { + var tx = _conn.BeginTransaction(); + _cmd.ExecuteNonQuery(); + tx.Commit(); + } - class Config : ManualConfig + class Config : ManualConfig + { + public Config() { - public Config() - { - AddColumn(StatisticColumn.OperationsPerSecond); - } + AddColumn(StatisticColumn.OperationsPerSecond); } } -} +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/ConnectionCreationBenchmarks.cs b/test/Npgsql.Benchmarks/ConnectionCreationBenchmarks.cs index 46cbfa4bad..e63bbba7c6 100644 --- a/test/Npgsql.Benchmarks/ConnectionCreationBenchmarks.cs +++ b/test/Npgsql.Benchmarks/ConnectionCreationBenchmarks.cs @@ -5,26 +5,25 @@ // ReSharper disable UnusedMember.Global -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +[Config(typeof(Config))] +public class ConnectionCreationBenchmarks { - [Config(typeof(Config))] - public class ConnectionCreationBenchmarks - { - const string NpgsqlConnectionString = "Host=foo;Database=bar;Username=user;Password=password"; - const string SqlClientConnectionString = @"Data Source=(localdb)\mssqllocaldb"; + const string NpgsqlConnectionString = "Host=foo;Database=bar;Username=user;Password=password"; + const string SqlClientConnectionString = @"Data Source=(localdb)\mssqllocaldb"; - [Benchmark] - public NpgsqlConnection Npgsql() => new NpgsqlConnection(NpgsqlConnectionString); + [Benchmark] + public NpgsqlConnection Npgsql() => new(NpgsqlConnectionString); - [Benchmark] - public SqlConnection SqlClient() => new SqlConnection(SqlClientConnectionString); + [Benchmark] + public SqlConnection SqlClient() => new(SqlClientConnectionString); - class Config : ManualConfig + class Config : ManualConfig + { + public Config() { - public Config() - { - AddColumn(StatisticColumn.OperationsPerSecond); - } + AddColumn(StatisticColumn.OperationsPerSecond); } } -} +} \ No newline at end of file diff --git 
a/test/Npgsql.Benchmarks/ConnectionOpenCloseBenchmarks.cs b/test/Npgsql.Benchmarks/ConnectionOpenCloseBenchmarks.cs index cf42c364ea..d733ff9c11 100644 --- a/test/Npgsql.Benchmarks/ConnectionOpenCloseBenchmarks.cs +++ b/test/Npgsql.Benchmarks/ConnectionOpenCloseBenchmarks.cs @@ -6,171 +6,170 @@ // ReSharper disable UnusedMember.Global // ReSharper disable MemberCanBePrivate.Global -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +[Config(typeof(Config))] +public class ConnectionOpenCloseBenchmarks { - [Config(typeof(Config))] - public class ConnectionOpenCloseBenchmarks - { - const string SqlClientConnectionString = @"Data Source=(localdb)\mssqllocaldb"; + const string SqlClientConnectionString = @"Data Source=(localdb)\mssqllocaldb"; #pragma warning disable CS8618 - NpgsqlCommand _noOpenCloseCmd; + NpgsqlCommand _noOpenCloseCmd; - readonly string _openCloseConnString = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { ApplicationName = nameof(OpenClose) }.ToString(); - readonly NpgsqlCommand _openCloseCmd = new NpgsqlCommand("SET lock_timeout = 1000"); - readonly SqlCommand _sqlOpenCloseCmd = new SqlCommand("SET LOCK_TIMEOUT 1000"); + readonly string _openCloseConnString = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { ApplicationName = nameof(OpenClose) }.ToString(); + readonly NpgsqlCommand _openCloseCmd = new("SET lock_timeout = 1000"); + readonly SqlCommand _sqlOpenCloseCmd = new("SET LOCK_TIMEOUT 1000"); - NpgsqlConnection _openCloseSameConn; - NpgsqlCommand _openCloseSameCmd; + NpgsqlConnection _openCloseSameConn; + NpgsqlCommand _openCloseSameCmd; - SqlConnection _sqlOpenCloseSameConn; - SqlCommand _sqlOpenCloseSameCmd; + SqlConnection _sqlOpenCloseSameConn; + SqlCommand _sqlOpenCloseSameCmd; - NpgsqlConnection _connWithPrepared; - NpgsqlCommand _withPreparedCmd; + NpgsqlConnection _connWithPrepared; + NpgsqlCommand _withPreparedCmd; - NpgsqlConnection _noResetConn; - NpgsqlCommand 
_noResetCmd; + NpgsqlConnection _noResetConn; + NpgsqlCommand _noResetCmd; - NpgsqlConnection _nonPooledConnection; - NpgsqlCommand _nonPooledCmd; + NpgsqlConnection _nonPooledConnection; + NpgsqlCommand _nonPooledCmd; #pragma warning restore CS8618 - // ReSharper disable once UnusedAutoPropertyAccessor.Global - [Params(0, 1, 5, 10)] - public int StatementsToSend { get; set; } - - [GlobalSetup] - public void GlobalSetup() - { - var csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { ApplicationName = nameof(NoOpenClose)}; - var noOpenCloseConn = new NpgsqlConnection(csb.ToString()); - noOpenCloseConn.Open(); - _noOpenCloseCmd = new NpgsqlCommand("SET lock_timeout = 1000", noOpenCloseConn); - - csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { ApplicationName = nameof(OpenCloseSameConnection) }; - _openCloseSameConn = new NpgsqlConnection(csb.ToString()); - _openCloseSameCmd = new NpgsqlCommand("SET lock_timeout = 1000", _openCloseSameConn); - - _sqlOpenCloseSameConn = new SqlConnection(SqlClientConnectionString); - _sqlOpenCloseSameCmd = new SqlCommand("SET LOCK_TIMEOUT 1000", _sqlOpenCloseSameConn); - - csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { ApplicationName = nameof(WithPrepared) }; - _connWithPrepared = new NpgsqlConnection(csb.ToString()); - _connWithPrepared.Open(); - using (var somePreparedCmd = new NpgsqlCommand("SELECT 1", _connWithPrepared)) - somePreparedCmd.Prepare(); - _connWithPrepared.Close(); - _withPreparedCmd = new NpgsqlCommand("SET lock_timeout = 1000", _connWithPrepared); - - csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) - { - ApplicationName = nameof(NoResetOnClose), - NoResetOnClose = true - }; - _noResetConn = new NpgsqlConnection(csb.ToString()); - _noResetCmd = new NpgsqlCommand("SET lock_timeout = 1000", _noResetConn); - csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { - 
ApplicationName = nameof(NonPooled), - Pooling = false - }; - _nonPooledConnection = new NpgsqlConnection(csb.ToString()); - _nonPooledCmd = new NpgsqlCommand("SET lock_timeout = 1000", _nonPooledConnection); - } - - [GlobalCleanup] - public void GlobalCleanup() - { - _noOpenCloseCmd.Connection?.Close(); - NpgsqlConnection.ClearAllPools(); - SqlConnection.ClearAllPools(); - } + // ReSharper disable once UnusedAutoPropertyAccessor.Global + [Params(0, 1, 5, 10)] + public int StatementsToSend { get; set; } - [Benchmark] - public void NoOpenClose() + [GlobalSetup] + public void GlobalSetup() + { + var csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { ApplicationName = nameof(NoOpenClose)}; + var noOpenCloseConn = new NpgsqlConnection(csb.ToString()); + noOpenCloseConn.Open(); + _noOpenCloseCmd = new NpgsqlCommand("SET lock_timeout = 1000", noOpenCloseConn); + + csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { ApplicationName = nameof(OpenCloseSameConnection) }; + _openCloseSameConn = new NpgsqlConnection(csb.ToString()); + _openCloseSameCmd = new NpgsqlCommand("SET lock_timeout = 1000", _openCloseSameConn); + + _sqlOpenCloseSameConn = new SqlConnection(SqlClientConnectionString); + _sqlOpenCloseSameCmd = new SqlCommand("SET LOCK_TIMEOUT 1000", _sqlOpenCloseSameConn); + + csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { ApplicationName = nameof(WithPrepared) }; + _connWithPrepared = new NpgsqlConnection(csb.ToString()); + _connWithPrepared.Open(); + using (var somePreparedCmd = new NpgsqlCommand("SELECT 1", _connWithPrepared)) + somePreparedCmd.Prepare(); + _connWithPrepared.Close(); + _withPreparedCmd = new NpgsqlCommand("SET lock_timeout = 1000", _connWithPrepared); + + csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { - for (var i = 0; i < StatementsToSend; i++) - _noOpenCloseCmd.ExecuteNonQuery(); - } + ApplicationName = 
nameof(NoResetOnClose), + NoResetOnClose = true + }; + _noResetConn = new NpgsqlConnection(csb.ToString()); + _noResetCmd = new NpgsqlCommand("SET lock_timeout = 1000", _noResetConn); + csb = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { + ApplicationName = nameof(NonPooled), + Pooling = false + }; + _nonPooledConnection = new NpgsqlConnection(csb.ToString()); + _nonPooledCmd = new NpgsqlCommand("SET lock_timeout = 1000", _nonPooledConnection); + } - [Benchmark] - public void OpenClose() - { - using (var conn = new NpgsqlConnection(_openCloseConnString)) - { - conn.Open(); - _openCloseCmd.Connection = conn; - for (var i = 0; i < StatementsToSend; i++) - _openCloseCmd.ExecuteNonQuery(); - } - } + [GlobalCleanup] + public void GlobalCleanup() + { + _noOpenCloseCmd.Connection?.Close(); + NpgsqlConnection.ClearAllPools(); + SqlConnection.ClearAllPools(); + } - [Benchmark(Baseline = true)] - public void SqlClientOpenClose() - { - using (var conn = new SqlConnection(SqlClientConnectionString)) - { - conn.Open(); - _sqlOpenCloseCmd.Connection = conn; - for (var i = 0; i < StatementsToSend; i++) - _sqlOpenCloseCmd.ExecuteNonQuery(); - } - } + [Benchmark] + public void NoOpenClose() + { + for (var i = 0; i < StatementsToSend; i++) + _noOpenCloseCmd.ExecuteNonQuery(); + } - [Benchmark] - public void OpenCloseSameConnection() + [Benchmark] + public void OpenClose() + { + using (var conn = new NpgsqlConnection(_openCloseConnString)) { - _openCloseSameConn.Open(); + conn.Open(); + _openCloseCmd.Connection = conn; for (var i = 0; i < StatementsToSend; i++) - _openCloseSameCmd.ExecuteNonQuery(); - _openCloseSameConn.Close(); + _openCloseCmd.ExecuteNonQuery(); } + } - [Benchmark] - public void SqlClientOpenCloseSameConnection() + [Benchmark(Baseline = true)] + public void SqlClientOpenClose() + { + using (var conn = new SqlConnection(SqlClientConnectionString)) { - _sqlOpenCloseSameConn.Open(); + conn.Open(); + _sqlOpenCloseCmd.Connection = conn; for 
(var i = 0; i < StatementsToSend; i++) - _sqlOpenCloseSameCmd.ExecuteNonQuery(); - _sqlOpenCloseSameConn.Close(); + _sqlOpenCloseCmd.ExecuteNonQuery(); } + } - /// - /// Having prepared statements alters the connection reset when closing. - /// - [Benchmark] - public void WithPrepared() - { - _connWithPrepared.Open(); - for (var i = 0; i < StatementsToSend; i++) - _withPreparedCmd.ExecuteNonQuery(); - _connWithPrepared.Close(); - } + [Benchmark] + public void OpenCloseSameConnection() + { + _openCloseSameConn.Open(); + for (var i = 0; i < StatementsToSend; i++) + _openCloseSameCmd.ExecuteNonQuery(); + _openCloseSameConn.Close(); + } - [Benchmark] - public void NoResetOnClose() - { - _noResetConn.Open(); - for (var i = 0; i < StatementsToSend; i++) - _noResetCmd.ExecuteNonQuery(); - _noResetConn.Close(); - } + [Benchmark] + public void SqlClientOpenCloseSameConnection() + { + _sqlOpenCloseSameConn.Open(); + for (var i = 0; i < StatementsToSend; i++) + _sqlOpenCloseSameCmd.ExecuteNonQuery(); + _sqlOpenCloseSameConn.Close(); + } - [Benchmark] - public void NonPooled() - { - _nonPooledConnection.Open(); - for (var i = 0; i < StatementsToSend; i++) - _nonPooledCmd.ExecuteNonQuery(); - _nonPooledConnection.Close(); - } + /// + /// Having prepared statements alters the connection reset when closing. 
+ /// + [Benchmark] + public void WithPrepared() + { + _connWithPrepared.Open(); + for (var i = 0; i < StatementsToSend; i++) + _withPreparedCmd.ExecuteNonQuery(); + _connWithPrepared.Close(); + } - class Config : ManualConfig + [Benchmark] + public void NoResetOnClose() + { + _noResetConn.Open(); + for (var i = 0; i < StatementsToSend; i++) + _noResetCmd.ExecuteNonQuery(); + _noResetConn.Close(); + } + + [Benchmark] + public void NonPooled() + { + _nonPooledConnection.Open(); + for (var i = 0; i < StatementsToSend; i++) + _nonPooledCmd.ExecuteNonQuery(); + _nonPooledConnection.Close(); + } + + class Config : ManualConfig + { + public Config() { - public Config() - { - AddColumn(StatisticColumn.OperationsPerSecond); - } + AddColumn(StatisticColumn.OperationsPerSecond); } } -} +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/CopyExport.cs b/test/Npgsql.Benchmarks/CopyExport.cs index f97d61cc0f..e4ea9c0698 100644 --- a/test/Npgsql.Benchmarks/CopyExport.cs +++ b/test/Npgsql.Benchmarks/CopyExport.cs @@ -1,40 +1,39 @@ using BenchmarkDotNet.Attributes; using NpgsqlTypes; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +public class CopyExport { - public class CopyExport + NpgsqlConnection _conn = default!; + const int Rows = 1000; + + [GlobalSetup] + public void Setup() { - NpgsqlConnection _conn = default!; - const int Rows = 1000; + _conn = BenchmarkEnvironment.OpenConnection(); + using (var cmd = new NpgsqlCommand("CREATE TEMP TABLE data (i1 INT, i2 INT, i3 INT, i4 INT, i5 INT, i6 INT, i7 INT, i8 INT, i9 INT, i10 INT)", _conn)) + cmd.ExecuteNonQuery(); - [GlobalSetup] - public void Setup() - { - _conn = BenchmarkEnvironment.OpenConnection(); - using (var cmd = new NpgsqlCommand("CREATE TEMP TABLE data (i1 INT, i2 INT, i3 INT, i4 INT, i5 INT, i6 INT, i7 INT, i8 INT, i9 INT, i10 INT)", _conn)) + using (var cmd = new NpgsqlCommand("INSERT INTO data VALUES (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)", _conn)) + for (var i = 0; i < Rows; i++) 
cmd.ExecuteNonQuery(); + } - using (var cmd = new NpgsqlCommand("INSERT INTO data VALUES (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)", _conn)) - for (var i = 0; i < Rows; i++) - cmd.ExecuteNonQuery(); - } - - [GlobalCleanup] - public void Cleanup() => _conn.Dispose(); + [GlobalCleanup] + public void Cleanup() => _conn.Dispose(); - [Benchmark] - public int Export() + [Benchmark] + public int Export() + { + var sum = 0; + unchecked { - var sum = 0; - unchecked - { - using (var exporter = _conn.BeginBinaryExport("COPY data TO STDOUT (FORMAT BINARY)")) - while (exporter.StartRow() != -1) - for (var col = 0; col < 10; col++) - sum += exporter.Read(NpgsqlDbType.Integer); - } - return sum; + using (var exporter = _conn.BeginBinaryExport("COPY data TO STDOUT (FORMAT BINARY)")) + while (exporter.StartRow() != -1) + for (var col = 0; col < 10; col++) + sum += exporter.Read(NpgsqlDbType.Integer); } + return sum; } -} +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/CopyImport.cs b/test/Npgsql.Benchmarks/CopyImport.cs index afb4002687..486d257d6c 100644 --- a/test/Npgsql.Benchmarks/CopyImport.cs +++ b/test/Npgsql.Benchmarks/CopyImport.cs @@ -1,43 +1,42 @@ using BenchmarkDotNet.Attributes; using NpgsqlTypes; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +public class CopyImport { - public class CopyImport - { - NpgsqlConnection _conn = default!; - NpgsqlCommand _truncateCmd = default!; - const int Rows = 1000; + NpgsqlConnection _conn = default!; + NpgsqlCommand _truncateCmd = default!; + const int Rows = 1000; - [GlobalSetup] - public void Setup() - { - _conn = BenchmarkEnvironment.OpenConnection(); - using (var cmd = new NpgsqlCommand("CREATE TEMP TABLE data (i1 INT, i2 INT, i3 INT, i4 INT, i5 INT, i6 INT, i7 INT, i8 INT, i9 INT, i10 INT)", _conn)) - cmd.ExecuteNonQuery(); + [GlobalSetup] + public void Setup() + { + _conn = BenchmarkEnvironment.OpenConnection(); + using (var cmd = new NpgsqlCommand("CREATE TEMP TABLE data (i1 INT, i2 INT, i3 INT, i4 INT, 
i5 INT, i6 INT, i7 INT, i8 INT, i9 INT, i10 INT)", _conn)) + cmd.ExecuteNonQuery(); - _truncateCmd = new NpgsqlCommand("TRUNCATE data", _conn); - _truncateCmd.Prepare(); - } + _truncateCmd = new NpgsqlCommand("TRUNCATE data", _conn); + _truncateCmd.Prepare(); + } - [GlobalCleanup] - public void Cleanup() => _conn.Dispose(); + [GlobalCleanup] + public void Cleanup() => _conn.Dispose(); - [IterationCleanup] - public void IterationCleanup() => _truncateCmd.ExecuteNonQuery(); + [IterationCleanup] + public void IterationCleanup() => _truncateCmd.ExecuteNonQuery(); - [Benchmark] - public void Import() + [Benchmark] + public void Import() + { + using (var importer = _conn.BeginBinaryImport("COPY data FROM STDIN (FORMAT BINARY)")) { - using (var importer = _conn.BeginBinaryImport("COPY data FROM STDIN (FORMAT BINARY)")) + for (var row = 0; row < Rows; row++) { - for (var row = 0; row < Rows; row++) - { - importer.StartRow(); - for (var col = 0; col < 10; col++) - importer.Write(col, NpgsqlDbType.Integer); - } + importer.StartRow(); + for (var col = 0; col < 10; col++) + importer.Write(col, NpgsqlDbType.Integer); } } } -} +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/GetFieldValue.cs b/test/Npgsql.Benchmarks/GetFieldValue.cs index b7b19fa0b9..0065f4546c 100644 --- a/test/Npgsql.Benchmarks/GetFieldValue.cs +++ b/test/Npgsql.Benchmarks/GetFieldValue.cs @@ -2,38 +2,37 @@ using BenchmarkDotNet.Columns; using BenchmarkDotNet.Configs; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +[Config(typeof(Config))] +public class GetFieldValue { - [Config(typeof(Config))] - public class GetFieldValue + readonly NpgsqlConnection _conn; + readonly NpgsqlCommand _cmd; + readonly NpgsqlDataReader _reader; + + public GetFieldValue() + { + _conn = BenchmarkEnvironment.OpenConnection(); + _cmd = new NpgsqlCommand("SELECT 0, 'str'", _conn); + _reader = _cmd.ExecuteReader(); + _reader.Read(); + } + + [Benchmark] + public void NullableField() => 
_reader.GetFieldValue(0); + + [Benchmark] + public void ValueTypeField() => _reader.GetFieldValue(0); + + [Benchmark] + public void ReferenceTypeField() => _reader.GetFieldValue(1); + + [Benchmark] + public void ObjectField() => _reader.GetFieldValue(1); + + class Config : ManualConfig { - readonly NpgsqlConnection _conn; - readonly NpgsqlCommand _cmd; - readonly NpgsqlDataReader _reader; - - public GetFieldValue() - { - _conn = BenchmarkEnvironment.OpenConnection(); - _cmd = new NpgsqlCommand("SELECT 0, 'str'", _conn); - _reader = _cmd.ExecuteReader(); - _reader.Read(); - } - - [Benchmark] - public void NullableField() => _reader.GetFieldValue(0); - - [Benchmark] - public void ValueTypeField() => _reader.GetFieldValue(0); - - [Benchmark] - public void ReferenceTypeField() => _reader.GetFieldValue(1); - - [Benchmark] - public void ObjectField() => _reader.GetFieldValue(1); - - class Config : ManualConfig - { - public Config() => AddColumn(StatisticColumn.OperationsPerSecond); - } + public Config() => AddColumn(StatisticColumn.OperationsPerSecond); } -} +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/Insert.cs b/test/Npgsql.Benchmarks/Insert.cs index db69b95fe0..2de57776d5 100644 --- a/test/Npgsql.Benchmarks/Insert.cs +++ b/test/Npgsql.Benchmarks/Insert.cs @@ -2,88 +2,87 @@ using BenchmarkDotNet.Attributes; using NpgsqlTypes; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +public class Insert { - public class Insert - { - NpgsqlConnection _conn = default!; - NpgsqlCommand _truncateCmd = default!; + NpgsqlConnection _conn = default!; + NpgsqlCommand _truncateCmd = default!; - [Params(1, 100, 1000, 10000)] - public int BatchSize { get; set; } + [Params(1, 100, 1000, 10000)] + public int BatchSize { get; set; } - [GlobalSetup] - public void GlobalSetup() + [GlobalSetup] + public void GlobalSetup() + { + var connString = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { - var connString = new 
NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) - { - Pooling = false - }.ToString(); - _conn = new NpgsqlConnection(connString); - _conn.Open(); + Pooling = false + }.ToString(); + _conn = new NpgsqlConnection(connString); + _conn.Open(); - using (var cmd = new NpgsqlCommand("CREATE TEMP TABLE data (int1 INT4, text1 TEXT, int2 INT4, text2 TEXT)", _conn)) - cmd.ExecuteNonQuery(); + using (var cmd = new NpgsqlCommand("CREATE TEMP TABLE data (int1 INT4, text1 TEXT, int2 INT4, text2 TEXT)", _conn)) + cmd.ExecuteNonQuery(); - _truncateCmd = new NpgsqlCommand("TRUNCATE data", _conn); - } + _truncateCmd = new NpgsqlCommand("TRUNCATE data", _conn); + } - [GlobalCleanup] - public void GlobalCleanup() => _conn.Close(); + [GlobalCleanup] + public void GlobalCleanup() => _conn.Close(); - [Benchmark(Baseline = true)] - public void Unbatched() - { - var cmd = new NpgsqlCommand("INSERT INTO data VALUES (@p0, @p1, @p2, @p3)", _conn); - cmd.Parameters.AddWithValue("p0", NpgsqlDbType.Integer, 8); - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Text, "foo"); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Integer, 9); - cmd.Parameters.AddWithValue("p3", NpgsqlDbType.Text, "bar"); - cmd.Prepare(); + [Benchmark(Baseline = true)] + public void Unbatched() + { + var cmd = new NpgsqlCommand("INSERT INTO data VALUES (@p0, @p1, @p2, @p3)", _conn); + cmd.Parameters.AddWithValue("p0", NpgsqlDbType.Integer, 8); + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Text, "foo"); + cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Integer, 9); + cmd.Parameters.AddWithValue("p3", NpgsqlDbType.Text, "bar"); + cmd.Prepare(); - for (var i = 0; i < BatchSize; i++) - cmd.ExecuteNonQuery(); - _truncateCmd.ExecuteNonQuery(); - } + for (var i = 0; i < BatchSize; i++) + cmd.ExecuteNonQuery(); + _truncateCmd.ExecuteNonQuery(); + } - [Benchmark] - public void Batched() + [Benchmark] + public void Batched() + { + var cmd = new NpgsqlCommand { Connection = _conn }; + var sb = new 
StringBuilder(); + for (var i = 0; i < BatchSize; i++) { - var cmd = new NpgsqlCommand { Connection = _conn }; - var sb = new StringBuilder(); - for (var i = 0; i < BatchSize; i++) - { - var p1 = (i * 4).ToString(); - var p2 = (i * 4 + 1).ToString(); - var p3 = (i * 4 + 2).ToString(); - var p4 = (i * 4 + 3).ToString(); - sb.Append("INSERT INTO data VALUES (@").Append(p1).Append(", @").Append(p2).Append(", @").Append(p3).Append(", @").Append(p4).Append(");"); - cmd.Parameters.AddWithValue(p1, NpgsqlDbType.Integer, 8); - cmd.Parameters.AddWithValue(p2, NpgsqlDbType.Text, "foo"); - cmd.Parameters.AddWithValue(p3, NpgsqlDbType.Integer, 9); - cmd.Parameters.AddWithValue(p4, NpgsqlDbType.Text, "bar"); - } - cmd.CommandText = sb.ToString(); - cmd.Prepare(); - cmd.ExecuteNonQuery(); - _truncateCmd.ExecuteNonQuery(); + var p1 = (i * 4).ToString(); + var p2 = (i * 4 + 1).ToString(); + var p3 = (i * 4 + 2).ToString(); + var p4 = (i * 4 + 3).ToString(); + sb.Append("INSERT INTO data VALUES (@").Append(p1).Append(", @").Append(p2).Append(", @").Append(p3).Append(", @").Append(p4).Append(");"); + cmd.Parameters.AddWithValue(p1, NpgsqlDbType.Integer, 8); + cmd.Parameters.AddWithValue(p2, NpgsqlDbType.Text, "foo"); + cmd.Parameters.AddWithValue(p3, NpgsqlDbType.Integer, 9); + cmd.Parameters.AddWithValue(p4, NpgsqlDbType.Text, "bar"); } + cmd.CommandText = sb.ToString(); + cmd.Prepare(); + cmd.ExecuteNonQuery(); + _truncateCmd.ExecuteNonQuery(); + } - [Benchmark] - public void Copy() + [Benchmark] + public void Copy() + { + using (var s = _conn.BeginBinaryImport("COPY data (int1, text1, int2, text2) FROM STDIN BINARY")) { - using (var s = _conn.BeginBinaryImport("COPY data (int1, text1, int2, text2) FROM STDIN BINARY")) + for (var i = 0; i < BatchSize; i++) { - for (var i = 0; i < BatchSize; i++) - { - s.StartRow(); - s.Write(8); - s.Write("foo"); - s.Write(9); - s.Write("bar"); - } + s.StartRow(); + s.Write(8); + s.Write("foo"); + s.Write(9); + s.Write("bar"); } - 
_truncateCmd.ExecuteNonQuery(); } + _truncateCmd.ExecuteNonQuery(); } -} +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/Npgsql.Benchmarks.csproj b/test/Npgsql.Benchmarks/Npgsql.Benchmarks.csproj index 7b3fe28f8a..922d4cbdce 100644 --- a/test/Npgsql.Benchmarks/Npgsql.Benchmarks.csproj +++ b/test/Npgsql.Benchmarks/Npgsql.Benchmarks.csproj @@ -4,6 +4,7 @@ portable Npgsql.Benchmarks Exe + $(NoWarn);NPG9001 @@ -13,6 +14,8 @@ + + diff --git a/test/Npgsql.Benchmarks/Prepare.cs b/test/Npgsql.Benchmarks/Prepare.cs index 16cd0666a9..6b8d9b06bc 100644 --- a/test/Npgsql.Benchmarks/Prepare.cs +++ b/test/Npgsql.Benchmarks/Prepare.cs @@ -1,5 +1,4 @@ -using System.Diagnostics.CodeAnalysis; -using System.Linq; +using System.Linq; using System.Reflection; using System.Text; using BenchmarkDotNet.Attributes; @@ -8,117 +7,116 @@ // ReSharper disable MemberCanBePrivate.Global // ReSharper disable AssignNullToNotNullAttribute.Global -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +public class Prepare { - public class Prepare - { - NpgsqlConnection _conn = default!, _autoPreparingConn = default!; - static readonly string[] Queries; - string _query = default!; - NpgsqlCommand _preparedCmd = default!; + NpgsqlConnection _conn = default!, _autoPreparingConn = default!; + static readonly string[] Queries; + string _query = default!; + NpgsqlCommand _preparedCmd = default!; - /// - /// The more tables are joined, the more complex the query is to plan, and therefore the more - /// impact statement preparation should have. - /// - [Params(0, 1, 2, 5, 10)] - public int TablesToJoin { get; set; } + /// + /// The more tables are joined, the more complex the query is to plan, and therefore the more + /// impact statement preparation should have. 
+ /// + [Params(0, 1, 2, 5, 10)] + public int TablesToJoin { get; set; } - [GlobalSetup] - public void GlobalSetup() + [GlobalSetup] + public void GlobalSetup() + { + _conn = BenchmarkEnvironment.OpenConnection(); + _autoPreparingConn = new NpgsqlConnection(new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) { - _conn = BenchmarkEnvironment.OpenConnection(); - _autoPreparingConn = new NpgsqlConnection(new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) - { - MaxAutoPrepare = 10 - }.ToString()); - _autoPreparingConn.Open(); + MaxAutoPrepare = 10 + }.ToString()); + _autoPreparingConn.Open(); - foreach (var conn in new[] { _conn, _autoPreparingConn }) + foreach (var conn in new[] { _conn, _autoPreparingConn }) + { + using (var cmd = new NpgsqlCommand { Connection = conn }) { - using (var cmd = new NpgsqlCommand { Connection = conn }) + for (var i = 0; i < 100; i++) { - for (var i = 0; i < 100; i++) - { - cmd.CommandText = $@" + cmd.CommandText = $@" CREATE TEMP TABLE table{i} (id INT PRIMARY KEY, data INT); INSERT INTO table{i} (id, data) VALUES (1, {i}); "; - cmd.ExecuteNonQuery(); - } + cmd.ExecuteNonQuery(); } } - _query = Queries[TablesToJoin]; - _preparedCmd = new NpgsqlCommand(_query, _conn); - _preparedCmd.Prepare(); } + _query = Queries[TablesToJoin]; + _preparedCmd = new NpgsqlCommand(_query, _conn); + _preparedCmd.Prepare(); + } - [GlobalCleanup] - public void GlobalCleanup() - { - _conn.Dispose(); - } + [GlobalCleanup] + public void GlobalCleanup() + { + _conn.Dispose(); + } - public Prepare() + public Prepare() + { + // Create tables and data + using (var conn = BenchmarkEnvironment.OpenConnection()) + using (var cmd = new NpgsqlCommand {Connection = conn}) { - // Create tables and data - using (var conn = BenchmarkEnvironment.OpenConnection()) - using (var cmd = new NpgsqlCommand {Connection = conn}) + for (var i = 0; i < TablesToJoinValues.Max(); i++) { - for (var i = 0; i < TablesToJoinValues.Max(); i++) - { - 
cmd.CommandText = $@" + cmd.CommandText = $@" DROP TABLE IF EXISTS table{i}; CREATE TABLE table{i} (id INT PRIMARY KEY, data INT); INSERT INTO table{i} (id, data) VALUES (1, {i}); "; - cmd.ExecuteNonQuery(); - } + cmd.ExecuteNonQuery(); } } + } - [Benchmark(Baseline = true)] - public object Unprepared() - { - using (var cmd = new NpgsqlCommand(_query, _conn)) - return cmd.ExecuteScalar()!; - } - - [Benchmark] - public object AutoPrepared() - { - using (var cmd = new NpgsqlCommand(_query, _autoPreparingConn)) - return cmd.ExecuteScalar()!; - } + [Benchmark(Baseline = true)] + public object Unprepared() + { + using (var cmd = new NpgsqlCommand(_query, _conn)) + return cmd.ExecuteScalar()!; + } - [Benchmark] - public object Prepared() => _preparedCmd.ExecuteScalar()!; + [Benchmark] + public object AutoPrepared() + { + using (var cmd = new NpgsqlCommand(_query, _autoPreparingConn)) + return cmd.ExecuteScalar()!; + } - static Prepare() - { - Queries = new string[TablesToJoinValues.Max() + 1]; - Queries[0] = "SELECT 1"; + [Benchmark] + public object Prepared() => _preparedCmd.ExecuteScalar()!; - foreach (var tablesToJoin in TablesToJoinValues.Where(i => i != 0)) - Queries[tablesToJoin] = GenerateQuery(tablesToJoin); - } + static Prepare() + { + Queries = new string[TablesToJoinValues.Max() + 1]; + Queries[0] = "SELECT 1"; - static string GenerateQuery(int tablesToJoin) - { - var sb = new StringBuilder(); - sb.AppendLine("SELECT "); - sb.AppendLine(string.Join("+", Enumerable.Range(0, tablesToJoin).Select(i => $"table{i}.data"))); - sb.AppendLine("FROM table0"); - for (var i = 1; i < tablesToJoin; i++) - sb.AppendLine($"JOIN table{i} ON table{i}.id = table{i - 1}.id"); - return sb.ToString(); - } + foreach (var tablesToJoin in TablesToJoinValues.Where(i => i != 0)) + Queries[tablesToJoin] = GenerateQuery(tablesToJoin); + } - static readonly int[] TablesToJoinValues = typeof(Prepare) - .GetProperty(nameof(TablesToJoin))! - .GetCustomAttribute()! 
- .Values - .Cast() - .ToArray(); + static string GenerateQuery(int tablesToJoin) + { + var sb = new StringBuilder(); + sb.AppendLine("SELECT "); + sb.AppendLine(string.Join("+", Enumerable.Range(0, tablesToJoin).Select(i => $"table{i}.data"))); + sb.AppendLine("FROM table0"); + for (var i = 1; i < tablesToJoin; i++) + sb.AppendLine($"JOIN table{i} ON table{i}.id = table{i - 1}.id"); + return sb.ToString(); } -} + + static readonly int[] TablesToJoinValues = typeof(Prepare) + .GetProperty(nameof(TablesToJoin))! + .GetCustomAttribute()! + .Values + .Cast() + .ToArray(); +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/Program.cs b/test/Npgsql.Benchmarks/Program.cs index 67a573c319..9a334f63b8 100644 --- a/test/Npgsql.Benchmarks/Program.cs +++ b/test/Npgsql.Benchmarks/Program.cs @@ -1,10 +1,9 @@ using BenchmarkDotNet.Running; using System.Reflection; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +class Program { - class Program - { - static void Main(string[] args) => new BenchmarkSwitcher(typeof(Program).GetTypeInfo().Assembly).Run(args); - } -} + static void Main(string[] args) => new BenchmarkSwitcher(typeof(Program).GetTypeInfo().Assembly).Run(args); +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/ReadArray.cs b/test/Npgsql.Benchmarks/ReadArray.cs index 8d57e322a7..e1e5b2d8de 100644 --- a/test/Npgsql.Benchmarks/ReadArray.cs +++ b/test/Npgsql.Benchmarks/ReadArray.cs @@ -1,89 +1,82 @@ using BenchmarkDotNet.Attributes; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Runtime.CompilerServices; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +public class ReadArrays { - public class ReadArrays - { - [Params(true, false)] - public bool AllNulls; + [Params(true, false)] + public bool AllNulls; - [Params(1, 10, 1000, 100000)] - public int NumElements; + [Params(1, 10, 1000, 100000)] + public int NumElements; - NpgsqlConnection 
_intConn = default!; - NpgsqlCommand _intCmd = default!; - NpgsqlDataReader _intReader = default!; + NpgsqlConnection _intConn = default!; + NpgsqlCommand _intCmd = default!; + NpgsqlDataReader _intReader = default!; - NpgsqlConnection _nullableIntConn = default!; - NpgsqlCommand _nullableIntCmd = default!; - NpgsqlDataReader _nullableIntReader = default!; + NpgsqlConnection _nullableIntConn = default!; + NpgsqlCommand _nullableIntCmd = default!; + NpgsqlDataReader _nullableIntReader = default!; - NpgsqlConnection _stringConn = default!; - NpgsqlCommand _stringCmd = default!; - NpgsqlDataReader _stringReader = default!; + NpgsqlConnection _stringConn = default!; + NpgsqlCommand _stringCmd = default!; + NpgsqlDataReader _stringReader = default!; - [GlobalSetup] - public void Setup() - { - var intArray = new int[NumElements]; - for (var i = 0; i < NumElements; i++) - intArray[i] = 666; - _intConn = BenchmarkEnvironment.OpenConnection(); - _intCmd = new NpgsqlCommand("SELECT @p1", _intConn); - _intCmd.Parameters.AddWithValue("p1", intArray); - _intReader = _intCmd.ExecuteReader(); - _intReader.Read(); + [GlobalSetup] + public void Setup() + { + var intArray = new int[NumElements]; + for (var i = 0; i < NumElements; i++) + intArray[i] = 666; + _intConn = BenchmarkEnvironment.OpenConnection(); + _intCmd = new NpgsqlCommand("SELECT @p1", _intConn); + _intCmd.Parameters.AddWithValue("p1", intArray); + _intReader = _intCmd.ExecuteReader(); + _intReader.Read(); - var nullableIntArray = new int?[NumElements]; - for (var i = 0; i < NumElements; i++) - nullableIntArray[i] = AllNulls ? 
(int?)null : 666; - _nullableIntConn = BenchmarkEnvironment.OpenConnection(); - _nullableIntCmd = new NpgsqlCommand("SELECT @p1", _nullableIntConn); - _nullableIntCmd.Parameters.AddWithValue("p1", nullableIntArray); - _nullableIntReader = _nullableIntCmd.ExecuteReader(); - _nullableIntReader.Read(); + var nullableIntArray = new int?[NumElements]; + for (var i = 0; i < NumElements; i++) + nullableIntArray[i] = AllNulls ? (int?)null : 666; + _nullableIntConn = BenchmarkEnvironment.OpenConnection(); + _nullableIntCmd = new NpgsqlCommand("SELECT @p1", _nullableIntConn); + _nullableIntCmd.Parameters.AddWithValue("p1", nullableIntArray); + _nullableIntReader = _nullableIntCmd.ExecuteReader(); + _nullableIntReader.Read(); - var stringArray = new string?[NumElements]; - for (var i = 0; i < NumElements; i++) - stringArray[i] = AllNulls ? null : "666"; - _stringConn = BenchmarkEnvironment.OpenConnection(); - _stringCmd = new NpgsqlCommand("SELECT @p1", _stringConn); - _stringCmd.Parameters.AddWithValue("p1", stringArray); - _stringReader = _stringCmd.ExecuteReader(); - _stringReader.Read(); - } + var stringArray = new string?[NumElements]; + for (var i = 0; i < NumElements; i++) + stringArray[i] = AllNulls ? 
null : "666"; + _stringConn = BenchmarkEnvironment.OpenConnection(); + _stringCmd = new NpgsqlCommand("SELECT @p1", _stringConn); + _stringCmd.Parameters.AddWithValue("p1", stringArray); + _stringReader = _stringCmd.ExecuteReader(); + _stringReader.Read(); + } - protected void Cleanup() - { - _intReader.Dispose(); - _nullableIntReader.Dispose(); - _stringReader.Dispose(); + protected void Cleanup() + { + _intReader.Dispose(); + _nullableIntReader.Dispose(); + _stringReader.Dispose(); - _intCmd.Dispose(); - _nullableIntCmd.Dispose(); - _stringCmd.Dispose(); + _intCmd.Dispose(); + _nullableIntCmd.Dispose(); + _stringCmd.Dispose(); - _intConn.Dispose(); - _nullableIntConn.Dispose(); - _stringConn.Dispose(); - } + _intConn.Dispose(); + _nullableIntConn.Dispose(); + _stringConn.Dispose(); + } - [Benchmark] - public int ReadIntArray() - => _intReader.GetFieldValue(0).Length; + [Benchmark] + public int ReadIntArray() + => _intReader.GetFieldValue(0).Length; - [Benchmark] - public int ReadNullableIntArray() - => _nullableIntReader.GetFieldValue(0).Length; + [Benchmark] + public int ReadNullableIntArray() + => _nullableIntReader.GetFieldValue(0).Length; - [Benchmark] - public int ReadStringArray() - => _stringReader.GetFieldValue(0).Length; - } -} + [Benchmark] + public int ReadStringArray() + => _stringReader.GetFieldValue(0).Length; +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/ReadColumns.cs b/test/Npgsql.Benchmarks/ReadColumns.cs index 3ec6229ad4..aa10d25f1a 100644 --- a/test/Npgsql.Benchmarks/ReadColumns.cs +++ b/test/Npgsql.Benchmarks/ReadColumns.cs @@ -3,71 +3,70 @@ using System.Text; using BenchmarkDotNet.Attributes; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +public class ReadColumns { - public class ReadColumns - { - NpgsqlConnection _conn = default!; - NpgsqlCommand _cmd = default!; + NpgsqlConnection _conn = default!; + NpgsqlCommand _cmd = default!; - [Params(1, 10, 100, 1000)] - public int NumColumns { get; set; } = 100; 
+ [Params(1, 10, 100, 1000)] + public int NumColumns { get; set; } = 100; - static readonly string[] Queries; + static readonly string[] Queries; - [GlobalSetup] - public void GlobalSetup() - { - _conn = BenchmarkEnvironment.OpenConnection(); - _cmd = new NpgsqlCommand(Queries[NumColumns], _conn); - } + [GlobalSetup] + public void GlobalSetup() + { + _conn = BenchmarkEnvironment.OpenConnection(); + _cmd = new NpgsqlCommand(Queries[NumColumns], _conn); + } - [GlobalCleanup] - public void Cleanup() - { - _cmd.Dispose(); - _conn.Dispose(); - } + [GlobalCleanup] + public void Cleanup() + { + _cmd.Dispose(); + _conn.Dispose(); + } - [Benchmark] - public int IntColumn() + [Benchmark] + public int IntColumn() + { + unchecked { - unchecked + var x = 0; + using (var reader = _cmd.ExecuteReader()) { - var x = 0; - using (var reader = _cmd.ExecuteReader()) - { - reader.Read(); - for (var i = 0; i < NumColumns; i++) - x += reader.GetInt32(i); - } - return x; + reader.Read(); + for (var i = 0; i < NumColumns; i++) + x += reader.GetInt32(i); } + return x; } + } - static ReadColumns() - { - Queries = new string[NumColumnsValues.Max() + 1]; - Queries[0] = "SELECT 1 WHERE 1=0"; - - foreach (var numColumns in NumColumnsValues.Where(i => i != 0)) - Queries[numColumns] = GenerateQuery(numColumns); - } + static ReadColumns() + { + Queries = new string[NumColumnsValues.Max() + 1]; + Queries[0] = "SELECT 1 WHERE 1=0"; - static string GenerateQuery(int numColumns) - { - var sb = new StringBuilder() - .Append("SELECT ") - .Append(string.Join(",", Enumerable.Range(0, numColumns))) - .Append(";"); - return sb.ToString(); - } + foreach (var numColumns in NumColumnsValues.Where(i => i != 0)) + Queries[numColumns] = GenerateQuery(numColumns); + } - static readonly int[] NumColumnsValues = typeof(ReadColumns) - .GetProperty(nameof(NumColumns))! - .GetCustomAttribute()! 
- .Values - .Cast() - .ToArray(); + static string GenerateQuery(int numColumns) + { + var sb = new StringBuilder() + .Append("SELECT ") + .Append(string.Join(",", Enumerable.Range(0, numColumns))) + .Append(";"); + return sb.ToString(); } -} + + static readonly int[] NumColumnsValues = typeof(ReadColumns) + .GetProperty(nameof(NumColumns))! + .GetCustomAttribute()! + .Values + .Cast() + .ToArray(); +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/ReadRows.cs b/test/Npgsql.Benchmarks/ReadRows.cs index 2a60570c5d..7ec8d9ed09 100644 --- a/test/Npgsql.Benchmarks/ReadRows.cs +++ b/test/Npgsql.Benchmarks/ReadRows.cs @@ -1,27 +1,26 @@ using BenchmarkDotNet.Attributes; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +public class ReadRows { - public class ReadRows - { - [Params(1, 10, 100, 1000)] - public int NumRows { get; set; } + [Params(1, 10, 100, 1000)] + public int NumRows { get; set; } - NpgsqlCommand Command { get; set; } = default!; + NpgsqlCommand Command { get; set; } = default!; - [GlobalSetup] - public void Setup() - { - var conn = BenchmarkEnvironment.OpenConnection(); - Command = new NpgsqlCommand($"SELECT generate_series(1, {NumRows})", conn); - Command.Prepare(); - } + [GlobalSetup] + public void Setup() + { + var conn = BenchmarkEnvironment.OpenConnection(); + Command = new NpgsqlCommand($"SELECT generate_series(1, {NumRows})", conn); + Command.Prepare(); + } - [Benchmark] - public void Read() - { - using (var reader = Command.ExecuteReader()) - while (reader.Read()) { } - } + [Benchmark] + public void Read() + { + using (var reader = Command.ExecuteReader()) + while (reader.Read()) { } } -} +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/ResolveHandler.cs b/test/Npgsql.Benchmarks/ResolveHandler.cs new file mode 100644 index 0000000000..86e5d20fbb --- /dev/null +++ b/test/Npgsql.Benchmarks/ResolveHandler.cs @@ -0,0 +1,42 @@ +using BenchmarkDotNet.Attributes; +using Npgsql.Internal; +using 
Npgsql.Internal.Postgres; + +namespace Npgsql.Benchmarks; + +[MemoryDiagnoser] +public class ResolveHandler +{ + NpgsqlDataSource? _dataSource; + PgSerializerOptions _serializerOptions = null!; + + [Params(0, 1, 2)] + public int NumPlugins { get; set; } + + [GlobalSetup] + public void Setup() + { + var dataSourceBuilder = new NpgsqlDataSourceBuilder(); + if (NumPlugins > 0) + dataSourceBuilder.UseNodaTime(); + if (NumPlugins > 1) + dataSourceBuilder.UseNetTopologySuite(); + _dataSource = dataSourceBuilder.Build(); + _serializerOptions = _dataSource.SerializerOptions; + } + + [GlobalCleanup] + public void Cleanup() => _dataSource?.Dispose(); + + [Benchmark] + public PgTypeInfo? ResolveDefault() + => _serializerOptions.GetDefaultTypeInfo(new Oid(23)); // int4 + + [Benchmark] + public PgTypeInfo? ResolveType() + => _serializerOptions.GetTypeInfo(typeof(int)); + + [Benchmark] + public PgTypeInfo? ResolveBoth() + => _serializerOptions.GetTypeInfo(typeof(int), new Oid(23)); // int4 +} diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Composite.cs b/test/Npgsql.Benchmarks/TypeHandlers/Composite.cs index 5b83145054..52418a7240 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Composite.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Composite.cs @@ -1,12 +1,4 @@ -using System.Collections.Generic; -using BenchmarkDotNet.Attributes; -using Npgsql.NameTranslation; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandlers; -using Npgsql.TypeHandlers.CompositeHandlers; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using Npgsql.Util; + /* Disabling for now: unmapped composite support is probably going away, and there's a good chance this * class can be simplified to a certain extent diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs b/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs index a31567b974..42f5f3936a 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs @@ -1,67 +1,66 @@ using 
System.Collections.Generic; using BenchmarkDotNet.Attributes; -using Npgsql.TypeHandlers.NumericHandlers; +using Npgsql.Internal.Converters; -namespace Npgsql.Benchmarks.TypeHandlers -{ - [Config(typeof(Config))] - public class Int16 : TypeHandlerBenchmarks - { - public Int16() : base(new Int16Handler(GetPostgresType("smallint"))) { } - } +namespace Npgsql.Benchmarks.TypeHandlers; - [Config(typeof(Config))] - public class Int32 : TypeHandlerBenchmarks - { - public Int32() : base(new Int32Handler(GetPostgresType("integer"))) { } - } +[Config(typeof(Config))] +public class Int16 : TypeHandlerBenchmarks +{ + public Int16() : base(new Int2Converter()) { } +} - [Config(typeof(Config))] - public class Int64 : TypeHandlerBenchmarks - { - public Int64() : base(new Int64Handler(GetPostgresType("bigint"))) { } - } +[Config(typeof(Config))] +public class Int32 : TypeHandlerBenchmarks +{ + public Int32() : base(new Int4Converter()) { } +} - [Config(typeof(Config))] - public class Single : TypeHandlerBenchmarks - { - public Single() : base(new SingleHandler(GetPostgresType("real"))) { } - } +[Config(typeof(Config))] +public class Int64 : TypeHandlerBenchmarks +{ + public Int64() : base(new Int8Converter()) { } +} - [Config(typeof(Config))] - public class Double : TypeHandlerBenchmarks - { - public Double() : base(new DoubleHandler(GetPostgresType("double precision"))) { } - } +[Config(typeof(Config))] +public class Single : TypeHandlerBenchmarks +{ + public Single() : base(new RealConverter()) { } +} - [Config(typeof(Config))] - public class Numeric : TypeHandlerBenchmarks - { - public Numeric() : base(new NumericHandler(GetPostgresType("numeric"))) { } +[Config(typeof(Config))] +public class Double : TypeHandlerBenchmarks +{ + public Double() : base(new DoubleConverter()) { } +} - protected override IEnumerable ValuesOverride() => new[] - { - 0.0000000000000000000000000001M, - 0.000000000000000000000001M, - 0.00000000000000000001M, - 0.0000000000000001M, - 0.000000000001M, - 
0.00000001M, - 0.0001M, - 1M, - 10000M, - 100000000M, - 1000000000000M, - 10000000000000000M, - 100000000000000000000M, - 1000000000000000000000000M, - 10000000000000000000000000000M, - }; - } +[Config(typeof(Config))] +public class Numeric : TypeHandlerBenchmarks +{ + public Numeric() : base(new DecimalNumericConverter()) { } - [Config(typeof(Config))] - public class Money : TypeHandlerBenchmarks + protected override IEnumerable ValuesOverride() => new[] { - public Money() : base(new MoneyHandler(GetPostgresType("money"))) { } - } + 0.0000000000000000000000000001M, + 0.000000000000000000000001M, + 0.00000000000000000001M, + 0.0000000000000001M, + 0.000000000001M, + 0.00000001M, + 0.0001M, + 1M, + 10000M, + 100000000M, + 1000000000000M, + 10000000000000000M, + 100000000000000000000M, + 1000000000000000000000000M, + 10000000000000000000000000000M, + }; +} + +[Config(typeof(Config))] +public class Money : TypeHandlerBenchmarks +{ + public Money() : base(new MoneyConverter()) { } } diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Text.cs b/test/Npgsql.Benchmarks/TypeHandlers/Text.cs index 47d444bc45..80d5f6ce0c 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Text.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Text.cs @@ -1,19 +1,18 @@ using BenchmarkDotNet.Attributes; -using Npgsql.TypeHandlers; using System.Collections.Generic; using System.Text; +using Npgsql.Internal.Converters; -namespace Npgsql.Benchmarks.TypeHandlers +namespace Npgsql.Benchmarks.TypeHandlers; + +[Config(typeof(Config))] +public class Text : TypeHandlerBenchmarks { - [Config(typeof(Config))] - public class Text : TypeHandlerBenchmarks - { - public Text() : base(new TextHandler(GetPostgresType("text"), Encoding.UTF8)) { } + public Text() : base(new StringTextConverter(Encoding.UTF8)) { } - protected override IEnumerable ValuesOverride() - { - for (var i = 1; i <= 10000; i *= 10) - yield return new string('x', i); - } + protected override IEnumerable ValuesOverride() + { + for (var i = 1; i <= 
10000; i *= 10) + yield return new string('x', i); } } diff --git a/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs b/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs index f540cb36b3..994839c219 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs @@ -2,103 +2,103 @@ using BenchmarkDotNet.Columns; using BenchmarkDotNet.Configs; using BenchmarkDotNet.Diagnosers; -using Npgsql.TypeHandling; using System; using System.Collections.Generic; using System.IO; -using System.Text; -using Npgsql.PostgresTypes; -using Npgsql.Util; +using System.Threading; +using Npgsql.Internal; #nullable disable -namespace Npgsql.Benchmarks.TypeHandlers +namespace Npgsql.Benchmarks.TypeHandlers; + +public abstract class TypeHandlerBenchmarks { - public abstract class TypeHandlerBenchmarks + protected class Config : ManualConfig { - protected class Config : ManualConfig + public Config() { - public Config() - { - AddColumn(StatisticColumn.OperationsPerSecond); - AddDiagnoser(MemoryDiagnoser.Default); - } + AddColumn(StatisticColumn.OperationsPerSecond); + AddDiagnoser(MemoryDiagnoser.Default); } + } - class EndlessStream : Stream - { - public override bool CanRead => true; - public override bool CanSeek => true; - public override bool CanWrite => true; - public override long Length => long.MaxValue; - public override long Position { get => 0L; set { } } - public override void Flush() { } - public override int Read(byte[] buffer, int offset, int count) => count; - public override long Seek(long offset, SeekOrigin origin) => 0L; - public override void SetLength(long value) { } - public override void Write(byte[] buffer, int offset, int count) { } - } + class EndlessStream : Stream + { + public override bool CanRead => true; + public override bool CanSeek => true; + public override bool CanWrite => true; + public override long Length => long.MaxValue; + public override long Position 
{ get => 0L; set { } } + public override void Flush() { } + public override int Read(byte[] buffer, int offset, int count) => count; + public override long Seek(long offset, SeekOrigin origin) => 0L; + public override void SetLength(long value) { } + public override void Write(byte[] buffer, int offset, int count) { } + } - readonly EndlessStream _stream; - readonly NpgsqlTypeHandler _handler; - readonly NpgsqlReadBuffer _readBuffer; - readonly NpgsqlWriteBuffer _writeBuffer; - T _value; - int _elementSize; + readonly PgConverter _converter; + readonly PgReader _reader; + readonly PgWriter _writer; + readonly NpgsqlWriteBuffer _writeBuffer; + readonly NpgsqlReadBuffer _readBuffer; + readonly BufferRequirements _binaryRequirements; - protected TypeHandlerBenchmarks(NpgsqlTypeHandler handler) - { - _stream = new EndlessStream(); - _handler = handler ?? throw new ArgumentNullException(nameof(handler)); - _readBuffer = new NpgsqlReadBuffer(null, _stream, null, NpgsqlReadBuffer.MinimumSize, Encoding.UTF8, PGUtil.RelaxedUTF8Encoding); - _writeBuffer = new NpgsqlWriteBuffer(null, _stream, null, NpgsqlWriteBuffer.MinimumSize, Encoding.UTF8); - } + T _value; + Size _elementSize; - protected static PostgresType GetPostgresType(string pgType) - { - using (var conn = BenchmarkEnvironment.OpenConnection()) - using (var cmd = new NpgsqlCommand($"SELECT NULL::{pgType}", conn)) - using (var reader = cmd.ExecuteReader()) - return reader.GetPostgresType(0); - } + protected TypeHandlerBenchmarks(PgConverter handler) + { + var stream = new EndlessStream(); + _converter = handler ?? 
throw new ArgumentNullException(nameof(handler)); + _readBuffer = new NpgsqlReadBuffer(null, stream, null, NpgsqlReadBuffer.MinimumSize, NpgsqlWriteBuffer.UTF8Encoding, NpgsqlWriteBuffer.RelaxedUTF8Encoding); + _writeBuffer = new NpgsqlWriteBuffer(null, stream, null, NpgsqlWriteBuffer.MinimumSize, NpgsqlWriteBuffer.UTF8Encoding); + _reader = new PgReader(_readBuffer); + _writer = new PgWriter(new NpgsqlBufferWriter(_writeBuffer)); + _writer.Init(new PostgresMinimalDatabaseInfo()); + _converter.CanConvert(DataFormat.Binary, out _binaryRequirements); + } - public IEnumerable Values() => ValuesOverride(); + public IEnumerable Values() => ValuesOverride(); - protected virtual IEnumerable ValuesOverride() => new[] { default(T) }; + protected virtual IEnumerable ValuesOverride() => new[] { default(T) }; - [ParamsSource(nameof(Values))] - public T Value + [ParamsSource(nameof(Values))] + public T Value + { + get => _value; + set { - get => _value; - set - { - NpgsqlLengthCache cache = null; + _value = value; + object state = null; + var size = _elementSize = _converter.GetSizeAsObject(new(DataFormat.Binary, _binaryRequirements.Write), value, ref state); + var current = new ValueMetadata { Format = DataFormat.Binary, BufferRequirement = _binaryRequirements.Write, Size = size, WriteState = state }; + _writer.BeginWrite(async: false, current, CancellationToken.None).GetAwaiter().GetResult(); + _converter.WriteAsObject(_writer, value); + Buffer.BlockCopy(_writeBuffer.Buffer, 0, _readBuffer.Buffer, 0, size.Value); - _value = value; - _elementSize = _handler.ValidateAndGetLength(value, ref cache, null); - - cache.Rewind(); - - _handler.WriteWithLengthInternal(_value, _writeBuffer, cache, null, false); - Buffer.BlockCopy(_writeBuffer.Buffer, 0, _readBuffer.Buffer, 0, _elementSize); - - _readBuffer.FilledBytes = _elementSize; - _writeBuffer.WritePosition = 0; - } + _writer.Commit(size.Value); + _readBuffer.FilledBytes = size.Value; + _writeBuffer.WritePosition = 0; } + } - 
[Benchmark] - public T Read() - { - _readBuffer.ReadPosition = sizeof(int); - return _handler.Read(_readBuffer, _elementSize); - } + [Benchmark] + public T Read() + { + _readBuffer.ReadPosition = sizeof(int); + _reader.StartRead(_binaryRequirements.Read); + var value = ((PgConverter)_converter).Read(_reader); + _reader.EndRead(); + return value; + } - [Benchmark] - public void Write() - { - _writeBuffer.WritePosition = 0; - _handler.WriteWithLengthInternal(_value, _writeBuffer, null, null, false); - } + [Benchmark] + public void Write() + { + _writeBuffer.WritePosition = 0; + var current = new ValueMetadata { Format = DataFormat.Binary, BufferRequirement = _binaryRequirements.Write, Size = _elementSize, WriteState = null }; + _writer.BeginWrite(async: false, current, CancellationToken.None).GetAwaiter().GetResult(); + ((PgConverter)_converter).Write(_writer, _value); } } diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs b/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs index f2a882ee4e..7c229a3b57 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs @@ -1,12 +1,11 @@ using System; using BenchmarkDotNet.Attributes; -using Npgsql.TypeHandlers; +using Npgsql.Internal.Converters; -namespace Npgsql.Benchmarks.TypeHandlers +namespace Npgsql.Benchmarks.TypeHandlers; + +[Config(typeof(Config))] +public class Uuid : TypeHandlerBenchmarks { - [Config(typeof(Config))] - public class Uuid : TypeHandlerBenchmarks - { - public Uuid() : base(new UuidHandler(GetPostgresType("uuid"))) { } - } + public Uuid() : base(new GuidUuidConverter()) { } } diff --git a/test/Npgsql.Benchmarks/UnixDomainSocket.cs b/test/Npgsql.Benchmarks/UnixDomainSocket.cs index 706cedc7ce..89c42a9a49 100644 --- a/test/Npgsql.Benchmarks/UnixDomainSocket.cs +++ b/test/Npgsql.Benchmarks/UnixDomainSocket.cs @@ -3,41 +3,40 @@ using System.Linq; using BenchmarkDotNet.Attributes; -namespace Npgsql.Benchmarks +namespace Npgsql.Benchmarks; + +public class 
UnixDomainSocket { - public class UnixDomainSocket - { - readonly NpgsqlConnection _tcpipConn; - readonly NpgsqlCommand _tcpipCmd; - readonly NpgsqlConnection _unixConn; - readonly NpgsqlCommand _unixCmd; + readonly NpgsqlConnection _tcpipConn; + readonly NpgsqlCommand _tcpipCmd; + readonly NpgsqlConnection _unixConn; + readonly NpgsqlCommand _unixCmd; - public UnixDomainSocket() - { - _tcpipConn = BenchmarkEnvironment.OpenConnection(); - _tcpipCmd = new NpgsqlCommand("SELECT @p", _tcpipConn); - _tcpipCmd.Parameters.AddWithValue("p", new string('x', 10000)); + public UnixDomainSocket() + { + _tcpipConn = BenchmarkEnvironment.OpenConnection(); + _tcpipCmd = new NpgsqlCommand("SELECT @p", _tcpipConn); + _tcpipCmd.Parameters.AddWithValue("p", new string('x', 10000)); - var port = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString).Port; - var candidateDirectories = new[] { "/var/run/postgresql", "/tmp" }; - var dir = candidateDirectories.FirstOrDefault(d => File.Exists(Path.Combine(d, $".s.PGSQL.{port}"))); - if (dir == null) - throw new Exception("No PostgreSQL unix domain socket was found"); + var port = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString).Port; + var candidateDirectories = new[] { "/var/run/postgresql", "/tmp" }; + var dir = candidateDirectories.FirstOrDefault(d => File.Exists(Path.Combine(d, $".s.PGSQL.{port}"))); + if (dir == null) + throw new Exception("No PostgreSQL unix domain socket was found"); - var connString = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) - { - Host = dir - }.ToString(); - _unixConn = new NpgsqlConnection(connString); - _unixConn.Open(); - _unixCmd = new NpgsqlCommand("SELECT @p", _unixConn); - _unixCmd.Parameters.AddWithValue("p", new string('x', 10000)); - } + var connString = new NpgsqlConnectionStringBuilder(BenchmarkEnvironment.ConnectionString) + { + Host = dir + }.ToString(); + _unixConn = new NpgsqlConnection(connString); + _unixConn.Open(); + 
_unixCmd = new NpgsqlCommand("SELECT @p", _unixConn); + _unixCmd.Parameters.AddWithValue("p", new string('x', 10000)); + } - [Benchmark(Baseline = true)] - public string Tcpip() => (string)_tcpipCmd.ExecuteScalar()!; + [Benchmark(Baseline = true)] + public string Tcpip() => (string)_tcpipCmd.ExecuteScalar()!; - [Benchmark] - public string UnixDomain() => (string)_unixCmd.ExecuteScalar()!; - } -} + [Benchmark] + public string UnixDomain() => (string)_unixCmd.ExecuteScalar()!; +} \ No newline at end of file diff --git a/test/Npgsql.Benchmarks/WriteVaryingNumberOfParameters.cs b/test/Npgsql.Benchmarks/WriteVaryingNumberOfParameters.cs index 45a8f3f2d1..429861f262 100644 --- a/test/Npgsql.Benchmarks/WriteVaryingNumberOfParameters.cs +++ b/test/Npgsql.Benchmarks/WriteVaryingNumberOfParameters.cs @@ -2,48 +2,47 @@ using BenchmarkDotNet.Attributes; using NpgsqlTypes; -namespace Npgsql.Benchmarks.Types +namespace Npgsql.Benchmarks.Types; + +public class WriteVaryingNumberOfParameters { - public class WriteVaryingNumberOfParameters - { - NpgsqlConnection _conn = default!; - NpgsqlCommand _cmd = default!; + NpgsqlConnection _conn = default!; + NpgsqlCommand _cmd = default!; - [Params(10)] - public int NumParams { get; set; } + [Params(10)] + public int NumParams { get; set; } - [GlobalSetup] - public void Setup() - { - _conn = BenchmarkEnvironment.OpenConnection(); + [GlobalSetup] + public void Setup() + { + _conn = BenchmarkEnvironment.OpenConnection(); - var funcParams = string.Join(",", - Enumerable.Range(0, NumParams) + var funcParams = string.Join(",", + Enumerable.Range(0, NumParams) .Select(i => $"IN p{i} int4") - ); - using (var cmd = new NpgsqlCommand($"CREATE FUNCTION pg_temp.swallow({funcParams}) RETURNS void AS 'BEGIN END;' LANGUAGE 'plpgsql'", _conn)) - cmd.ExecuteNonQuery(); - - var cmdParams = string.Join(",", Enumerable.Range(0, NumParams).Select(i => $"@p{i}")); - _cmd = new NpgsqlCommand($"SELECT pg_temp.swallow({cmdParams})", _conn); - for (var i = 0; i < 
NumParams; i++) - _cmd.Parameters.Add(new NpgsqlParameter("p" + i, NpgsqlDbType.Integer)); - _cmd.Prepare(); - } - - [GlobalCleanup] - public void Cleanup() - { - _cmd.Unprepare(); - _conn.Close(); - } - - [Benchmark] - public void WriteParameters() - { - for (var i = 0; i < NumParams; i++) - _cmd.Parameters[i].Value = i; - _cmd.ExecuteNonQuery(); - } + ); + using (var cmd = new NpgsqlCommand($"CREATE FUNCTION pg_temp.swallow({funcParams}) RETURNS void AS 'BEGIN END;' LANGUAGE 'plpgsql'", _conn)) + cmd.ExecuteNonQuery(); + + var cmdParams = string.Join(",", Enumerable.Range(0, NumParams).Select(i => $"@p{i}")); + _cmd = new NpgsqlCommand($"SELECT pg_temp.swallow({cmdParams})", _conn); + for (var i = 0; i < NumParams; i++) + _cmd.Parameters.Add(new NpgsqlParameter("p" + i, NpgsqlDbType.Integer)); + _cmd.Prepare(); + } + + [GlobalCleanup] + public void Cleanup() + { + _cmd.Unprepare(); + _conn.Close(); + } + + [Benchmark] + public void WriteParameters() + { + for (var i = 0; i < NumParams; i++) + _cmd.Parameters[i].Value = i; + _cmd.ExecuteNonQuery(); } -} +} \ No newline at end of file diff --git a/test/Npgsql.DependencyInjection.Tests/DependencyInjectionTests.cs b/test/Npgsql.DependencyInjection.Tests/DependencyInjectionTests.cs new file mode 100644 index 0000000000..ebbf0e2388 --- /dev/null +++ b/test/Npgsql.DependencyInjection.Tests/DependencyInjectionTests.cs @@ -0,0 +1,184 @@ +using System; +using System.Data; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Npgsql.Tests; +using Npgsql.Tests.Support; +using NUnit.Framework; + +namespace Npgsql.DependencyInjection.Tests; + +[TestFixture(DataSourceMode.Standard)] +[TestFixture(DataSourceMode.Slim)] +public class DependencyInjectionTests(DataSourceMode mode) +{ + [Test] + public async Task NpgsqlDataSource_is_registered_properly([Values] bool async) + { + var serviceCollection = new ServiceCollection(); + 
RegisterDataSource(serviceCollection, TestUtil.ConnectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + var dataSource = serviceProvider.GetRequiredService(); + + await using var connection = async + ? await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + } + + [Test] + public async Task NpgsqlMultiHostDataSource_is_registered_properly([Values] bool async) + { + var serviceCollection = new ServiceCollection(); + RegisterMultiHostDataSource(serviceCollection, TestUtil.ConnectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + var multiHostDataSource = serviceProvider.GetRequiredService(); + var dataSource = serviceProvider.GetRequiredService(); + + Assert.That(dataSource, Is.SameAs(multiHostDataSource)); + + await using var connection = async + ? await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + } + + [Test] + public async Task NpgsqlDataSource_with_service_key_is_registered_properly([Values] bool async) + { + const string serviceKey = "key"; + var serviceCollection = new ServiceCollection(); + RegisterDataSource(serviceCollection, TestUtil.ConnectionString, serviceKey); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + var dataSource = serviceProvider.GetRequiredKeyedService(serviceKey); + Assert.Throws(() => serviceProvider.GetRequiredService()); + + await using var connection = async + ? 
await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + } + + [Test] + public async Task NpgsqlMultiHostDataSource_with_service_key_is_registered_properly([Values] bool async) + { + const string serviceKey = "key"; + var serviceCollection = new ServiceCollection(); + RegisterMultiHostDataSource(serviceCollection, TestUtil.ConnectionString, serviceKey); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + var multiHostDataSource = serviceProvider.GetRequiredKeyedService(serviceKey); + var dataSource = serviceProvider.GetRequiredKeyedService(serviceKey); + Assert.Throws(() => serviceProvider.GetRequiredService()); + Assert.Throws(() => serviceProvider.GetRequiredService()); + + Assert.That(dataSource, Is.SameAs(multiHostDataSource)); + + await using var connection = async + ? await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + } + + [Test] + public void NpgsqlDataSource_is_registered_as_singleton_by_default() + { + var serviceCollection = new ServiceCollection(); + RegisterDataSource(serviceCollection, TestUtil.ConnectionString); + + using var serviceProvider = serviceCollection.BuildServiceProvider(); + using var scope1 = serviceProvider.CreateScope(); + using var scope2 = serviceProvider.CreateScope(); + var scopeServiceProvider1 = scope1.ServiceProvider; + var scopeServiceProvider2 = scope2.ServiceProvider; + + var dataSource1 = scopeServiceProvider1.GetRequiredService(); + var dataSource2 = scopeServiceProvider2.GetRequiredService(); + + Assert.That(dataSource2, Is.SameAs(dataSource1)); + } + + [Test] + public async Task NpgsqlConnection_is_registered_properly([Values] bool async) + { + var serviceCollection = new ServiceCollection(); + RegisterDataSource(serviceCollection, TestUtil.ConnectionString); + + using var serviceProvider = serviceCollection.BuildServiceProvider(); + using var scope = serviceProvider.CreateScope(); + var scopedServiceProvider = scope.ServiceProvider; + + var connection 
= scopedServiceProvider.GetRequiredService(); + + Assert.That(connection.State, Is.EqualTo(ConnectionState.Closed)); + + if (async) + await connection.OpenAsync(); + else + connection.Open(); + } + + [Test] + public void NpgsqlConnection_is_registered_as_transient_by_default() + { + var serviceCollection = new ServiceCollection(); + RegisterDataSource(serviceCollection, "Host=localhost;Username=test;Password=test"); + + using var serviceProvider = serviceCollection.BuildServiceProvider(); + using var scope1 = serviceProvider.CreateScope(); + var scopedServiceProvider1 = scope1.ServiceProvider; + + var connection1 = scopedServiceProvider1.GetRequiredService(); + var connection2 = scopedServiceProvider1.GetRequiredService(); + + Assert.That(connection2, Is.Not.SameAs(connection1)); + + using var scope2 = serviceProvider.CreateScope(); + var scopedServiceProvider2 = scope2.ServiceProvider; + + var connection3 = scopedServiceProvider2.GetRequiredService(); + Assert.That(connection3, Is.Not.SameAs(connection1)); + } + + [Test] + public async Task LoggerFactory_is_picked_up_from_ServiceCollection() + { + var listLoggerProvider = new ListLoggerProvider(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddLogging(b => b.AddProvider(listLoggerProvider)); + RegisterDataSource(serviceCollection, TestUtil.ConnectionString); + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + var dataSource = serviceProvider.GetRequiredService(); + await using var command = dataSource.CreateCommand("SELECT 1"); + + using (listLoggerProvider.Record()) + _ = command.ExecuteNonQuery(); + + Assert.That(listLoggerProvider.Log.Any(l => l.Id == NpgsqlEventId.CommandExecutionCompleted)); + } + + IServiceCollection RegisterDataSource(ServiceCollection serviceCollection, string connectionString, object? 
serviceKey = null) + => mode switch + { + DataSourceMode.Standard => serviceCollection.AddNpgsqlDataSource(connectionString, serviceKey: serviceKey), + DataSourceMode.Slim => serviceCollection.AddNpgsqlSlimDataSource(connectionString, serviceKey: serviceKey), + _ => throw new NotSupportedException($"Mode {mode} not supported") + }; + + IServiceCollection RegisterMultiHostDataSource(ServiceCollection serviceCollection, string connectionString, object? serviceKey = null) + => mode switch + { + DataSourceMode.Standard => serviceCollection.AddMultiHostNpgsqlDataSource(connectionString, serviceKey: serviceKey), + DataSourceMode.Slim => serviceCollection.AddMultiHostNpgsqlSlimDataSource(connectionString, serviceKey: serviceKey), + _ => throw new NotSupportedException($"Mode {mode} not supported") + }; +} + +public enum DataSourceMode +{ + Standard, + Slim +} diff --git a/test/Npgsql.DependencyInjection.Tests/Npgsql.DependencyInjection.Tests.csproj b/test/Npgsql.DependencyInjection.Tests/Npgsql.DependencyInjection.Tests.csproj new file mode 100644 index 0000000000..9637e56366 --- /dev/null +++ b/test/Npgsql.DependencyInjection.Tests/Npgsql.DependencyInjection.Tests.csproj @@ -0,0 +1,20 @@ + + + + net8.0 + + + + + + + + + + + + + + + + diff --git a/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj b/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj new file mode 100644 index 0000000000..bc680c3052 --- /dev/null +++ b/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj @@ -0,0 +1,22 @@ + + + exe + true + + net8.0 + true + true + true + true + true + false + true + Size + + + + + + + + diff --git a/test/Npgsql.NativeAotTests/Program.cs b/test/Npgsql.NativeAotTests/Program.cs new file mode 100644 index 0000000000..098c978296 --- /dev/null +++ b/test/Npgsql.NativeAotTests/Program.cs @@ -0,0 +1,19 @@ +using System; +using Npgsql; + +var connectionString = Environment.GetEnvironmentVariable("NPGSQL_TEST_DB") + ?? 
"Server=localhost;Username=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests;Timeout=0;Command Timeout=0"; + +var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(connectionString); +await using var dataSource = dataSourceBuilder.Build(); + +await using var conn = dataSource.CreateConnection(); +await conn.OpenAsync(); +await using var cmd = new NpgsqlCommand("SELECT 'Hello World'", conn); +await using var reader = await cmd.ExecuteReaderAsync(); +if (!await reader.ReadAsync()) + throw new Exception("Got nothing from the database"); + +var value = reader.GetFieldValue(0); +if (value != "Hello World") + throw new Exception($"Got {value} instead of the expected 'Hello World'"); diff --git a/test/Npgsql.PluginTests/GeoJSONTests.cs b/test/Npgsql.PluginTests/GeoJSONTests.cs index 4988ae1e71..0630eebc8d 100644 --- a/test/Npgsql.PluginTests/GeoJSONTests.cs +++ b/test/Npgsql.PluginTests/GeoJSONTests.cs @@ -1,44 +1,48 @@ using System; -using System.Diagnostics; -using System.Text.RegularExpressions; +using System.Collections.Concurrent; +using System.Linq; using System.Threading.Tasks; using GeoJSON.Net; using GeoJSON.Net.Converters; using GeoJSON.Net.CoordinateReferenceSystem; using GeoJSON.Net.Geometry; using Newtonsoft.Json; -using Npgsql.GeoJSON; using Npgsql.Tests; +using NpgsqlTypes; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.PluginTests +namespace Npgsql.PluginTests; + +public class GeoJSONTests : TestBase { - public class GeoJSONTests : TestBase + public struct TestData { - public struct TestData - { - public GeoJSONObject Geometry; - public string CommandText; - } + public GeoJSONObject Geometry; + public string CommandText; + } - public static readonly TestData[] Tests = + public static readonly TestData[] Tests = + { + new() { - new TestData { - Geometry = new Point( + Geometry = new Point( new Position(longitude: 1d, latitude: 2d)) { BoundingBoxes = new[] { 1d, 2d, 1d, 2d } }, - CommandText = "st_makepoint(1,2)" - }, 
- new TestData { - Geometry = new LineString(new[] { + CommandText = "st_makepoint(1,2)" + }, + new() + { + Geometry = new LineString(new[] { new Position(longitude: 1d, latitude: 1d), new Position(longitude: 1d, latitude: 2d) }) { BoundingBoxes = new[] { 1d, 1d, 1d, 2d } }, - CommandText = "st_makeline(st_makepoint(1,1), st_makepoint(1,2))" - }, - new TestData { - Geometry = new Polygon(new[] { + CommandText = "st_makeline(st_makepoint(1,1), st_makepoint(1,2))" + }, + new() + { + Geometry = new Polygon(new[] { new LineString(new[] { new Position(longitude: 1d, latitude: 1d), new Position(longitude: 2d, latitude: 2d), @@ -47,27 +51,30 @@ public struct TestData }) }) { BoundingBoxes = new[] { 1d, 1d, 3d, 3d } }, - CommandText = "st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1), st_makepoint(2,2), st_makepoint(3,3), st_makepoint(1,1)]))" - }, - new TestData { - Geometry = new MultiPoint(new[] { + CommandText = "st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1), st_makepoint(2,2), st_makepoint(3,3), st_makepoint(1,1)]))" + }, + new() + { + Geometry = new MultiPoint(new[] { new Point(new Position(longitude: 1d, latitude: 1d)) }) { BoundingBoxes = new[] { 1d, 1d, 1d, 1d } }, - CommandText = "st_multi(st_makepoint(1, 1))" - }, - new TestData { - Geometry = new MultiLineString(new[] { + CommandText = "st_multi(st_makepoint(1, 1))" + }, + new() + { + Geometry = new MultiLineString(new[] { new LineString(new[] { new Position(longitude: 1d, latitude: 1d), new Position(longitude: 1d, latitude: 2d) }) }) { BoundingBoxes = new[] { 1d, 1d, 1d, 2d } }, - CommandText = "st_multi(st_makeline(st_makepoint(1,1), st_makepoint(1,2)))" - }, - new TestData { - Geometry = new MultiPolygon(new[] { + CommandText = "st_multi(st_makeline(st_makepoint(1,1), st_makepoint(1,2)))" + }, + new() + { + Geometry = new MultiPolygon(new[] { new Polygon(new[] { new LineString(new[] { new Position(longitude: 1d, latitude: 1d), @@ -78,248 +85,350 @@ public struct TestData }) }) { BoundingBoxes = new[] 
{ 1d, 1d, 3d, 3d } }, - CommandText = "st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1), st_makepoint(2,2), st_makepoint(3,3), st_makepoint(1,1)])))" - }, - new TestData { - Geometry = new GeometryCollection(new IGeometryObject[] { + CommandText = "st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1), st_makepoint(2,2), st_makepoint(3,3), st_makepoint(1,1)])))" + }, + new() + { + Geometry = new GeometryCollection(new IGeometryObject[] { new Point(new Position(longitude: 1d, latitude: 1d)), new MultiPolygon(new[] { new Polygon(new[] { new LineString(new[] { - new Position(longitude: 1d, latitude: 1d), - new Position(longitude: 2d, latitude: 2d), - new Position(longitude: 3d, latitude: 3d), - new Position(longitude: 1d, latitude: 1d) + new Position(longitude: 1d, latitude: 1d), + new Position(longitude: 2d, latitude: 2d), + new Position(longitude: 3d, latitude: 3d), + new Position(longitude: 1d, latitude: 1d) }) }) }) }) { BoundingBoxes = new[] { 1d, 1d, 3d, 3d } }, - CommandText = "st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1), st_makepoint(2,2), st_makepoint(3,3), st_makepoint(1,1)]))))" - }, - }; + CommandText = "st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1), st_makepoint(2,2), st_makepoint(3,3), st_makepoint(1,1)]))))" + }, + }; - [Test, TestCaseSource(nameof(Tests))] - public void Read(TestData data) - { - using (var conn = OpenConnection(option: GeoJSONOptions.BoundingBox)) - using (var cmd = new NpgsqlCommand($"SELECT {data.CommandText}, st_asgeojson({data.CommandText},options:=1)", conn)) - using (var reader = cmd.ExecuteReader()) - { - Assert.That(reader.Read()); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(data.Geometry)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(JsonConvert.DeserializeObject(reader.GetFieldValue(1), new GeometryConverter()))); - } - } + [Test, TestCaseSource(nameof(Tests))] + public async Task Read(TestData data) + { + 
await using var conn = await OpenConnectionAsync(GeoJSONOptions.BoundingBox); + await using var cmd = new NpgsqlCommand($"SELECT {data.CommandText}, st_asgeojson({data.CommandText},options:=1)", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync()); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(data.Geometry)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(JsonConvert.DeserializeObject(reader.GetFieldValue(1), new GeometryConverter()))); + } - [Test, TestCaseSource(nameof(Tests))] - public void Write(TestData data) - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand($"SELECT st_asewkb(@p) = st_asewkb({data.CommandText})", conn)) - { - cmd.Parameters.AddWithValue("p", data.Geometry); - Assert.That(cmd.ExecuteScalar(), Is.True); - } - } + [Test, TestCaseSource(nameof(Tests))] + public async Task Write(TestData data) + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($"SELECT st_asewkb(@p) = st_asewkb({data.CommandText})", conn); + cmd.Parameters.AddWithValue("p", data.Geometry); + Assert.That(await cmd.ExecuteScalarAsync(), Is.True); + } - [Test] - public void IgnoreM() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT st_makepointm(1,1,1)", conn)) - using (var reader = cmd.ExecuteReader()) - { - Assert.That(reader.Read()); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(new Point(new Position(1d, 1d)))); - } - } + [Test] + public async Task IgnoreM() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT st_makepointm(1,1,1)", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync()); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(new Point(new Position(1d, 1d)))); + } - public static readonly TestData[] NotAllZSpecifiedTests = + public static readonly TestData[] NotAllZSpecifiedTests = + { + 
new() { - new TestData { - Geometry = new LineString(new[] { - new Position(1d, 1d, 0d), - new Position(2d, 2d) - }) - }, - new TestData { - Geometry = new LineString(new[] { - new Position(1d, 1d, 0d), - new Position(2d, 2d), - new Position(3d, 3d), - new Position(4d, 4d) - }) - } - }; - - [Test, TestCaseSource(nameof(NotAllZSpecifiedTests))] - public void NotAllZSpecified(TestData data) + Geometry = new LineString(new[] { + new Position(1d, 1d, 0d), + new Position(2d, 2d) + }) + }, + new() { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", data.Geometry); - Assert.That(() => cmd.ExecuteScalar(), Throws.ArgumentException); - } + Geometry = new LineString(new[] { + new Position(1d, 1d, 0d), + new Position(2d, 2d), + new Position(3d, 3d), + new Position(4d, 4d) + }) } + }; - [Test] - public void ReadUnknownCRS() - { - using (var conn = OpenConnection(option: GeoJSONOptions.ShortCRS)) - using (var cmd = new NpgsqlCommand("SELECT st_setsrid(st_makepoint(0,0), 1)", conn)) - using (var reader = cmd.ExecuteReader()) - { - Assert.That(reader.Read()); - Assert.That(() => reader.GetValue(0), Throws.InvalidOperationException); - } - } + [Test, TestCaseSource(nameof(NotAllZSpecifiedTests))] + public async Task Not_all_Z_specified(TestData data) + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.AddWithValue("p", data.Geometry); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.ArgumentException); + } - [Test] - public void ReadUnspecifiedCRS() - { - using (var conn = OpenConnection(option: GeoJSONOptions.ShortCRS)) - using (var cmd = new NpgsqlCommand("SELECT st_setsrid(st_makepoint(0,0), 0)", conn)) - using (var reader = cmd.ExecuteReader()) - { - Assert.That(reader.Read()); - Assert.That(reader.GetFieldValue(0).CRS, Is.Null); - } - } + [Test] + public async Task Read_unknown_CRS() + { + 
await using var conn = await OpenConnectionAsync(GeoJSONOptions.ShortCRS); + await using var cmd = new NpgsqlCommand("SELECT st_setsrid(st_makepoint(0,0), 1)", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync()); + Assert.That(() => reader.GetValue(0), Throws.InvalidOperationException); + } - [Test] - public void ReadShortCRS() - { - using (var conn = OpenConnection(option: GeoJSONOptions.ShortCRS)) - using (var cmd = new NpgsqlCommand("SELECT st_setsrid(st_makepoint(0,0), 4326)", conn)) - { - var point = (Point)cmd.ExecuteScalar()!; - var crs = point.CRS as NamedCRS; - - Assert.That(crs, Is.Not.Null); - Assert.That(crs!.Properties["name"], Is.EqualTo("EPSG:4326")); - } - } + [Test] + public async Task Read_unspecified_CRS() + { + await using var conn = await OpenConnectionAsync(GeoJSONOptions.ShortCRS); + await using var cmd = new NpgsqlCommand("SELECT st_setsrid(st_makepoint(0,0), 0)", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync()); + Assert.That(reader.GetFieldValue(0).CRS, Is.Null); + } + + [Test] + public async Task Read_short_CRS() + { + await using var conn = await OpenConnectionAsync(GeoJSONOptions.ShortCRS); + await using var cmd = new NpgsqlCommand("SELECT st_setsrid(st_makepoint(0,0), 4326)", conn); + var point = (Point)(await cmd.ExecuteScalarAsync())!; + var crs = point.CRS as NamedCRS; + + Assert.That(crs, Is.Not.Null); + Assert.That(crs!.Properties["name"], Is.EqualTo("EPSG:4326")); + } + + [Test] + public async Task Read_long_CRS() + { + await using var conn = await OpenConnectionAsync(GeoJSONOptions.LongCRS); + await using var cmd = new NpgsqlCommand("SELECT st_setsrid(st_makepoint(0,0), 4326)", conn); + var point = (Point)(await cmd.ExecuteScalarAsync())!; + var crs = point.CRS as NamedCRS; - [Test] - public void ReadLongCRS() + Assert.That(crs, Is.Not.Null); + Assert.That(crs!.Properties["name"], 
Is.EqualTo("urn:ogc:def:crs:EPSG::4326")); + } + + [Test] + public async Task Write_ill_formed_CRS() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn); + cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new NamedCRS("ill:formed") }); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.TypeOf()); + } + + [Test] + public async Task Write_linked_CRS() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn); + cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new LinkedCRS("href") }); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.TypeOf()); + } + + [Test] + public async Task Write_unspecified_CRS() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn); + cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new UnspecifiedCRS() }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(0)); + } + + [Test] + public async Task Write_short_CRS() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn); + cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new NamedCRS("EPSG:4326") }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(4326)); + } + + [Test] + public async Task Write_long_CRS() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn); + cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new NamedCRS("urn:ogc:def:crs:EPSG::4326") }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(4326)); + } + + [Test] + public async Task Write_CRS84() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 
st_srid(@p)", conn); + cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new NamedCRS("urn:ogc:def:crs:OGC::CRS84") }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(4326)); + } + + [Test] + public async Task Roundtrip_geometry_geography() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "geom GEOMETRY, geog GEOGRAPHY"); + + var point = new Point(new Position(0d, 0d)); + await using (var cmd = new NpgsqlCommand($"INSERT INTO {table} (geom, geog) VALUES (@p, @p)", conn)) { - using (var conn = OpenConnection(option: GeoJSONOptions.LongCRS)) - using (var cmd = new NpgsqlCommand("SELECT st_setsrid(st_makepoint(0,0), 4326)", conn)) - { - var point = (Point)cmd.ExecuteScalar()!; - var crs = point.CRS as NamedCRS; - - Assert.That(crs, Is.Not.Null); - Assert.That(crs!.Properties["name"], Is.EqualTo("urn:ogc:def:crs:EPSG::4326")); - } + cmd.Parameters.AddWithValue("p", point); + await cmd.ExecuteNonQueryAsync(); } - [Test] - public void WriteIllFormedCRS() + await using (var cmd = new NpgsqlCommand($"SELECT geom, geog FROM {table}", conn)) + await using (var reader = await cmd.ExecuteReaderAsync()) { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn)) - { - cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new NamedCRS("ill:formed") }); - Assert.That(() => cmd.ExecuteScalar(), Throws.TypeOf()); - } + await reader.ReadAsync(); + Assert.That(reader[0], Is.EqualTo(point)); + Assert.That(reader[1], Is.EqualTo(point)); } + } + + [Test, TestCaseSource(nameof(Tests))] + public async Task Import_geometry(TestData data) + { + await using var conn = await OpenConnectionAsync(options: GeoJSONOptions.BoundingBox); + var table = await CreateTempTable(conn, "field geometry"); - [Test] - public void WriteLinkedCRS() + await using (var writer = await conn.BeginBinaryImportAsync($"COPY {table} (field) FROM STDIN BINARY")) { - using 
(var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn)) - { - cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new LinkedCRS("href") }); - Assert.That(() => cmd.ExecuteScalar(), Throws.TypeOf()); - } + await writer.StartRowAsync(); + await writer.WriteAsync(data.Geometry, NpgsqlDbType.Geometry); + + var rowsWritten = await writer.CompleteAsync(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - [Test] - public void WriteUnspecifiedCRS() + await using var cmd = conn.CreateCommand(); + cmd.CommandText = $"SELECT field FROM {table}"; + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.IsTrue(await reader.ReadAsync()); + var actual = reader.GetValue(0); + Assert.That(actual, Is.EqualTo(data.Geometry)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4827")] + public async Task Import_big_geometry() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id text, field geometry"); + + var geometry = new MultiLineString(new[] { + new LineString( + Enumerable.Range(1, 507) + .Select(i => new Position(longitude: i, latitude: i)) + .Append(new Position(longitude: 1d, latitude: 1d))), + new LineString(new[] { + new Position(longitude: 1d, latitude: 1d), + new Position(longitude: 1d, latitude: 2d), + new Position(longitude: 1d, latitude: 3d), + new Position(longitude: 1d, latitude: 1d), + }) + }); + + await using (var writer = await conn.BeginBinaryImportAsync($"COPY {table} (id, field) FROM STDIN BINARY")) { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn)) - { - cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new UnspecifiedCRS() }); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(0)); - } + await writer.StartRowAsync(); + await writer.WriteAsync("a", NpgsqlDbType.Text); + await writer.WriteAsync(geometry, NpgsqlDbType.Geometry); + + var rowsWritten = await 
writer.CompleteAsync(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - [Test] - public void WriteShortCRS() + await using var cmd = conn.CreateCommand(); + cmd.CommandText = $"SELECT field FROM {table}"; + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.IsTrue(await reader.ReadAsync()); + var actual = reader.GetValue(0); + Assert.That(actual, Is.EqualTo(geometry)); + } + + [Test, TestCaseSource(nameof(Tests))] + public async Task Export_geometry(TestData data) + { + await using var conn = await OpenConnectionAsync(options: GeoJSONOptions.BoundingBox); + var table = await CreateTempTable(conn, "field geometry"); + + await using (var writer = await conn.BeginBinaryImportAsync($"COPY {table} (field) FROM STDIN BINARY")) { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn)) - { - cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new NamedCRS("EPSG:4326") }); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(4326)); - } + await writer.StartRowAsync(); + await writer.WriteAsync(data.Geometry, NpgsqlDbType.Geometry); + + var rowsWritten = await writer.CompleteAsync(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - [Test] - public void WriteLongCRS() + await using (var reader = await conn.BeginBinaryExportAsync($"COPY {table} (field) TO STDOUT BINARY")) { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn)) - { - cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new NamedCRS("urn:ogc:def:crs:EPSG::4326") }); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(4326)); - } + await reader.StartRowAsync(); + var field = await reader.ReadAsync(NpgsqlDbType.Geometry); + Assert.That(field, Is.EqualTo(data.Geometry)); } + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4830")] + public async Task Export_big_geometry() + { + await using var conn = await OpenConnectionAsync(); + var table = await 
CreateTempTable(conn, "id text, field geometry"); - [Test] - public void WriteCRS84() + var geometry = new Polygon(new[] { + new LineString( + Enumerable.Range(1, 507) + .Select(i => new Position(longitude: i, latitude: i)) + .Append(new Position(longitude: 1d, latitude: 1d))), + new LineString(new[] { + new Position(longitude: 1d, latitude: 1d), + new Position(longitude: 1d, latitude: 2d), + new Position(longitude: 1d, latitude: 3d), + new Position(longitude: 1d, latitude: 1d), + }) + }); + + await using (var writer = await conn.BeginBinaryImportAsync($"COPY {table} (id, field) FROM STDIN BINARY")) { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT st_srid(@p)", conn)) - { - cmd.Parameters.AddWithValue("p", new Point(new Position(0d, 0d)) { CRS = new NamedCRS("urn:ogc:def:crs:OGC::CRS84") }); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(4326)); - } + await writer.StartRowAsync(); + await writer.WriteAsync("aaaa", NpgsqlDbType.Text); + await writer.WriteAsync(geometry, NpgsqlDbType.Geometry); + + var rowsWritten = await writer.CompleteAsync(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - [Test] - public void RoundtripGeometryGeography() + await using (var reader = await conn.BeginBinaryExportAsync($"COPY {table} (id, field) TO STDOUT BINARY")) { - var point = new Point(new Position(0d, 0d)); - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery("CREATE TEMP TABLE data (geom GEOMETRY, geog GEOGRAPHY)"); - using (var cmd = new NpgsqlCommand("INSERT INTO data (geom, geog) VALUES (@p, @p)", conn)) - { - cmd.Parameters.AddWithValue("p", point); - cmd.ExecuteNonQuery(); - } - - using (var cmd = new NpgsqlCommand("SELECT geom, geog FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(point)); - Assert.That(reader[1], Is.EqualTo(point)); - } - } + await reader.StartRowAsync(); + var id = await reader.ReadAsync(); + var field = await 
reader.ReadAsync(NpgsqlDbType.Geometry); + Assert.That(id, Is.EqualTo("aaaa")); + Assert.That(field, Is.EqualTo(geometry)); } + } - protected override NpgsqlConnection OpenConnection(string? connectionString = null) - => OpenConnection(connectionString, GeoJSONOptions.None); + ValueTask OpenConnectionAsync(GeoJSONOptions options = GeoJSONOptions.None) + => GetDataSource(options).OpenConnectionAsync(); - protected NpgsqlConnection OpenConnection(string? connectionString = null, GeoJSONOptions option = GeoJSONOptions.None) + NpgsqlDataSource GetDataSource(GeoJSONOptions options = GeoJSONOptions.None) + => GeoJsonDataSources.GetOrAdd(options, _ => { - var conn = base.OpenConnection(connectionString); - conn.TypeMapper.UseGeoJson(option); - return conn; - } + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UseGeoJson(options); + return dataSourceBuilder.Build(); + }); - [OneTimeSetUp] - public async Task SetUp() - { - await using var conn = await base.OpenConnectionAsync(); - await TestUtil.EnsurePostgis(conn); - } + [OneTimeSetUp] + public async Task SetUp() + { + await using var conn = await OpenConnectionAsync(); + await EnsurePostgis(conn); } + + [OneTimeTearDown] + public async Task Teardown() + => await Task.WhenAll(GeoJsonDataSources.Values.Select(async ds => await ds.DisposeAsync())); + + ConcurrentDictionary GeoJsonDataSources = new(); } diff --git a/test/Npgsql.PluginTests/JsonNetTests.cs b/test/Npgsql.PluginTests/JsonNetTests.cs index b686dd3569..b3fb1e26bb 100644 --- a/test/Npgsql.PluginTests/JsonNetTests.cs +++ b/test/Npgsql.PluginTests/JsonNetTests.cs @@ -5,275 +5,284 @@ using NUnit.Framework; using System; using System.Text; +using System.Threading.Tasks; // ReSharper disable AccessToModifiedClosure // ReSharper disable AccessToDisposedClosure -namespace Npgsql.PluginTests +namespace Npgsql.PluginTests; + +/// +/// Tests for the Npgsql.Json.NET mapping plugin +/// +[TestFixture(NpgsqlDbType.Jsonb)] +[TestFixture(NpgsqlDbType.Json)] 
+public class JsonNetTests : TestBase { - /// - /// Tests for the Npgsql.Json.NET mapping plugin - /// - [NonParallelizable] - [TestFixture(NpgsqlDbType.Jsonb)] - [TestFixture(NpgsqlDbType.Json)] - public class JsonNetTests : TestBase + [Test] + public Task Roundtrip_object() + => AssertType( + JsonDataSource, + new Foo { Bar = 8 }, + IsJsonb ? @"{""Bar"": 8}" : @"{""Bar"":8}", + _pgTypeName, + _npgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3085")] + public Task Roundtrip_string() + => AssertType( + JsonDataSource, + @"{""p"": 1}", + @"{""p"": 1}", + _pgTypeName, + _npgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3085")] + public Task Roundtrip_char_array() + => AssertType( + JsonDataSource, + @"{""p"": 1}".ToCharArray(), + @"{""p"": 1}", + _pgTypeName, + _npgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3085")] + public Task Roundtrip_byte_array() + => AssertType( + JsonDataSource, + Encoding.ASCII.GetBytes(@"{""p"": 1}"), + @"{""p"": 1}", + _pgTypeName, + _npgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Roundtrip_JObject() + => AssertType( + JsonDataSource, + new JObject { ["Bar"] = 8 }, + IsJsonb ? @"{""Bar"": 8}" : @"{""Bar"":8}", + _pgTypeName, + _npgsqlDbType, + // By default we map JObject to jsonb + isDefaultForWriting: IsJsonb, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Roundtrip_JArray() + => AssertType( + JsonDataSource, + new JArray(new[] { 1, 2, 3 }), + IsJsonb ? 
"[1, 2, 3]" : "[1,2,3]", + _pgTypeName, + _npgsqlDbType, + // By default we map JArray to jsonb + isDefaultForWriting: IsJsonb, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public async Task Deserialize_failure() { - [Test] - public void RoundtripObject() - { - var expected = new Foo { Bar = 8 }; - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand(@"SELECT @p1, @p2", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", _npgsqlDbType) { Value = expected }); - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "p2", NpgsqlDbType = _npgsqlDbType, TypedValue = expected - }); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(expected)); - } - } - } - - [Test] - public void DeserializeFailure() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand($@"SELECT '[1, 2, 3]'::{_pgTypeName}", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - // Attempt to deserialize JSON array into object - Assert.That(() => reader.GetFieldValue(0), Throws.TypeOf()); - // State should still be OK to continue - var actual = reader.GetFieldValue(0); - Assert.That((int)actual[0], Is.EqualTo(1)); - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3085")] - public void RoundtripStringTypes() - { - var expected = "{\"p\":1}"; - // If we serialize to JSONB, Postgres will not store the Json.NET formatting, and will add a space after ':' - var expectedString = _npgsqlDbType.Equals(NpgsqlDbType.Jsonb) ? 
"{\"p\": 1}" - : "{\"p\":1}"; - - using var conn = OpenConnection(); - using var cmd = new NpgsqlCommand(@"SELECT @p1, @p2, @p3", conn); - - cmd.Parameters.Add(new NpgsqlParameter("p1", _npgsqlDbType) { Value = expected }); - cmd.Parameters.Add(new NpgsqlParameter("p2", _npgsqlDbType) { Value = expected.ToCharArray() }); - cmd.Parameters.Add(new NpgsqlParameter("p3", _npgsqlDbType) { Value = Encoding.ASCII.GetBytes(expected) }); - - using var reader = cmd.ExecuteReader(); - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expectedString)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(expectedString.ToCharArray())); - Assert.That(reader.GetFieldValue(2), Is.EqualTo(Encoding.ASCII.GetBytes(expectedString))); - } - - [Test, Ignore("INpgsqlTypeHandler>.Read currently not yet implemented in TextHandler")] - public void RoundtripArraySegment() - { - var expected = "{\"p\":1}"; - // If we serialize to JSONB, Postgres will not store the Json.NET formatting, and will add a space after ':' - var expectedString = _npgsqlDbType.Equals(NpgsqlDbType.Jsonb) ? 
"{\"p\": 1}" - : "{\"p\":1}"; + await using var conn = await JsonDataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($@"SELECT '[1, 2, 3]'::{_pgTypeName}", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + // Attempt to deserialize JSON array into object + Assert.That(() => reader.GetFieldValue(0), Throws.TypeOf()); + // State should still be OK to continue + var actual = reader.GetFieldValue(0); + Assert.That((int)actual[0], Is.EqualTo(1)); + } - using var conn = OpenConnection(); - using var cmd = new NpgsqlCommand(@"SELECT @p1", conn); + [Test] + public async Task Clr_type_mapping() + { + var dataSourceBuilder = CreateDataSourceBuilder(); + if (IsJsonb) + dataSourceBuilder.UseJsonNet(jsonbClrTypes: new[] { typeof(Foo) }); + else + dataSourceBuilder.UseJsonNet(jsonClrTypes: new[] { typeof(Foo) }); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType( + dataSource, + new Foo { Bar = 8 }, + IsJsonb ? @"{""Bar"": 8}" : @"{""Bar"":8}", + _pgTypeName, + _npgsqlDbType, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + } - cmd.Parameters.Add(new NpgsqlParameter>("p1", _npgsqlDbType) { Value = new ArraySegment(expected.ToCharArray()) }); + [Test] + public async Task Roundtrip_clr_array() + { + var dataSourceBuilder = CreateDataSourceBuilder(); + if (IsJsonb) + dataSourceBuilder.UseJsonNet(jsonbClrTypes: new[] { typeof(int[]) }); + else + dataSourceBuilder.UseJsonNet(jsonClrTypes: new[] { typeof(int[]) }); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType( + dataSource, + new[] { 1, 2, 3 }, + IsJsonb ? 
"[1, 2, 3]" : "[1,2,3]", + _pgTypeName, + _npgsqlDbType, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + } - using var reader = cmd.ExecuteReader(); - reader.Read(); - Assert.That(reader.GetFieldValue>(0), Is.EqualTo(expectedString)); - } + class DateWrapper + { + public DateTime Date; + public override bool Equals(object? obj) => (obj as DateWrapper)?.Date == Date; + public override int GetHashCode() => Date.GetHashCode(); + } - class Foo - { - public int Bar { get; set; } - public override bool Equals(object? obj) => (obj as Foo)?.Bar == Bar; - public override int GetHashCode() => Bar.GetHashCode(); - } + [Test] + public async Task Custom_serializer_settings() + { + var settings = new JsonSerializerSettings { DateFormatString = @"T\he d\t\h o\f MMMM, yyyy" }; + + var dataSourceBuilder = CreateDataSourceBuilder(); + if (IsJsonb) + dataSourceBuilder.UseJsonNet(jsonbClrTypes: new[] { typeof(DateWrapper) }, settings: settings); + else + dataSourceBuilder.UseJsonNet(jsonClrTypes: new[] { typeof(DateWrapper) }, settings: settings); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType( + dataSource, + new DateWrapper { Date = new DateTime(2018, 04, 20) }, + IsJsonb ? 
"{\"Date\": \"The 20th of April, 2018\"}" : "{\"Date\":\"The 20th of April, 2018\"}", + _pgTypeName, + _npgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + } - class Bar - { - public int A { get; set; } - } + [Test] + public async Task Bug3464() + { + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UseJsonNet(jsonbClrTypes: new[] { typeof(Bug3464Class) }); + await using var dataSource = dataSourceBuilder.Build(); - [Test] - public void RoundtripJObject() - { - var expected = new JObject { ["Bar"] = 8 }; - - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand(@"SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p", _npgsqlDbType) { Value = expected }); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var actual = reader.GetFieldValue(0); - Assert.That((int)actual["Bar"], Is.EqualTo(8)); - } - } - } - - [Test] - public void RoundtripJArray() - { - var expected = new JArray(new[] { 1, 2, 3 }); - - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand(@"SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p", _npgsqlDbType) { Value = expected }); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var jarray = reader.GetFieldValue(0); - Assert.That(jarray.ToObject(), Is.EqualTo(new[] { 1, 2, 3 })); - } - } - } - - [Test] - public void ClrTypeMapping() - { - var expected = new Foo { Bar = 8 }; - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand(@"SELECT @p", conn)) - { - conn.TypeMapper.UseJsonNet(new[] { typeof(Foo) }); - - cmd.Parameters.AddWithValue("p", expected); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var actual = reader.GetFieldValue(0); - Assert.That(actual.Bar, Is.EqualTo(8)); - } - } - } - - [Test, Ignore("https://github.com/npgsql/npgsql/issues/2568")] - public void ClrTypeMappingTwoTypes() - { - var value1 = new Foo { Bar = 8 }; - var value2 = new Bar { A = 8 }; - 
using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand(@"SELECT @p1, @p2", conn)) - { - conn.TypeMapper.UseJsonNet(new[] { typeof(Foo), typeof(Bar) }); - - cmd.Parameters.AddWithValue("p1", value1); - cmd.Parameters.AddWithValue("p2", value1); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var actual1 = reader.GetFieldValue(0); - Assert.That(actual1.Bar, Is.EqualTo(8)); - var actual2 = reader.GetFieldValue(1); - Assert.That(actual2.A, Is.EqualTo(8)); - } - } - } - - [Test] - public void RoundtripClrArray() - { - var expected = new[] { 1, 2, 3 }; - - using (var conn = OpenConnection()) - { - conn.TypeMapper.UseJsonNet(new[] { typeof(int[]) }); - - using (var cmd = new NpgsqlCommand($@"SELECT @p::{_pgTypeName}", conn)) - { - cmd.Parameters.AddWithValue("p", expected); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var actual = reader.GetFieldValue(0); - Assert.That(actual, Is.EqualTo(expected)); - } - } - } - } - - class DateWrapper + var expected = new Bug3464Class { SomeString = new string('5', 8174) }; + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand(@"SELECT @p1, @p2", conn); + + cmd.Parameters.AddWithValue("p1", expected).NpgsqlDbType = _npgsqlDbType; + cmd.Parameters.AddWithValue("p2", expected).NpgsqlDbType = _npgsqlDbType; + + await using var reader = cmd.ExecuteReader(); + } + + public class Bug3464Class + { + public string? 
SomeString { get; set; } + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5475")] + public async Task Read_jarray_from_get_value() + { + await using var conn = await JsonDataSource.OpenConnectionAsync(); + + await using var cmd = new NpgsqlCommand { Connection = conn }; + + var json = new JArray(new JObject { { "name", "value1" } }); + + cmd.CommandText = $"SELECT @p"; + cmd.Parameters.Add(new("p", json)); + await cmd.ExecuteScalarAsync(); + } + [Test] + public async Task Write_jobject_without_npgsqldbtype() + { + await using var conn = await JsonDataSource.OpenConnectionAsync(); + var tableName = await TestUtil.CreateTempTable(conn, "key SERIAL PRIMARY KEY, ingredients json"); + + await using var cmd = new NpgsqlCommand { Connection = conn }; + + var jsonObject = new JObject { - public System.DateTime Date; - public override bool Equals(object? obj) => (obj as DateWrapper)?.Date == Date; - public override int GetHashCode() => Date.GetHashCode(); - } + { "name", "value1" }, + { "amount", 1 }, + { "unit", "ml" } + }; + + cmd.CommandText = $"INSERT INTO {tableName} (ingredients) VALUES (@p)"; + cmd.Parameters.Add(new("p", jsonObject)); + await cmd.ExecuteNonQueryAsync(); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4537")] + public async Task Write_jobject_array_without_npgsqldbtype() + { + await using var conn = await JsonDataSource.OpenConnectionAsync(); + var tableName = await TestUtil.CreateTempTable(conn, "key SERIAL PRIMARY KEY, ingredients json[]"); + + await using var cmd = new NpgsqlCommand { Connection = conn }; - void RoundtripCustomSerializerSettings(bool asJsonb) + var jsonObject1 = new JObject { - var expected = new DateWrapper() { Date = new System.DateTime(2018, 04, 20) }; - - var settings = new JsonSerializerSettings() - { - DateFormatString = @"T\he d\t\h o\f MMMM, yyyy" - }; - - // If we serialize to JSONB, Postgres will not store the Json.NET formatting, and will add a space after ':' - var expectedString = 
asJsonb ? "{\"Date\": \"The 20th of April, 2018\"}" - : "{\"Date\":\"The 20th of April, 2018\"}"; - - using (var conn = OpenConnection()) - { - if (asJsonb) - { - conn.TypeMapper.UseJsonNet(jsonbClrTypes : new[] { typeof(DateWrapper) }, settings : settings); - } - else - { - conn.TypeMapper.UseJsonNet(jsonClrTypes : new[] { typeof(DateWrapper) }, settings : settings); - } - - using (var cmd = new NpgsqlCommand($@"SELECT @p::{_pgTypeName}, @p::text", conn)) - { - cmd.Parameters.AddWithValue("p", expected); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var actual = reader.GetFieldValue(0); - var actualString = reader.GetFieldValue(1); - Assert.That(actual, Is.EqualTo(expected)); - Assert.That(actualString, Is.EqualTo(expectedString)); - } - } - } - } - - [Test] - public void RoundtripJsonbCustomSerializerSettings() => RoundtripCustomSerializerSettings(asJsonb : true); - - [Test] - public void RoundtripJsonCustomSerializerSettings() => RoundtripCustomSerializerSettings(asJsonb : false); - - protected override NpgsqlConnection OpenConnection(string? connectionString = null) + { "name", "value1" }, + { "amount", 1 }, + { "unit", "ml" } + }; + + var jsonObject2 = new JObject { - var conn = base.OpenConnection(connectionString); - conn.TypeMapper.UseJsonNet(); - return conn; - } + { "name", "value2" }, + { "amount", 2 }, + { "unit", "g" } + }; + + cmd.CommandText = $"INSERT INTO {tableName} (ingredients) VALUES (@p)"; + cmd.Parameters.Add(new("p", new[] { jsonObject1, jsonObject2 })); + await cmd.ExecuteNonQueryAsync(); + } - readonly NpgsqlDbType _npgsqlDbType; - readonly string _pgTypeName; + class Foo + { + public int Bar { get; set; } + public override bool Equals(object? 
obj) => (obj as Foo)?.Bar == Bar; + public override int GetHashCode() => Bar.GetHashCode(); + } - public JsonNetTests(NpgsqlDbType npgsqlDbType) - { - _npgsqlDbType = npgsqlDbType; - _pgTypeName = npgsqlDbType.ToString().ToLower(); - } + readonly NpgsqlDbType _npgsqlDbType; + readonly string _pgTypeName; + + [OneTimeSetUp] + public void SetUp() + { + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UseJsonNet(); + JsonDataSource = dataSourceBuilder.Build(); } + + [OneTimeTearDown] + public async Task Teardown() + => await JsonDataSource.DisposeAsync(); + + public JsonNetTests(NpgsqlDbType npgsqlDbType) + { + _npgsqlDbType = npgsqlDbType; + _pgTypeName = npgsqlDbType.ToString().ToLower(); + } + + bool IsJsonb => _npgsqlDbType == NpgsqlDbType.Jsonb; + + NpgsqlDataSource JsonDataSource = default!; } diff --git a/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs b/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs new file mode 100644 index 0000000000..3f5eb05177 --- /dev/null +++ b/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs @@ -0,0 +1,106 @@ +using System; +using System.Data; +using System.Threading.Tasks; +using NodaTime; +using Npgsql.Tests; +using NpgsqlTypes; +using NUnit.Framework; +using Npgsql.NodaTime.Internal; + +namespace Npgsql.PluginTests; + +[NonParallelizable] // Since this test suite manipulates an AppContext switch +public class LegacyNodaTimeTests : TestBase, IDisposable +{ + const string TimeZone = "Europe/Berlin"; + + [Test] + public async Task Timestamp_as_ZonedDateTime() + { + await AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InZoneLeniently(DateTimeZoneProviders.Tzdb[TimeZone]), + "1998-04-12 13:26:38.789+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTimeOffset, + isNpgsqlDbTypeInferredFromClrType: false, isDefault: false); + } + + [Test] + public Task Timestamp_as_Instant() + => AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc().ToInstant(), + "1998-04-12 
13:26:38.789", + "timestamp without time zone", + NpgsqlDbType.Timestamp, + DbType.DateTime, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Timestamp_as_LocalDateTime() + => AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789), + "1998-04-12 13:26:38.789", + "timestamp without time zone", + NpgsqlDbType.Timestamp, + DbType.DateTime, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Timestamptz_as_Instant() + => AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc().ToInstant(), + "1998-04-12 15:26:38.789+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTimeOffset, + isDefaultForWriting: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public async Task Timestamptz_ZonedDateTime_infinite_values_are_not_supported() + { + await AssertTypeUnsupportedRead("infinity", "timestamptz"); + await AssertTypeUnsupportedWrite(Instant.MaxValue.WithOffset(Offset.Zero), "timestamptz"); + } + + [Test] + public async Task Timestamptz_OffsetDateTime_infinite_values_are_not_supported() + { + await AssertTypeUnsupportedRead("infinity", "timestamptz"); + await AssertTypeUnsupportedWrite(Instant.MaxValue.WithOffset(Offset.Zero), "timestamptz"); + } + + #region Support + + protected override NpgsqlDataSource DataSource { get; } + + public LegacyNodaTimeTests() + { +#if DEBUG + NodaTimeUtils.LegacyTimestampBehavior = true; + Util.Statics.LegacyTimestampBehavior = true; + + var builder = CreateDataSourceBuilder(); + builder.UseNodaTime(); + builder.ConnectionStringBuilder.Timezone = TimeZone; + DataSource = builder.Build(); +#else + Assert.Ignore( + "Legacy NodaTime tests rely on the Npgsql.EnableLegacyTimestampBehavior AppContext switch and can only be run in DEBUG builds"); +#endif + } + + public void Dispose() + { +#if DEBUG + NodaTimeUtils.LegacyTimestampBehavior = false; + Util.Statics.LegacyTimestampBehavior = false; + + DataSource.Dispose(); 
+#endif + } + + #endregion Support +} diff --git a/test/Npgsql.PluginTests/LegacyPostgisTests.cs b/test/Npgsql.PluginTests/LegacyPostgisTests.cs deleted file mode 100644 index f29a65bfed..0000000000 --- a/test/Npgsql.PluginTests/LegacyPostgisTests.cs +++ /dev/null @@ -1,415 +0,0 @@ -using System; -using System.Diagnostics; -using System.Linq; -using System.Text.RegularExpressions; -using System.Threading.Tasks; -using Npgsql.LegacyPostgis; -using Npgsql.Tests; -using NpgsqlTypes; -using NUnit.Framework; - -namespace Npgsql.PluginTests -{ - class LegacyPostgisTests : TestBase - { - public class TestAtt - { - public PostgisGeometry Geom = default!; - public string SQL = default!; - } - - static readonly TestAtt[] Tests = - { - new TestAtt { Geom = new PostgisPoint(1D, 2500D), SQL = "st_makepoint(1,2500)" }, - new TestAtt { - Geom = new PostgisLineString(new[] { new Coordinate2D(1D, 1D), new Coordinate2D(1D, 2500D) }), - SQL = "st_makeline(st_makepoint(1,1),st_makepoint(1,2500))" - }, - new TestAtt { - Geom = new PostgisPolygon(new[] { new[] { - new Coordinate2D(1d,1d), - new Coordinate2D(2d,2d), - new Coordinate2D(3d,3d), - new Coordinate2D(1d,1d) - }}), - SQL = "st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))" - }, - new TestAtt { - Geom = new PostgisMultiPoint(new[] { new Coordinate2D(1D, 1D) }), - SQL = "st_multi(st_makepoint(1,1))" - }, - new TestAtt { - Geom = new PostgisMultiLineString(new[] { - new PostgisLineString(new[] { - new Coordinate2D(1D, 1D), - new Coordinate2D(1D, 2500D) - }) - }), - SQL = "st_multi(st_makeline(st_makepoint(1,1),st_makepoint(1,2500)))" - }, - new TestAtt { - Geom = new PostgisMultiPolygon(new[] { - new PostgisPolygon(new[] { new[] { - new Coordinate2D(1d,1d), - new Coordinate2D(2d,2d), - new Coordinate2D(3d,3d), - new Coordinate2D(1d,1d) - }}) - }), - SQL = "st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))" 
- }, - new TestAtt { - Geom = new PostgisGeometryCollection(new PostgisGeometry[] { - new PostgisPoint(1,1), - new PostgisMultiPolygon(new[] { - new PostgisPolygon(new[] { new[] { - new Coordinate2D(1d,1d), - new Coordinate2D(2d,2d), - new Coordinate2D(3d,3d), - new Coordinate2D(1d,1d) - }}) - }) - }), - SQL = "st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))))" - }, - new TestAtt { - Geom = new PostgisGeometryCollection(new PostgisGeometry[] { - new PostgisPoint(1,1), - new PostgisGeometryCollection(new PostgisGeometry[] { - new PostgisPoint(1,1), - new PostgisMultiPolygon(new[] { - new PostgisPolygon(new[] { new[] { - new Coordinate2D(1d,1d), - new Coordinate2D(2d,2d), - new Coordinate2D(3d,3d), - new Coordinate2D(1d,1d) - }}) - }) - }) - }), - SQL = "st_collect(st_makepoint(1,1),st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))))" - } - }; - - [Test,TestCaseSource(nameof(Tests))] - public void PostgisTestRead(TestAtt att) - { - using (var conn = OpenConnection()) - using (var cmd = conn.CreateCommand()) - { - var a = att; - cmd.CommandText = "Select " + a.SQL; - var p = cmd.ExecuteScalar()!; - Assert.IsTrue(p.Equals(a.Geom)); - } - } - - [Test, TestCaseSource(nameof(Tests))] - public void PostgisTestWrite(TestAtt a) - { - using (var conn = OpenConnection()) - using (var cmd = conn.CreateCommand()) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Geometry,a.Geom); - a.Geom.SRID = 0; - cmd.CommandText = "Select st_asewkb(:p1) = st_asewkb(" + a.SQL + ")"; - bool areEqual; - try { - areEqual = (bool)cmd.ExecuteScalar()!; - } - catch (Exception e) - { - throw new Exception("Exception caught on " + a.Geom, e); - } - Assert.IsTrue(areEqual, "Error on comparison of " + a.Geom); - } - } - - [Test, TestCaseSource(nameof(Tests))] - public void PostgisTestWriteSrid(TestAtt a) - { 
- using (var conn = OpenConnection()) - using (var cmd = conn.CreateCommand()) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Geometry, a.Geom); - a.Geom.SRID = 3942; - cmd.CommandText = "Select st_asewkb(:p1) = st_asewkb(st_setsrid("+ a.SQL + ",3942))"; - var p = (bool)cmd.ExecuteScalar()!; - Assert.IsTrue(p); - } - } - - [Test, TestCaseSource(nameof(Tests))] - public void PostgisTestReadSrid(TestAtt a) - { - using (var conn = OpenConnection()) - using (var cmd = conn.CreateCommand()) - { - cmd.CommandText = "Select st_setsrid(" + a.SQL + ",3942)"; - var p = cmd.ExecuteScalar()!; - Assert.IsTrue(p.Equals(a.Geom)); - Assert.IsTrue(((PostgisGeometry)p).SRID == 3942); - } - } - - [Test] - public void PostgisTestArrayRead() - { - using (var conn = OpenConnection()) - using (var cmd = conn.CreateCommand()) - { - cmd.CommandText = "Select ARRAY(select st_makepoint(1,1))"; - var p = cmd.ExecuteScalar() as PostgisGeometry[]; - var p2 = new PostgisPoint(1d, 1d); - Assert.IsTrue(p?[0] is PostgisPoint && p2 == (PostgisPoint)p[0]); - } - } - - [Test] - public void PostgisTestArrayWrite() - { - using (var conn = OpenConnection()) - using (var cmd = conn.CreateCommand()) - { - var p = new PostgisPoint[1] { new PostgisPoint(1d, 1d) }; - cmd.Parameters.AddWithValue(":p1", NpgsqlDbType.Array | NpgsqlDbType.Geometry, p); - cmd.CommandText = "SELECT :p1 = array(select st_makepoint(1,1))"; - Assert.IsTrue((bool)cmd.ExecuteScalar()!); - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1022")] - public void MultiPolygonWithMultiplePolygons() - { - var geom2 = new PostgisMultiPolygon(new[] - { - new PostgisPolygon(new[] { - new[] - { - new Coordinate2D(40, 40), - new Coordinate2D(20, 45), - new Coordinate2D(45, 30), - new Coordinate2D(40, 40) - } - }), - new PostgisPolygon(new[] { - new[] - { - new Coordinate2D(20, 35), - new Coordinate2D(10, 30), - new Coordinate2D(10, 10), - new Coordinate2D(30, 5), - new Coordinate2D(45, 20), - new Coordinate2D(20, 35) - } - 
}) - }) { SRID = 4326 }; - using (var conn = OpenConnection()) - using (var command = conn.CreateCommand()) - { - command.Parameters.AddWithValue("p1", geom2); - command.CommandText = "Select :p1"; - command.ExecuteScalar(); - } - } - - [Test, TestCaseSource(nameof(Tests)), IssueLink("https://github.com/npgsql/npgsql/issues/1260")] - public void CopyBinary(TestAtt a) - { - using (var c = OpenConnection()) - { - using (var cmd = new NpgsqlCommand("CREATE TEMPORARY TABLE testcopybin (g geometry)", c)) - cmd.ExecuteNonQuery(); - - try - { - using (var writer = c.BeginBinaryImport($"COPY testcopybin (g) FROM STDIN (FORMAT BINARY)")) - { - for (var i = 0; i < 1000; i++) - writer.WriteRow(a.Geom); - writer.Complete(); - } - } - catch (Exception e) - { - Assert.Fail($"Copy from stdin failed with {e} at geometry {a.Geom}."); - } - - try - { - using (var rdr = c.BeginBinaryExport($"COPY testcopybin (g) TO STDOUT (FORMAT BINARY) ")) - { - for (var i =0; i < 1000; i++) - { - rdr.StartRow(); - Assert.IsTrue(a.Geom.Equals(rdr.Read())); - } - } - } - catch(Exception e) - { - Assert.Fail($"Copy from stdout failed with {e} at geometry {a.Geom}."); - } - } - } - - [Test, TestCaseSource(nameof(Tests)), IssueLink("https://github.com/npgsql/npgsql/issues/1260")] - public void CopyBinaryArray(TestAtt a) - { - using (var c = OpenConnection()) - { - using (var cmd = new NpgsqlCommand("CREATE TEMPORARY TABLE testcopybinarray (g geometry[3])", c)) - cmd.ExecuteNonQuery(); - - var t = new PostgisGeometry[3] { a.Geom, a.Geom, a.Geom }; - try - { - using (var writer = c.BeginBinaryImport("COPY testcopybinarray (g) FROM STDIN (FORMAT BINARY)")) - { - for (var i = 0; i < 1000; i++) - writer.WriteRow(new[] { t }); - writer.Complete(); - } - } - catch(Exception e) - { - Assert.Fail($"Copy from stdin failed with {e} at geometry {a.Geom}."); - } - - try - { - using (var rdr = c.BeginBinaryExport("COPY testcopybinarray (g) TO STDOUT (FORMAT BINARY)")) - for (var i = 0; i < 1000; i++) - { - 
rdr.StartRow(); - Assert.IsTrue(t.SequenceEqual(rdr.Read())); - } - } - catch(Exception e) - { - Assert.Fail($"Copy to stdout failed with {e} at geometry {a.Geom}."); - } - } - } - - [Test] - public void TestPolygonEnumeration() - { - var a = new Coordinate2D[2][] { - new Coordinate2D[4] { new Coordinate2D(0D, 0D), new Coordinate2D(0D, 1D), - new Coordinate2D(1D, 1D), new Coordinate2D(0D, 0D) }, - new Coordinate2D[5] { new Coordinate2D(0D, 0D), new Coordinate2D(0D, 2D), - new Coordinate2D(2D, 2D),new Coordinate2D(2D, 0D), - new Coordinate2D(0D, 0D) } }; - Assert.That(a.SequenceEqual(new PostgisPolygon(a))); - } - - [Test] - public void ReadAsConcreteType() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT st_makepoint(1, 1)", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(new PostgisPoint(1, 1))); - Assert.That(() => reader.GetFieldValue(0), Throws.Exception.TypeOf()); - } - } - - [Test] - public void Bug1381() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add("p", NpgsqlTypes.NpgsqlDbType.Geometry).Value = new PostgisMultiPolygon(new[] - { - new PostgisPolygon(new[] - { - new[] - { - new Coordinate2D(-0.555701, 46.42473701), - new Coordinate2D(-0.549486, 46.42707801), - new Coordinate2D(-0.549843, 46.42749901), - new Coordinate2D(-0.555524, 46.42533901), - new Coordinate2D(-0.555701, 46.42473701) - } - }) - // This is the problem: - { SRID = 4326 } - }) { SRID = 4326 }; - - cmd.ExecuteNonQuery(); - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1557")] - public void SubGeometriesWithSRID() - { - var point = new PostgisPoint(1, 1) - { - SRID = 4326 - }; - - var lineString = new PostgisLineString(new[] { new Coordinate2D(2, 2), new Coordinate2D(3, 3) }) - { - SRID = 4326 - }; - - var polygon = new PostgisPolygon(new[] { new[] { new Coordinate2D(4, 4), new 
Coordinate2D(5, 5), new Coordinate2D(6, 6), new Coordinate2D(4, 4) } }) - { - SRID = 4326 - }; - - var collection = new PostgisGeometryCollection(new PostgisGeometry[] { point, lineString, polygon }) - { - SRID = 4326 - }; - - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT :p", conn)) - { - cmd.Parameters.AddWithValue("p", collection); - cmd.ExecuteNonQuery(); - } - } - - [Test] - public void RoundtripGeometryGeography() - { - var point = new PostgisPoint(1d, 1d); - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery("CREATE TEMP TABLE data (geom GEOMETRY, geog GEOGRAPHY)"); - using (var cmd = new NpgsqlCommand("INSERT INTO data (geom, geog) VALUES (@p, @p)", conn)) - { - cmd.Parameters.AddWithValue("@p", point); - cmd.ExecuteNonQuery(); - } - - using (var cmd = new NpgsqlCommand("SELECT geom, geog FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(point)); - Assert.That(reader[1], Is.EqualTo(point)); - } - } - } - - protected override NpgsqlConnection OpenConnection(string? 
connectionString = null) - { - var conn = base.OpenConnection(connectionString); - conn.TypeMapper.UseLegacyPostgis(); - return conn; - } - - [OneTimeSetUp] - public async Task SetUp() - { - await using var conn = await base.OpenConnectionAsync(); - await TestUtil.EnsurePostgis(conn); - } - } -} diff --git a/test/Npgsql.PluginTests/NetTopologySuiteTests.cs b/test/Npgsql.PluginTests/NetTopologySuiteTests.cs index 0b9935d0ed..4e225d121c 100644 --- a/test/Npgsql.PluginTests/NetTopologySuiteTests.cs +++ b/test/Npgsql.PluginTests/NetTopologySuiteTests.cs @@ -1,38 +1,63 @@ using System; -using System.Collections; -using System.Diagnostics; -using System.Text.RegularExpressions; +using System.Collections.Concurrent; +using System.Linq; using System.Threading.Tasks; using NetTopologySuite.Geometries; using NetTopologySuite.Geometries.Implementation; using Npgsql.Tests; +using NpgsqlTypes; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.PluginTests +namespace Npgsql.PluginTests; + +public class NetTopologySuiteTests : TestBase { - public class NetTopologySuiteTests : TestBase + static readonly TestCaseData[] TestCases = { - public struct TestData - { - public Ordinates Ordinates; - public Geometry Geometry; - public string CommandText; - } + new TestCaseData(Ordinates.None, new Point(1d, 2500d), "st_makepoint(1,2500)") + .SetName("Point"), - public static IEnumerable TestCases { - get - { - // Two dimensional data - yield return new TestCaseData(Ordinates.None, new Point(1d, 2500d), "st_makepoint(1,2500)"); + new TestCaseData(Ordinates.None, new MultiPoint(new[] { new Point(new Coordinate(1d, 1d)) }), "st_multi(st_makepoint(1, 1))") + .SetName("MultiPoint"), + + new TestCaseData( + Ordinates.None, + new LineString(new[] { new Coordinate(1d, 1d), new Coordinate(1d, 2500d) }), + "st_makeline(st_makepoint(1,1),st_makepoint(1,2500))") + .SetName("LineString"), - yield return new TestCaseData( - Ordinates.None, - new LineString(new[] { new 
Coordinate(1d, 1d), new Coordinate(1d, 2500d) }), - "st_makeline(st_makepoint(1,1),st_makepoint(1,2500))" - ); + new TestCaseData( + Ordinates.None, + new MultiLineString(new[] + { + new LineString(new[] + { + new Coordinate(1d, 1d), + new Coordinate(1d, 2500d) + }) + }), + "st_multi(st_makeline(st_makepoint(1,1),st_makepoint(1,2500)))") + .SetName("MultiLineString"), - yield return new TestCaseData( - Ordinates.None, + new TestCaseData( + Ordinates.None, + new Polygon( + new LinearRing(new[] + { + new Coordinate(1d, 1d), + new Coordinate(2d, 2d), + new Coordinate(3d, 3d), + new Coordinate(1d, 1d) + }) + ), + "st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))") + .SetName("Polygon"), + + new TestCaseData( + Ordinates.None, + new MultiPolygon(new[] + { new Polygon( new LinearRing(new[] { @@ -41,31 +66,19 @@ public static IEnumerable TestCases { new Coordinate(3d, 3d), new Coordinate(1d, 1d) }) - ), - "st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))" - ); - - yield return new TestCaseData( - Ordinates.None, - new MultiPoint(new[] { new Point(new Coordinate(1d, 1d)) }), - "st_multi(st_makepoint(1, 1))" - ); - - yield return new TestCaseData( - Ordinates.None, - new MultiLineString(new[] - { - new LineString(new[] - { - new Coordinate(1d, 1d), - new Coordinate(1d, 2500d) - }) - }), - "st_multi(st_makeline(st_makepoint(1,1),st_makepoint(1,2500)))" - ); + ) + }), + "st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))") + .SetName("MultiPolygon"), - yield return new TestCaseData( - Ordinates.None, + new TestCaseData(Ordinates.None, GeometryCollection.Empty, "st_geomfromtext('GEOMETRYCOLLECTION EMPTY')") + .SetName("EmptyCollection"), + + new TestCaseData( + Ordinates.None, + new GeometryCollection(new Geometry[] + { + new Point(new Coordinate(1d, 1d)), new MultiPolygon(new[] { new Polygon( 
@@ -77,18 +90,16 @@ public static IEnumerable TestCases { new Coordinate(1d, 1d) }) ) - }), - "st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))" - ); + }) + }), + "st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))))") + .SetName("Collection"), - yield return new TestCaseData( - Ordinates.None, - GeometryCollection.Empty, - "st_geomfromtext('GEOMETRYCOLLECTION EMPTY')" - ); - - yield return new TestCaseData( - Ordinates.None, + new TestCaseData( + Ordinates.None, + new GeometryCollection(new Geometry[] + { + new Point(new Coordinate(1d, 1d)), new GeometryCollection(new Geometry[] { new Point(new Coordinate(1d, 1d)), @@ -104,151 +115,233 @@ public static IEnumerable TestCases { }) ) }) - }), - "st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))))" - ); + }) + }), + "st_collect(st_makepoint(1,1),st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))))") + .SetName("CollectionNested"), - yield return new TestCaseData( - Ordinates.None, - new GeometryCollection(new Geometry[] + new TestCaseData(Ordinates.XYZ, new Point(1d, 2d, 3d), "st_makepoint(1,2,3)") + .SetName("PointXYZ"), + + new TestCaseData( + Ordinates.XYZM, + new Point( + new DotSpatialAffineCoordinateSequence(new[] { 1d, 2d }, new[] { 3d }, new[] { 4d }), + GeometryFactory.Default), + "st_makepoint(1,2,3,4)") + .SetName("PointXYZM"), + + new TestCaseData( + Ordinates.None, + new LinearRing(new[] { - new Point(new Coordinate(1d, 1d)), - new GeometryCollection(new Geometry[] - { - new Point(new Coordinate(1d, 1d)), - new MultiPolygon(new[] - { - new Polygon( - new LinearRing(new[] - { - new Coordinate(1d, 1d), - new Coordinate(2d, 2d), - new Coordinate(3d, 3d), - 
new Coordinate(1d, 1d) - }) - ) - }) - }) + new Coordinate(1d, 1d), + new Coordinate(2d, 2d), + new Coordinate(3d, 3d), + new Coordinate(1d, 1d) }), - "st_collect(st_makepoint(1,1),st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))))" - ); - - yield return new TestCaseData(Ordinates.XYZ, new Point(1d, 2d, 3d), "st_makepoint(1,2,3)"); - - yield return new TestCaseData( - Ordinates.XYZM, - new Point( - new DotSpatialAffineCoordinateSequence(new[] { 1d, 2d }, new[] { 3d }, new[] { 4d }), - GeometryFactory.Default), - "st_makepoint(1,2,3,4)" - ); - } - } + "st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])") + .SetName("LinearRing") + }; - [Test, TestCaseSource(nameof(TestCases))] - public void TestRead(Ordinates ordinates, Geometry geometry, string sqlRepresentation) - { - using (var conn = OpenConnection()) - using (var cmd = conn.CreateCommand()) - { - cmd.CommandText = $"SELECT {sqlRepresentation}"; - Assert.That(Equals(cmd.ExecuteScalar(), geometry)); - } - } + [Test, TestCaseSource(nameof(TestCases))] + public async Task Read(Ordinates ordinates, Geometry geometry, string sqlRepresentation) + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = $"SELECT {sqlRepresentation}"; + Assert.That(Equals(cmd.ExecuteScalar(), geometry)); + } + + [Test, TestCaseSource(nameof(TestCases))] + public async Task Write(Ordinates ordinates, Geometry geometry, string sqlRepresentation) + { + await using var conn = await OpenConnectionAsync(handleOrdinates: ordinates); + await using var cmd = conn.CreateCommand(); + cmd.Parameters.AddWithValue("p1", geometry); + cmd.CommandText = $"SELECT st_asewkb(@p1) = st_asewkb({sqlRepresentation})"; + Assert.That(cmd.ExecuteScalar(), Is.True); + } - [Test, TestCaseSource(nameof(TestCases))] - public void TestWrite(Ordinates ordinates, 
Geometry geometry, string sqlRepresentation) + [Test] + public async Task Array() + { + var point = new Point(new Coordinate(1d, 1d)); + + await AssertType( + DataSource, + new Geometry[] { point }, + '{' + GetSqlLiteral(point) + '}', + "geometry[]", + NpgsqlDbType.Geometry | NpgsqlDbType.Array, + isNpgsqlDbTypeInferredFromClrType: false); + } + + [Test] + public async Task Read_as_concrete_type() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT st_makepoint(1,1)", conn); + await using var reader = cmd.ExecuteReader(); + reader.Read(); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(new Point(new Coordinate(1d, 1d)))); + Assert.That(() => reader.GetFieldValue(0), Throws.Exception.TypeOf()); + } + + [Test] + public async Task Roundtrip_geometry_geography() + { + var point = new Point(new Coordinate(1d, 1d)); + await using var conn = await OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync("CREATE TEMP TABLE data (geom GEOMETRY, geog GEOGRAPHY)"); + await using (var cmd = new NpgsqlCommand("INSERT INTO data (geom, geog) VALUES (@p, @p)", conn)) { - using (var conn = OpenConnection(handleOrdinates: ordinates)) - using (var cmd = conn.CreateCommand()) - { - cmd.Parameters.AddWithValue("p1", geometry); - cmd.CommandText = $"SELECT st_asewkb(@p1) = st_asewkb({sqlRepresentation})"; - Assert.That(cmd.ExecuteScalar(), Is.True); - } + cmd.Parameters.AddWithValue("@p", point); + cmd.ExecuteNonQuery(); } - [Test] - public void TestArrayRead() + await using (var cmd = new NpgsqlCommand("SELECT geom, geog FROM data", conn)) + await using (var reader = cmd.ExecuteReader()) { - using (var conn = OpenConnection(handleOrdinates: Ordinates.XY)) - using (var cmd = conn.CreateCommand()) - { - cmd.CommandText = "SELECT ARRAY(SELECT st_makepoint(1,1))"; - var result = cmd.ExecuteScalar(); - Assert.That(result, Is.InstanceOf()); - Assert.That(result, Is.EquivalentTo(new[] { new Point(new Coordinate(1d, 1d)) })); - } + 
reader.Read(); + Assert.That(reader[0], Is.EqualTo(point)); + Assert.That(reader[1], Is.EqualTo(point)); } + } + + [Test, Explicit] + public async Task Concurrency_test() + { + await using var adminConnection = await OpenConnectionAsync(); + var table = await CreateTempTable( + adminConnection, + "point GEOMETRY, linestring GEOMETRY, polygon GEOMETRY, " + + "multipoint GEOMETRY, multilinestring GEOMETRY, multipolygon GEOMETRY, " + + "collection GEOMETRY"); + await adminConnection.ExecuteNonQueryAsync($"INSERT INTO {table} DEFAULT VALUES"); - [Test] - public void TestArrayWrite() + var point = new Point(new Coordinate(1d, 1d)); + var lineString = new LineString(new[] { new Coordinate(1d, 1d), new Coordinate(1d, 2500d) }); + var polygon = new Polygon( + new LinearRing(new[] + { + new Coordinate(1d, 1d), + new Coordinate(2d, 2d), + new Coordinate(3d, 3d), + new Coordinate(1d, 1d) + }) + ); + var multiPoint = new MultiPoint(new[] { new Point(new Coordinate(1d, 1d)) }); + var multiLineString = new MultiLineString(new[] { - using (var conn = OpenConnection(handleOrdinates: Ordinates.XY)) - using (var cmd = conn.CreateCommand()) + new LineString(new[] { - cmd.Parameters.AddWithValue("@p1", new[] { new Point(new Coordinate(1d, 1d)) }); - cmd.CommandText = "SELECT @p1 = array(select st_makepoint(1,1))"; - Assert.That(cmd.ExecuteScalar(), Is.True); - } - } - - [Test] - public void ReadAsConcreteType() + new Coordinate(1d, 1d), + new Coordinate(1d, 2500d) + }) + }); + var multiPolygon = new MultiPolygon(new[] + { + new Polygon( + new LinearRing(new[] + { + new Coordinate(1d, 1d), + new Coordinate(2d, 2d), + new Coordinate(3d, 3d), + new Coordinate(1d, 1d) + }) + ) + }); + var collection = new GeometryCollection(new Geometry[] { - using (var conn = OpenConnection(handleOrdinates: Ordinates.XY)) - using (var cmd = new NpgsqlCommand("SELECT st_makepoint(1,1)", conn)) - using (var reader = cmd.ExecuteReader()) + new Point(new Coordinate(1d, 1d)), + new MultiPolygon(new[] { - 
reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(new Point(new Coordinate(1d, 1d)))); - Assert.That(() => reader.GetFieldValue(0), Throws.Exception.TypeOf()); - } - } + new Polygon( + new LinearRing(new[] + { + new Coordinate(1d, 1d), + new Coordinate(2d, 2d), + new Coordinate(3d, 3d), + new Coordinate(1d, 1d) + }) + ) + }) + }); - [Test] - public void RoundtripGeometryGeography() + await Task.WhenAll(Enumerable.Range(0, 30).Select(i => Task.Run(async () => { - var point = new Point(new Coordinate(1d, 1d)); - using (var conn = OpenConnection(handleOrdinates: Ordinates.XY)) + for (var i = 0; i < 1000; i++) { - conn.ExecuteNonQuery("CREATE TEMP TABLE data (geom GEOMETRY, geog GEOGRAPHY)"); - using (var cmd = new NpgsqlCommand("INSERT INTO data (geom, geog) VALUES (@p, @p)", conn)) + await using var connection = OpenConnection(); + + await using (var cmd = new NpgsqlCommand()) { - cmd.Parameters.AddWithValue("@p", point); - cmd.ExecuteNonQuery(); + cmd.Connection = connection; + cmd.CommandText = + $"UPDATE {table} SET point=$1, linestring=$2, polygon=$3, multipoint=$4, multilinestring=$5, multipolygon=$6, collection=$7"; + cmd.Parameters.Add(new() { Value = point }); + cmd.Parameters.Add(new() { Value = lineString }); + cmd.Parameters.Add(new() { Value = polygon }); + cmd.Parameters.Add(new() { Value = multiPoint }); + cmd.Parameters.Add(new() { Value = multiLineString }); + cmd.Parameters.Add(new() { Value = multiPolygon }); + cmd.Parameters.Add(new() { Value = collection }); + await cmd.ExecuteNonQueryAsync(); } - using (var cmd = new NpgsqlCommand("SELECT geom, geog FROM data", conn)) - using (var reader = cmd.ExecuteReader()) + await using (var cmd = new NpgsqlCommand($"SELECT * FROM {table}", connection)) + await using (var reader = await cmd.ExecuteReaderAsync()) { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(point)); - Assert.That(reader[1], Is.EqualTo(point)); + await reader.ReadAsync(); + Assert.That(reader.GetFieldValue(0), 
Is.EqualTo(point)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(lineString)); + Assert.That(reader.GetFieldValue(2), Is.EqualTo(polygon)); + Assert.That(reader.GetFieldValue(3), Is.EqualTo(multiPoint)); + Assert.That(reader.GetFieldValue(4), Is.EqualTo(multiLineString)); + Assert.That(reader.GetFieldValue(5), Is.EqualTo(multiPolygon)); + Assert.That(reader.GetFieldValue(6), Is.EqualTo(collection)); } } - } + }))); + } - protected override NpgsqlConnection OpenConnection(string? connectionString = null) - => OpenConnection(connectionString); + protected ValueTask OpenConnectionAsync(string? connectionString = null, Ordinates handleOrdinates = Ordinates.None) + { + if (handleOrdinates == Ordinates.None) + handleOrdinates = Ordinates.XY; - protected NpgsqlConnection OpenConnection(string? connectionString = null, Ordinates handleOrdinates = Ordinates.None) + var dataSource = NtsDataSources.GetOrAdd(handleOrdinates, o => { - if (handleOrdinates == Ordinates.None) - handleOrdinates = Ordinates.XY; - - var conn = base.OpenConnection(connectionString); - conn.TypeMapper.UseNetTopologySuite( + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UseNetTopologySuite( new DotSpatialAffineCoordinateSequenceFactory(handleOrdinates), - handleOrdinates: handleOrdinates); - return conn; - } + handleOrdinates: o); + return dataSourceBuilder.Build(); + }); - [OneTimeSetUp] - public async Task SetUp() - { - using var conn = await base.OpenConnectionAsync(); - await TestUtil.EnsurePostgis(conn); - } + if (handleOrdinates == Ordinates.XY) + _xyDataSource ??= dataSource; + + return dataSource.OpenConnectionAsync(); + } + + static string GetSqlLiteral(Geometry geometry) + => string.Join("", geometry.ToBinary().Select(b => $"{b:X2}")); + + [OneTimeSetUp] + public async Task SetUp() + { + var connection = await OpenConnectionAsync(handleOrdinates: Ordinates.XY); + await EnsurePostgis(connection); } + + [OneTimeTearDown] + public async Task Teardown() + => await 
Task.WhenAll(NtsDataSources.Values.Select(async ds => await ds.DisposeAsync())); + + protected override NpgsqlDataSource DataSource => _xyDataSource ?? throw new InvalidOperationException(); + NpgsqlDataSource? _xyDataSource; + + ConcurrentDictionary NtsDataSources = new(); } diff --git a/test/Npgsql.PluginTests/NodaTimeInfinityTests.cs b/test/Npgsql.PluginTests/NodaTimeInfinityTests.cs new file mode 100644 index 0000000000..59f581e7de --- /dev/null +++ b/test/Npgsql.PluginTests/NodaTimeInfinityTests.cs @@ -0,0 +1,312 @@ +using System; +using System.Threading.Tasks; +using NodaTime; +using Npgsql.Tests; +using Npgsql.Util; +using NpgsqlTypes; +using NUnit.Framework; +using static Npgsql.NodaTime.Internal.NodaTimeUtils; + +namespace Npgsql.PluginTests; + +[TestFixture(false)] +#if DEBUG +[TestFixture(true)] +[NonParallelizable] // Since this test suite manipulates an AppContext switch +#endif +public class NodaTimeInfinityTests : TestBase, IDisposable +{ + [Test] // #4715 + public async Task DateRange_with_upper_bound_infinity() + { + if (Statics.DisableDateTimeInfinityConversions) + return; + + await AssertType( + new DateInterval(LocalDate.MinIsoValue, LocalDate.MaxIsoValue), + "[-infinity,infinity]", + "daterange", + NpgsqlDbType.DateRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] {new DateInterval(LocalDate.MinIsoValue, LocalDate.MaxIsoValue)}, + """{"[-infinity,infinity]"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] {new DateInterval(LocalDate.MinIsoValue, LocalDate.MaxIsoValue)}, + """{[-infinity,infinity]}""", + "datemultirange", + NpgsqlDbType.DateMultirange, isNpgsqlDbTypeInferredFromClrType: false, 
skipArrayCheck: true); + } + + [Test] + public async Task Timestamptz_read_values() + { + if (Statics.DisableDateTimeInfinityConversions) + return; + + await using var conn = await OpenConnectionAsync(); + await using var cmd = + new NpgsqlCommand("SELECT 'infinity'::timestamp with time zone, '-infinity'::timestamp with time zone", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + Assert.That(reader.GetFieldValue(0), Is.EqualTo(Instant.MaxValue)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(DateTime.MaxValue)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(Instant.MinValue)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(DateTime.MinValue)); + } + + [Test] + public async Task Timestamptz_write_values() + { + if (Statics.DisableDateTimeInfinityConversions) + return; + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1::text, $2::text, $3::text, $4::text", conn) + { + Parameters = + { + new() { Value = Instant.MaxValue }, + new() { Value = DateTime.MaxValue }, + new() { Value = Instant.MinValue }, + new() { Value = DateTime.MinValue } + } + }; + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + Assert.That(reader[0], Is.EqualTo("infinity")); + Assert.That(reader[1], Is.EqualTo("infinity")); + Assert.That(reader[2], Is.EqualTo("-infinity")); + Assert.That(reader[3], Is.EqualTo("-infinity")); + } + + [Test] + public async Task Timestamptz_write() + { + await using var conn = await OpenConnectionAsync(); + + await using var cmd = new NpgsqlCommand("SELECT ($1 AT TIME ZONE 'UTC')::text", conn) + { + Parameters = { new() { Value = Instant.MinValue, NpgsqlDbType = NpgsqlDbType.TimestampTz } } + }; + + if (Statics.DisableDateTimeInfinityConversions) + { + // NodaTime Instant.MinValue is outside the PG timestamp range. 
+ Assert.That(async () => await cmd.ExecuteScalarAsync(), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.DatetimeFieldOverflow)); + } + else + { + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("-infinity")); + } + + await using var cmd2 = new NpgsqlCommand("SELECT ($1 AT TIME ZONE 'UTC')::text", conn) + { + Parameters = { new() { Value = Instant.MaxValue, NpgsqlDbType = NpgsqlDbType.TimestampTz } } + }; + + Assert.That(await cmd2.ExecuteScalarAsync(), Is.EqualTo(Statics.DisableDateTimeInfinityConversions ? "9999-12-31 23:59:59.999999" : "infinity")); + } + + [Test] + public async Task Timestamptz_read() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand( + "SELECT '-infinity'::timestamp with time zone, 'infinity'::timestamp with time zone", conn); + + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + if (Statics.DisableDateTimeInfinityConversions) + { + Assert.That(() => reader[0], Throws.Exception.TypeOf()); + Assert.That(() => reader[1], Throws.Exception.TypeOf()); + } + else + { + Assert.That(reader[0], Is.EqualTo(Instant.MinValue)); + Assert.That(reader[1], Is.EqualTo(Instant.MaxValue)); + } + } + + [Test] + public async Task Timestamp_write() + { + await using var conn = await OpenConnectionAsync(); + + await using var cmd = new NpgsqlCommand("SELECT $1::text", conn) + { + Parameters = { new() { Value = LocalDateTime.MinIsoValue, NpgsqlDbType = NpgsqlDbType.Timestamp } } + }; + + if (Statics.DisableDateTimeInfinityConversions) + { + // NodaTime LocalDateTime.MinValue is outside the PG timestamp range. 
+ Assert.That(async () => await cmd.ExecuteScalarAsync(), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.DatetimeFieldOverflow)); + } + else + { + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("-infinity")); + } + + await using var cmd2 = new NpgsqlCommand("SELECT $1::text", conn) + { + Parameters = { new() { Value = LocalDateTime.MaxIsoValue, NpgsqlDbType = NpgsqlDbType.Timestamp } } + }; + + Assert.That(await cmd2.ExecuteScalarAsync(), Is.EqualTo(Statics.DisableDateTimeInfinityConversions + ? "9999-12-31 23:59:59.999999" + : "infinity")); + } + + [Test] + public async Task Timestamp_read() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand( + "SELECT '-infinity'::timestamp without time zone, 'infinity'::timestamp without time zone", conn); + + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + if (Statics.DisableDateTimeInfinityConversions) + { + Assert.That(() => reader[0], Throws.Exception.TypeOf()); + Assert.That(() => reader[1], Throws.Exception.TypeOf()); + } + else + { + Assert.That(reader[0], Is.EqualTo(LocalDateTime.MinIsoValue)); + Assert.That(reader[1], Is.EqualTo(LocalDateTime.MaxIsoValue)); + } + } + + [Test] + public async Task Date_write() + { + await using var conn = await OpenConnectionAsync(); + + await using var cmd = new NpgsqlCommand("SELECT $1::text", conn) + { + Parameters = { new() { Value = LocalDate.MinIsoValue, NpgsqlDbType = NpgsqlDbType.Date } } + }; + + // LocalDate.MinIsoValue is outside of the PostgreSQL date range + if (Statics.DisableDateTimeInfinityConversions) + Assert.That(async () => await cmd.ExecuteScalarAsync(), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.DatetimeFieldOverflow)); + else + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("-infinity")); + + cmd.Parameters[0].Value = 
LocalDate.MaxIsoValue; + + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(Statics.DisableDateTimeInfinityConversions ? "9999-12-31" : "infinity")); + } + + [Test] + public async Task Date_read() + { + await using var conn = await OpenConnectionAsync(); + + await using var cmd = new NpgsqlCommand("SELECT '-infinity'::date, 'infinity'::date", conn); + + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + if (Statics.DisableDateTimeInfinityConversions) + { + Assert.That(() => reader[0], Throws.Exception.TypeOf()); + Assert.That(() => reader[1], Throws.Exception.TypeOf()); + } + else + { + Assert.That(reader[0], Is.EqualTo(LocalDate.MinIsoValue)); + Assert.That(reader[1], Is.EqualTo(LocalDate.MaxIsoValue)); + } + } + + [Test, Description("Makes sure that when ConvertInfinityDateTime is true, infinity values are properly converted")] + public async Task DateConvertInfinity() + { + if (Statics.DisableDateTimeInfinityConversions) + return; + + await using var conn = await OpenConnectionAsync(); + conn.ExecuteNonQuery("CREATE TEMP TABLE data (d1 DATE, d2 DATE, d3 DATE, d4 DATE)"); + + using (var cmd = new NpgsqlCommand("INSERT INTO data VALUES (@p1, @p2, @p3, @p4)", conn)) + { + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Date, LocalDate.MaxIsoValue); + cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Date, LocalDate.MinIsoValue); + cmd.Parameters.AddWithValue("p3", NpgsqlDbType.Date, DateTime.MaxValue); + cmd.Parameters.AddWithValue("p4", NpgsqlDbType.Date, DateTime.MinValue); + cmd.ExecuteNonQuery(); + } + + using (var cmd = new NpgsqlCommand("SELECT d1::TEXT, d2::TEXT, d3::TEXT, d4::TEXT FROM data", conn)) + using (var reader = cmd.ExecuteReader()) + { + reader.Read(); + Assert.That(reader.GetValue(0), Is.EqualTo("infinity")); + Assert.That(reader.GetValue(1), Is.EqualTo("-infinity")); + Assert.That(reader.GetValue(2), Is.EqualTo("infinity")); + Assert.That(reader.GetValue(3), Is.EqualTo("-infinity")); + } + + using (var cmd 
= new NpgsqlCommand("SELECT * FROM data", conn)) + using (var reader = cmd.ExecuteReader()) + { + reader.Read(); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(LocalDate.MaxIsoValue)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(LocalDate.MinIsoValue)); + Assert.That(reader.GetFieldValue(2), Is.EqualTo(DateTime.MaxValue)); + Assert.That(reader.GetFieldValue(3), Is.EqualTo(DateTime.MinValue)); + } + } + + protected override NpgsqlDataSource DataSource { get; } + + public NodaTimeInfinityTests(bool disableDateTimeInfinityConversions) + { +#if DEBUG + Statics.DisableDateTimeInfinityConversions = disableDateTimeInfinityConversions; +#else + if (disableDateTimeInfinityConversions) + { + Assert.Ignore( + "NodaTimeInfinityTests rely on the Npgsql.DisableDateTimeInfinityConversions AppContext switch and can only be run in DEBUG builds"); + } +#endif + + var builder = CreateDataSourceBuilder(); + builder.UseNodaTime(); + builder.ConnectionStringBuilder.Options = "-c TimeZone=Europe/Berlin"; + DataSource = builder.Build(); + } + + public void Dispose() + { +#if DEBUG + Statics.DisableDateTimeInfinityConversions = false; +#endif + + DataSource.Dispose(); + } +} diff --git a/test/Npgsql.PluginTests/NodaTimeTests.cs b/test/Npgsql.PluginTests/NodaTimeTests.cs index 815e841672..adccd163cc 100644 --- a/test/Npgsql.PluginTests/NodaTimeTests.cs +++ b/test/Npgsql.PluginTests/NodaTimeTests.cs @@ -1,377 +1,727 @@ using System; using System.Data; -using System.Globalization; +using System.Threading.Tasks; using NodaTime; +using Npgsql.NodaTime.Properties; using Npgsql.Tests; using NpgsqlTypes; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; // ReSharper disable AccessToModifiedClosure // ReSharper disable AccessToDisposedClosure -namespace Npgsql.PluginTests +namespace Npgsql.PluginTests; + +public class NodaTimeTests : MultiplexingTestBase, IDisposable { - public class NodaTimeTests : TestBase + #region Timestamp without time zone + + static readonly TestCaseData[] 
TimestampValues = { - #region Timestamp + new TestCaseData(new LocalDateTime(1998, 4, 12, 13, 26, 38, 789), "1998-04-12 13:26:38.789") + .SetName("Timestamp_pre2000"), + new TestCaseData(new LocalDateTime(2015, 1, 27, 8, 45, 12, 345), "2015-01-27 08:45:12.345") + .SetName("Timestamp_post2000"), + new TestCaseData(new LocalDateTime(1999, 12, 31, 23, 59, 59, 999).PlusNanoseconds(456000), "1999-12-31 23:59:59.999456") + .SetName("Timestamp_with_microseconds") + }; + + [Test, TestCaseSource(nameof(TimestampValues))] + public Task Timestamp_as_LocalDateTime(LocalDateTime localDateTime, string sqlLiteral) + => AssertType(localDateTime, sqlLiteral, "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Timestamp_as_unspecified_DateTime() + => AssertType( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), + "1998-04-12 13:26:38", + "timestamp without time zone", + NpgsqlDbType.Timestamp, + DbType.DateTime2, + isDefaultForReading: false); + + [Test] + public Task Timestamp_as_long() + => AssertType( + -54297202000000, + "1998-04-12 13:26:38", + "timestamp without time zone", + NpgsqlDbType.Timestamp, + DbType.DateTime2, + isDefault: false); + + [Test] + public Task Timestamp_cannot_use_as_Instant() + => AssertTypeUnsupported( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc().ToInstant(), + "1998-04-12 13:26:38.789", + "timestamp without time zone"); + + [Test] + public Task Timestamp_cannot_use_as_ZonedDateTime() + => AssertTypeUnsupported( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc(), + "1998-04-12 13:26:38.789", + "timestamp without time zone"); + + [Test] + public Task Timestamp_cannot_use_as_OffsetDateTime() + => AssertTypeUnsupported( + new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).WithOffset(Offset.FromHours(2)), + "1998-04-12 13:26:38.789", + "timestamp without time zone"); + + [Test] + public Task Timestamp_cannot_use_as_DateTimeOffset() + 
=> AssertTypeUnsupported( + new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.Zero), + "1998-04-12 13:26:38", + "timestamp without time zone"); + + [Test] + public Task Timestamp_cannot_write_utc_DateTime() + => AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), "timestamp without time zone"); + + [Test] + public async Task Tsrange_as_NpgsqlRange_of_LocalDateTime() + { + await AssertType( + new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38), + new(1998, 4, 12, 15, 26, 38)), + """["1998-04-12 13:26:38","1998-04-12 15:26:38"]""", + "tsrange", + NpgsqlDbType.TimestampRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] { new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38), + new(1998, 4, 12, 15, 26, 38)), }, + """{"[\"1998-04-12 13:26:38\",\"1998-04-12 15:26:38\"]"}""", + "tsrange[]", + NpgsqlDbType.TimestampRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] { new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38), + new(1998, 4, 12, 15, 26, 38)), }, + """{["1998-04-12 13:26:38","1998-04-12 15:26:38"]}""", + "tsmultirange", + NpgsqlDbType.TimestampMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + } - static readonly TestCaseData[] TimestampCases = { - new TestCaseData(new LocalDateTime(1998, 4, 12, 13, 26, 38, 789)).SetName(nameof(Timestamp) + "Pre2000"), - new TestCaseData(new LocalDateTime(2015, 1, 27, 8, 45, 12, 345)).SetName(nameof(Timestamp) + "Post2000"), - new TestCaseData(new LocalDateTime(1999, 12, 31, 23, 59, 59, 999).PlusNanoseconds(456000)).SetName(nameof(Timestamp) + "Microseconds"), - }; + [Test] + public async Task Tsmultirange_as_array_of_NpgsqlRange_of_LocalDateTime() + { + await using var conn = await 
OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); - [Test, TestCaseSource(nameof(TimestampCases))] - public void Timestamp(LocalDateTime localDateTime) - { - using (var conn = OpenConnection()) + await AssertType( + new[] { - var instant = localDateTime.InUtc().ToInstant(); - var minTimestampPostgres = Instant.FromUtc(-4713, 12, 31, 00, 00, 00); - var maxTimestampPostgres = Instant.MaxValue; - var dateTime = new DateTime(2020, 03, 04, 12, 20, 44, 0, DateTimeKind.Utc); - - conn.ExecuteNonQuery("CREATE TEMP TABLE data (d1 TIMESTAMP, d2 TIMESTAMP, d3 TIMESTAMP, d4 TIMESTAMP, d5 TIMESTAMP, d6 TIMESTAMP, d7 TIMESTAMP, d8 TIMESTAMP)"); - - using (var cmd = new NpgsqlCommand("INSERT INTO data VALUES (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8)", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Timestamp) { Value = instant }); - cmd.Parameters.Add(new NpgsqlParameter("p2", DbType.DateTime) { Value = instant }); - cmd.Parameters.Add(new NpgsqlParameter("p3", DbType.DateTime2) { Value = instant }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p4", Value = instant }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p5", Value = localDateTime }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p6", Value = minTimestampPostgres }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p7", Value = maxTimestampPostgres }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p8", Value = dateTime }); - cmd.ExecuteNonQuery(); - } - - // Make sure the values inserted are the good ones, textually - using (var cmd = new NpgsqlCommand("SELECT d1::TEXT, d2::TEXT, d3::TEXT, d4::TEXT, d5::TEXT FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - for (var i = 0; i < reader.FieldCount; i++) - Assert.That(reader.GetValue(i), Is.EqualTo(instant.ToString("yyyy'-'MM'-'dd' 'HH':'mm':'ss'.'FFFFFF", CultureInfo.InvariantCulture))); - } 
- - using (var cmd = new NpgsqlCommand("SELECT d6::TEXT, d7::TEXT, d8::TEXT FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo("4714-12-31 00:00:00 BC")); - Assert.That(reader.GetValue(1), Is.EqualTo(maxTimestampPostgres.ToString("yyyy'-'MM'-'dd' 'HH':'mm':'ss'.'FFFFFF", CultureInfo.InvariantCulture))); - Assert.That(reader.GetValue(2), Is.EqualTo("2020-03-04 12:20:44")); - } - - using (var cmd = new NpgsqlCommand("SELECT d1, d2, d3, d4, d5 FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - - for (var i = 0; i < reader.FieldCount; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Instant))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(instant)); - Assert.That(reader.GetValue(i), Is.EqualTo(instant)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(localDateTime)); - Assert.That(() => reader.GetFieldValue(i), Throws.TypeOf()); - Assert.That(() => reader.GetDateTime(i), Is.EqualTo(localDateTime.ToDateTimeUnspecified())); - Assert.That(() => reader.GetFieldValue(i), Is.EqualTo(localDateTime.ToDateTimeUnspecified())); - Assert.That(() => reader.GetDate(i), Throws.TypeOf()); - } - } - } - } + new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38), + new(1998, 4, 12, 15, 26, 38)), + new NpgsqlRange( + new(1998, 4, 13, 13, 26, 38), + new(1998, 4, 13, 15, 26, 38)), + }, + """{["1998-04-12 13:26:38","1998-04-12 15:26:38"],["1998-04-13 13:26:38","1998-04-13 15:26:38"]}""", + "tsmultirange", + NpgsqlDbType.TimestampMultirange, + isNpgsqlDbTypeInferredFromClrType: false); + } - [Test, Description("Makes sure that when ConvertInfinityDateTime is true, infinity values are properly converted")] - public void TimestampConvertInfinity() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { ConvertInfinityDateTime = true }; - using (var conn = OpenConnection(csb)) + #endregion Timestamp without time zone + + #region Timestamp with time zone + 
+ static readonly TestCaseData[] TimestamptzValues = + { + new TestCaseData(new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), "1998-04-12 15:26:38+02") + .SetName("Timestamptz_pre2000"), + new TestCaseData(new LocalDateTime(2015, 1, 27, 8, 45, 12, 345).InUtc().ToInstant(), "2015-01-27 09:45:12.345+01") + .SetName("Timestamptz_post2000"), + new TestCaseData(new LocalDateTime(2013, 7, 25, 0, 0, 0).InUtc().ToInstant(), "2013-07-25 02:00:00+02") + .SetName("Timestamptz_write_date_only"), + new TestCaseData(new LocalDateTime(1999, 12, 31, 23, 59, 59, 999).PlusNanoseconds(456000).InUtc().ToInstant(), "2000-01-01 00:59:59.999456+01") + .SetName("Timestamptz_with_microseconds") + }; + + [Test, TestCaseSource(nameof(TimestamptzValues))] + public Task Timestamptz_as_Instant(Instant instant, string sqlLiteral) + => AssertType(instant, sqlLiteral, "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Timestamptz_as_ZonedDateTime() + => AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc(), + "1998-04-12 15:26:38+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTime, + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false); + + [Test] + public Task Timestamptz_as_OffsetDateTime() + => AssertType( + new LocalDateTime(1998, 4, 12, 13, 26, 38).WithOffset(Offset.Zero), + "1998-04-12 15:26:38+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTime, + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false); + + [Test] + public Task Timestamptz_as_utc_DateTime() + => AssertType( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + "1998-04-12 15:26:38+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTime, + isDefaultForReading: false); + + [Test] + public Task Timestamptz_as_DateTimeOffset() + => AssertType( + new DateTimeOffset(1998, 4, 12, 13, 26, 38, 
TimeSpan.Zero), + "1998-04-12 15:26:38+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTime, + isDefaultForReading: false); + + [Test] + public Task Timestamptz_as_long() + => AssertType( + -54297202000000, + "1998-04-12 15:26:38+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTime, + isDefault: false); + + [Test] + public Task Timestamptz_cannot_use_as_LocalDateTime() + => AssertTypeUnsupported(new LocalDateTime(1998, 4, 12, 13, 26, 38), "1998-04-12 13:26:38Z", "timestamp with time zone"); + + [Test] + public async Task Timestamptz_cannot_write_non_utc_ZonedDateTime() + => await AssertTypeUnsupportedWrite( + new LocalDateTime().InUtc().ToInstant().InZone(DateTimeZoneProviders.Tzdb["Europe/Berlin"]), + "timestamp with time zone"); + + [Test] + public async Task Timestamptz_cannot_write_non_utc_OffsetDateTime() + => await AssertTypeUnsupportedWrite(new LocalDateTime().WithOffset(Offset.FromHours(2)), "timestamp with time zone"); + + [Test] + public async Task Timestamptz_cannot_write_non_utc_DateTime() + { + await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), "timestamp with time zone"); + await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), "timestamp with time zone"); + } + + [Test] + public async Task Tstzrange_as_Interval() + { + await AssertType( + new Interval( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), + """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02")""", + "tstzrange", + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] { new Interval( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), }, 
+ """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\")"}""", + "tstzrange[]", + NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] { new Interval( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), }, + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02")}""", + "tstzmultirange", + NpgsqlDbType.TimestampTzMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + } + + [Test] + public Task Tstzrange_with_no_end_as_Interval() + => AssertType( + new Interval(new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), null), + """["1998-04-12 15:26:38+02",)""", + "tstzrange", + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + + [Test] + public Task Tstzrange_with_no_start_as_Interval() + => AssertType( + new Interval(null, new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant()), + """(,"1998-04-12 15:26:38+02")""", + "tstzrange", + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + + [Test] + public Task Tstzrange_with_no_start_or_end_as_Interval() + => AssertType( + new Interval(null, null), + """(,)""", + "tstzrange", + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + + [Test] + public Task Tstzrange_as_NpgsqlRange_of_Instant() + => AssertType( + new NpgsqlRange( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), + """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"]""", + "tstzrange", + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false, skipArrayCheck: true); + + [Test] 
+ public Task Tstzrange_as_NpgsqlRange_of_ZonedDateTime() + => AssertType( + new NpgsqlRange( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc()), + """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"]""", + "tstzrange", + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false, skipArrayCheck: true); + + [Test] + public Task Tstzrange_as_NpgsqlRange_of_OffsetDateTime() + => AssertType( + new NpgsqlRange( + new LocalDateTime(1998, 4, 12, 13, 26, 38).WithOffset(Offset.Zero), + new LocalDateTime(1998, 4, 12, 15, 26, 38).WithOffset(Offset.Zero)), + """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"]""", + "tstzrange", + NpgsqlDbType.TimestampTzRange, + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false, skipArrayCheck: true); + + [Test] + public async Task Tstzmultirange_as_array_of_Interval() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + + await AssertType( + new[] { - conn.ExecuteNonQuery("CREATE TEMP TABLE data (d1 TIMESTAMP, d2 TIMESTAMP, d3 TIMESTAMP, d4 TIMESTAMP)"); - - using (var cmd = new NpgsqlCommand("INSERT INTO data VALUES (@p1, @p2, @p3, @p4)", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Timestamp, Instant.MaxValue); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Timestamp, Instant.MinValue); - cmd.Parameters.AddWithValue("p3", NpgsqlDbType.Timestamp, DateTime.MaxValue); - cmd.Parameters.AddWithValue("p4", NpgsqlDbType.Timestamp, DateTime.MinValue); - cmd.ExecuteNonQuery(); - } - - using (var cmd = new NpgsqlCommand("SELECT d1::TEXT, d2::TEXT, d3::TEXT, d4::TEXT FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo("infinity")); - Assert.That(reader.GetValue(1), Is.EqualTo("-infinity")); - Assert.That(reader.GetValue(2), 
Is.EqualTo("infinity")); - Assert.That(reader.GetValue(3), Is.EqualTo("-infinity")); - } - - using (var cmd = new NpgsqlCommand("SELECT * FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(Instant.MaxValue)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(Instant.MinValue)); - Assert.That(reader.GetFieldValue(2), Is.EqualTo(DateTime.MaxValue)); - Assert.That(reader.GetFieldValue(3), Is.EqualTo(DateTime.MinValue)); - } - } - } + new Interval( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), + new Interval( + new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), + }, + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"),["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02")}""", + "tstzmultirange", + NpgsqlDbType.TimestampTzMultirange, + isNpgsqlDbTypeInferredFromClrType: false); + } - #endregion Timestamp + [Test] + public async Task Tstzmultirange_as_array_of_NpgsqlRange_of_Instant() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); - #region Timestamp with time zone + await AssertType( + new[] + { + new NpgsqlRange( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), + new NpgsqlRange( + new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), + }, + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"],["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02"]}""", + "tstzmultirange", + NpgsqlDbType.TimestampTzMultirange, + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false); + } - [Test] - public void TimestampTz() - { - using (var conn = OpenConnection()) + 
[Test] + public async Task Tstzmultirange_as_array_of_NpgsqlRange_of_ZonedDateTime() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + + await AssertType( + new[] { - var timezone = "America/New_York"; - conn.ExecuteNonQuery($"SET TIMEZONE TO '{timezone}'"); - Assert.That(conn.Timezone, Is.EqualTo(timezone)); - // Nodatime provider should return timestamptz's as ZonedDateTime in the session timezone - - var instant = Instant.FromUtc(2015, 6, 27, 8, 45, 12) + Duration.FromMilliseconds(345); - var utcZonedDateTime = instant.InUtc(); - var localZonedDateTime = utcZonedDateTime.WithZone(DateTimeZoneProviders.Tzdb[timezone]); - var offsetDateTime = localZonedDateTime.ToOffsetDateTime(); - var dateTimeOffset = offsetDateTime.ToDateTimeOffset(); - var dateTime = dateTimeOffset.DateTime; - var localDateTime = dateTimeOffset.LocalDateTime; - - conn.ExecuteNonQuery("CREATE TEMP TABLE data (d1 TIMESTAMPTZ, d2 TIMESTAMPTZ, d3 TIMESTAMPTZ, d4 TIMESTAMPTZ, d5 TIMESTAMPTZ, d6 TIMESTAMPTZ)"); - - using (var cmd = new NpgsqlCommand("INSERT INTO data VALUES (@p1, @p2, @p3, @p4, @p5, @p6)", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.TimestampTz) { Value = instant }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p2", Value = utcZonedDateTime }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p3", Value = localZonedDateTime }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p4", Value = offsetDateTime }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p5", Value = dateTimeOffset }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p6", Value = dateTime }); - cmd.ExecuteNonQuery(); - } - - using (var cmd = new NpgsqlCommand("SELECT d1::TEXT, d2::TEXT, d3::TEXT, d4::TEXT, d5::TEXT, d6::TEXT FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - // When converting 
timestamptz as a string as we're doing here, PostgreSQL automatically converts - // it to the session timezone - for (var i = 0; i < reader.FieldCount; i++) - Assert.That(reader.GetValue(i), Is.EqualTo( - localZonedDateTime.ToString("uuuu'-'MM'-'dd' 'HH':'mm':'ss'.'fff", CultureInfo.InvariantCulture) + "-04") - ); - } - - using (var cmd = new NpgsqlCommand("SELECT * FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - - for (var i = 0; i < reader.FieldCount; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Instant))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(instant)); - Assert.That(reader.GetValue(i), Is.EqualTo(instant)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(localZonedDateTime)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(offsetDateTime)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(dateTimeOffset)); - Assert.That(() => reader.GetFieldValue(i), Throws.TypeOf()); - Assert.That(() => reader.GetDateTime(i), Is.EqualTo(localDateTime)); - Assert.That(() => reader.GetDate(i), Throws.TypeOf()); - } - } - } - } + new NpgsqlRange( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc()), + new NpgsqlRange( + new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc(), + new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc()), + }, + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"],["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02"]}""", + "tstzmultirange", + NpgsqlDbType.TimestampTzMultirange, + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false); + } + + [Test] + public async Task Tstzmultirange_as_array_of_NpgsqlRange_of_OffsetDateTime() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); - #endregion Timestamp with time zone + await AssertType( + new[] + { + new NpgsqlRange( + new LocalDateTime(1998, 4, 12, 13, 26, 
38).WithOffset(Offset.Zero), + new LocalDateTime(1998, 4, 12, 15, 26, 38).WithOffset(Offset.Zero)), + new NpgsqlRange( + new LocalDateTime(1998, 4, 13, 13, 26, 38).WithOffset(Offset.Zero), + new LocalDateTime(1998, 4, 13, 15, 26, 38).WithOffset(Offset.Zero)), + }, + """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"],["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02"]}""", + "tstzmultirange", + NpgsqlDbType.TimestampTzMultirange, + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false); + } - #region Date + [Test] + public async Task Tstzrange_array_as_array_of_Interval() + { + await using var conn = await OpenConnectionAsync(); - [Test] - public void Date() - { - using (var conn = OpenConnection()) + await AssertType( + new[] { - var localDate = new LocalDate(2002, 3, 4); - var dateTime = new DateTime(localDate.Year, localDate.Month, localDate.Day); - - using (var cmd = new NpgsqlCommand("CREATE TEMP TABLE data (d1 DATE, d2 DATE, d3 DATE, d4 DATE, d5 DATE)", conn)) - cmd.ExecuteNonQuery(); - - using (var cmd = new NpgsqlCommand("INSERT INTO data VALUES (@p1, @p2, @p3, @p4, @p5)", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Date) { Value = localDate }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p2", Value = localDate }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p3", Value = new LocalDate(-5, 3, 3) }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p4", Value = dateTime }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p5", Value = dateTime, NpgsqlDbType = NpgsqlDbType.Date }); - cmd.ExecuteNonQuery(); - } - - using (var cmd = new NpgsqlCommand("SELECT d1::TEXT, d2::TEXT, d3::TEXT, d4::TEXT, d5::TEXT FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo("2002-03-04")); - Assert.That(reader.GetValue(1), Is.EqualTo("2002-03-04")); - Assert.That(reader.GetValue(2), 
Is.EqualTo("0006-03-03 BC")); - Assert.That(reader.GetValue(3), Is.EqualTo("2002-03-04")); - Assert.That(reader.GetValue(4), Is.EqualTo("2002-03-04")); - } - - using (var cmd = new NpgsqlCommand("SELECT * FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(LocalDate))); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(localDate)); - Assert.That(reader.GetValue(0), Is.EqualTo(localDate)); - Assert.That(() => reader.GetDateTime(0), Is.EqualTo(dateTime)); - Assert.That(() => reader.GetDate(0), Is.EqualTo(new NpgsqlDate(localDate.Year, localDate.Month, localDate.Day))); - Assert.That(reader.GetFieldValue(2), Is.EqualTo(new LocalDate(-5, 3, 3))); - Assert.That(reader.GetFieldValue(3), Is.EqualTo(dateTime)); - Assert.That(reader.GetDateTime(4), Is.EqualTo(dateTime)); - } - } - } + new Interval( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), + new Interval( + new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), + new Interval( + new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc().ToInstant(), + null), + new Interval( + null, + new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc().ToInstant()), + new Interval( + null, + null) + }, + """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\")","[\"1998-04-13 15:26:38+02\",\"1998-04-13 17:26:38+02\")","[\"1998-04-13 15:26:38+02\",)","(,\"1998-04-13 15:26:38+02\")","(,)"}""", + "tstzrange[]", + NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, + isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForWriting: false); + } - [Test, Description("Makes sure that when ConvertInfinityDateTime is true, infinity values are properly converted")] - public void DateConvertInfinity() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { ConvertInfinityDateTime = true }; - 
using (var conn = OpenConnection(csb)) + [Test] + public async Task Tstzrange_array_as_array_of_NpgsqlRange_of_Instant() + { + await using var conn = await OpenConnectionAsync(); + + await AssertType( + new[] { - conn.ExecuteNonQuery("CREATE TEMP TABLE data (d1 DATE, d2 DATE, d3 DATE, d4 DATE)"); - - using (var cmd = new NpgsqlCommand("INSERT INTO data VALUES (@p1, @p2, @p3, @p4)", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Date, LocalDate.MaxIsoValue); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Date, LocalDate.MinIsoValue); - cmd.Parameters.AddWithValue("p3", NpgsqlDbType.Date, DateTime.MaxValue); - cmd.Parameters.AddWithValue("p4", NpgsqlDbType.Date, DateTime.MinValue); - cmd.ExecuteNonQuery(); - } - - using (var cmd = new NpgsqlCommand("SELECT d1::TEXT, d2::TEXT, d3::TEXT, d4::TEXT FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo("infinity")); - Assert.That(reader.GetValue(1), Is.EqualTo("-infinity")); - Assert.That(reader.GetValue(2), Is.EqualTo("infinity")); - Assert.That(reader.GetValue(3), Is.EqualTo("-infinity")); - } - - using (var cmd = new NpgsqlCommand("SELECT * FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(LocalDate.MaxIsoValue)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(LocalDate.MinIsoValue)); - Assert.That(reader.GetFieldValue(2), Is.EqualTo(DateTime.MaxValue)); - Assert.That(reader.GetFieldValue(3), Is.EqualTo(DateTime.MinValue)); - } - } - } + new NpgsqlRange( + new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), + new NpgsqlRange( + new LocalDateTime(1998, 4, 13, 13, 26, 38).InUtc().ToInstant(), + new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), + }, + """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\"]","[\"1998-04-13 15:26:38+02\",\"1998-04-13 
17:26:38+02\"]"}""", + "tstzrange[]", + NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, + isNpgsqlDbTypeInferredFromClrType: false, + isDefault: false); + } - #endregion Date + #endregion Timestamp with time zone - #region Time + #region Date - [Test] - public void Time() - { - using (var conn = OpenConnection()) + [Test] + public Task Date_as_LocalDate() + => AssertType(new LocalDate(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Date_as_DateTime() + => AssertType(new DateTime(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefault: false); + + [Test] + public Task Date_as_int() + => AssertType(7579, "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefault: false); + + [Test] + public async Task Daterange_as_DateInterval() + { + await AssertType( + new DateInterval(new(2002, 3, 4), new(2002, 3, 6)), + "[2002-03-04,2002-03-07)", + "daterange", + NpgsqlDbType.DateRange, + isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // DateInterval[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] {new DateInterval(new(2002, 3, 4), new(2002, 3, 6))}, + """{"[2002-03-04,2002-03-07)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] {new DateInterval(new(2002, 3, 4), new(2002, 3, 6))}, + """{[2002-03-04,2002-03-07)}""", + "datemultirange", + NpgsqlDbType.DateMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + } + + [Test] + public async Task Daterange_as_NpgsqlRange_of_LocalDate() + { + await AssertType( + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + "[2002-03-04,2002-03-06)", + "daterange", + NpgsqlDbType.DateRange, + 
isNpgsqlDbTypeInferredFromClrType: false, + isDefaultForReading: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + await AssertType( + new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, + """{"[2002-03-04,2002-03-06)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, + """{[2002-03-04,2002-03-06)}""", + "datemultirange", + NpgsqlDbType.DateMultirange, isDefault: false, skipArrayCheck: true); + } + + [Test] + public async Task Datemultirange_as_array_of_DateInterval() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + + await AssertType( + new[] { - var expected = new LocalTime(1, 2, 3, 4).PlusNanoseconds(5000); - var timeSpan = new TimeSpan(0, 1, 2, 3, 4).Add(TimeSpan.FromTicks(50)); - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Time) { Value = expected }); - cmd.Parameters.Add(new NpgsqlParameter("p2", DbType.Time) { Value = expected }); - cmd.Parameters.Add(new NpgsqlParameter("p3", DbType.Time) { Value = timeSpan }); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(LocalTime))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - Assert.That(() => reader.GetTimeSpan(i), Is.EqualTo(timeSpan)); - } - } - } - } - } + new DateInterval(new(2002, 3, 4), new(2002, 3, 5)), + new DateInterval(new(2002, 3, 8), new(2002, 3, 10)) + }, + 
"{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", + "datemultirange", + NpgsqlDbType.DateMultirange, + isNpgsqlDbTypeInferredFromClrType: false); + } - #endregion Time + [Test] + public async Task Datemultirange_as_array_of_NpgsqlRange_of_LocalDate() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); - #region Time with time zone + await AssertType( + new[] + { + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) + }, + "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", + "datemultirange", + NpgsqlDbType.DateMultirange, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + } - [Test] - public void TimeTz() - { - using (var conn = OpenConnection()) +#if NET6_0_OR_GREATER + [Test] + public Task Date_as_DateOnly() + => AssertType(new DateOnly(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefaultForReading: false); + + [Test] + public async Task Daterange_as_NpgsqlRange_of_DateOnly() + { + await AssertType( + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + "[2002-03-04,2002-03-06)", + "daterange", + NpgsqlDbType.DateRange, + isDefaultForReading: false, skipArrayCheck: true); + + await AssertType( + new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, + """{"[2002-03-04,2002-03-06)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false, skipArrayCheck: true); + + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(14, 0)) + return; + + await AssertType( + new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, + """{[2002-03-04,2002-03-06)}""", + "datemultirange", + NpgsqlDbType.DateMultirange, isDefault: false, skipArrayCheck: true); + } +#endif + + [Test] + public async Task 
Daterange_array_as_array_of_DateInterval() + { + await using var conn = await OpenConnectionAsync(); + + await AssertType( + new[] { - var time = new LocalTime(1, 2, 3, 4).PlusNanoseconds(5000); - var offset = Offset.FromHoursAndMinutes(3, 30) + Offset.FromSeconds(5); - var expected = new OffsetTime(time, offset); - var dateTimeOffset = new DateTimeOffset(0001, 01, 02, 03, 43, 20, TimeSpan.FromHours(3)); - var dateTime = dateTimeOffset.DateTime; - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4, @p5, @p6", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.TimeTz) { Value = expected }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p2", Value = expected }); - cmd.Parameters.Add(new NpgsqlParameter("p3", NpgsqlDbType.TimeTz) { Value = dateTimeOffset }); - cmd.Parameters.Add(new NpgsqlParameter("p4", dateTimeOffset)); - cmd.Parameters.Add(new NpgsqlParameter("p5", NpgsqlDbType.TimeTz) { Value = dateTime }); - cmd.Parameters.Add(new NpgsqlParameter("p6", dateTime)); - - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - - for (var i = 0; i < 2; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(OffsetTime))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - } - for (var i = 2; i < 4; i++) - { - Assert.That(reader.GetFieldValue(i), Is.EqualTo(dateTimeOffset)); - } - for (var i = 4; i < 6; i++) - { - Assert.That(reader.GetFieldValue(i), Is.EqualTo(dateTime)); - } - } - } - } - } + new DateInterval(new(2002, 3, 4), new(2002, 3, 5)), + new DateInterval(new(2002, 3, 8), new(2002, 3, 10)) + }, + """{"[2002-03-04,2002-03-06)","[2002-03-08,2002-03-11)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefaultForWriting: false); + } - #endregion Time with time zone + [Test] + public async Task Daterange_array_as_array_of_NpgsqlRange_of_LocalDate() + { + await using var conn = await OpenConnectionAsync(); - 
#region Interval + await AssertType( + new[] + { + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) + }, + """{"[2002-03-04,2002-03-06)","[2002-03-08,2002-03-11)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false); + } - [Test] - public void IntervalAsPeriod() - { - // PG has microsecond precision, so sub-microsecond values are stripped - var expectedPeriod = new PeriodBuilder + #endregion Date + + #region Time + + [Test] + public Task Time_as_LocalTime() + => AssertType(new LocalTime(10, 45, 34, 500), "10:45:34.5", "time without time zone", NpgsqlDbType.Time, DbType.Time, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Time_as_TimeSpan() + => AssertType( + new TimeSpan(0, 10, 45, 34, 500), + "10:45:34.5", + "time without time zone", + NpgsqlDbType.Time, + DbType.Time, + isDefault: false); + +#if NET6_0_OR_GREATER + [Test] + public Task Time_as_TimeOnly() + => AssertType( + new TimeOnly(10, 45, 34, 500), + "10:45:34.5", + "time without time zone", + NpgsqlDbType.Time, + DbType.Time, + isDefaultForReading: false); +#endif + + #endregion Time + + #region Time with time zone + + [Test] + public Task TimeTz_as_OffsetTime() + => AssertType( + new OffsetTime(new LocalTime(1, 2, 3, 4).PlusNanoseconds(5000), Offset.FromHoursAndMinutes(3, 30) + Offset.FromSeconds(5)), + "01:02:03.004005+03:30:05", + "time with time zone", + NpgsqlDbType.TimeTz, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public async Task TimeTz_as_DateTimeOffset() + { + await AssertTypeRead( + "13:03:45.51+02", + "time with time zone", + new DateTimeOffset(1, 1, 2, 13, 3, 45, 510, TimeSpan.FromHours(2)), isDefault: false); + + await AssertTypeWrite( + new DateTimeOffset(1, 1, 1, 13, 3, 45, 510, TimeSpan.FromHours(2)), + "13:03:45.51+02", + "time with time zone", + NpgsqlDbType.TimeTz, + isDefault: false); + } + + #endregion Time with time zone + + 
#region Interval + + [Test] + public Task Interval_as_Period() + => AssertType( + new PeriodBuilder { Years = 1, Months = 2, @@ -382,91 +732,65 @@ public void IntervalAsPeriod() Seconds = 7, Milliseconds = 8, Nanoseconds = 9000 - }.Build().Normalize(); + }.Build().Normalize(), + "1 year 2 mons 25 days 05:06:07.008009", + "interval", + NpgsqlDbType.Interval, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Interval_as_Duration() + => AssertType( + Duration.FromDays(5) + Duration.FromMinutes(4) + Duration.FromSeconds(3) + Duration.FromMilliseconds(2) + + Duration.FromNanoseconds(1000), + "5 days 00:04:03.002001", + "interval", + NpgsqlDbType.Interval, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public async Task Interval_as_Duration_with_months_fails() + { + var exception = await AssertTypeUnsupportedRead("2 months", "interval"); + Assert.That(exception.Message, Is.EqualTo(NpgsqlNodaTimeStrings.CannotReadIntervalWithMonthsAsDuration)); + } - using var conn = OpenConnection(); - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Interval) { Value = expectedPeriod }); - cmd.Parameters.AddWithValue("p2", expectedPeriod); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - - for (var i = 0; i < 2; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Period))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expectedPeriod)); - Assert.That(reader.GetValue(i), Is.EqualTo(expectedPeriod)); - } - } - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3438")] + public async Task Bug3438() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - [Test] - public void IntervalAsDuration() - { - using var conn = OpenConnection(); - using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - - // PG has microsecond precision, so 
sub-microsecond values are stripped - var expected = Duration.FromDays(5) + Duration.FromMinutes(4) + Duration.FromSeconds(3) + Duration.FromMilliseconds(2) + - Duration.FromNanoseconds(1500); - - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Interval) { Value = expected }); - cmd.Parameters.AddWithValue("p2", expected); - using var reader = cmd.ExecuteReader(); - reader.Read(); - for (var i = 0; i < 2; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Period))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected - Duration.FromNanoseconds(500))); - } - } + var expected = Duration.FromSeconds(2148); - [Test] - public void IntervalAsTimeSpan() + cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Interval) { Value = expected }); + cmd.Parameters.AddWithValue("p2", expected); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + for (var i = 0; i < 2; i++) { - var expected = new TimeSpan(1, 2, 3, 4, 5); - using var conn = OpenConnection(); - using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Interval) { Value = expected }); - cmd.Parameters.AddWithValue("p2", expected); - using var reader = cmd.ExecuteReader(); - reader.Read(); - - for (var i = 0; i < 2; i++) - { - Assert.That(() => reader.GetTimeSpan(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - } + Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Period))); } + } - [Test] - public void IntervalAsDurationWithMonthsFails() - { - using var conn = OpenConnection(); - using var cmd = new NpgsqlCommand("SELECT make_interval(months => 2)", conn); - using var reader = cmd.ExecuteReader(); - reader.Read(); + #endregion Interval - Assert.That(() => reader.GetFieldValue(0), Throws.Exception.TypeOf().With.Message.EqualTo( - "Cannot read PostgreSQL interval with non-zero months to NodaTime Duration. 
Try reading as a NodaTime Period instead.")); - } + #region Support - #endregion Interval + protected override NpgsqlDataSource DataSource { get; } - #region Support + public NodaTimeTests(MultiplexingMode multiplexingMode) + : base(multiplexingMode) + { + var builder = CreateDataSourceBuilder(); + builder.UseNodaTime(); + builder.ConnectionStringBuilder.Options = "-c TimeZone=Europe/Berlin"; + DataSource = builder.Build(); + } - protected override NpgsqlConnection OpenConnection(string? connectionString = null) - { - var conn = new NpgsqlConnection(connectionString ?? ConnectionString); - conn.Open(); - conn.TypeMapper.UseNodaTime(); - return conn; - } + public void Dispose() + => DataSource.Dispose(); - #endregion Support - } + #endregion Support } diff --git a/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj b/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj index eb4a7ab472..30dfb8ea16 100644 --- a/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj +++ b/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj @@ -1,20 +1,18 @@  - - false - + + - + - + - diff --git a/test/Npgsql.Specification.Tests/NpgsqlCommandTests.cs b/test/Npgsql.Specification.Tests/NpgsqlCommandTests.cs index 147db93b2d..c92cd069f9 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlCommandTests.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlCommandTests.cs @@ -1,17 +1,16 @@ using AdoNet.Specification.Tests; -namespace Npgsql.Specification.Tests +namespace Npgsql.Specification.Tests; + +public sealed class NpgsqlCommandTests : CommandTestBase { - public sealed class NpgsqlCommandTests : CommandTestBase + public NpgsqlCommandTests(NpgsqlDbFactoryFixture fixture) + : base(fixture) { - public NpgsqlCommandTests(NpgsqlDbFactoryFixture fixture) - : base(fixture) - { - } - - // PostgreSQL only supports a single transaction on a given connection at a given time. As a result, - // Npgsql completely ignores DbCommand.Transaction. 
- public override void ExecuteReader_throws_when_transaction_required() {} - public override void ExecuteReader_throws_when_transaction_mismatched() {} } -} + + // PostgreSQL only supports a single transaction on a given connection at a given time. As a result, + // Npgsql completely ignores DbCommand.Transaction. + public override void ExecuteReader_throws_when_transaction_required() {} + public override void ExecuteReader_throws_when_transaction_mismatched() {} +} \ No newline at end of file diff --git a/test/Npgsql.Specification.Tests/NpgsqlConnectionTests.cs b/test/Npgsql.Specification.Tests/NpgsqlConnectionTests.cs index cfdcb3ae2b..fa71ea0f2f 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlConnectionTests.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlConnectionTests.cs @@ -1,12 +1,11 @@ using AdoNet.Specification.Tests; -namespace Npgsql.Specification.Tests +namespace Npgsql.Specification.Tests; + +public sealed class NpgsqlConnectionTests : ConnectionTestBase { - public sealed class NpgsqlConnectionTests : ConnectionTestBase + public NpgsqlConnectionTests(NpgsqlDbFactoryFixture fixture) + : base(fixture) { - public NpgsqlConnectionTests(NpgsqlDbFactoryFixture fixture) - : base(fixture) - { - } } -} +} \ No newline at end of file diff --git a/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs b/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs index 8437b526c7..356d1da966 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs @@ -1,11 +1,9 @@ using AdoNet.Specification.Tests; -using Xunit; -namespace Npgsql.Specification.Tests +namespace Npgsql.Specification.Tests; + +public sealed class NpgsqlDataReaderTests : DataReaderTestBase { - public sealed class NpgsqlDataReaderTests : DataReaderTestBase - { - public NpgsqlDataReaderTests(NpgsqlSelectValueFixture fixture) - : base(fixture) {} - } -} + public NpgsqlDataReaderTests(NpgsqlSelectValueFixture fixture) + : 
base(fixture) {} +} \ No newline at end of file diff --git a/test/Npgsql.Specification.Tests/NpgsqlDbFactoryFixture.cs b/test/Npgsql.Specification.Tests/NpgsqlDbFactoryFixture.cs index 19d847dc59..6d8fcbad17 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlDbFactoryFixture.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlDbFactoryFixture.cs @@ -2,16 +2,15 @@ using System.Data.Common; using AdoNet.Specification.Tests; -namespace Npgsql.Specification.Tests +namespace Npgsql.Specification.Tests; + +public class NpgsqlDbFactoryFixture : IDbFactoryFixture { - public class NpgsqlDbFactoryFixture : IDbFactoryFixture - { - public DbProviderFactory Factory => NpgsqlFactory.Instance; + public DbProviderFactory Factory => NpgsqlFactory.Instance; - const string DefaultConnectionString = - "Server=localhost;Username=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests;Timeout=0;Command Timeout=0"; + const string DefaultConnectionString = + "Server=localhost;Username=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests;Timeout=0;Command Timeout=0"; - public string ConnectionString => - Environment.GetEnvironmentVariable("NPGSQL_TEST_DB") ?? DefaultConnectionString; - } -} + public string ConnectionString => + Environment.GetEnvironmentVariable("NPGSQL_TEST_DB") ?? 
DefaultConnectionString; +} \ No newline at end of file diff --git a/test/Npgsql.Specification.Tests/NpgsqlSelectValueFixture.cs b/test/Npgsql.Specification.Tests/NpgsqlSelectValueFixture.cs index 5602f9fb45..67f1d9f1b4 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlSelectValueFixture.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlSelectValueFixture.cs @@ -5,13 +5,14 @@ using System.Linq; using AdoNet.Specification.Tests; -namespace Npgsql.Specification.Tests +namespace Npgsql.Specification.Tests; + +public class NpgsqlSelectValueFixture : NpgsqlDbFactoryFixture, ISelectValueFixture, IDeleteFixture, IDisposable { - public class NpgsqlSelectValueFixture : NpgsqlDbFactoryFixture, ISelectValueFixture, IDisposable + public NpgsqlSelectValueFixture() { - public NpgsqlSelectValueFixture() - { - Utility.ExecuteNonQuery(this, @"DROP TABLE IF EXISTS select_value; + Utility.ExecuteNonQuery(this, @" +DROP TABLE IF EXISTS select_value; CREATE TABLE select_value ( id INTEGER NOT NULL PRIMARY KEY, @@ -38,34 +39,37 @@ INSERT INTO select_value VALUES (4, NULL, false, '0001-01-01', '0001-01-01', '0001-01-01', 0.000000000000001, 2.23e-308, '33221100-5544-7766-9988-aabbccddeeff', -32768, -2147483648, -9223372036854775808, 1.18e-38, NULL, '00:00:00'), (5, NULL, true, '9999-12-31', '9999-12-31 23:59:59.999', '9999-12-31 23:59:59.999 +14:00', 99999999999999999999.999999999999999, 1.79e308, 'ccddeeff-aabb-8899-7766-554433221100', 32767, 2147483647, 9223372036854775807, 3.40e38, NULL, '23:59:59.999'); "); - } + } - public void Dispose() => Utility.ExecuteNonQuery(this, "DROP TABLE IF EXISTS select_value;"); + public void Dispose() => Utility.ExecuteNonQuery(this, "DROP TABLE IF EXISTS select_value;"); - public string CreateSelectSql(DbType dbType, ValueKind kind) => - $"SELECT \"{dbType.ToString()}\" FROM select_value WHERE id = {(int)kind};"; + public string CreateSelectSql(DbType dbType, ValueKind kind) => + $"SELECT \"{dbType.ToString()}\" FROM select_value WHERE id = 
{(int)kind};"; - public string CreateSelectSql(byte[] value) => - $@"SELECT E'{string.Join("", value.Select(x => @"\x" + x.ToString("X2")))}'::bytea"; + public string CreateSelectSql(byte[] value) => + $@"SELECT E'{string.Join("", value.Select(x => @"\x" + x.ToString("X2")))}'::bytea"; - public string SelectNoRows => "SELECT 1 WHERE 0 = 1;"; + public string SelectNoRows => "SELECT 1 WHERE 0 = 1;"; - public IReadOnlyCollection SupportedDbTypes { get; } = new ReadOnlyCollection(new[] - { - DbType.Binary, - DbType.Boolean, - DbType.Date, - DbType.DateTime, - DbType.DateTimeOffset, - DbType.Decimal, - DbType.Double, - DbType.Guid, - DbType.Int16, - DbType.Int32, - DbType.Int64, - DbType.Single, - DbType.String, - DbType.Time - }); - } -} + public IReadOnlyCollection SupportedDbTypes { get; } = new ReadOnlyCollection(new[] + { + DbType.Binary, + DbType.Boolean, + DbType.Date, + DbType.DateTime, + DbType.DateTimeOffset, + DbType.Decimal, + DbType.Double, + DbType.Guid, + DbType.Int16, + DbType.Int32, + DbType.Int64, + DbType.Single, + DbType.String, + DbType.Time + }); + + public Type NullValueExceptionType => typeof(InvalidCastException); + + public string DeleteNoRows => "DELETE FROM select_value WHERE 1 = 0"; +} \ No newline at end of file diff --git a/test/Npgsql.Specification.Tests/Utility.cs b/test/Npgsql.Specification.Tests/Utility.cs index b1f9a73934..51bdc18dcd 100644 --- a/test/Npgsql.Specification.Tests/Utility.cs +++ b/test/Npgsql.Specification.Tests/Utility.cs @@ -1,23 +1,21 @@ -using System; using AdoNet.Specification.Tests; -namespace Npgsql.Specification.Tests +namespace Npgsql.Specification.Tests; + +public static class Utility { - public static class Utility + public static void ExecuteNonQuery(IDbFactoryFixture factoryFixture, string sql) { - public static void ExecuteNonQuery(IDbFactoryFixture factoryFixture, string sql) + using (var connection = factoryFixture.Factory.CreateConnection()!) 
{ - using (var connection = factoryFixture.Factory.CreateConnection()!) - { - connection.ConnectionString = factoryFixture.ConnectionString; - connection.Open(); + connection.ConnectionString = factoryFixture.ConnectionString; + connection.Open(); - using (var command = connection.CreateCommand()) - { - command.CommandText = sql; - command.ExecuteNonQuery(); - } + using (var command = connection.CreateCommand()) + { + command.CommandText = sql; + command.ExecuteNonQuery(); } } } -} +} \ No newline at end of file diff --git a/test/Npgsql.Tests/AsyncTests.cs b/test/Npgsql.Tests/AsyncTests.cs index 277bcb93db..3d7ebc3300 100644 --- a/test/Npgsql.Tests/AsyncTests.cs +++ b/test/Npgsql.Tests/AsyncTests.cs @@ -1,55 +1,49 @@ -using System.Data; +using NUnit.Framework; +using System.Data; using System.Threading.Tasks; -using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class AsyncTests : TestBase { - public class AsyncTests : TestBase + [Test] + public async Task NonQuery() { - [Test] - public async Task NonQuery() - { - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery("CREATE TEMP TABLE data (int INTEGER)"); - using (var cmd = new NpgsqlCommand("INSERT INTO data (int) VALUES (4)", conn)) - await cmd.ExecuteNonQueryAsync(); - Assert.That(conn.ExecuteScalar("SELECT int FROM data"), Is.EqualTo(4)); - } - } + await using var conn = await OpenConnectionAsync(); + var tableName = await CreateTempTable(conn, "int INTEGER"); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = $"INSERT INTO {tableName} (int) VALUES (4)"; + await cmd.ExecuteNonQueryAsync(); + Assert.That(await conn.ExecuteScalarAsync($"SELECT int FROM {tableName}"), Is.EqualTo(4)); + } - [Test] - public async Task Scalar() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); - } - } + [Test] + public async Task 
Scalar() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + } - [Test] - public async Task Reader() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - await reader.ReadAsync(); - Assert.That(reader[0], Is.EqualTo(1)); - } - } + [Test] + public async Task Reader() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + Assert.That(reader[0], Is.EqualTo(1)); + } - [Test] - public async Task Columnar() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT NULL, 2, 'Some Text'", conn)) - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess)) - { - await reader.ReadAsync(); - Assert.That(await reader.IsDBNullAsync(0), Is.True); - Assert.That(await reader.GetFieldValueAsync(2), Is.EqualTo("Some Text")); - } - } + [Test] + public async Task Columnar() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT NULL, 2, 'Some Text'", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); + Assert.That(await reader.IsDBNullAsync(0), Is.True); + Assert.That(await reader.GetFieldValueAsync(2), Is.EqualTo("Some Text")); } } diff --git a/test/Npgsql.Tests/AuthenticationTests.cs b/test/Npgsql.Tests/AuthenticationTests.cs new file mode 100644 index 0000000000..5a041a7aca --- /dev/null +++ b/test/Npgsql.Tests/AuthenticationTests.cs @@ -0,0 +1,536 @@ +using System; +using System.Data; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using Npgsql.Properties; 
+using Npgsql.Tests.Support; +using NUnit.Framework; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests; + +public class AuthenticationTests : MultiplexingTestBase +{ + [Test] + [NonParallelizable] // Sets environment variable + public async Task Connect_UserNameFromEnvironment_Succeeds() + { + using var _ = SetEnvironmentVariable("PGUSER", new NpgsqlConnectionStringBuilder(ConnectionString).Username); + await using var dataSource = CreateDataSource(csb => csb.Username = null); + await using var __ = await dataSource.OpenConnectionAsync(); + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Connect_PasswordFromEnvironment_Succeeds() + { + using var _ = SetEnvironmentVariable("PGPASSWORD", new NpgsqlConnectionStringBuilder(ConnectionString).Password); + await using var dataSource = CreateDataSource(csb => csb.Passfile = null); + await using var __ = await dataSource.OpenConnectionAsync(); + } + + [Test] + public async Task Set_Password_on_NpgsqlDataSource() + { + var dataSourceBuilder = GetPasswordlessDataSourceBuilder(); + await using var dataSource = dataSourceBuilder.Build(); + + // No password provided + Assert.That(() => dataSource.OpenConnectionAsync(), Throws.Exception.TypeOf()); + + var connectionStringBuilder = new NpgsqlConnectionStringBuilder(TestUtil.ConnectionString); + dataSource.Password = connectionStringBuilder.Password!; + + await using var connection1 = await dataSource.OpenConnectionAsync(); + await using var connection2 = dataSource.OpenConnection(); + } + + [Test] + public async Task Password_provider([Values]bool async) + { + var dataSourceBuilder = GetPasswordlessDataSourceBuilder(); + var password = new NpgsqlConnectionStringBuilder(TestUtil.ConnectionString).Password!; + var syncProviderCalled = false; + var asyncProviderCalled = false; + dataSourceBuilder.UsePasswordProvider(_ => + { + syncProviderCalled = true; + return password; + }, (_,_) => + { + asyncProviderCalled = true; + return 
new(password); + }); + + using var dataSource = dataSourceBuilder.Build(); + using var conn = async ? await dataSource.OpenConnectionAsync() : dataSource.OpenConnection(); + Assert.True(async ? asyncProviderCalled : syncProviderCalled, "Password_provider not used"); + } + + [Test] + public void Password_provider_exception() + { + var dataSourceBuilder = GetPasswordlessDataSourceBuilder(); + dataSourceBuilder.UsePasswordProvider(_ => throw new Exception(), (_,_) => throw new Exception()); + + using var dataSource = dataSourceBuilder.Build(); + Assert.ThrowsAsync(async () => await dataSource.OpenConnectionAsync()); + } + + [Test] + public async Task Periodic_password_provider() + { + var dataSourceBuilder = GetPasswordlessDataSourceBuilder(); + var password = new NpgsqlConnectionStringBuilder(TestUtil.ConnectionString).Password!; + + var mre = new ManualResetEvent(false); + dataSourceBuilder.UsePeriodicPasswordProvider((_, _) => + { + mre.Set(); + return new(password); + }, TimeSpan.FromMilliseconds(100), TimeSpan.FromMilliseconds(10)); + + await using (var dataSource = dataSourceBuilder.Build()) + { + await using var connection1 = await dataSource.OpenConnectionAsync(); + await using var connection2 = dataSource.OpenConnection(); + + mre.Reset(); + if (!mre.WaitOne(TimeSpan.FromSeconds(30))) + Assert.Fail("Periodic password refresh did not occur"); + } + + mre.Reset(); + if (mre.WaitOne(TimeSpan.FromSeconds(1))) + Assert.Fail("Periodic password refresh occurred after disposal of the data source"); + } + + [Test] + public async Task Periodic_password_provider_with_first_time_exception() + { + var dataSourceBuilder = GetPasswordlessDataSourceBuilder(); + dataSourceBuilder.UsePeriodicPasswordProvider( + (_, _) => throw new Exception("FOO"), TimeSpan.FromDays(30), TimeSpan.FromSeconds(10)); + await using var dataSource = dataSourceBuilder.Build(); + + Assert.That(() => dataSource.OpenConnectionAsync(), Throws.Exception.TypeOf() + 
.With.InnerException.With.Message.EqualTo("FOO")); + Assert.That(() => dataSource.OpenConnection(), Throws.Exception.TypeOf() + .With.InnerException.With.Message.EqualTo("FOO")); + } + + [Test] + public async Task Periodic_password_provider_with_second_time_exception() + { + var dataSourceBuilder = GetPasswordlessDataSourceBuilder(); + var password = new NpgsqlConnectionStringBuilder(TestUtil.ConnectionString).Password!; + + var times = 0; + var mre = new ManualResetEvent(false); + + dataSourceBuilder.UsePeriodicPasswordProvider( + (_, _) => + { + if (times++ > 1) + { + mre.Set(); + throw new Exception("FOO"); + } + + return new(password); + }, + TimeSpan.FromMilliseconds(100), + TimeSpan.FromMilliseconds(10)); + await using var dataSource = dataSourceBuilder.Build(); + + mre.WaitOne(); + + // The periodic timer threw, but previously returned a password. Make sure we keep using that last known one. + using (await dataSource.OpenConnectionAsync()) {} + using (dataSource.OpenConnection()) {} + } + + [Test] + public void Both_password_and_password_provider_is_not_supported() + { + var dataSourceBuilder = new NpgsqlDataSourceBuilder(TestUtil.ConnectionString); + dataSourceBuilder.UsePeriodicPasswordProvider((_, _) => new("foo"), TimeSpan.FromMinutes(1), TimeSpan.FromSeconds(10)); + Assert.That(() => dataSourceBuilder.Build(), Throws.Exception.TypeOf() + .With.Message.EqualTo(NpgsqlStrings.CannotSetBothPasswordProviderAndPassword)); + } + + [Test] + public void Multiple_password_providers_is_not_supported() + { + var dataSourceBuilder = new NpgsqlDataSourceBuilder(TestUtil.ConnectionString); + dataSourceBuilder + .UsePeriodicPasswordProvider((_, _) => new("foo"), TimeSpan.FromMinutes(1), TimeSpan.FromSeconds(10)) + .UsePasswordProvider(_ => "foo", (_,_) => new("foo")); + Assert.That(() => dataSourceBuilder.Build(), Throws.Exception.TypeOf() + .With.Message.EqualTo(NpgsqlStrings.CannotSetMultiplePasswordProviderKinds)); + } + + #region pgpass + + [Test] + 
[NonParallelizable] // Sets environment variable + public async Task Use_pgpass_from_connection_string() + { + using var resetPassword = SetEnvironmentVariable("PGPASSWORD", null); + var builder = new NpgsqlConnectionStringBuilder(ConnectionString); + var passFile = Path.GetTempFileName(); + File.WriteAllText(passFile, $"*:*:*:{builder.Username}:{builder.Password}"); + + try + { + await using var dataSource = CreateDataSource(csb => + { + csb.Passfile = null; + csb.Passfile = passFile; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + } + finally + { + File.Delete(passFile); + } + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Use_pgpass_from_environment_variable() + { + using var resetPassword = SetEnvironmentVariable("PGPASSWORD", null); + var builder = new NpgsqlConnectionStringBuilder(ConnectionString); + var passFile = Path.GetTempFileName(); + File.WriteAllText(passFile, $"*:*:*:{builder.Username}:{builder.Password}"); + using var passFileVariable = SetEnvironmentVariable("PGPASSFILE", passFile); + + try + { + await using var dataSource = CreateDataSource(csb => csb.Password = null); + await using var conn = await dataSource.OpenConnectionAsync(); + } + finally + { + File.Delete(passFile); + } + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Use_pgpass_from_homedir() + { + using var resetPassword = SetEnvironmentVariable("PGPASSWORD", null); + + string? dirToDelete = null; + string passFile; + string? 
previousPassFile = null; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var dir = Path.Combine(Environment.GetEnvironmentVariable("APPDATA")!, "postgresql"); + if (!Directory.Exists(dir)) + { + Directory.CreateDirectory(dir); + dirToDelete = dir; + } + passFile = Path.Combine(dir, "pgpass.conf"); + } + else + { + passFile = Path.Combine(Environment.GetEnvironmentVariable("HOME")!, ".pgpass"); + } + + if (File.Exists(passFile)) + { + previousPassFile = Path.GetTempFileName(); + File.Move(passFile, previousPassFile); + } + + try + { + var builder = new NpgsqlConnectionStringBuilder(ConnectionString); + File.WriteAllText(passFile, $"*:*:*:{builder.Username}:{builder.Password}"); + await using var dataSource = CreateDataSource(csb => csb.Passfile = null); + await using var conn = await dataSource.OpenConnectionAsync(); + } + finally + { + File.Delete(passFile); + if (dirToDelete is not null) + Directory.Delete(dirToDelete); + if (previousPassFile is not null) + File.Move(previousPassFile, passFile); + } + } + + #endregion pgpass + + [Test] + [NonParallelizable] // Sets environment variable + public void Password_source_precedence() + { + using var resetPassword = SetEnvironmentVariable("PGPASSWORD", null); + + var builder = new NpgsqlConnectionStringBuilder(ConnectionString); + var password = builder.Password; + var passwordBad = password + "_bad"; + + var passFile = Path.GetTempFileName(); + var passFileBad = passFile + "_bad"; + + using var deletePassFile = Defer(() => File.Delete(passFile)); + using var deletePassFileBad = Defer(() => File.Delete(passFileBad)); + + File.WriteAllText(passFile, $"*:*:*:{builder.Username}:{password}"); + File.WriteAllText(passFileBad, $"*:*:*:{builder.Username}:{passwordBad}"); + + using (SetEnvironmentVariable("PGPASSFILE", passFileBad)) + { + // Password from the connection string goes first + using (SetEnvironmentVariable("PGPASSWORD", passwordBad)) + { + using var dataSource1 = CreateDataSource(csb => + { + 
csb.Password = password; + csb.Passfile = passFileBad; + }); + + Assert.That(() => dataSource1.OpenConnection(), Throws.Nothing); + } + + // Password from the environment variable goes second + using (SetEnvironmentVariable("PGPASSWORD", password)) + { + using var dataSource2 = CreateDataSource(csb => + { + csb.Password = null; + csb.Passfile = passFileBad; + }); + + Assert.That(() => dataSource2.OpenConnection(), Throws.Nothing); + } + + // Passfile from the connection string goes third + using var dataSource3 = CreateDataSource(csb => + { + csb.Password = null; + csb.Passfile = passFile; + }); + + Assert.That(() => dataSource3.OpenConnection(), Throws.Nothing); + } + + // Passfile from the environment variable goes fourth + using (SetEnvironmentVariable("PGPASSFILE", passFile)) + { + using var dataSource4 = CreateDataSource(csb => + { + csb.Password = null; + csb.Passfile = null; + }); + + Assert.That(() => dataSource4.OpenConnection(), Throws.Nothing); + } + + static DeferDisposable Defer(Action action) => new(action); + } + + readonly struct DeferDisposable : IDisposable + { + readonly Action _action; + public DeferDisposable(Action action) => _action = action; + public void Dispose() => _action(); + } + + [Test, Description("Connects with a bad password to ensure the proper error is thrown")] + public void Authentication_failure() + { + using var dataSource = CreateDataSource(csb => csb.Password = "bad"); + using var conn = dataSource.CreateConnection(); + + Assert.That(() => conn.OpenAsync(), Throws.Exception + .TypeOf() + .With.Property(nameof(PostgresException.SqlState)).StartsWith("28") + ); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); + } + + [Test, Description("Simulates a timeout during the authentication phase")] + [IssueLink("https://github.com/npgsql/npgsql/issues/3227")] + public async Task Timeout_during_authentication() + { + var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { Timeout = 1 }; + await using var 
postmasterMock = new PgPostmasterMock(builder.ConnectionString); + _ = postmasterMock.AcceptServer(); + + // The server will accept a connection from the client, but will not respond to the client's authentication + // request. This should trigger a timeout + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var connection = dataSource.CreateConnection(); + Assert.That(async () => await connection.OpenAsync(), + Throws.Exception.TypeOf() + .With.InnerException.TypeOf()); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1180")] + public void Pool_by_password() + { + using var _ = CreateTempPool(ConnectionString, out var connectionString); + using (var goodConn = new NpgsqlConnection(connectionString)) + goodConn.Open(); + + var badConnectionString = new NpgsqlConnectionStringBuilder(connectionString) + { + Password = "badpasswd" + }.ConnectionString; + using (var conn = new NpgsqlConnection(badConnectionString)) + Assert.That(conn.Open, Throws.Exception.TypeOf()); + } + + [Test, Explicit("Requires user specific local setup")] + public async Task AuthenticateIntegratedSecurity() + { + await using var dataSource = NpgsqlDataSource.Create(new NpgsqlConnectionStringBuilder(ConnectionString) + { + Username = null, + Password = null + }); + await using var c = await dataSource.OpenConnectionAsync(); + Assert.That(c.State, Is.EqualTo(ConnectionState.Open)); + } + + #region ProvidePasswordCallback Tests + +#pragma warning disable CS0618 // ProvidePasswordCallback is Obsolete + + [Test, Description("ProvidePasswordCallback is used when password is not supplied in connection string")] + public async Task ProvidePasswordCallback_is_used() + { + using var _ = CreateTempPool(ConnectionString, out var connString); + var builder = new NpgsqlConnectionStringBuilder(connString); + var goodPassword = builder.Password; + var getPasswordDelegateWasCalled = false; + builder.Password = null; + + Assume.That(goodPassword, 
Is.Not.Null); + + using (var conn = new NpgsqlConnection(builder.ConnectionString) { ProvidePasswordCallback = ProvidePasswordCallback }) + { + conn.Open(); + Assert.True(getPasswordDelegateWasCalled, "ProvidePasswordCallback delegate not used"); + + // Do this again, since with multiplexing the very first connection attempt is done via + // the non-multiplexing path, to surface any exceptions. + NpgsqlConnection.ClearPool(conn); + conn.Close(); + getPasswordDelegateWasCalled = false; + conn.Open(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + Assert.True(getPasswordDelegateWasCalled, "ProvidePasswordCallback delegate not used"); + } + + string ProvidePasswordCallback(string host, int port, string database, string username) + { + getPasswordDelegateWasCalled = true; + return goodPassword!; + } + } + + [Test, Description("ProvidePasswordCallback is not used when password is supplied in connection string")] + public void ProvidePasswordCallback_is_not_used() + { + using var _ = CreateTempPool(ConnectionString, out var connString); + + using (var conn = new NpgsqlConnection(connString) { ProvidePasswordCallback = ProvidePasswordCallback }) + { + conn.Open(); + + // Do this again, since with multiplexing the very first connection attempt is done via + // the non-multiplexing path, to surface any exceptions. 
+ NpgsqlConnection.ClearPool(conn); + conn.Close(); + conn.Open(); + } + + string ProvidePasswordCallback(string host, int port, string database, string username) + { + throw new Exception("password should come from connection string, not delegate"); + } + } + + [Test, Description("Exceptions thrown from client application are wrapped when using ProvidePasswordCallback Delegate")] + public void ProvidePasswordCallback_exceptions_are_wrapped() + { + using var _ = CreateTempPool(ConnectionString, out var connString); + var builder = new NpgsqlConnectionStringBuilder(connString) + { + Password = null + }; + + using (var conn = new NpgsqlConnection(builder.ConnectionString) { ProvidePasswordCallback = ProvidePasswordCallback }) + { + Assert.That(() => conn.Open(), Throws.Exception + .TypeOf() + .With.InnerException.Message.EqualTo("inner exception from ProvidePasswordCallback") + ); + } + + string ProvidePasswordCallback(string host, int port, string database, string username) + { + throw new Exception("inner exception from ProvidePasswordCallback"); + } + } + + [Test, Description("Parameters passed to ProvidePasswordCallback delegate are correct")] + public void ProvidePasswordCallback_gets_correct_arguments() + { + using var _ = CreateTempPool(ConnectionString, out var connString); + var builder = new NpgsqlConnectionStringBuilder(connString); + var goodPassword = builder.Password; + builder.Password = null; + + Assume.That(goodPassword, Is.Not.Null); + + string? receivedHost = null; + int? receivedPort = null; + string? receivedDatabase = null; + string? 
receivedUsername = null; + + using (var conn = new NpgsqlConnection(builder.ConnectionString) { ProvidePasswordCallback = ProvidePasswordCallback }) + { + conn.Open(); + Assert.AreEqual(builder.Host, receivedHost); + Assert.AreEqual(builder.Port, receivedPort); + Assert.AreEqual(builder.Database, receivedDatabase); + Assert.AreEqual(builder.Username, receivedUsername); + } + + string ProvidePasswordCallback(string host, int port, string database, string username) + { + receivedHost = host; + receivedPort = port; + receivedDatabase = database; + receivedUsername = username; + + return goodPassword!; + } + } + +#pragma warning restore CS0618 // ProvidePasswordCallback is Obsolete + + #endregion + + NpgsqlDataSourceBuilder GetPasswordlessDataSourceBuilder() + => new(TestUtil.ConnectionString) + { + ConnectionStringBuilder = + { + Password = null + } + }; + + public AuthenticationTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} +} diff --git a/test/Npgsql.Tests/AutoPrepareTests.cs b/test/Npgsql.Tests/AutoPrepareTests.cs index f06e5d2815..14d6997230 100644 --- a/test/Npgsql.Tests/AutoPrepareTests.cs +++ b/test/Npgsql.Tests/AutoPrepareTests.cs @@ -1,481 +1,569 @@ -using System; +using NpgsqlTypes; +using NUnit.Framework; +using System; +using System.Data; using System.Linq; -using System.Threading; using System.Threading.Tasks; -using NpgsqlTypes; -using NUnit.Framework; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class AutoPrepareTests : TestBase { - [Parallelizable(ParallelScope.None)] - public class AutoPrepareTests : TestBase + [Test] + public void Basic() { - [Test] - public void Basic() + using var dataSource = CreateDataSource(csb => { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }; + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + using var conn = dataSource.OpenConnection(); + using var checkCmd = new 
NpgsqlCommand(CountPreparedStatements, conn); + checkCmd.Prepare(); - using (var conn = OpenConnection(csb)) - using (var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn)) - { - checkCmd.Prepare(); - - conn.ExecuteNonQuery("SELECT 1"); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(0)); - - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - Assert.That(cmd.IsPrepared, Is.False); - cmd.ExecuteScalar(); - Assert.That(cmd.IsPrepared, Is.True); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - cmd.ExecuteScalar(); - Assert.That(cmd.IsPrepared, Is.True); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - } - - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.ExecuteScalar(); - Assert.That(cmd.IsPrepared, Is.True); - } - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - conn.UnprepareAll(); - } + conn.ExecuteNonQuery("SELECT 1"); + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(0)); + + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + { + Assert.That(cmd.IsPrepared, Is.False); + cmd.ExecuteScalar(); + Assert.That(cmd.IsPrepared, Is.True); + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); + cmd.ExecuteScalar(); + Assert.That(cmd.IsPrepared, Is.True); + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); } - [Test, Description("Passes the maximum limit for autoprepared statements, recycling the least-recently used one")] - public void Recycle() + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - AutoPrepareMinUsages = 2, - MaxAutoPrepare = 2 - }; + cmd.ExecuteScalar(); + Assert.That(cmd.IsPrepared, Is.True); + } + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); + } - using (var conn = OpenConnection(csb)) - using (var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn)) - { - checkCmd.Prepare(); - - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(0)); - var cmd1 = new NpgsqlCommand("SELECT 1", conn); - 
cmd1.ExecuteNonQuery(); cmd1.ExecuteNonQuery(); - Assert.That(cmd1.IsPrepared, Is.True); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - Thread.Sleep(10); - - var cmd2 = new NpgsqlCommand("SELECT 2", conn); - cmd2.ExecuteNonQuery(); cmd2.ExecuteNonQuery(); - Assert.That(cmd2.IsPrepared, Is.True); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(2)); - - // Use cmd1 to make cmd2 the lru - Thread.Sleep(1); - cmd1.ExecuteNonQuery(); - - // Cause another statement to be autoprepared. This should eject cmd2. - conn.ExecuteNonQuery("SELECT 3"); conn.ExecuteNonQuery("SELECT 3"); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(2)); - - cmd2.ExecuteNonQuery(); - Assert.That(cmd2.IsPrepared, Is.False); - using (var getTextCmd = new NpgsqlCommand("SELECT statement FROM pg_prepared_statements WHERE statement NOT LIKE '%COUNT%' ORDER BY statement", conn)) - using (var reader = getTextCmd.ExecuteReader()) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetString(0), Is.EqualTo("SELECT 1")); - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetString(0), Is.EqualTo("SELECT 3")); - } - conn.UnprepareAll(); - } + [Test, Description("Passes the maximum limit for autoprepared statements, recycling the least-recently used one")] + public void Recycle() + { + using var dataSource = CreateDataSource(csb => + { + csb.AutoPrepareMinUsages = 2; + csb.MaxAutoPrepare = 2; + }); + using var conn = dataSource.OpenConnection(); + using var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn); + checkCmd.Prepare(); + + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(0)); + var cmd1 = new NpgsqlCommand("SELECT 1", conn); + cmd1.ExecuteNonQuery(); cmd1.ExecuteNonQuery(); + Assert.That(cmd1.IsPrepared, Is.True); + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); + + var cmd2 = new NpgsqlCommand("SELECT 2", conn); + cmd2.ExecuteNonQuery(); cmd2.ExecuteNonQuery(); + Assert.That(cmd2.IsPrepared, Is.True); + Assert.That(checkCmd.ExecuteScalar(), 
Is.EqualTo(2)); + + cmd1.ExecuteNonQuery(); + + // Cause another statement to be autoprepared. This should eject cmd2. + conn.ExecuteNonQuery("SELECT 3"); conn.ExecuteNonQuery("SELECT 3"); + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(2)); + + cmd2.ExecuteNonQuery(); + Assert.That(cmd2.IsPrepared, Is.False); + using (var getTextCmd = new NpgsqlCommand("SELECT statement FROM pg_prepared_statements WHERE statement NOT LIKE '%COUNT%' ORDER BY statement", conn)) + using (var reader = getTextCmd.ExecuteReader()) + { + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetString(0), Is.EqualTo("SELECT 1")); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetString(0), Is.EqualTo("SELECT 3")); } + } - [Test] - public void Persist() + [Test] + public void Persist() + { + using var dataSource = CreateDataSource(csb => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(Persist), - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }.ToString(); - try - { - using (var conn = OpenConnection(connString)) - using (var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn)) - { - checkCmd.Prepare(); - conn.ExecuteNonQuery("SELECT 1"); conn.ExecuteNonQuery("SELECT 1"); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - } - - // We now have two prepared statements which should be persisted - - using (var conn = OpenConnection(connString)) - using (var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn)) - { - checkCmd.Prepare(); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.ExecuteScalar(); - //Assert.That(cmd.IsPrepared); - } - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - } - } - finally + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + + using (var conn = dataSource.OpenConnection()) + using (var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn)) + { + checkCmd.Prepare(); + 
conn.ExecuteNonQuery("SELECT 1"); conn.ExecuteNonQuery("SELECT 1"); + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); + } + + // We now have two prepared statements which should be persisted + + using (var conn = dataSource.OpenConnection()) + using (var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn)) + { + checkCmd.Prepare(); + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - using (var conn = new NpgsqlConnection(connString)) - NpgsqlConnection.ClearPool(conn); + cmd.ExecuteScalar(); + //Assert.That(cmd.IsPrepared); } + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); } + } - [Test] - public void PromoteAutoToExplicit() + [Test] + public async Task Positional_parameter() + { + await using var dataSource = CreateDataSource(csb => { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }; - using (var conn = OpenConnection(csb)) - using (var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn)) - using (var cmd1 = new NpgsqlCommand("SELECT 1", conn)) - using (var cmd2 = new NpgsqlCommand("SELECT 1", conn)) - { - checkCmd.Prepare(); + csb.AutoPrepareMinUsages = 2; + csb.MaxAutoPrepare = 2; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn); + await checkCmd.PrepareAsync(); + + await using var cmd = new NpgsqlCommand("SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); + + Assert.That(cmd.IsPrepared, Is.False); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + Assert.That(cmd.IsPrepared, Is.False); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + Assert.That(cmd.IsPrepared, Is.True); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + Assert.That(cmd.IsPrepared, Is.True); + } - cmd1.ExecuteNonQuery(); 
cmd1.ExecuteNonQuery(); - // cmd1 is now autoprepared - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - Assert.That(conn.Connector!.PreparedStatementManager.NumPrepared, Is.EqualTo(2)); + [Test] + public void Promote_auto_to_explicit() + { + using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + using var conn = dataSource.OpenConnection(); + using var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn); + using var cmd1 = new NpgsqlCommand("SELECT 1", conn); + using var cmd2 = new NpgsqlCommand("SELECT 1", conn); + checkCmd.Prepare(); + + cmd1.ExecuteNonQuery(); cmd1.ExecuteNonQuery(); + // cmd1 is now autoprepared + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); + Assert.That(conn.Connector!.PreparedStatementManager.NumPrepared, Is.EqualTo(2)); + + // Promote (replace) the autoprepared statement with an explicit one. + cmd2.Prepare(); + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); + Assert.That(conn.Connector.PreparedStatementManager.NumPrepared, Is.EqualTo(2)); + + // cmd1's statement is no longer valid (has been closed), make sure it still works (will run unprepared) + cmd2.ExecuteScalar(); + } - // Promote (replace) the autoprepared statement with an explicit one. 
- cmd2.Prepare(); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - Assert.That(conn.Connector.PreparedStatementManager.NumPrepared, Is.EqualTo(2)); + [Test] + public void Candidate_eject() + { + using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 3; + }); + using var conn = dataSource.OpenConnection(); + using var cmd = conn.CreateCommand(); - // cmd1's statement is no longer valid (has been closed), make sure it still works (will run unprepared) - cmd2.ExecuteScalar(); - conn.UnprepareAll(); - } + for (var i = 0; i < PreparedStatementManager.CandidateCount; i++) + { + cmd.CommandText = $"SELECT {i}"; + cmd.ExecuteNonQuery(); } - [Test] - public void CandidateEject() + // The candidate list is now full with single-use statements. + + cmd.CommandText = "SELECT 'double_use'"; + cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); + // We now have a single statement that has been used twice. + + for (var i = PreparedStatementManager.CandidateCount; i < PreparedStatementManager.CandidateCount * 2; i++) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 3 - }; - using (var conn = OpenConnection(csb)) - using (var cmd = new NpgsqlCommand()) - { - cmd.Connection = conn; - - for (var i = 0; i < PreparedStatementManager.CandidateCount; i++) - { - cmd.CommandText = $"SELECT {i}"; - cmd.ExecuteNonQuery(); - Thread.Sleep(1); - } - - // The candidate list is now full with single-use statements. - - cmd.CommandText = $"SELECT 'double_use'"; - cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); - // We now have a single statement that has been used twice. 
- - for (var i = PreparedStatementManager.CandidateCount; i < PreparedStatementManager.CandidateCount * 2; i++) - { - cmd.CommandText = $"SELECT {i}"; - cmd.ExecuteNonQuery(); - Thread.Sleep(1); - } - - // The new single-use statements should have ejected all previous single-use statements - cmd.CommandText = "SELECT 1"; - cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); - Assert.That(cmd.IsPrepared, Is.False); - - // But the double-use statement should still be there - cmd.CommandText = "SELECT 'double_use'"; - cmd.ExecuteNonQuery(); - Assert.That(cmd.IsPrepared, Is.True); - - conn.UnprepareAll(); - } + cmd.CommandText = $"SELECT {i}"; + cmd.ExecuteNonQuery(); } - [Test] - public void OneCommandSameSqlTwice() + // The new single-use statements should have ejected all previous single-use statements + cmd.CommandText = "SELECT 1"; + cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); + Assert.That(cmd.IsPrepared, Is.False); + + // But the double-use statement should still be there + cmd.CommandText = "SELECT 'double_use'"; + cmd.ExecuteNonQuery(); + Assert.That(cmd.IsPrepared, Is.True); + } + + [Test] + public void One_command_same_sql_twice() + { + using var dataSource = CreateDataSource(csb => { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }; - using (var conn = OpenConnection(csb)) - using (var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn)) - using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 1; SELECT 1; SELECT 1", conn)) - { - //cmd.Prepare(); - //Assert.That(cmd.IsPrepared, Is.True); - cmd.ExecuteNonQuery(); - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - conn.UnprepareAll(); - } - } + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + using var conn = dataSource.OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT 1; SELECT 1; SELECT 1; SELECT 1", conn); + //cmd.Prepare(); + //Assert.That(cmd.IsPrepared, Is.True); + cmd.ExecuteNonQuery(); + 
Assert.That(conn.ExecuteScalar(CountPreparedStatements), Is.EqualTo(1)); + } - [Test] - public void AcrossCloseOpenDifferentConnector() + [Test] + public void Across_close_open_different_connector() + { + using var dataSource = CreateDataSource(csb => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(AutoPrepareTests) + '.' + nameof(AcrossCloseOpenDifferentConnector), - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }.ToString(); - using (var conn1 = new NpgsqlConnection(connString)) - using (var conn2 = new NpgsqlConnection(connString)) - using (var cmd = new NpgsqlCommand("SELECT 1", conn1)) - { - conn1.Open(); - cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); - Assert.That(cmd.IsPrepared, Is.True); - var processId = conn1.ProcessID; - conn1.Close(); - conn2.Open(); - conn1.Open(); - Assert.That(conn1.ProcessID, Is.Not.EqualTo(processId)); - Assert.That(cmd.IsPrepared, Is.False); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); // Execute unprepared - cmd.Prepare(); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - NpgsqlConnection.ClearPool(conn1); - } - } + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + using var conn1 = dataSource.CreateConnection(); + using var conn2 = dataSource.CreateConnection(); + using var cmd = new NpgsqlCommand("SELECT 1", conn1); + conn1.Open(); + cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); + Assert.That(cmd.IsPrepared, Is.True); + var processId = conn1.ProcessID; + conn1.Close(); + conn2.Open(); + conn1.Open(); + Assert.That(conn1.ProcessID, Is.Not.EqualTo(processId)); + Assert.That(cmd.IsPrepared, Is.False); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); // Execute unprepared + cmd.Prepare(); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + } - [Test] - public void UnprepareAll() + [Test] + public void Unprepare_all() + { + using var dataSource = CreateDataSource(csb => { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - 
MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }; + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + using var conn = dataSource.OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Prepare(); // Explicit + conn.ExecuteNonQuery("SELECT 2"); conn.ExecuteNonQuery("SELECT 2"); // Auto + Assert.That(conn.ExecuteScalar(CountPreparedStatements), Is.EqualTo(2)); + conn.UnprepareAll(); + Assert.That(conn.ExecuteScalar(CountPreparedStatements), Is.Zero); + } - using (var conn = OpenConnection(csb)) - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(); // Explicit - conn.ExecuteNonQuery("SELECT 2"); conn.ExecuteNonQuery("SELECT 2"); // Auto - Assert.That(conn.ExecuteScalar(CountPreparedStatements), Is.EqualTo(2)); - conn.UnprepareAll(); - Assert.That(conn.ExecuteScalar(CountPreparedStatements), Is.Zero); - } + [Test, Description("Prepares the same SQL with different parameters (overloading)")] + public void Overloaded_sql() + { + using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + using var conn = dataSource.OpenConnection(); + using (var cmd = new NpgsqlCommand("SELECT @p", conn)) + { + cmd.Parameters.AddWithValue("p", NpgsqlDbType.Integer, 8); + cmd.ExecuteNonQuery(); + cmd.ExecuteNonQuery(); + Assert.That(cmd.IsPrepared, Is.True); } - - [Test, Description("Prepares the same SQL with different parameters (overloading)")] - public void OverloadedSql() + using (var cmd = new NpgsqlCommand("SELECT @p", conn)) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }; - using (var conn = OpenConnection(csb)) - { - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Integer, 8); - cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); - Assert.That(cmd.IsPrepared, Is.True); - } - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - 
{ - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Text, "foo"); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo("foo")); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo("foo")); - Assert.That(cmd.IsPrepared, Is.False); - } - - // SQL overloading is a pretty rare/exotic scenario. Handling it properly would involve keying - // prepared statements not just by SQL but also by the parameter types, which would pointlessly - // increase allocations. Instead, the second execution simply reuns unprepared - Assert.That(conn.ExecuteScalar("SELECT COUNT(*) FROM pg_prepared_statements"), Is.EqualTo(1)); - conn.UnprepareAll(); - } + cmd.Parameters.AddWithValue("p", NpgsqlDbType.Text, "foo"); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo("foo")); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo("foo")); + Assert.That(cmd.IsPrepared, Is.False); } - [Test, Description("Tests parameter derivation a parameterized query (CommandType.Text) that is already auto-prepared.")] - public void DeriveParametersForAutoPreparedStatement() - { - const string query = "SELECT @p::integer"; - const int answer = 42; - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }; - using (var conn = OpenConnection(csb)) - using (var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn)) - using (var cmd = new NpgsqlCommand(query, conn)) - { - checkCmd.Prepare(); - cmd.Parameters.AddWithValue("@p", NpgsqlDbType.Integer, answer); - cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); // cmd1 is now autoprepared - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); - Assert.That(conn.Connector!.PreparedStatementManager.NumPrepared, Is.EqualTo(2)); + // SQL overloading is a pretty rare/exotic scenario. Handling it properly would involve keying + // prepared statements not just by SQL but also by the parameter types, which would pointlessly + // increase allocations. Instead, the second execution simply runs unprepared. 
+ Assert.That(conn.ExecuteScalar(CountPreparedStatements), Is.EqualTo(1)); + } - // Derive parameters for the already autoprepared statement - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters.Count, Is.EqualTo(1)); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("p")); + [Test, Description("Tests parameter derivation a parameterized query (CommandType.Text) that is already auto-prepared.")] + public void Derive_parameters_for_auto_prepared_statement() + { + const string query = "SELECT @p::integer"; + const int answer = 42; + using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + using var conn = dataSource.OpenConnection(); + using var checkCmd = new NpgsqlCommand(CountPreparedStatements, conn); + using var cmd = new NpgsqlCommand(query, conn); + checkCmd.Prepare(); + cmd.Parameters.AddWithValue("@p", NpgsqlDbType.Integer, answer); + cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); // cmd1 is now autoprepared + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(1)); + Assert.That(conn.Connector!.PreparedStatementManager.NumPrepared, Is.EqualTo(2)); + + // Derive parameters for the already autoprepared statement + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters.Count, Is.EqualTo(1)); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("p")); + + // DeriveParameters should have silently unprepared the autoprepared statements + Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(0)); + Assert.That(conn.Connector.PreparedStatementManager.NumPrepared, Is.EqualTo(1)); + + cmd.Parameters["@p"].Value = answer; + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(answer)); + } - // DeriveParameters should have silently unprepared the autoprepared statements - Assert.That(checkCmd.ExecuteScalar(), Is.EqualTo(0)); - Assert.That(conn.Connector.PreparedStatementManager.NumPrepared, Is.EqualTo(1)); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2644")] + 
public void Row_description_properly_cloned() + { + using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + using var conn = dataSource.OpenConnection(); + conn.UnprepareAll(); + using var cmd1 = new NpgsqlCommand("SELECT 1 AS foo", conn); + using var cmd2 = new NpgsqlCommand("SELECT 1 AS bar", conn); + + cmd1.ExecuteNonQuery(); + cmd1.ExecuteNonQuery(); // Query is now auto-prepared + cmd2.ExecuteNonQuery(); + using var reader = cmd1.ExecuteReader(); + Assert.That(reader.GetName(0), Is.EqualTo("foo")); + } - cmd.Parameters["@p"].Value = answer; - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(answer)); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3106")] + public async Task Dont_auto_prepare_more_than_max_statements_in_batch() + { + const int maxAutoPrepare = 50; - conn.UnprepareAll(); - } + await using var dataSource = CreateDataSource(csb => csb.MaxAutoPrepare = maxAutoPrepare); + await using var connection = await dataSource.OpenConnectionAsync(); + for (var i = 0; i < 100; i++) + { + await using var command = connection.CreateCommand(); + command.CommandText = string.Join("", Enumerable.Range(0, 100).Select(n => $"SELECT {n};")); + await command.ExecuteNonQueryAsync(); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2644")] - public void RowDescriptionProperlyCloned() + Assert.That(await connection.ExecuteScalarAsync(CountPreparedStatements), Is.LessThanOrEqualTo(maxAutoPrepare)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3106")] + public async Task Dont_auto_prepare_more_than_max_statements_in_batch_random() + { + const int maxAutoPrepare = 10; + + await using var dataSource = CreateDataSource(csb => csb.MaxAutoPrepare = maxAutoPrepare); + await using var connection = await dataSource.OpenConnectionAsync(); + var random = new Random(1); + for (var i = 0; i < 100; i++) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - 
MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }; - using var conn = OpenConnection(csb); - using var cmd1 = new NpgsqlCommand("SELECT 1 AS foo", conn); - using var cmd2 = new NpgsqlCommand("SELECT 1 AS bar", conn); - - cmd1.ExecuteNonQuery(); - cmd1.ExecuteNonQuery(); // Query is now auto-prepared - cmd2.ExecuteNonQuery(); - using (var reader = cmd1.ExecuteReader()) - Assert.That(reader.GetName(0), Is.EqualTo("foo")); - - conn.UnprepareAll(); + await using var command = connection.CreateCommand(); + command.CommandText = string.Join("", Enumerable.Range(0, 100).Select(n => $"SELECT {random.Next(200)};")); + await command.ExecuteNonQueryAsync(); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3106")] - public async Task DontAutoPrepareMoreThanMaxStatementsInBatch() + Assert.That(await connection.ExecuteScalarAsync(CountPreparedStatements), Is.LessThanOrEqualTo(maxAutoPrepare)); + } + + [Test] + public async Task Replace_and_execute_within_same_batch() + { + await using var dataSource = CreateDataSource(csb => { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 50, - }; + csb.MaxAutoPrepare = 1; + csb.AutoPrepareMinUsages = 2; + }); + await using var connection = await dataSource.OpenConnectionAsync(); + for (var i = 0; i < 2; i++) + await connection.ExecuteNonQueryAsync("SELECT 1"); + + // SELECT 1 is now auto-prepared and occupying the only slot. + // Within the same batch, cause another SQL to replace it, and then execute it. 
+ await connection.ExecuteNonQueryAsync("SELECT 2; SELECT 2; SELECT 1"); + } - using var _ = CreateTempPool(builder.ToString(), out var connectionString); - await using var connection = new NpgsqlConnection(connectionString); - await connection.OpenAsync(); - for (var i = 0; i < 100; i++) - { - using var command = connection.CreateCommand(); - command.CommandText = string.Join("", Enumerable.Range(0, 100).Select(n => $"SELECT {n};")); - await command.ExecuteNonQueryAsync(); - } - } + // Exclude some internal Npgsql queries which include pg_type as well as the count statement itself + const string CountPreparedStatements = """ +SELECT COUNT(*) FROM pg_prepared_statements +WHERE statement NOT LIKE '%pg_prepared_statements%' +AND statement NOT LIKE '%pg_type%' +"""; - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3106")] - public async Task DontAutoPrepareMoreThanMaxStatementsInBatchRandom() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2665")] + public async Task Auto_prepared_command_failure() + { + await using var dataSource = CreateDataSource(csb => { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 10, - }; + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + await using var conn = await dataSource.OpenConnectionAsync(); - await using var connection = new NpgsqlConnection(builder.ToString()); - await connection.OpenAsync(); - var random = new Random(1); - for (var i = 0; i < 100; i++) - { - using var command = connection.CreateCommand(); - command.CommandText = string.Join("", Enumerable.Range(0, 100).Select(n => $"SELECT {random.Next(200)};")); - await command.ExecuteNonQueryAsync(); - } - } + var tableName = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (id integer)"); - [Test] - public async Task ReplaceAndExecuteWithinSameBatch() + await using (var command = new NpgsqlCommand($"INSERT INTO {tableName} (id) VALUES (1)", conn)) { - var 
builder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 1, - AutoPrepareMinUsages = 2 - }; - - using var _ = CreateTempPool(builder.ToString(), out var connectionString); - await using var connection = new NpgsqlConnection(connectionString); - await connection.OpenAsync(); - for (var i = 0; i < 2; i++) - await connection.ExecuteNonQueryAsync("SELECT 1"); - - // SELECT 1 is now auto-prepared and occupying the only slot. - // Within the same batch, cause another SQL to replace it, and then execute it. - await connection.ExecuteNonQueryAsync("SELECT 2; SELECT 2; SELECT 1"); + await command.ExecuteNonQueryAsync(); + await conn.ExecuteNonQueryAsync($"DROP TABLE {tableName}"); + Assert.ThrowsAsync(async () => await command.ExecuteNonQueryAsync()); } - // Exclude some internal Npgsql queries which include pg_type as well as the count statement itself - const string CountPreparedStatements = @" -SELECT COUNT(*) FROM pg_prepared_statements - WHERE statement NOT LIKE '%pg_prepared_statements%' - AND statement NOT LIKE '%pg_type%'"; + await conn.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (id integer)"); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2665")] - public void AutoPreparedCommandFailure() + await using (var command = new NpgsqlCommand($"INSERT INTO {tableName} (id) VALUES (1)", conn)) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2 - }; - using var conn = OpenConnection(csb); - - conn.ExecuteNonQuery("CREATE TEMP TABLE test_table (id integer)"); + await command.ExecuteNonQueryAsync(); + await command.ExecuteNonQueryAsync(); + } + } - using (var command = new NpgsqlCommand("INSERT INTO test_table (id) VALUES (1)", conn)) - { - command.ExecuteNonQuery(); - conn.ExecuteNonQuery("DROP TABLE test_table"); - Assert.Throws(() => command.ExecuteNonQuery()); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3002")] + public void 
Replace_with_bad_sql() + { + using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 2; + csb.AutoPrepareMinUsages = 1; + }); + using var conn = dataSource.OpenConnection(); + + conn.ExecuteNonQuery("SELECT 1"); + conn.ExecuteNonQuery("SELECT 2"); + + // Attempt to replace SELECT 1, but fail because of bad SQL. + // Because of the issue, PreparedStatementManager.NumPrepared is reduced from 2 to 1 + Assert.That(() => conn.ExecuteNonQuery("SELECTBAD"), Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.SyntaxError)); + // Prevent SELECT 2 from being the LRU + conn.ExecuteNonQuery("SELECT 2"); + // And attempt to replace again, reducing PreparedStatementManager.NumPrepared to 0 + Assert.That(() => conn.ExecuteNonQuery("SELECTBAD"), Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.SyntaxError)); + + // Since PreparedStatementManager.NumPrepared is 0, Npgsql will now send DISCARD ALL, but our internal state thinks + // SELECT 2 is still prepared. 
+ conn.Close(); + conn.Open(); + + Assert.That(conn.ExecuteScalar("SELECT 2"), Is.EqualTo(2)); + } - conn.ExecuteNonQuery("CREATE TEMP TABLE test_table (id integer)"); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4082")] + public async Task Batch_statement_execution_error_cleanup() + { + await using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 2; + csb.AutoPrepareMinUsages = 1; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + var funcName = await GetTempFunctionName(conn); + + // Create a function we can use to raise an error with a single statement + await conn.ExecuteNonQueryAsync( +$""" +CREATE OR REPLACE FUNCTION {funcName}() RETURNS VOID AS + 'BEGIN RAISE EXCEPTION ''testexception'' USING ERRCODE = ''12345'', DETAIL = ''testdetail''; END;' +LANGUAGE 'plpgsql'; +"""); + + conn.UnprepareAll(); + + // Occupy _auto1 and _auto2 + await conn.ExecuteNonQueryAsync("SELECT 1"); + await conn.ExecuteNonQueryAsync("SELECT 2"); + + // Execute two new SELECTs which will replace the above two. _auto1 will now contain SELECT pg_temp.emit_exception() + // and _auto2 will contain SELECT 4. Note that they must be in this order because only the statements following + // the error-triggering statement will be unprepared. + // + // We expect error 12345. Prior to the error being raised, the SELECT pg_temp.emit_exception will be successfully prepared + // and the previous _auto1 (SELECT 1) will be successfully closed. However, the subsequent SELECT 4 will not be prepared, + // and the previous _auto2 (SELECT 2) will not be properly closed. SELECT 4 will then be unprepared. + var ex = Assert.ThrowsAsync(async () => await conn.ExecuteNonQueryAsync($"SELECT {funcName}(); SELECT 4"))!; + Assert.That(ex, Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo("12345")); + + // The PreparedStatementManager prioritises replacement of unprepared statements, so we know this will replace SELECT 4 in + // _auto2. 
The code previously assumed that cleanup was never required when replacing an unprepared statement (since it + // was never prepared in PG) and this is true in most cases. However, in this case, SELECT 3 needs to logically replace + // SELECT 2. + // + // Due to the bug, _auto2 never gets cleaned up and this throws a 42P05 (prepared statement "_auto2" already exists) + // when we try to use that slot + Assert.That(await conn.ExecuteScalarAsync("SELECT 3"), Is.EqualTo(3)); + } - using (var command = new NpgsqlCommand("INSERT INTO test_table (id) VALUES (1)", conn)) - { - command.ExecuteNonQuery(); - command.ExecuteNonQuery(); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4404"), IssueLink("https://github.com/npgsql/npgsql/issues/5220")] + public async Task SchemaOnly() + { + await using var dataSource = CreateDataSource(csb => + { + csb.AutoPrepareMinUsages = 2; + csb.MaxAutoPrepare = 10; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); - conn.UnprepareAll(); + for (var i = 0; i < 5; i++) + { + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); } - void DumpPreparedStatements(NpgsqlConnection conn) + // Make sure there is no protocol desync due to #5220 + await cmd.ExecuteScalarAsync(); + } + + [Test] + public async Task Auto_prepared_statement_invalidation() + { + await using var dataSource = CreateDataSource(csb => { - using (var cmd = new NpgsqlCommand("SELECT name,statement FROM pg_prepared_statements", conn)) - using (var reader = cmd.ExecuteReader()) - { - while (reader.Read()) - Console.WriteLine($"{reader.GetString(0)}: {reader.GetString(1)}"); - } - } + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + await using var connection = await dataSource.OpenConnectionAsync(); + var table = await CreateTempTable(connection, "foo int"); + + await using var command = new NpgsqlCommand($"SELECT * FROM {table}", 
connection); + for (var i = 0; i < 2; i++) + await command.ExecuteNonQueryAsync(); + + await connection.ExecuteNonQueryAsync($"ALTER TABLE {table} RENAME COLUMN foo TO bar"); + + // Since we've changed the table schema, the next execution of the prepared statement will error with 0A000 + var exception = Assert.ThrowsAsync(() => command.ExecuteNonQueryAsync())!; + Assert.That(exception.SqlState, Is.EqualTo(PostgresErrorCodes.FeatureNotSupported)); // cached plan must not change result type + + // However, Npgsql should invalidate the prepared statement in this case, so the next execution should work + Assert.DoesNotThrowAsync(() => command.ExecuteNonQueryAsync()); + } + + void DumpPreparedStatements(NpgsqlConnection conn) + { + using var cmd = new NpgsqlCommand("SELECT name,statement FROM pg_prepared_statements", conn); + using var reader = cmd.ExecuteReader(); + while (reader.Read()) + Console.WriteLine($"{reader.GetString(0)}: {reader.GetString(1)}"); } } diff --git a/test/Npgsql.Tests/BatchTests.cs b/test/Npgsql.Tests/BatchTests.cs new file mode 100644 index 0000000000..acbff0a540 --- /dev/null +++ b/test/Npgsql.Tests/BatchTests.cs @@ -0,0 +1,911 @@ +using NUnit.Framework; +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Threading.Tasks; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests; + +[TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.Default)] +[TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.Default)] +[TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.SequentialAccess)] +[TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.SequentialAccess)] +public class BatchTests : MultiplexingTestBase +{ + #region Parameters + + [Test] + public async Task Named_parameters() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new("SELECT @p") { Parameters = { new("p", 8) } }, + 
new("SELECT @p1, @p2") { Parameters = { new("p1", 9), new("p2", 10) } } + } + }; + + await using var reader = await batch.ExecuteReaderAsync(Behavior); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader[0], Is.EqualTo(8)); + Assert.That(await reader.ReadAsync(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.True); + Assert.That(reader.FieldCount, Is.EqualTo(2)); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader[0], Is.EqualTo(9)); + Assert.That(reader[1], Is.EqualTo(10)); + Assert.That(await reader.ReadAsync(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + [Test] + public async Task Positional_parameters() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new("SELECT $1") { Parameters = { new() { Value = 8 } } }, + new("SELECT $1, $2") { Parameters = { new() { Value = 9 }, new() { Value = 10 } } } + } + }; + + await using var reader = await batch.ExecuteReaderAsync(Behavior); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader[0], Is.EqualTo(8)); + Assert.That(await reader.ReadAsync(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.True); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader.FieldCount, Is.EqualTo(2)); + Assert.That(reader[0], Is.EqualTo(9)); + Assert.That(reader[1], Is.EqualTo(10)); + Assert.That(await reader.ReadAsync(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + #endregion Parameters + + #region NpgsqlBatchCommand + + [Test] + public async Task RecordsAffected_and_Rows() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new($"INSERT INTO {table} (name) VALUES 
('a'), ('b')"), + new($"UPDATE {table} SET name='c' WHERE name='b'"), + new($"UPDATE {table} SET name='d' WHERE name='doesnt_exist'"), + new($"SELECT name FROM {table}"), + new($"DELETE FROM {table}") + } + }; + await using var reader = await batch.ExecuteReaderAsync(Behavior); + + // Consume SELECT result set to parse the CommandComplete + await reader.CloseAsync(); + + var command = batch.BatchCommands[0]; + Assert.That(command.RecordsAffected, Is.EqualTo(2)); + Assert.That(command.Rows, Is.EqualTo(2)); + + command = batch.BatchCommands[1]; + Assert.That(command.RecordsAffected, Is.EqualTo(1)); + Assert.That(command.Rows, Is.EqualTo(1)); + + command = batch.BatchCommands[2]; + Assert.That(command.RecordsAffected, Is.EqualTo(0)); + Assert.That(command.Rows, Is.EqualTo(0)); + + command = batch.BatchCommands[3]; + Assert.That(command.RecordsAffected, Is.EqualTo(-1)); + Assert.That(command.Rows, Is.EqualTo(2)); + + command = batch.BatchCommands[4]; + Assert.That(command.RecordsAffected, Is.EqualTo(2)); + Assert.That(command.Rows, Is.EqualTo(2)); + } + + [Test] + public async Task Merge_RecordsAffected_and_Rows() + { + await using var conn = await OpenConnectionAsync(); + + MinimumPgVersion(conn, "15.0", "MERGE statement was introduced in PostgreSQL 15"); + + var table = await CreateTempTable(conn, "name TEXT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new($"INSERT INTO {table} (name) VALUES ('a'), ('b')"), + new($"MERGE INTO {table} S USING (SELECT 'b' as name) T ON T.name = S.name WHEN MATCHED THEN UPDATE SET name = 'c'"), + new($"MERGE INTO {table} S USING (SELECT 'b' as name) T ON T.name = S.name WHEN NOT MATCHED THEN INSERT (name) VALUES ('b')"), + new($"MERGE INTO {table} S USING (SELECT 'b' as name) T ON T.name = S.name WHEN MATCHED THEN DELETE"), + new($"MERGE INTO {table} S USING (SELECT 'b' as name) T ON T.name = S.name WHEN NOT MATCHED THEN DO NOTHING") + } + }; + await using var reader = await 
batch.ExecuteReaderAsync(Behavior); + + // Consume MERGE result set to parse the CommandComplete + await reader.CloseAsync(); + + var command = batch.BatchCommands[0]; + Assert.That(command.StatementType, Is.EqualTo(StatementType.Insert)); + Assert.That(command.RecordsAffected, Is.EqualTo(2)); + Assert.That(command.Rows, Is.EqualTo(2)); + + command = batch.BatchCommands[1]; + Assert.That(command.StatementType, Is.EqualTo(StatementType.Merge)); + Assert.That(command.RecordsAffected, Is.EqualTo(1)); + Assert.That(command.Rows, Is.EqualTo(1)); + + command = batch.BatchCommands[2]; + Assert.That(command.StatementType, Is.EqualTo(StatementType.Merge)); + Assert.That(command.RecordsAffected, Is.EqualTo(1)); + Assert.That(command.Rows, Is.EqualTo(1)); + + command = batch.BatchCommands[3]; + Assert.That(command.StatementType, Is.EqualTo(StatementType.Merge)); + Assert.That(command.RecordsAffected, Is.EqualTo(1)); + Assert.That(command.Rows, Is.EqualTo(1)); + + command = batch.BatchCommands[4]; + Assert.That(command.StatementType, Is.EqualTo(StatementType.Merge)); + Assert.That(command.RecordsAffected, Is.EqualTo(0)); + Assert.That(command.Rows, Is.EqualTo(0)); + } + + [Test] + public async Task StatementTypes() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new($"INSERT INTO {table} (name) VALUES ('a'), ('b')"), + new($"UPDATE {table} SET name='c' WHERE name='b'"), + new($"UPDATE {table} SET name='d' WHERE name='doesnt_exist'"), + new("BEGIN"), + new($"SELECT name FROM {table}"), + new($"DELETE FROM {table}"), + new("COMMIT") + } + }; + + await using var reader = await batch.ExecuteReaderAsync(Behavior); + + // Consume SELECT result set to parse the CommandComplete + await reader.CloseAsync(); + + Assert.That(batch.BatchCommands[0].StatementType, Is.EqualTo(StatementType.Insert)); + 
Assert.That(batch.BatchCommands[1].StatementType, Is.EqualTo(StatementType.Update)); + Assert.That(batch.BatchCommands[2].StatementType, Is.EqualTo(StatementType.Update)); + Assert.That(batch.BatchCommands[3].StatementType, Is.EqualTo(StatementType.Other)); + Assert.That(batch.BatchCommands[4].StatementType, Is.EqualTo(StatementType.Select)); + Assert.That(batch.BatchCommands[5].StatementType, Is.EqualTo(StatementType.Delete)); + Assert.That(batch.BatchCommands[6].StatementType, Is.EqualTo(StatementType.Other)); + } + + [Test] + public async Task StatementType_Call() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "11.0", "Stored procedures are supported starting with PG 11"); + + var sproc = await GetTempProcedureName(conn); + await conn.ExecuteNonQueryAsync($"CREATE PROCEDURE {sproc}() LANGUAGE sql AS ''"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new($"CALL {sproc}()") } + }; + + await using var reader = await batch.ExecuteReaderAsync(Behavior); + + // Consume SELECT result set to parse the CommandComplete + await reader.CloseAsync(); + + Assert.That(batch.BatchCommands[0].StatementType, Is.EqualTo(StatementType.Call)); + } + + [Test] + public async Task CommandType_StoredProcedure() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "11.0", "Stored procedures are supported starting with PG 11"); + + var sproc = await GetTempProcedureName(conn); + await conn.ExecuteNonQueryAsync($"CREATE PROCEDURE {sproc}() LANGUAGE sql AS ''"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new($"{sproc}") {CommandType = CommandType.StoredProcedure} } + }; + + await using var reader = await batch.ExecuteReaderAsync(Behavior); + + // Consume SELECT result set to parse the CommandComplete + await reader.CloseAsync(); + + Assert.That(batch.BatchCommands[0].StatementType, Is.EqualTo(StatementType.Call)); + } + + + [Test] + public async Task 
StatementType_Merge() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "15.0", "Stored procedures are supported starting with PG 11"); + + var table = await CreateTempTable(conn, "name TEXT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new($"MERGE INTO {table} S USING (SELECT 'b' as name) T ON T.name = S.name WHEN NOT MATCHED THEN DO NOTHING") } + }; + + await using var reader = await batch.ExecuteReaderAsync(Behavior); + + // Consume SELECT result set to parse the CommandComplete + await reader.CloseAsync(); + + Assert.That(batch.BatchCommands[0].StatementType, Is.EqualTo(StatementType.Merge)); + } + + [Test] + public async Task StatementOID() + { + using var conn = await OpenConnectionAsync(); + + MaximumPgVersionExclusive(conn, "12.0", + "Support for 'CREATE TABLE ... WITH OIDS' has been removed in 12.0. See https://www.postgresql.org/docs/12/release-12.html#id-1.11.6.5.4"); + + var table = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($"CREATE TABLE {table} (name TEXT) WITH OIDS"); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new($"INSERT INTO {table} (name) VALUES (@p1)") { Parameters = { new("p1", "foo") } }, + new($"UPDATE {table} SET name='b' WHERE name=@p2") { Parameters = { new("p2", "bar") } } + } + }; + + await batch.ExecuteNonQueryAsync(); + + Assert.That(batch.BatchCommands[0].OID, Is.Not.EqualTo(0)); + Assert.That(batch.BatchCommands[1].OID, Is.EqualTo(0)); + } + + [Test] + public void CanCreateParameter() => Assert.True(new NpgsqlBatchCommand().CanCreateParameter); + + [Test] + public void CreateParameter() => Assert.NotNull(new NpgsqlBatchCommand().CreateParameter()); + + #endregion NpgsqlBatchCommand + + #region Command behaviors + + [Test] + public async Task SingleResult() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1"), new("SELECT 2") 
} + }; + var reader = await batch.ExecuteReaderAsync(CommandBehavior.SingleResult | Behavior); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.NextResult(), Is.False); + } + + [Test] + public async Task SingleRow() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1"), new("SELECT 2") } + }; + + await using var reader = await batch.ExecuteReaderAsync(CommandBehavior.SingleRow | Behavior); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.NextResult(), Is.False); + } + + [Test] + public async Task SchemaOnly_GetFieldType() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1"), new("SELECT 'foo'") } + }; + + await using var reader = await batch.ExecuteReaderAsync(CommandBehavior.SchemaOnly | Behavior); + Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(int))); + Assert.That(await reader.NextResultAsync(), Is.True); + Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(string))); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + [Test] + public async Task SchemaOnly_returns_no_data() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1"), new("SELECT 'foo'") } + }; + + await using var reader = await batch.ExecuteReaderAsync(CommandBehavior.SchemaOnly | Behavior); + Assert.That(reader.Read(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.True); + Assert.That(reader.Read(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/693")] + public async Task CloseConnection() + { + await using var conn = await OpenConnectionAsync(); + 
await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1"), new("SELECT 2") } + }; + + await using (var reader = await batch.ExecuteReaderAsync(CommandBehavior.CloseConnection | Behavior)) + while (reader.Read()) {} + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + } + + #endregion Command behaviors + + #region Error barriers + + [Test] + public async Task Batch_with_error_at_start([Values] bool withErrorBarriers) + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new("INVALID SQL"), + new($"INSERT INTO {table} (id) VALUES (8)") + }, + EnableErrorBarriers = withErrorBarriers + }; + + var exception = Assert.ThrowsAsync(async () => await batch.ExecuteReaderAsync(Behavior))!; + Assert.That(exception.BatchCommand, Is.SameAs(batch.BatchCommands[0])); + + Assert.That(await conn.ExecuteScalarAsync($"SELECT count(*) FROM {table}"), withErrorBarriers + ? Is.EqualTo(1) + : Is.EqualTo(0)); + } + + [Test] + public async Task Batch_with_error_at_end([Values] bool withErrorBarriers) + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new($"INSERT INTO {table} (id) VALUES (8)"), + new("INVALID SQL") + }, + EnableErrorBarriers = withErrorBarriers + }; + + var exception = Assert.ThrowsAsync(async () => await batch.ExecuteReaderAsync(Behavior))!; + Assert.That(exception.BatchCommand, Is.SameAs(batch.BatchCommands[1])); + + Assert.That(await conn.ExecuteScalarAsync($"SELECT count(*) FROM {table}"), withErrorBarriers + ? 
Is.EqualTo(1) + : Is.EqualTo(0)); + } + + [Test] + public async Task Batch_with_multiple_errors([Values] bool withErrorBarriers) + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new($"INSERT INTO {table} (id) VALUES (8)"), + new("INVALID SQL"), + new($"INSERT INTO {table} (id) VALUES (9)"), + new("INVALID SQL"), + new($"INSERT INTO {table} (id) VALUES (10)") + }, + EnableErrorBarriers = withErrorBarriers + }; + + if (withErrorBarriers) + { + // A Sync is inserted after each command, so all commands are executed and all exceptions are thrown as an AggregateException + var exception = Assert.ThrowsAsync(async () => await batch.ExecuteReaderAsync(Behavior))!; + var aggregateException = (AggregateException)exception.InnerException!; + Assert.That(((PostgresException)aggregateException.InnerExceptions[0]).BatchCommand, Is.SameAs(batch.BatchCommands[1])); + Assert.That(((PostgresException)aggregateException.InnerExceptions[1]).BatchCommand, Is.SameAs(batch.BatchCommands[3])); + + Assert.That(await conn.ExecuteScalarAsync($"SELECT count(*) FROM {table}"), Is.EqualTo(3)); + } + else + { + // PG skips all commands after the first error; an exception is only raised for the first one, and the entire batch is + // rolled back (implicit transaction). 
+ var exception = Assert.ThrowsAsync(async () => await batch.ExecuteReaderAsync(Behavior))!; + Assert.That(exception.BatchCommand, Is.SameAs(batch.BatchCommands[1])); + + Assert.That(await conn.ExecuteScalarAsync($"SELECT count(*) FROM {table}"), Is.EqualTo(0)); + } + + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task Batch_close_dispose_reader_with_multiple_errors([Values] bool withErrorBarriers, [Values] bool dispose) + { + // Create a temp pool since we dispose the reader (and check the state afterwards) and it can be reused by another connection + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new("SELECT NULL WHERE 1=0"), + new($"INSERT INTO {table} (id) VALUES (8)"), + new("INVALID SQL"), + new($"INSERT INTO {table} (id) VALUES (9)"), + new("INVALID SQL"), + new($"INSERT INTO {table} (id) VALUES (10)") + }, + EnableErrorBarriers = withErrorBarriers + }; + + await using (var reader = await batch.ExecuteReaderAsync(Behavior)) + { + if (withErrorBarriers) + { + // A Sync is inserted after each command, so all commands are executed and all exceptions are thrown as an AggregateException + var exception = Assert.ThrowsAsync(async () => + { + if (dispose) + await reader.DisposeAsync(); + else + await reader.CloseAsync(); + })!; + var aggregateException = (AggregateException)exception.InnerException!; + Assert.That(((PostgresException)aggregateException.InnerExceptions[0]).BatchCommand, Is.SameAs(batch.BatchCommands[2])); + Assert.That(((PostgresException)aggregateException.InnerExceptions[1]).BatchCommand, Is.SameAs(batch.BatchCommands[4])); + } + else + { + // PG skips all commands after the first error; an exception is only raised for the first one, and the entire batch is + // rolled back (implicit transaction). 
+ var exception = Assert.ThrowsAsync(async () => + { + if (dispose) + await reader.DisposeAsync(); + else + await reader.CloseAsync(); + })!; + + Assert.That(exception.BatchCommand, Is.SameAs(batch.BatchCommands[2])); + } + + Assert.That(reader.State, Is.EqualTo(dispose ? ReaderState.Disposed : ReaderState.Closed)); + } + + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task Batch_with_result_sets_and_error([Values] bool withErrorBarriers) + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new($"INSERT INTO {table} (id) VALUES (9)"), + new("SELECT 1"), + new("INVALID SQL"), + new($"INSERT INTO {table} (id) VALUES (9)"), + new("SELECT 2") + }, + EnableErrorBarriers = withErrorBarriers + }; + + await using (var reader = await batch.ExecuteReaderAsync(Behavior)) + { + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader[0], Is.EqualTo(1)); + Assert.That(await reader.ReadAsync(), Is.False); + + Assert.That(async () => await reader.NextResultAsync(), Throws.Exception.TypeOf()); + + Assert.That(reader.State, Is.EqualTo(ReaderState.Consumed)); + Assert.That(await reader.ReadAsync(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + Assert.That(await conn.ExecuteScalarAsync($"SELECT count(*) FROM {table}"), withErrorBarriers + ? 
Is.EqualTo(2) + : Is.EqualTo(0)); + } + + [Test] + public async Task Error_with_AppendErrorBarrier() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new($"INSERT INTO {table} (id) VALUES (8)"), + new("INVALID SQL") { AppendErrorBarrier = true }, + new($"INSERT INTO {table} (id) VALUES (9)") + } + }; + + // A Sync is placed after the 2nd command (INVALID SQL), so the 1st command is rolled back but not the 3rd. + var exception = Assert.ThrowsAsync(async () => await batch.ExecuteReaderAsync(Behavior))!; + Assert.That(exception.BatchCommand, Is.SameAs(batch.BatchCommands[1])); + + Assert.That(await conn.ExecuteScalarAsync($"SELECT id FROM {table} ORDER BY id"), Is.EqualTo(9)); + } + + [Test] + public async Task AppendErrorBarrier_on_last_command([Values] bool enabled) + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT"); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new($"INSERT INTO {table} (id) VALUES (8)"), + new($"INSERT INTO {table} (id) VALUES (9)") { AppendErrorBarrier = enabled } + }, + EnableErrorBarriers = true + }; + + Assert.That(await batch.ExecuteNonQueryAsync(), Is.EqualTo(2)); + } + + [Test] + public async Task Error_barriers_with_SchemaOnly() + { + await using var conn = await OpenConnectionAsync(); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new("SELECT 1"), + new("SELECT 'foo'") + }, + EnableErrorBarriers = true + }; + + await using var reader = await batch.ExecuteReaderAsync(CommandBehavior.SchemaOnly | Behavior); + + var columnSchema = await reader.GetColumnSchemaAsync(); + Assert.That(columnSchema[0].DataType, Is.SameAs(typeof(int))); + + Assert.That(await reader.NextResultAsync(), Is.True); + columnSchema = await reader.GetColumnSchemaAsync(); + Assert.That(columnSchema[0].DataType, 
Is.SameAs(typeof(string))); + } + + #endregion Error barriers + + #region Miscellaneous + + [Test] + public async Task Single_batch_command() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 8") } + }; + + await using var reader = await batch.ExecuteReaderAsync(Behavior); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader[0], Is.EqualTo(8)); + Assert.That(await reader.ReadAsync(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + [Test] + public async Task Empty_batch() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn); + await using var reader = await batch.ExecuteReaderAsync(Behavior); + + Assert.That(await reader.ReadAsync(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + [Test] + public async Task Semicolon_is_not_allowed() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1; SELECT 2") } + }; + + Assert.That(() => batch.ExecuteReaderAsync(Behavior), Throws.Exception.TypeOf()); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/967")] + public async Task NpgsqlException_references_BatchCommand_with_single_command() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE OR REPLACE FUNCTION {function}() RETURNS VOID AS + 'BEGIN RAISE EXCEPTION ''testexception'' USING ERRCODE = ''12345''; END;' +LANGUAGE 'plpgsql'"); + + // We use NpgsqlConnection.CreateBatch to test that the batch isn't recycled when referenced in an exception + var batch = conn.CreateBatch(); + batch.BatchCommands.Add(new($"SELECT {function}()")); + + var e = Assert.ThrowsAsync(async () => await 
batch.ExecuteReaderAsync(Behavior))!; + Assert.That(e.BatchCommand, Is.SameAs(batch.BatchCommands[0])); + + // Make sure the command isn't recycled by the connection when it's disposed - this is important since internal command + // resources are referenced by the exception above, which is very likely to escape the using statement of the command. + batch.Dispose(); + var cmd2 = conn.CreateBatch(); + Assert.AreNotSame(cmd2, batch); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/967")] + public async Task NpgsqlException_references_BatchCommand_with_multiple_commands() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE OR REPLACE FUNCTION {function}() RETURNS VOID AS + 'BEGIN RAISE EXCEPTION ''testexception'' USING ERRCODE = ''12345''; END;' +LANGUAGE 'plpgsql'"); + + // We use NpgsqlConnection.CreateBatch to test that the batch isn't recycled when referenced in an exception + var batch = conn.CreateBatch(); + batch.BatchCommands.Add(new("SELECT 1")); + batch.BatchCommands.Add(new($"SELECT {function}()")); + + await using (var reader = await batch.ExecuteReaderAsync(Behavior)) + { + + var e = Assert.ThrowsAsync(async () => await reader.NextResultAsync())!; + Assert.That(e.BatchCommand, Is.SameAs(batch.BatchCommands[1])); + } + + // Make sure the command isn't recycled by the connection when it's disposed - this is important since internal command + // resources are referenced by the exception above, which is very likely to escape the using statement of the command. 
+ batch.Dispose(); + var cmd2 = conn.CreateBatch(); + Assert.AreNotSame(cmd2, batch); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4202")] + public async Task ExecuteScalar_without_parameters() + { + await using var conn = await OpenConnectionAsync(); + var batch = new NpgsqlBatch(conn) { BatchCommands = { new("SELECT 1") } }; + Assert.That(await batch.ExecuteScalarAsync(), Is.EqualTo(1)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4264")] + public async Task Batch_with_auto_prepare_reuse() + { + await using var dataSource = CreateDataSource(csb => csb.MaxAutoPrepare = 20); + await using var conn = await dataSource.OpenConnectionAsync(); + + var tempTableName = await CreateTempTable(conn, "id int"); + + await using var batch = new NpgsqlBatch(conn); + for (var i = 0; i < 2; ++i) + { + for (var j = 0; j < 10; ++j) + { + batch.BatchCommands.Add(new NpgsqlBatchCommand($"DELETE FROM {tempTableName} WHERE 1=0")); + } + await batch.ExecuteNonQueryAsync(); + batch.BatchCommands.Clear(); + } + } + +#if NET6_0_OR_GREATER // no batch reuse until 6.0 + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5239")] + public async Task Batch_dispose_reuse() + { + await using var conn = await OpenConnectionAsync(); + NpgsqlBatch firstBatch; + await using (var batch = conn.CreateBatch()) + { + firstBatch = batch; + + batch.BatchCommands.Add(new NpgsqlBatchCommand("SELECT 1")); + Assert.That(await batch.ExecuteScalarAsync(), Is.EqualTo(1)); + } + + await using (var batch = conn.CreateBatch()) + { + Assert.That(batch, Is.SameAs(firstBatch)); + + batch.BatchCommands.Add(new NpgsqlBatchCommand("SELECT 2")); + Assert.That(await batch.ExecuteScalarAsync(), Is.EqualTo(2)); + } + + await conn.CloseAsync(); + await conn.OpenAsync(); + + await using (var batch = conn.CreateBatch()) + { + Assert.That(batch, Is.SameAs(firstBatch)); + + batch.BatchCommands.Add(new NpgsqlBatchCommand("SELECT 3")); + Assert.That(await batch.ExecuteScalarAsync(), 
Is.EqualTo(3)); + } + } +#endif + + #endregion Miscellaneous + + #region Logging + + [Test] + public async Task Log_ExecuteScalar_single_statement_without_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1") } + }; + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed").And.Contains("SELECT 1")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT 1"); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + + if (!IsMultiplexing) + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Log_ExecuteScalar_multiple_statements_with_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new("SELECT $1") { Parameters = { new() { Value = 8 } } }, + new("SELECT $1, 9") { Parameters = { new() { Value = 9 } } } + } + }; + + using (listLoggerProvider.Record()) + { + await batch.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + + // Note: the message formatter of Microsoft.Extensions.Logging doesn't seem to handle arrays inside tuples, so we get the + // following ugliness (https://github.com/dotnet/runtime/issues/63165). Serilog handles this fine. 
+ Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[(SELECT $1, System.Object[]), (SELECT $1, 9, System.Object[])]")); + AssertLoggingStateDoesNotContain(executingCommandEvent, "CommandText"); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + + if (!IsMultiplexing) + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + + var batchCommands = (IList<(string CommandText, object[] Parameters)>)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0].CommandText, Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[0].Parameters[0], Is.EqualTo(8)); + Assert.That(batchCommands[1].CommandText, Is.EqualTo("SELECT $1, 9")); + Assert.That(batchCommands[1].Parameters[0], Is.EqualTo(9)); + } + + [Test] + public async Task Log_ExecuteScalar_single_statement_with_parameter_logging_off() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, sensitiveDataLoggingEnabled: false); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new("SELECT $1") { Parameters = { new() { Value = 8 } } }, + new("SELECT $1, 9") { Parameters = { new() { Value = 9 } } } + } + }; + + using (listLoggerProvider.Record()) + { + await batch.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[SELECT $1, SELECT $1, 9]")); + var batchCommands = (IList)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0], Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[1], Is.EqualTo("SELECT $1, 9")); + } + + #endregion 
Logging + + #region Initialization / setup / teardown + + // ReSharper disable InconsistentNaming + readonly bool IsSequential; + readonly CommandBehavior Behavior; + // ReSharper restore InconsistentNaming + + public BatchTests(MultiplexingMode multiplexingMode, CommandBehavior behavior) : base(multiplexingMode) + { + Behavior = behavior; + IsSequential = (Behavior & CommandBehavior.SequentialAccess) != 0; + } + + #endregion +} diff --git a/test/Npgsql.Tests/BugTests.cs b/test/Npgsql.Tests/BugTests.cs index 0dd22ba8f1..e3c05dd5fb 100644 --- a/test/Npgsql.Tests/BugTests.cs +++ b/test/Npgsql.Tests/BugTests.cs @@ -1,825 +1,766 @@ -using System; -using System.Collections.Generic; +using Npgsql.BackendMessages; +using Npgsql.Tests.Support; +using NpgsqlTypes; +using NUnit.Framework; +using System; using System.Data; -using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; -using NpgsqlTypes; -using NUnit.Framework; using System.Transactions; +using Npgsql.Internal.Postgres; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests; -namespace Npgsql.Tests +public class BugTests : TestBase { - public class BugTests : TestBase + static uint ByteaOid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Bytea).Value; + + #region Sequential reader bugs + + [Test, Description("In sequential access, performing a null check on a non-first field would check the first field")] + public void SequentialNullCheckOnNonFirstField() { - #region Sequential reader bugs + using var conn = OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT 'X', NULL", conn); + using var dr = cmd.ExecuteReader(CommandBehavior.SequentialAccess); + dr.Read(); + Assert.That(dr.IsDBNull(1), Is.True); + } - [Test, Description("In sequential access, performing a null check on a non-first field would check the first field")] - public void SequentialNullCheckOnNonFirstField() - { - using (var conn = OpenConnection()) - using (var cmd = new 
NpgsqlCommand("SELECT 'X', NULL", conn)) - using (var dr = cmd.ExecuteReader(CommandBehavior.SequentialAccess)) - { - dr.Read(); - Assert.That(dr.IsDBNull(1), Is.True); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1034")] + public void SequentialSkipOverFirstRow() + { + using var conn = OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); + using var reader = cmd.ExecuteReader(CommandBehavior.SequentialAccess); + Assert.That(reader.NextResult(), Is.True); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(2)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1034")] - public void SequentialSkipOverFirstRow() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) - using (var reader = cmd.ExecuteReader(CommandBehavior.SequentialAccess)) - { - Assert.That(reader.NextResult(), Is.True); - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(2)); - } - } + [Test] + public void SequentialConsumeWithNull() + { + using var conn = OpenConnection(); + using var command = new NpgsqlCommand("SELECT 1, NULL", conn); + using var reader = command.ExecuteReader(CommandBehavior.SequentialAccess); + reader.Read(); + } - [Test] - public void SequentialConsumeWithNull() + #endregion + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1210")] + public void Many_parameters_with_mixed_FormatCode() + { + using var conn = OpenConnection(); + using var cmd = new NpgsqlCommand(); + cmd.Connection = conn; + var sb = new StringBuilder("SELECT @text_param"); + cmd.Parameters.AddWithValue("@text_param", "some_text"); + for (var i = 0; i < conn.Settings.WriteBufferSize; i++) { - using (var conn = OpenConnection()) - using (var command = new NpgsqlCommand("SELECT 1, NULL", conn)) - using (var reader = command.ExecuteReader(CommandBehavior.SequentialAccess)) - reader.Read(); + var paramName = 
$"@binary_param{i}"; + sb.Append(','); + sb.Append(paramName); + cmd.Parameters.AddWithValue(paramName, 8); } + cmd.CommandText = sb.ToString(); - #endregion - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1210")] - public void ManyParametersWithMixedFormatCode() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand()) - { - cmd.Connection = conn; - var sb = new StringBuilder("SELECT @text_param"); - cmd.Parameters.AddWithValue("@text_param", "some_text"); - for (var i = 0; i < conn.Settings.WriteBufferSize; i++) - { - var paramName = $"@binary_param{i}"; - sb.Append(","); - sb.Append(paramName); - cmd.Parameters.AddWithValue(paramName, 8); - } - cmd.CommandText = sb.ToString(); + var ex = Assert.Throws(() => cmd.ExecuteNonQuery())!; + Assert.That(ex.SqlState, Is.EqualTo(PostgresErrorCodes.ProgramLimitExceeded) + .Or.EqualTo(PostgresErrorCodes.TooManyColumns)); // PostgreSQL 14.5, 13.8, 12.12, 11.17 and 10.22 changed the returned error + } - Assert.That(() => cmd.ExecuteNonQuery(), Throws.Exception - .TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("54000") - ); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1450")] + public void Bug1450() + { + using var conn = OpenConnection(); + using var cmd = new NpgsqlCommand(); + cmd.Connection = conn; + cmd.CommandText = "CREATE TEMP TABLE a (a1 int); CREATE TEMP TABLE b (b1 int);"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = "CREATE TEMP TABLE c (c1 int);"; + cmd.ExecuteNonQuery(); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1238")] - public void RecordWithNonIntField() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT ('one'::TEXT, 2)", conn)) - using (var reader = cmd.ExecuteReader()) + [Test] + public async Task Bug1645() + { + await using var conn = await OpenConnectionAsync(); + var tableName = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 
INTEGER"); + Assert.That(() => { - reader.Read(); - var record = reader.GetFieldValue(0); - Assert.That(record[0], Is.EqualTo("one")); - Assert.That(record[1], Is.EqualTo(2)); - } - } + using var writer = conn.BeginBinaryImport($"COPY {tableName} (field_text, field_int4) FROM STDIN BINARY"); + writer.StartRow(); + writer.Write("foo"); + writer.Write(8); + + writer.StartRow(); + throw new InvalidOperationException("Catch me outside the using statement if you can!"); + }, Throws.Exception + .TypeOf() + .With.Property(nameof(InvalidOperationException.Message)).EqualTo("Catch me outside the using statement if you can!") + ); + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {tableName}"), Is.Zero); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1450")] - public void Bug1450() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3600")] + public async Task Bug3600() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand()) - { - cmd.Connection = conn; - cmd.CommandText = "CREATE TEMP TABLE a (a1 int); CREATE TEMP TABLE b (b1 int);"; - cmd.ExecuteNonQuery(); + CommandTimeout = 1, + }; + await using var postmasterMock = PgPostmasterMock.Start(csb.ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + var serverMock = await postmasterMock.WaitForServerConnection(); + await serverMock + .WriteCopyInResponse() + .FlushAsync(); + var ex = Assert.ThrowsAsync(async () => + { + await using var importer = await conn.BeginTextImportAsync($"COPY SomeTable (field_text, field_int4) FROM STDIN"); + }); + Assert.That(ex!.InnerException, Is.TypeOf()); + } - cmd.CommandText = "CREATE TEMP TABLE c (c1 int);"; - cmd.ExecuteNonQuery(); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1497")] + public async Task Bug1497() + { + await 
using var conn = await OpenConnectionAsync(); + var tableName = await CreateTempTable(conn, "id INT4"); + conn.ExecuteNonQuery($"INSERT INTO {tableName} (id) VALUES (NULL)"); + await using var cmd = new NpgsqlCommand($"SELECT * FROM {tableName}", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + var dt = new DataTable(); + dt.Load(reader); + } - [Test] - public void Bug1645() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1558")] + public void Bug1558() + { + using var dataSource = CreateDataSource(csb => { - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery("CREATE TEMP TABLE data (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); - Assert.That(() => - { - using (var writer = conn.BeginBinaryImport("COPY data (field_text, field_int4) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write("foo"); - writer.Write(8); - - writer.StartRow(); - throw new InvalidOperationException("Catch me outside the using statement if you can!"); - } - }, Throws.Exception - .TypeOf() - .With.Property(nameof(InvalidOperationException.Message)).EqualTo("Catch me outside the using statement if you can!") - ); - Assert.That(conn.ExecuteScalar("SELECT COUNT(*) FROM data"), Is.Zero); - } - } + csb.Pooling = false; + csb.Enlist = true; + }); + using var tx = new TransactionScope(); + using var conn = dataSource.OpenConnection(); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1497")] - public void Bug1497() + [Test] + public void Bug1695() + { + using var dataSource = CreateDataSource(csb => { - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery("CREATE TEMP TABLE data (id INT4)"); - conn.ExecuteNonQuery("INSERT INTO data (id) VALUES (NULL)"); - using (var cmd = new NpgsqlCommand("SELECT * FROM data", conn)) - using (var reader = cmd.ExecuteReader()) - { - var dt = new DataTable(); - dt.Load(reader); - } - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1558")] - public void Bug1558() 
+ csb.Pooling = false; + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 1; + }); + using var conn = dataSource.OpenConnection(); + using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) + using (var reader = cmd.ExecuteReader()) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = false, - Enlist = true - }; - using (var tx = new TransactionScope()) - using (var conn = new NpgsqlConnection(csb.ToString())) - { - conn.Open(); - } + reader.Read(); + // Both statements should get prepared. However, purposefully skip processing the + // second resultset and make sure the second statement got prepared correctly. } + Assert.That(conn.ExecuteScalar("SELECT 2"), Is.EqualTo(2)); + } - [Test] - public void Bug1695() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1700")] + public void Bug1700() + { + Assert.That(() => { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = false, - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 1 - }; - using (var conn = OpenConnection(csb)) - { - using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - // Both statements should get prepared. However, purposefully skip processing the - // second resultset and make sure the second statement got prepared correctly. 
- } - Assert.That(conn.ExecuteScalar("SELECT 2"), Is.EqualTo(2)); - } - } + using var conn = OpenConnection(); + using var tx = conn.BeginTransaction(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 1"; + var reader = cmd.ExecuteReader(); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1700")] - public void Bug1700() - { - Assert.That(() => + while (reader.Read()) { - using (var conn = OpenConnection()) - using (var tx = conn.BeginTransaction()) - using (var cmd = conn.CreateCommand()) - { - cmd.CommandText = "SELECT 1"; - var reader = cmd.ExecuteReader(); - - while (reader.Read()) - { - // Simulate exception whilst processing the data reader... - throw new InvalidOperationException("Some problem parsing the returned data"); - - // As this exception unwinds the stack, it calls Dispose on the NpgsqlTransaction - // which then throws a NpgsqlOperationInProgressException as it tries to rollback - // the transaction. This hides the underlying cause of the problem (in this case - // our InvalidOperationException exception) - } - - // Note, we never get here - tx.Commit(); - } - }, Throws.InvalidOperationException.With.Message.EqualTo("Some problem parsing the returned data")); - } + // Simulate exception whilst processing the data reader... 
+ throw new InvalidOperationException("Some problem parsing the returned data"); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1964")] - public void Bug1964() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("INVALID SQL", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Direction = ParameterDirection.Output }); - Assert.That(() => cmd.ExecuteNonQuery(), Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("42601")); + // As this exception unwinds the stack, it calls Dispose on the NpgsqlTransaction + // which then throws a NpgsqlOperationInProgressException as it tries to rollback + // the transaction. This hides the underlying cause of the problem (in this case + // our InvalidOperationException exception) } - } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1986")] - public void Bug1986() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT 'hello', 'goodbye'", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - using (var textReader1 = reader.GetTextReader(0)) - { + // Note, we never get here + tx.Commit(); + }, Throws.InvalidOperationException.With.Message.EqualTo("Some problem parsing the returned data")); + } - } - using (var textReader2 = reader.GetTextReader(1)) - { + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1964")] + public void Bug1964() + { + using var conn = OpenConnection(); + using var cmd = new NpgsqlCommand("INVALID SQL", conn); + cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Direction = ParameterDirection.Output }); + Assert.That(() => cmd.ExecuteNonQuery(), Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.SyntaxError)); + } - } - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1986")] + public void Bug1986() + { + using var conn = OpenConnection(); + using var cmd = new 
NpgsqlCommand("SELECT 'hello', 'goodbye'", conn); + using var reader = cmd.ExecuteReader(); + reader.Read(); + using (var textReader1 = reader.GetTextReader(0)) + { - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1987")] - public void Bug1987() + } + using (var textReader2 = reader.GetTextReader(1)) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 10, - AutoPrepareMinUsages = 2, - Pooling = false - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy')"); - conn.ReloadTypes(); - conn.TypeMapper.MapEnum("mood"); - for (var i = 0; i < 2; i++) - { - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", Mood.Happy); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(Mood.Happy)); - } - } - } } + } - enum Mood { Sad, Ok, Happy }; + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1987")] + public async Task Bug1987() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2003")] - public void Bug2003() - { - // A big RowDescription (larger than buffer size) causes an oversize buffer allocation, but which isn't - // picked up by sequential reader which continues to read from the original buffer. 
- using (var conn = OpenConnection()) - { - var longFieldName = new string('x', conn.Settings.ReadBufferSize); - using (var cmd = new NpgsqlCommand($"SELECT 8 AS {longFieldName}", conn)) - using (var reader = cmd.ExecuteReader(CommandBehavior.SequentialAccess)) - { - reader.Read(); - Assert.That(reader.GetInt32(0), Is.EqualTo(8)); - } - } - } + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); - [Test] - public async Task Bug2046() + for (var i = 0; i < 2; i++) { - var expected = 64.27245f; - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT @p = 64.27245::real, 64.27245::real, @p", conn)) - { - cmd.Parameters.AddWithValue("p", expected); - using (var rdr = await cmd.ExecuteReaderAsync()) - { - rdr.Read(); - Assert.That(rdr.GetFieldValue(0)); - Assert.That(rdr.GetFieldValue(1), Is.EqualTo(expected)); - Assert.That(rdr.GetFieldValue(2), Is.EqualTo(expected)); - } - } + await using var cmd = new NpgsqlCommand("SELECT @p", connection); + cmd.Parameters.AddWithValue("p", Mood.Happy); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(Mood.Happy)); } + } - [Test] - [Ignore("Multiplexing: fails")] - public void Bug1761() - { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Enlist = true, - Pooling = true, - MinPoolSize = 1, - MaxPoolSize = 1 - }.ConnectionString; + enum Mood { Sad, Ok, Happy }; - for (var i = 0; i < 2; i++) - { - try - { - using (var scope = new TransactionScope(TransactionScopeOption.Required, TimeSpan.FromMilliseconds(100))) - { - Thread.Sleep(1000); - - // Ambient transaction is now unusable, attempts to enlist to it will fail. We should recover - // properly from this failure. 
- - using (var connection = OpenConnection(connString)) - using (var cmd = new NpgsqlCommand("SELECT 1", connection)) - { - cmd.CommandText = "select 1;"; - cmd.ExecuteNonQuery(); - } - - scope.Complete(); - } - } - catch (TransactionException) - { - //do nothing - } - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2003")] + public void Bug2003() + { + // A big RowDescription (larger than buffer size) causes an oversize buffer allocation, but which isn't + // picked up by sequential reader which continues to read from the original buffer. + using var conn = OpenConnection(); + var longFieldName = new string('x', conn.Settings.ReadBufferSize); + using var cmd = new NpgsqlCommand($"SELECT 8 AS {longFieldName}", conn); + using var reader = cmd.ExecuteReader(CommandBehavior.SequentialAccess); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(8)); + } + + [Test] + public async Task Bug2046() + { + var expected = 64.27245f; + using var conn = OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT @p = 64.27245::real, 64.27245::real, @p", conn); + cmd.Parameters.AddWithValue("p", expected); + using var rdr = await cmd.ExecuteReaderAsync(); + rdr.Read(); + Assert.That(rdr.GetFieldValue(0)); + Assert.That(rdr.GetFieldValue(1), Is.EqualTo(expected)); + Assert.That(rdr.GetFieldValue(2), Is.EqualTo(expected)); + } - [Test] - public void Bug2274() + [Test] + public void Bug1761() + { + using var dataSource = CreateDataSource(csb => { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "p", - Direction = ParameterDirection.Output - }); - using (var reader = cmd.ExecuteReader(CommandBehavior.SingleRow)) - { - Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - Assert.That(reader.Read(), Is.False); - } - } - } + csb.Enlist = true; + 
csb.Pooling = true; + csb.MinPoolSize = 1; + csb.MaxPoolSize = 1; + }); - [Test] - public void Bug2278() + for (var i = 0; i < 2; i++) { - using (var conn = OpenConnection()) + try { - try - { - conn.ExecuteNonQuery("CREATE TYPE enum_type AS ENUM ('left', 'right')"); - conn.ExecuteNonQuery("CREATE DOMAIN enum_domain AS enum_type NOT NULL"); - conn.ExecuteNonQuery("CREATE TYPE composite_type AS (value enum_domain)"); - conn.ExecuteNonQuery("CREATE TEMP TABLE data (value composite_type)"); - conn.ExecuteNonQuery("INSERT INTO data (value) VALUES (ROW('left'))"); + using var scope = new TransactionScope(TransactionScopeOption.Required, TimeSpan.FromMilliseconds(100)); + Thread.Sleep(1000); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("composite_type"); - conn.TypeMapper.MapEnum("enum_type"); + // Ambient transaction is now unusable, attempts to enlist to it will fail. We should recover + // properly from this failure. - conn.ExecuteScalar("SELECT * FROM data AS d"); - } - finally + using (var connection = dataSource.OpenConnection()) + using (var cmd = new NpgsqlCommand("SELECT 1", connection)) { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS data; DROP TYPE IF EXISTS composite_type; DROP DOMAIN IF EXISTS enum_domain; DROP TYPE IF EXISTS enum_type"); - conn.ReloadTypes(); + cmd.CommandText = "select 1;"; + cmd.ExecuteNonQuery(); } - } - } - class Bug2278CompositeType - { - public Bug2278EnumType Value { get; set; } + scope.Complete(); + } + catch (TransactionException) + { + //do nothing + } } + } - enum Bug2278EnumType + [Test] + public void Bug2274() + { + using var conn = OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Parameters.Add(new NpgsqlParameter { - Left, - Right - } - + ParameterName = "p", + Direction = ParameterDirection.Output + }); + using var reader = cmd.ExecuteReader(CommandBehavior.SingleRow); + Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); + Assert.That(reader.Read(), Is.True); + 
Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.Read(), Is.False); + } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/2178")] - public void Bug2178() - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString); - builder.AutoPrepareMinUsages = 2; - builder.MaxAutoPrepare = 2; - using (var conn = new NpgsqlConnection(builder.ConnectionString)) - using (var cmd = new NpgsqlCommand()) - { - conn.Open(); - cmd.Connection = conn; + [Test] + public async Task Bug2278() + { + await using var adminConnection = await OpenConnectionAsync(); + var enumType = await GetTempTypeName(adminConnection); + var domainType = await GetTempTypeName(adminConnection); + var compositeType = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {enumType} AS ENUM ('left', 'right'); +CREATE DOMAIN {domainType} AS {enumType} NOT NULL; +CREATE TYPE {compositeType} AS (value {domainType})"); + var table = await CreateTempTable(adminConnection, $"value {compositeType}"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(compositeType); + dataSourceBuilder.MapEnum(enumType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await connection.ExecuteScalarAsync($"SELECT * FROM {table} AS d"); + } - cmd.CommandText = "SELECT 1"; - cmd.ExecuteScalar(); - cmd.ExecuteScalar(); - Assert.That(cmd.IsPrepared); + class Bug2278CompositeType + { + public Bug2278EnumType Value { get; set; } + } - // Now executing a faulty command multiple times - cmd.CommandText = "SELECT * FROM public.dummy_table_name"; - for (var i = 0; i < 3; ++i) - { - try - { - cmd.ExecuteScalar(); - } - catch { } - } + enum Bug2278EnumType + { + Left, + Right + } - cmd.CommandText = "SELECT 1"; - cmd.ExecuteScalar(); - Assert.That(cmd.IsPrepared); - } - } - [Test] - public void Bug2296() + [Test] + 
[IssueLink("https://github.com/npgsql/npgsql/issues/2178")] + public async Task Bug2178() + { + await using var dataSource = CreateDataSource(csb => { - using (var conn = OpenConnection()) - { - try - { - conn.ExecuteNonQuery("CREATE DOMAIN pg_temp.\"boolean\" AS bool"); - conn.ExecuteNonQuery("CREATE TEMP TABLE data (mybool \"boolean\")"); - conn.ExecuteNonQuery("INSERT INTO data (mybool) VALUES (TRUE)"); + csb.AutoPrepareMinUsages = 2; + csb.MaxAutoPrepare = 2; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand(); + cmd.Connection = conn; + + cmd.CommandText = "SELECT 1"; + await cmd.ExecuteScalarAsync(); + await cmd.ExecuteScalarAsync(); + Assert.That(cmd.IsPrepared); + + // Now executing a faulty command multiple times + cmd.CommandText = "SELECT * FROM public.dummy_table_name"; + for (var i = 0; i < 3; ++i) + { + Assert.ThrowsAsync(async () => await cmd.ExecuteScalarAsync()); + } - conn.ReloadTypes(); + cmd.CommandText = "SELECT 1"; + await cmd.ExecuteScalarAsync(); + Assert.That(cmd.IsPrepared); + } - conn.ExecuteScalar("SELECT mybool FROM data"); - } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS data; DROP TYPE IF EXISTS \"boolean\""); - conn.ReloadTypes(); - } - } - } + [Test] + public async Task Bug2296() + { + await using var conn = await OpenConnectionAsync(); + // Note that the type has to be named boolean + await conn.ExecuteNonQueryAsync("DROP TYPE IF EXISTS \"boolean\" CASCADE"); + await conn.ExecuteNonQueryAsync("CREATE DOMAIN pg_temp.\"boolean\" AS bool"); + conn.ReloadTypes(); + var tableName = await CreateTempTable(conn, $"mybool \"boolean\""); + await conn.ExecuteNonQueryAsync($"INSERT INTO {tableName} (mybool) VALUES (TRUE)"); + + await conn.ExecuteScalarAsync($"SELECT mybool FROM {tableName}"); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2660")] - public void StandardConformingStrings() - { - using var conn = OpenConnection(); + [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/2660")] + public void Standard_conforming_strings() + { + using var conn = OpenConnection(); - var sql = @" + var sql = @" SELECT table_name FROM information_schema.views WHERE table_name LIKE @p0 escape '\' AND (is_updatable = 'NO') = @p1"; - using var cmd = new NpgsqlCommand(sql, conn); - cmd.Parameters.AddWithValue("@p0", "%trig%"); - cmd.Parameters.AddWithValue("@p1", true); - using var reader = cmd.ExecuteReader(); - reader.Read(); - } + using var cmd = new NpgsqlCommand(sql, conn); + cmd.Parameters.AddWithValue("@p0", "%trig%"); + cmd.Parameters.AddWithValue("@p1", true); + using var reader = cmd.ExecuteReader(); + reader.Read(); + } - #region Bug1285 + #region Bug1285 - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1285")] - public void Bug1285() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand { Connection = conn }) - { - cmd.CommandText = Bug1285CreateStatement; - cmd.ExecuteNonQuery(); - - cmd.CommandText = Bug1285SelectStatement; - cmd.Parameters.Add(new NpgsqlParameter("@1", Guid.NewGuid())); - cmd.ExecuteNonQuery(); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1285")] + public void Bug1285() + { + using var conn = OpenConnection(); + using var cmd = new NpgsqlCommand { Connection = conn }; + cmd.CommandText = Bug1285CreateStatement; + cmd.ExecuteNonQuery(); + + cmd.CommandText = Bug1285SelectStatement; + cmd.Parameters.Add(new NpgsqlParameter("@1", Guid.NewGuid())); + cmd.ExecuteNonQuery(); + } - const string Bug1285SelectStatement = - "select " + - " \"Afbeelding\"" + - ", \"Afbeelding#Id\"" + - ", \"Omschrijving\"" + - ", \"Website\"" + - ", \"EMailadres\"" + - ", \"SocialMediaVastleggen\"" + - ", \"Facebook\"" + - ", \"Linkedin\"" + - ", \"Twitter\"" + - ", \"Youtube\"" + - ", \"Branche\"" + - ", \"Branche#Id\"" + - ", \"Branche#ComponentId\"" + - ", \"Telefoonnummer\"" + - ", \"Overheidsidentificatienummer\"" + - ", \"Adres\"" + - ", 
\"Adres#Id\"" + - ", \"Adres#ComponentId\"" + - ", \"2gIKz62kaTdGxw82_OrzTc4ANSM_\"" + - ", \"OnderdeelVanOrganisatie\"" + - ", \"OnderdeelVanOrganisatie#Id\"" + - ", \"OnderdeelVanOrganisatie#ComponentId\"" + - ", \"Profit\"" + - ", \"Profit#Id\"" + - ", \"Profit#ComponentId\"" + - ", \"Taal\"" + - ", \"Taal#Id\"" + - ", \"Taal#ComponentId\"" + - ", \"KlantSinds\"" + - ", \"BarcodeKlant\"" + - ", \"BarcodeKlant#BarcodeType\"" + - ", \"BarcodeKlant#Value\"" + - ", \"Postadres\"" + - ", \"Postadres#Id\"" + - ", \"Postadres#ComponentId\"" + - ", \"LeverancierSinds\"" + - ", \"BarcodeLeverancier\"" + - ", \"BarcodeLeverancier#BarcodeType\"" + - ", \"BarcodeLeverancier#Value\"" + - ", \"Zoeknaam\"" + - ", \"TitelEnAanhef\"" + - ", \"TitelEnAanhef#Id\"" + - ", \"TitelEnAanhef#ComponentId\"" + - ", \"Rechtsvorm\"" + - ", \"Rechtsvorm#Id\"" + - ", \"Rechtsvorm#ComponentId\"" + - ", \"LandVanVestiging\"" + - ", \"LandVanVestiging#Id\"" + - ", \"LandVanVestiging#ComponentId\"" + - ", \"AantalMedewerkers\"" + - ", \"NaamStatutair\"" + - ", \"VestigingStatutair\"" + - ", \"Correspondentie\"" + - ", \"Medium\"" + - ", \"FiscaalNummer\"" + - ", \"AangebrachtDoor\"" + - ", \"AangebrachtDoor#Id\"" + - ", \"AangebrachtDoor#ComponentId\"" + - ", \"Status\"" + - ", \"NummerKamerVanKoophandel\"" + - ", \"zII9SOHbwUPS_jKSlcRrQzuEr6A_\"" + - ", \"Code\"" + - ", \"OrganisatorischeEenheid\"" + - ", \"OrganisatorischeEenheid#Id\"" + - ", \"OrganisatorischeEenheid#ComponentId\"" + - ", \"PJlr3asVdeHVoqmAyHIF2fF1gVM_\"" + - ", \"IsUwv\"" + - ", \"Uwv\"" + - ", \"Uwv#Id\"" + - ", \"Uwv#ComponentId\"" + - ", \"IsVerzekeraarVoorWerkgever\"" + - ", \"VerzekeraarVoorWerkgever\"" + - ", \"VerzekeraarVoorWerkgever#Id\"" + - ", \"VerzekeraarVoorWerkgever#ComponentId\"" + - ", \"IsAbonnementsadministratie\"" + - ", \"Abonnementsadministratie\"" + - ", \"Abonnementsadministratie#Id\"" + - ", \"Abonnementsadministratie#ComponentId\"" + - ", \"IsPensioenfonds\"" + - ", \"Pensioenfonds\"" + - ", 
\"Pensioenfonds#Id\"" + - ", \"Pensioenfonds#ComponentId\"" + - ", \"IsProfit\"" + - ", \"Profit1\"" + - ", \"Profit1#Id\"" + - ", \"Profit1#ComponentId\"" + - ", \"Verantwoordelijke\"" + - ", \"Verantwoordelijke#Id\"" + - ", \"Verantwoordelijke#ComponentId\"" + - ", \"IsKlant\"" + - ", \"Klant\"" + - ", \"Klant#Id\"" + - ", \"Klant#ComponentId\"" + - ", \"IsFactureringsadministratie\"" + - ", \"Factureringsadministratie\"" + - ", \"Factureringsadministratie#Id\"" + - ", \"Factureringsadministratie#ComponentId\"" + - ", \"IsLeninggever\"" + - ", \"Leninggever\"" + - ", \"Leninggever#Id\"" + - ", \"Leninggever#ComponentId\"" + - ", \"IsProjectadministratie\"" + - ", \"Projectadministratie\"" + - ", \"Projectadministratie#Id\"" + - ", \"Projectadministratie#ComponentId\"" + - ", \"IsVervangingsfonds\"" + - ", \"Vervangingsfonds\"" + - ", \"Vervangingsfonds#Id\"" + - ", \"Vervangingsfonds#ComponentId\"" + - ", \"IsPensioen\"" + - ", \"Pensioen\"" + - ", \"Pensioen#Id\"" + - ", \"Pensioen#ComponentId\"" + - ", \"IsVasteActivaAdministratie\"" + - ", \"VasteActivaAdministratie\"" + - ", \"VasteActivaAdministratie#Id\"" + - ", \"VasteActivaAdministratie#ComponentId\"" + - ", \"IsBelastingdienst\"" + - ", \"Belastingdienst\"" + - ", \"Belastingdienst#Id\"" + - ", \"Belastingdienst#ComponentId\"" + - ", \"IsCursusadministratie\"" + - ", \"Cursusadministratie\"" + - ", \"Cursusadministratie#Id\"" + - ", \"Cursusadministratie#ComponentId\"" + - ", \"IsUitvoerderVervangingsfonds\"" + - ", \"UitvoerderVervangingsfonds\"" + - ", \"UitvoerderVervangingsfonds#Id\"" + - ", \"UitvoerderVervangingsfonds#ComponentId\"" + - ", \"IsAssemblageadministratie\"" + - ", \"Assemblageadministratie\"" + - ", \"Assemblageadministratie#Id\"" + - ", \"Assemblageadministratie#ComponentId\"" + - ", \"IsLeasemaatschappij\"" + - ", \"Leasemaatschappij\"" + - ", \"Leasemaatschappij#Id\"" + - ", \"Leasemaatschappij#ComponentId\"" + - ", \"IsBeslaglegger\"" + - ", \"Beslaglegger\"" + - ", 
\"Beslaglegger#Id\"" + - ", \"Beslaglegger#ComponentId\"" + - ", \"IsVerzekeraarVoorMedewerker\"" + - ", \"VerzekeraarVoorMedewerker\"" + - ", \"VerzekeraarVoorMedewerker#Id\"" + - ", \"VerzekeraarVoorMedewerker#ComponentId\"" + - ", \"IsWoningverhuur\"" + - ", \"Woningverhuur\"" + - ", \"Woningverhuur#Id\"" + - ", \"Woningverhuur#ComponentId\"" + - ", \"9cSTu7fkjOZm08FFQUHvclQHSWY_\"" + - ", \"Goederenstroomadministratie\"" + - ", \"Goederenstroomadministratie#Id\"" + - ", \"Goederenstroomadministratie#ComponentId\"" + - ", \"IsConcurrent\"" + - ", \"Concurrent\"" + - ", \"Concurrent#Id\"" + - ", \"Concurrent#ComponentId\"" + - ", \"IsArbodienst\"" + - ", \"Arbodienst\"" + - ", \"Arbodienst#Id\"" + - ", \"Arbodienst#ComponentId\"" + - ", \"IsUitvoerderSociaalFonds\"" + - ", \"UitvoerderSociaalFonds\"" + - ", \"UitvoerderSociaalFonds#Id\"" + - ", \"UitvoerderSociaalFonds#ComponentId\"" + - ", \"IsPensioenuitvoerder\"" + - ", \"Pensioenuitvoerder\"" + - ", \"Pensioenuitvoerder#Id\"" + - ", \"Pensioenuitvoerder#ComponentId\"" + - ", \"IsProspect\"" + - ", \"Prospect\"" + - ", \"Prospect#Id\"" + - ", \"Prospect#ComponentId\"" + - ", \"IsProspectadministratie\"" + - ", \"Prospectadministratie\"" + - ", \"Prospectadministratie#Id\"" + - ", \"Prospectadministratie#ComponentId\"" + - ", \"IsArtikelbeheeradministratie\"" + - ", \"Artikelbeheeradministratie\"" + - ", \"Artikelbeheeradministratie#Id\"" + - ", \"Artikelbeheeradministratie#ComponentId\"" + - ", \"v1J4Rq2eNZy9GBvGhuCBKqga0Rg_\"" + - ", \"g4i3gYZGL0yu0T6UwmiZTaUDI8Y_\"" + - ", \"g4i3gYZGL0yu0T6UwmiZTaUDI8Y_#Id\"" + - ", \"g4i3gYZGL0yu0T6UwmiZTaUDI8Y_#ComponentId\"" + - ", \"IsSociaalFonds\"" + - ", \"SociaalFonds\"" + - ", \"SociaalFonds#Id\"" + - ", \"SociaalFonds#ComponentId\"" + - ", \"IsWagenparkadministratie\"" + - ", \"Wagenparkadministratie\"" + - ", \"Wagenparkadministratie#Id\"" + - ", \"Wagenparkadministratie#ComponentId\"" + - ", \"IsLeverancier\"" + - ", \"Leverancier\"" + - ", \"Leverancier#Id\"" + 
- ", \"Leverancier#ComponentId\"" + - ", \"IsWagenparkAfnemer\"" + - ", \"WagenparkAfnemer\"" + - ", \"WagenparkAfnemer#Id\"" + - ", \"WagenparkAfnemer#ComponentId\"" + - ", \"IsEvenementadministratie\"" + - ", \"Evenementadministratie\"" + - ", \"Evenementadministratie#Id\"" + - ", \"Evenementadministratie#ComponentId\"" + - ", \"IsCrediteur\"" + - ", \"Crediteur\"" + - ", \"Crediteur#Id\"" + - ", \"Crediteur#ComponentId\"" + - ", \"IsFinancieleAdministratie\"" + - ", \"FinancieleAdministratie\"" + - ", \"FinancieleAdministratie#Id\"" + - ", \"FinancieleAdministratie#ComponentId\"" + - ", \"IsSociaalSecretariaat\"" + - ", \"SociaalSecretariaat\"" + - ", \"SociaalSecretariaat#Id\"" + - ", \"SociaalSecretariaat#ComponentId\"" + - ", \"IsBank\"" + - ", \"Bank\"" + - ", \"Bank#Id\"" + - ", \"Bank#ComponentId\"" + - ", \"IsWerkgever\"" + - ", \"Werkgever\"" + - ", \"Werkgever#Id\"" + - ", \"Werkgever#ComponentId\"" + - ", \"IsVerkoopadministratie\"" + - ", \"Verkoopadministratie\"" + - ", \"Verkoopadministratie#Id\"" + - ", \"Verkoopadministratie#ComponentId\"" + - ", \"IsInkoopadministratie\"" + - ", \"Inkoopadministratie\"" + - ", \"Inkoopadministratie#Id\"" + - ", \"Inkoopadministratie#ComponentId\"" + - ", \"IsSubsidient\"" + - ", \"Subsidient\"" + - ", \"Subsidient#Id\"" + - ", \"Subsidient#ComponentId\"" + - ", \"IsEnqueteadministratie\"" + - ", \"Enqueteadministratie\"" + - ", \"Enqueteadministratie#Id\"" + - ", \"Enqueteadministratie#ComponentId\"" + - ", \"XWiQaVjEbD041r7QN0kj2aKeCys_\"" + - ", \"X_TVE5FRBaQ97JJQbT7LmX4HBVY_\"" + - ", \"BrancheDescription\"" + - ", \"AdresDescription\"" + - ", \"PicMWY7oCqeiZMTdDqwFqi1U508_\"" + - ", \"ProfitDescription\"" + - ", \"TaalDescription\"" + - ", \"PostadresDescription\"" + - ", \"TitelEnAanhefDescription\"" + - ", \"RechtsvormDescription\"" + - ", \"LandVanVestigingDescription\"" + - ", \"AangebrachtDoorDescription\"" + - ", \"KFEDo6ZYK_ffNCeHuujBCshlPVs_\"" + - ", \"VerantwoordelijkeDescription\"" + - ", 
\"FunctionalState\"" + - ", \"KlantFinancieleAdministratie\"" + - ", \"KlantFinancieleAdministratie#Id\"" + - ", \"KlantFinancieleAdministratie#ComponentId\"" + - ", \"KlantVerkoopadministratie\"" + - ", \"KlantVerkoopadministratie#Id\"" + - ", \"KlantVerkoopadministratie#ComponentId\"" + - ", \"NQwaM_lfzxIm_JPrPhwruL0YOXY_\"" + - ", \"NQwaM_lfzxIm_JPrPhwruL0YOXY_#Id\"" + - ", \"NQwaM_lfzxIm_JPrPhwruL0YOXY_#ComponentId\"" + - ", \"uSf1yseW4YQ1K6EoHNlzCxPofm0_\"" + - ", \"uSf1yseW4YQ1K6EoHNlzCxPofm0_#Id\"" + - ", \"uSf1yseW4YQ1K6EoHNlzCxPofm0_#ComponentId\"" + - ", \"KlantEvenementadministratie\"" + - ", \"KlantEvenementadministratie#Id\"" + - ", \"KlantEvenementadministratie#ComponentId\"" + - ", \"KlantCursusadministratie\"" + - ", \"KlantCursusadministratie#Id\"" + - ", \"KlantCursusadministratie#ComponentId\"" + - ", \"KlantProspectadministratie\"" + - ", \"KlantProspectadministratie#Id\"" + - ", \"KlantProspectadministratie#ComponentId\"" + - ", \"KlantInkoopadministratie\"" + - ", \"KlantInkoopadministratie#Id\"" + - ", \"KlantInkoopadministratie#ComponentId\"" + - ", \"KlantProjectadministratie\"" + - ", \"KlantProjectadministratie#Id\"" + - ", \"KlantProjectadministratie#ComponentId\"" + - ", \"P0gBTGcrSbl2kK8BM4_24fIeMvk_\"" + - ", \"P0gBTGcrSbl2kK8BM4_24fIeMvk_#Id\"" + - ", \"P0gBTGcrSbl2kK8BM4_24fIeMvk_#ComponentId\"" + - ", \"KlantMain\"" + - ", \"KlantMain#Id\"" + - ", \"KlantMain#ComponentId\"" + - ", \"8awXARDcCdVtrvN6IRlwAk5UcrI_\"" + - ", \"8awXARDcCdVtrvN6IRlwAk5UcrI_#Id\"" + - ", \"8awXARDcCdVtrvN6IRlwAk5UcrI_#ComponentId\"" + - ", \"MjgFAhRfH64O3M9Ts_5b0ENQDBE_\"" + - ", \"MjgFAhRfH64O3M9Ts_5b0ENQDBE_#Id\"" + - ", \"MjgFAhRfH64O3M9Ts_5b0ENQDBE_#ComponentId\"" + - ", \"LeverancierMain\"" + - ", \"LeverancierMain#Id\"" + - ", \"LeverancierMain#ComponentId\"" + - ", \"2CaNQvGgtPNHP2XCsoKBQEpwmYA_\"" + - ", \"2CaNQvGgtPNHP2XCsoKBQEpwmYA_#Id\"" + - ", \"2CaNQvGgtPNHP2XCsoKBQEpwmYA_#ComponentId\"" + - ", \"kleo39GG1utTUtP0F15mWkXZBFQ_\"" + - ", 
\"kleo39GG1utTUtP0F15mWkXZBFQ_#Id\"" + - ", \"kleo39GG1utTUtP0F15mWkXZBFQ_#ComponentId\"" + - ", \"d9jEYARbghWBrZU6jKbtZPyZUAk_\"" + - ", \"d9jEYARbghWBrZU6jKbtZPyZUAk_#Id\"" + - ", \"d9jEYARbghWBrZU6jKbtZPyZUAk_#ComponentId\"" + - ", \"CrediteurMain\"" + - ", \"CrediteurMain#Id\"" + - ", \"CrediteurMain#ComponentId\"" + - ", \"TTStart\"" + - ", \"TTEnd\"" + - ", \"InstanceId\"" + - ", \"StartDate\"" + - ", \"EndDate\"" + - ", \"UserId\"" + - ", \"Id\"" + - " from \"OrganisatieQmo_Organisatie_QueryModelObjects_Imp\" WHERE (\"InstanceId\" = @1) AND ((\"StartDate\" IS NULL) AND (\"TTEnd\" IS NULL)) ORDER BY \"Id\" ASC NULLS FIRST OFFSET 0 ROWS FETCH NEXT 2 ROWS ONLY;"; - - const string Bug1285CreateStatement = @" + const string Bug1285SelectStatement = + "select " + + " \"Afbeelding\"" + + ", \"Afbeelding#Id\"" + + ", \"Omschrijving\"" + + ", \"Website\"" + + ", \"EMailadres\"" + + ", \"SocialMediaVastleggen\"" + + ", \"Facebook\"" + + ", \"Linkedin\"" + + ", \"Twitter\"" + + ", \"Youtube\"" + + ", \"Branche\"" + + ", \"Branche#Id\"" + + ", \"Branche#ComponentId\"" + + ", \"Telefoonnummer\"" + + ", \"Overheidsidentificatienummer\"" + + ", \"Adres\"" + + ", \"Adres#Id\"" + + ", \"Adres#ComponentId\"" + + ", \"2gIKz62kaTdGxw82_OrzTc4ANSM_\"" + + ", \"OnderdeelVanOrganisatie\"" + + ", \"OnderdeelVanOrganisatie#Id\"" + + ", \"OnderdeelVanOrganisatie#ComponentId\"" + + ", \"Profit\"" + + ", \"Profit#Id\"" + + ", \"Profit#ComponentId\"" + + ", \"Taal\"" + + ", \"Taal#Id\"" + + ", \"Taal#ComponentId\"" + + ", \"KlantSinds\"" + + ", \"BarcodeKlant\"" + + ", \"BarcodeKlant#BarcodeType\"" + + ", \"BarcodeKlant#Value\"" + + ", \"Postadres\"" + + ", \"Postadres#Id\"" + + ", \"Postadres#ComponentId\"" + + ", \"LeverancierSinds\"" + + ", \"BarcodeLeverancier\"" + + ", \"BarcodeLeverancier#BarcodeType\"" + + ", \"BarcodeLeverancier#Value\"" + + ", \"Zoeknaam\"" + + ", \"TitelEnAanhef\"" + + ", \"TitelEnAanhef#Id\"" + + ", \"TitelEnAanhef#ComponentId\"" + + ", \"Rechtsvorm\"" + + ", 
\"Rechtsvorm#Id\"" + + ", \"Rechtsvorm#ComponentId\"" + + ", \"LandVanVestiging\"" + + ", \"LandVanVestiging#Id\"" + + ", \"LandVanVestiging#ComponentId\"" + + ", \"AantalMedewerkers\"" + + ", \"NaamStatutair\"" + + ", \"VestigingStatutair\"" + + ", \"Correspondentie\"" + + ", \"Medium\"" + + ", \"FiscaalNummer\"" + + ", \"AangebrachtDoor\"" + + ", \"AangebrachtDoor#Id\"" + + ", \"AangebrachtDoor#ComponentId\"" + + ", \"Status\"" + + ", \"NummerKamerVanKoophandel\"" + + ", \"zII9SOHbwUPS_jKSlcRrQzuEr6A_\"" + + ", \"Code\"" + + ", \"OrganisatorischeEenheid\"" + + ", \"OrganisatorischeEenheid#Id\"" + + ", \"OrganisatorischeEenheid#ComponentId\"" + + ", \"PJlr3asVdeHVoqmAyHIF2fF1gVM_\"" + + ", \"IsUwv\"" + + ", \"Uwv\"" + + ", \"Uwv#Id\"" + + ", \"Uwv#ComponentId\"" + + ", \"IsVerzekeraarVoorWerkgever\"" + + ", \"VerzekeraarVoorWerkgever\"" + + ", \"VerzekeraarVoorWerkgever#Id\"" + + ", \"VerzekeraarVoorWerkgever#ComponentId\"" + + ", \"IsAbonnementsadministratie\"" + + ", \"Abonnementsadministratie\"" + + ", \"Abonnementsadministratie#Id\"" + + ", \"Abonnementsadministratie#ComponentId\"" + + ", \"IsPensioenfonds\"" + + ", \"Pensioenfonds\"" + + ", \"Pensioenfonds#Id\"" + + ", \"Pensioenfonds#ComponentId\"" + + ", \"IsProfit\"" + + ", \"Profit1\"" + + ", \"Profit1#Id\"" + + ", \"Profit1#ComponentId\"" + + ", \"Verantwoordelijke\"" + + ", \"Verantwoordelijke#Id\"" + + ", \"Verantwoordelijke#ComponentId\"" + + ", \"IsKlant\"" + + ", \"Klant\"" + + ", \"Klant#Id\"" + + ", \"Klant#ComponentId\"" + + ", \"IsFactureringsadministratie\"" + + ", \"Factureringsadministratie\"" + + ", \"Factureringsadministratie#Id\"" + + ", \"Factureringsadministratie#ComponentId\"" + + ", \"IsLeninggever\"" + + ", \"Leninggever\"" + + ", \"Leninggever#Id\"" + + ", \"Leninggever#ComponentId\"" + + ", \"IsProjectadministratie\"" + + ", \"Projectadministratie\"" + + ", \"Projectadministratie#Id\"" + + ", \"Projectadministratie#ComponentId\"" + + ", \"IsVervangingsfonds\"" + + ", 
\"Vervangingsfonds\"" + + ", \"Vervangingsfonds#Id\"" + + ", \"Vervangingsfonds#ComponentId\"" + + ", \"IsPensioen\"" + + ", \"Pensioen\"" + + ", \"Pensioen#Id\"" + + ", \"Pensioen#ComponentId\"" + + ", \"IsVasteActivaAdministratie\"" + + ", \"VasteActivaAdministratie\"" + + ", \"VasteActivaAdministratie#Id\"" + + ", \"VasteActivaAdministratie#ComponentId\"" + + ", \"IsBelastingdienst\"" + + ", \"Belastingdienst\"" + + ", \"Belastingdienst#Id\"" + + ", \"Belastingdienst#ComponentId\"" + + ", \"IsCursusadministratie\"" + + ", \"Cursusadministratie\"" + + ", \"Cursusadministratie#Id\"" + + ", \"Cursusadministratie#ComponentId\"" + + ", \"IsUitvoerderVervangingsfonds\"" + + ", \"UitvoerderVervangingsfonds\"" + + ", \"UitvoerderVervangingsfonds#Id\"" + + ", \"UitvoerderVervangingsfonds#ComponentId\"" + + ", \"IsAssemblageadministratie\"" + + ", \"Assemblageadministratie\"" + + ", \"Assemblageadministratie#Id\"" + + ", \"Assemblageadministratie#ComponentId\"" + + ", \"IsLeasemaatschappij\"" + + ", \"Leasemaatschappij\"" + + ", \"Leasemaatschappij#Id\"" + + ", \"Leasemaatschappij#ComponentId\"" + + ", \"IsBeslaglegger\"" + + ", \"Beslaglegger\"" + + ", \"Beslaglegger#Id\"" + + ", \"Beslaglegger#ComponentId\"" + + ", \"IsVerzekeraarVoorMedewerker\"" + + ", \"VerzekeraarVoorMedewerker\"" + + ", \"VerzekeraarVoorMedewerker#Id\"" + + ", \"VerzekeraarVoorMedewerker#ComponentId\"" + + ", \"IsWoningverhuur\"" + + ", \"Woningverhuur\"" + + ", \"Woningverhuur#Id\"" + + ", \"Woningverhuur#ComponentId\"" + + ", \"9cSTu7fkjOZm08FFQUHvclQHSWY_\"" + + ", \"Goederenstroomadministratie\"" + + ", \"Goederenstroomadministratie#Id\"" + + ", \"Goederenstroomadministratie#ComponentId\"" + + ", \"IsConcurrent\"" + + ", \"Concurrent\"" + + ", \"Concurrent#Id\"" + + ", \"Concurrent#ComponentId\"" + + ", \"IsArbodienst\"" + + ", \"Arbodienst\"" + + ", \"Arbodienst#Id\"" + + ", \"Arbodienst#ComponentId\"" + + ", \"IsUitvoerderSociaalFonds\"" + + ", \"UitvoerderSociaalFonds\"" + + ", 
\"UitvoerderSociaalFonds#Id\"" + + ", \"UitvoerderSociaalFonds#ComponentId\"" + + ", \"IsPensioenuitvoerder\"" + + ", \"Pensioenuitvoerder\"" + + ", \"Pensioenuitvoerder#Id\"" + + ", \"Pensioenuitvoerder#ComponentId\"" + + ", \"IsProspect\"" + + ", \"Prospect\"" + + ", \"Prospect#Id\"" + + ", \"Prospect#ComponentId\"" + + ", \"IsProspectadministratie\"" + + ", \"Prospectadministratie\"" + + ", \"Prospectadministratie#Id\"" + + ", \"Prospectadministratie#ComponentId\"" + + ", \"IsArtikelbeheeradministratie\"" + + ", \"Artikelbeheeradministratie\"" + + ", \"Artikelbeheeradministratie#Id\"" + + ", \"Artikelbeheeradministratie#ComponentId\"" + + ", \"v1J4Rq2eNZy9GBvGhuCBKqga0Rg_\"" + + ", \"g4i3gYZGL0yu0T6UwmiZTaUDI8Y_\"" + + ", \"g4i3gYZGL0yu0T6UwmiZTaUDI8Y_#Id\"" + + ", \"g4i3gYZGL0yu0T6UwmiZTaUDI8Y_#ComponentId\"" + + ", \"IsSociaalFonds\"" + + ", \"SociaalFonds\"" + + ", \"SociaalFonds#Id\"" + + ", \"SociaalFonds#ComponentId\"" + + ", \"IsWagenparkadministratie\"" + + ", \"Wagenparkadministratie\"" + + ", \"Wagenparkadministratie#Id\"" + + ", \"Wagenparkadministratie#ComponentId\"" + + ", \"IsLeverancier\"" + + ", \"Leverancier\"" + + ", \"Leverancier#Id\"" + + ", \"Leverancier#ComponentId\"" + + ", \"IsWagenparkAfnemer\"" + + ", \"WagenparkAfnemer\"" + + ", \"WagenparkAfnemer#Id\"" + + ", \"WagenparkAfnemer#ComponentId\"" + + ", \"IsEvenementadministratie\"" + + ", \"Evenementadministratie\"" + + ", \"Evenementadministratie#Id\"" + + ", \"Evenementadministratie#ComponentId\"" + + ", \"IsCrediteur\"" + + ", \"Crediteur\"" + + ", \"Crediteur#Id\"" + + ", \"Crediteur#ComponentId\"" + + ", \"IsFinancieleAdministratie\"" + + ", \"FinancieleAdministratie\"" + + ", \"FinancieleAdministratie#Id\"" + + ", \"FinancieleAdministratie#ComponentId\"" + + ", \"IsSociaalSecretariaat\"" + + ", \"SociaalSecretariaat\"" + + ", \"SociaalSecretariaat#Id\"" + + ", \"SociaalSecretariaat#ComponentId\"" + + ", \"IsBank\"" + + ", \"Bank\"" + + ", \"Bank#Id\"" + + ", \"Bank#ComponentId\"" + 
+ ", \"IsWerkgever\"" + + ", \"Werkgever\"" + + ", \"Werkgever#Id\"" + + ", \"Werkgever#ComponentId\"" + + ", \"IsVerkoopadministratie\"" + + ", \"Verkoopadministratie\"" + + ", \"Verkoopadministratie#Id\"" + + ", \"Verkoopadministratie#ComponentId\"" + + ", \"IsInkoopadministratie\"" + + ", \"Inkoopadministratie\"" + + ", \"Inkoopadministratie#Id\"" + + ", \"Inkoopadministratie#ComponentId\"" + + ", \"IsSubsidient\"" + + ", \"Subsidient\"" + + ", \"Subsidient#Id\"" + + ", \"Subsidient#ComponentId\"" + + ", \"IsEnqueteadministratie\"" + + ", \"Enqueteadministratie\"" + + ", \"Enqueteadministratie#Id\"" + + ", \"Enqueteadministratie#ComponentId\"" + + ", \"XWiQaVjEbD041r7QN0kj2aKeCys_\"" + + ", \"X_TVE5FRBaQ97JJQbT7LmX4HBVY_\"" + + ", \"BrancheDescription\"" + + ", \"AdresDescription\"" + + ", \"PicMWY7oCqeiZMTdDqwFqi1U508_\"" + + ", \"ProfitDescription\"" + + ", \"TaalDescription\"" + + ", \"PostadresDescription\"" + + ", \"TitelEnAanhefDescription\"" + + ", \"RechtsvormDescription\"" + + ", \"LandVanVestigingDescription\"" + + ", \"AangebrachtDoorDescription\"" + + ", \"KFEDo6ZYK_ffNCeHuujBCshlPVs_\"" + + ", \"VerantwoordelijkeDescription\"" + + ", \"FunctionalState\"" + + ", \"KlantFinancieleAdministratie\"" + + ", \"KlantFinancieleAdministratie#Id\"" + + ", \"KlantFinancieleAdministratie#ComponentId\"" + + ", \"KlantVerkoopadministratie\"" + + ", \"KlantVerkoopadministratie#Id\"" + + ", \"KlantVerkoopadministratie#ComponentId\"" + + ", \"NQwaM_lfzxIm_JPrPhwruL0YOXY_\"" + + ", \"NQwaM_lfzxIm_JPrPhwruL0YOXY_#Id\"" + + ", \"NQwaM_lfzxIm_JPrPhwruL0YOXY_#ComponentId\"" + + ", \"uSf1yseW4YQ1K6EoHNlzCxPofm0_\"" + + ", \"uSf1yseW4YQ1K6EoHNlzCxPofm0_#Id\"" + + ", \"uSf1yseW4YQ1K6EoHNlzCxPofm0_#ComponentId\"" + + ", \"KlantEvenementadministratie\"" + + ", \"KlantEvenementadministratie#Id\"" + + ", \"KlantEvenementadministratie#ComponentId\"" + + ", \"KlantCursusadministratie\"" + + ", \"KlantCursusadministratie#Id\"" + + ", \"KlantCursusadministratie#ComponentId\"" + + ", 
\"KlantProspectadministratie\"" + + ", \"KlantProspectadministratie#Id\"" + + ", \"KlantProspectadministratie#ComponentId\"" + + ", \"KlantInkoopadministratie\"" + + ", \"KlantInkoopadministratie#Id\"" + + ", \"KlantInkoopadministratie#ComponentId\"" + + ", \"KlantProjectadministratie\"" + + ", \"KlantProjectadministratie#Id\"" + + ", \"KlantProjectadministratie#ComponentId\"" + + ", \"P0gBTGcrSbl2kK8BM4_24fIeMvk_\"" + + ", \"P0gBTGcrSbl2kK8BM4_24fIeMvk_#Id\"" + + ", \"P0gBTGcrSbl2kK8BM4_24fIeMvk_#ComponentId\"" + + ", \"KlantMain\"" + + ", \"KlantMain#Id\"" + + ", \"KlantMain#ComponentId\"" + + ", \"8awXARDcCdVtrvN6IRlwAk5UcrI_\"" + + ", \"8awXARDcCdVtrvN6IRlwAk5UcrI_#Id\"" + + ", \"8awXARDcCdVtrvN6IRlwAk5UcrI_#ComponentId\"" + + ", \"MjgFAhRfH64O3M9Ts_5b0ENQDBE_\"" + + ", \"MjgFAhRfH64O3M9Ts_5b0ENQDBE_#Id\"" + + ", \"MjgFAhRfH64O3M9Ts_5b0ENQDBE_#ComponentId\"" + + ", \"LeverancierMain\"" + + ", \"LeverancierMain#Id\"" + + ", \"LeverancierMain#ComponentId\"" + + ", \"2CaNQvGgtPNHP2XCsoKBQEpwmYA_\"" + + ", \"2CaNQvGgtPNHP2XCsoKBQEpwmYA_#Id\"" + + ", \"2CaNQvGgtPNHP2XCsoKBQEpwmYA_#ComponentId\"" + + ", \"kleo39GG1utTUtP0F15mWkXZBFQ_\"" + + ", \"kleo39GG1utTUtP0F15mWkXZBFQ_#Id\"" + + ", \"kleo39GG1utTUtP0F15mWkXZBFQ_#ComponentId\"" + + ", \"d9jEYARbghWBrZU6jKbtZPyZUAk_\"" + + ", \"d9jEYARbghWBrZU6jKbtZPyZUAk_#Id\"" + + ", \"d9jEYARbghWBrZU6jKbtZPyZUAk_#ComponentId\"" + + ", \"CrediteurMain\"" + + ", \"CrediteurMain#Id\"" + + ", \"CrediteurMain#ComponentId\"" + + ", \"TTStart\"" + + ", \"TTEnd\"" + + ", \"InstanceId\"" + + ", \"StartDate\"" + + ", \"EndDate\"" + + ", \"UserId\"" + + ", \"Id\"" + + " from \"OrganisatieQmo_Organisatie_QueryModelObjects_Imp\" WHERE (\"InstanceId\" = @1) AND ((\"StartDate\" IS NULL) AND (\"TTEnd\" IS NULL)) ORDER BY \"Id\" ASC NULLS FIRST OFFSET 0 ROWS FETCH NEXT 2 ROWS ONLY;"; + + const string Bug1285CreateStatement = @" CREATE TEMP TABLE ""OrganisatieQmo_Organisatie_QueryModelObjects_Imp"" ( ""Id"" uuid NOT NULL, @@ -1143,179 +1084,313 
@@ CREATE TEMP TABLE ""OrganisatieQmo_Organisatie_QueryModelObjects_Imp"" ""CrediteurMain"" boolean, CONSTRAINT ""pk_OrganisatieQmo_Organisatie_QueryModelObjects_Imp"" PRIMARY KEY (""Id"") )"; - #endregion Bug1285 + #endregion Bug1285 - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2849")] - public async Task ChunkedStringWriteBufferEncodingSpace() - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString); - // write buffer size must be 8192 for this test to work - // so guard against changes to the default / a change in the test harness - builder.WriteBufferSize = 8192; - using var conn = OpenConnection(builder.ConnectionString); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2849")] + public async Task Chunked_string_write_buffer_encoding_space() + { + // write buffer size must be 8192 for this test to work so guard against changes to the default / a change in the test harness + await using var dataSource = CreateDataSource(csb => csb.WriteBufferSize = 8192); + await using var conn = await dataSource.OpenConnectionAsync(); - try - { - conn.ExecuteNonQuery("CREATE TABLE bug_2849 (col1 text, col2 text)"); + var tableName = await CreateTempTable(conn, "col1 text, col2 text"); - using (var binaryImporter = conn.BeginBinaryImport("COPY bug_2849 FROM STDIN (FORMAT BINARY);")) - { - // 8163 writespace left - await binaryImporter.StartRowAsync(); + await using var binaryImporter = await conn.BeginBinaryImportAsync($"COPY {tableName} FROM STDIN (FORMAT BINARY);"); + // 8163 writespace left + await binaryImporter.StartRowAsync(); - // we need to almost fill the write buffer - we need one byte left in the buffer before we chunk the string for the column after this one! 
- var almostBufferFillingString = new string('a', 8152); - await binaryImporter.WriteAsync(almostBufferFillingString, NpgsqlTypes.NpgsqlDbType.Text); + // we need to almost fill the write buffer - we need one byte left in the buffer before we chunk the string for the column after this one! + var almostBufferFillingString = new string('a', 8152); + await binaryImporter.WriteAsync(almostBufferFillingString, NpgsqlDbType.Text); - var unicodeCharacterThatEncodesToThreeBytesInUtf8 = '\uD55C'; - // This string needs to be long enough to be eligible for chunking, and start with a unicode character that will - // get encoded to multiple bytes - var longStringStartingWithAforementionedUnicodeCharacter = unicodeCharacterThatEncodesToThreeBytesInUtf8 + new string('a', 10000); - await binaryImporter.WriteAsync(longStringStartingWithAforementionedUnicodeCharacter, NpgsqlDbType.Text); + var unicodeCharacterThatEncodesToThreeBytesInUtf8 = '\uD55C'; + // This string needs to be long enough to be eligible for chunking, and start with a unicode character that will + // get encoded to multiple bytes + var longStringStartingWithAforementionedUnicodeCharacter = unicodeCharacterThatEncodesToThreeBytesInUtf8 + new string('a', 10000); + await binaryImporter.WriteAsync(longStringStartingWithAforementionedUnicodeCharacter, NpgsqlDbType.Text); - await binaryImporter.CompleteAsync(); - } - } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS bug_2849"); - } - } + await binaryImporter.CompleteAsync(); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2849")] - public async Task ChunkedCharArrayWriteBufferEncodingSpace() - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString); - // write buffer size must be 8192 for this test to work - // so guard against changes to the default / a change in the test harness - builder.WriteBufferSize = 8192; - using var conn = OpenConnection(builder.ConnectionString); + [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/2849")] + public async Task Chunked_char_array_write_buffer_encoding_space() + { + // write buffer size must be 8192 for this test to work so guard against changes to the default / a change in the test harness + await using var dataSource = CreateDataSource(csb => csb.WriteBufferSize = 8192); + await using var conn = await dataSource.OpenConnectionAsync(); - try - { - conn.ExecuteNonQuery("CREATE TABLE bug_2849 (col1 text, col2 text)"); + var tableName = await CreateTempTable(conn, "col1 text, col2 text"); - using (var binaryImporter = conn.BeginBinaryImport("COPY bug_2849 FROM STDIN (FORMAT BINARY);")) - { - // 8163 writespace left - await binaryImporter.StartRowAsync(); + await using var binaryImporter = await conn.BeginBinaryImportAsync($"COPY {tableName} FROM STDIN (FORMAT BINARY);"); + // 8163 writespace left + await binaryImporter.StartRowAsync(); - // we need to almost fill the write buffer - we need one byte left in the buffer before we chunk the string for the column after this one! - var almostBufferFillingString = new string('a', 8152); - await binaryImporter.WriteAsync(almostBufferFillingString, NpgsqlTypes.NpgsqlDbType.Text); + // we need to almost fill the write buffer - we need one byte left in the buffer before we chunk the string for the column after this one! 
+ var almostBufferFillingString = new string('a', 8152); + await binaryImporter.WriteAsync(almostBufferFillingString, NpgsqlDbType.Text); - var unicodeCharacterThatEncodesToThreeBytesInUtf8 = '\uD55C'; - // This string needs to be long enough to be eligible for chunking, and start with a unicode character that will - // get encoded to multiple bytes - var longStringStartingWithAforementionedUnicodeCharacter = unicodeCharacterThatEncodesToThreeBytesInUtf8 + new string('a', 10000); - await binaryImporter.WriteAsync(longStringStartingWithAforementionedUnicodeCharacter.ToCharArray(), NpgsqlDbType.Text); + var unicodeCharacterThatEncodesToThreeBytesInUtf8 = '\uD55C'; + // This string needs to be long enough to be eligible for chunking, and start with a unicode character that will + // get encoded to multiple bytes + var longStringStartingWithAforementionedUnicodeCharacter = unicodeCharacterThatEncodesToThreeBytesInUtf8 + new string('a', 10000); + await binaryImporter.WriteAsync(longStringStartingWithAforementionedUnicodeCharacter.ToCharArray(), NpgsqlDbType.Text); - await binaryImporter.CompleteAsync(); - } - } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS bug_2849"); - } - } + await binaryImporter.CompleteAsync(); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2371")] + public async Task NRE_in_BeginTextExport() + { + await using var conn = await OpenConnectionAsync(); + var funcName = await GetTempFunctionName(conn); + await using var transaction = await conn.BeginTransactionAsync(); + await conn.ExecuteNonQueryAsync($"CREATE OR REPLACE FUNCTION {funcName}() RETURNS TABLE (i INT) AS $$ BEGIN RETURN QUERY SELECT s.a FROM pg_stat_activity p; end; $$ LANGUAGE plpgsql;"); + using var reader = await conn.BeginTextExportAsync($"copy (select * FROM {funcName}()) TO STDOUT WITH (format csv)"); + Assert.That(() => reader.ReadLine(), Throws.Exception + .TypeOf() + 
.With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedTable) + ); + } + + public enum TestEnum + { + One, + Two + } + + class SomeComposite + { + /// + /// An enum without proper handler + /// + public TestEnum Test { get; set; } + public int X { get; set; } + public string SomeText { get; set; } = ""; + } + + [Test] + public async Task CompositePostgresType() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + var func = await GetTempFunctionName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} as (x int, some_text text, test int)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await connection.ExecuteNonQueryAsync(@$" +CREATE OR REPLACE FUNCTION {func}(id int, out comp1 {type}, OUT comp2 {type}[]) +LANGUAGE plpgsql AS +$$ +BEGIN + comp1 = ROW(9, 'bar', 1)::{type}; + comp2 = ARRAY[ROW(9, 'bar', 1)::{type}]; +END; +$$;"); + + Assert.ThrowsAsync(async () => await connection.ExecuteScalarAsync($"SELECT {func}(0)")); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2371")] - public void NullReferenceExceptionInBeginTextExport() + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/3117")] + public void Bug3117() + { + const string OkCommand = "SELECT 1"; + const string ErrorCommand = "SELECT * FROM public.imnotexist"; + using var dataSource = CreateDataSource(); + using (var conn = dataSource.OpenConnection()) { - using var conn = OpenConnection(); - try - { - using var transaction = conn.BeginTransaction(); - var command = conn.CreateCommand(); - command.CommandText = "CREATE OR REPLACE FUNCTION f_test() RETURNS TABLE (i INT) AS $$ BEGIN RETURN QUERY SELECT s.a FROM pg_stat_activity p; end; $$ LANGUAGE plpgsql;"; - 
command.ExecuteNonQuery(); - using var reader = conn.BeginTextExport("copy (select * FROM f_test()) TO STDOUT WITH (format csv)"); - Assert.That(() => reader.ReadLine(), Throws.Exception - .TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("42P01") - ); - } - finally - { - conn.ExecuteNonQuery("DROP FUNCTION IF EXISTS f_test()"); - } + var okCommand = new NpgsqlCommand(OkCommand, conn); + okCommand.Prepare(); + using (okCommand.ExecuteReader()) { } + + var errorCommand = new NpgsqlCommand(ErrorCommand, conn); + Assert.That(() => errorCommand.Prepare(), Throws.Exception + .TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedTable)); } - public enum TestEnum + using (var conn = dataSource.OpenConnection()) { - One, - Two + var okCommand = new NpgsqlCommand(OkCommand, conn); + okCommand.Prepare(); + using (okCommand.ExecuteReader()) { } } + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3209")] + public async Task Bug3209() + { + await using var conn = CreateConnection(); + await conn.CloseAsync(); + await conn.OpenAsync(); + await conn.CloseAsync(); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3373")] + public async Task Bug3373() + { + await using var conn = await OpenConnectionAsync(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT repeat('1', 10000); SELECT * from pg_sleep(3)"; + cmd.CommandTimeout = 0; - class SomeComposite + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.DoesNotThrowAsync(async () => await reader.NextResultAsync()); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3649")] + public async Task Bug3649() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "value integer"); + + using (var importer = await conn.BeginBinaryImportAsync($"COPY {table} (value) FROM STDIN (FORMAT 
binary)")) { - /// - /// An enum without proper handler - /// - public TestEnum Test { get; set; } - public int X { get; set; } - public string SomeText { get; set; } = ""; + await importer.StartRowAsync(); + await importer.WriteAsync(DBNull.Value, NpgsqlDbType.Integer); + await importer.StartRowAsync(); + await importer.WriteAsync(1, NpgsqlDbType.Integer); + await importer.StartRowAsync(); + await importer.WriteAsync(2, NpgsqlDbType.Integer); + await importer.CompleteAsync(); } - [Test] - public void CompositePostgresType() + using (var exporter = await conn.BeginBinaryExportAsync($"COPY {table} (value) TO STDIN (FORMAT binary)")) { - using var conn = OpenConnection(); - conn.ExecuteNonQuery("CREATE TYPE pg_temp.comp1 as (x int, some_text text, test int)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("comp1"); - - conn.ExecuteNonQuery(@" -CREATE FUNCTION pg_temp.func(id int, out comp1 comp1, OUT comp2 COMP1[]) -LANGUAGE plpgsql AS -$$ -BEGIN - comp1 = ROW(9, 'bar', 1)::comp1; - comp2 = ARRAY[ROW(9, 'bar', 1)::comp1]; -END; -$$;"); + await exporter.StartRowAsync(); + Assert.IsTrue(exporter.IsNull); + await exporter.SkipAsync(); + await exporter.StartRowAsync(); + Assert.AreEqual(1, await exporter.ReadAsync()); + await exporter.StartRowAsync(); + Assert.AreEqual(2, await exporter.ReadAsync()); + } + } - using var cmd = new NpgsqlCommand("SELECT pg_temp.func(0)", conn); - Assert.That(() => cmd.ExecuteScalar(), Throws.TypeOf()); + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/3839")] + public async Task UIThreadSynchronizationContext_deadlock() + { + var syncContext = new SingleThreadSynchronizationContext(nameof(UIThreadSynchronizationContext_deadlock)); + using (var _ = syncContext.Enter()) + { + // We have to Yield, so the current thread is changed to the one used by SingleThreadSynchronizationContext + await Task.Yield(); + using var connection = OpenConnection(); + + var data = new string('x', 5_000_000); + using var cmd = new 
NpgsqlCommand("SELECT generate_series(1, 500000); SELECT @p", connection); + cmd.Parameters.AddWithValue("p", NpgsqlDbType.Text, data); + cmd.ExecuteNonQuery(); } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/3117")] - public void Bug3117() + // We have to make another Yield to change the current thread from the one used by SingleThreadSynchronizationContext + await Task.Yield(); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/3924")] + public async Task Bug3924() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { - const string OkCommand = "SELECT 1"; - const string ErrorCommand = "SELECT * FROM public.imnotexist"; - using (var conn = new NpgsqlConnection(ConnectionString)) - { - conn.Open(); - var okCommand = new NpgsqlCommand(OkCommand, conn); - okCommand.Prepare(); - using (okCommand.ExecuteReader()) { } - - var errorCommand = new NpgsqlCommand(ErrorCommand, conn); - Assert.That(() => errorCommand.Prepare(), Throws.Exception - .TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedTable)); - } + CommandTimeout = 10, + KeepAlive = 5, + }; - using (var conn = new NpgsqlConnection(ConnectionString)) - { - conn.Open(); - var okCommand = new NpgsqlCommand(OkCommand, conn); - okCommand.Prepare(); - using (okCommand.ExecuteReader()) { } - conn.UnprepareAll(); - } + await using var postmaster = PgPostmasterMock.Start(csb.ConnectionString); + await using var dataSource = CreateDataSource(postmaster.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + var serverMock = await postmaster.WaitForServerConnection(); + + using (var cmd = conn.CreateCommand()) + { + cmd.CommandTimeout = 1; + cmd.CommandText = "SELECT 1"; + var queryTask = cmd.ExecuteNonQueryAsync(); + await serverMock.ExpectExtendedQuery(); + _ = serverMock.WriteScalarResponseAndFlush(1); + await queryTask; } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3209")] - 
public async Task Bug3209() + // Giving some time for keepalive to send a query and wait for a response for a little bit + await serverMock.ExpectMessage(FrontendMessageCode.Sync); + await Task.Delay(1000); + _ = serverMock.WriteReadyForQuery().FlushAsync(); + + using (var cmd = conn.CreateCommand()) { - await using var conn = CreateConnection(); - await conn.CloseAsync(); - await conn.OpenAsync(); - await conn.CloseAsync(); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + cmd.CommandTimeout = 1; + cmd.CommandText = "SELECT 1"; + var queryTask = cmd.ExecuteNonQueryAsync(); + await serverMock.ExpectExtendedQuery(); + _ = serverMock.WriteScalarResponseAndFlush(1); + Assert.DoesNotThrowAsync(async () => await queryTask); } } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4099")] + public async Task Bug4099() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + Multiplexing = true, + MaxPoolSize = 1 + }; + await using var postmaster = PgPostmasterMock.Start(csb.ConnectionString); + await using var dataSource = CreateDataSource(postmaster.ConnectionString); + await using var firstConn = await dataSource.OpenConnectionAsync(); + await using var secondConn = await dataSource.OpenConnectionAsync(); + + var firstQuery = firstConn.ExecuteScalarAsync("SELECT data"); + + var server = await postmaster.WaitForServerConnection(); + await server.ExpectExtendedQuery(); + + var secondQuery = secondConn.ExecuteScalarAsync("SELECT other_data"); + await server.ExpectExtendedQuery(); + + var data = new byte[10000]; + await server + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(ByteaOid)) + .WriteDataRowWithFlush(data); + + var otherData = new byte[10]; + await server + .WriteCommandComplete() + .WriteReadyForQuery() + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(ByteaOid)) + .WriteDataRow(otherData) + .WriteCommandComplete() + 
.WriteReadyForQuery() + .FlushAsync(); + + Assert.That(data, Is.EquivalentTo((byte[])(await firstQuery)!)); + Assert.That(otherData, Is.EquivalentTo((byte[])(await secondQuery)!)); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4123")] + public async Task Bug4123() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + await using var rdr = await cmd.ExecuteReaderAsync(); + + await rdr.ReadAsync(); + await using var stream = await rdr.GetStreamAsync(0); + + Assert.DoesNotThrowAsync(stream.FlushAsync); + Assert.DoesNotThrow(stream.Flush); + } } diff --git a/test/Npgsql.Tests/CommandBuilderTests.cs b/test/Npgsql.Tests/CommandBuilderTests.cs index 98bfefead4..e917b7f6b3 100644 --- a/test/Npgsql.Tests/CommandBuilderTests.cs +++ b/test/Npgsql.Tests/CommandBuilderTests.cs @@ -5,929 +5,386 @@ using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; -using static Npgsql.Util.Statics; -namespace Npgsql.Tests -{ - class CommandBuilderTests : TestBase - { - // TODO: REMOVE ME - bool IsMultiplexing = false; - - [Test, Description("Tests function parameter derivation with IN, OUT and INOUT parameters")] - public async Task DeriveFunctionParameters_Various() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = GetTempFunctionName(conn, out var function); - - // This function returns record because of the two Out (InOut & Out) parameters - await conn.ExecuteNonQueryAsync($@" - CREATE OR REPLACE FUNCTION {function}(IN param1 INT, OUT param2 text, INOUT param3 INT) RETURNS record AS - ' - BEGIN - param2 = ''sometext''; - param3 = param1 + param3; - END; - ' LANGUAGE 'plpgsql'; - "); - - var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(3)); - Assert.That(cmd.Parameters[0].Direction, 
Is.EqualTo(ParameterDirection.Input)); - Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); - Assert.That(cmd.Parameters[0].PostgresType, Is.TypeOf()); - Assert.That(cmd.Parameters[0].DataTypeName, Is.EqualTo("integer")); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("param1")); - Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); - Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); - Assert.That(cmd.Parameters[1].PostgresType, Is.TypeOf()); - Assert.That(cmd.Parameters[1].DataTypeName, Is.EqualTo("text")); - Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("param2")); - Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.InputOutput)); - Assert.That(cmd.Parameters[2].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); - Assert.That(cmd.Parameters[2].PostgresType, Is.TypeOf()); - Assert.That(cmd.Parameters[2].DataTypeName, Is.EqualTo("integer")); - Assert.That(cmd.Parameters[2].ParameterName, Is.EqualTo("param3")); - cmd.Parameters[0].Value = 5; - cmd.Parameters[2].Value = 4; - cmd.ExecuteNonQuery(); - Assert.That(cmd.Parameters[0].Value, Is.EqualTo(5)); - Assert.That(cmd.Parameters[1].Value, Is.EqualTo("sometext")); - Assert.That(cmd.Parameters[2].Value, Is.EqualTo(9)); - } - } - - [Test, Description("Tests function parameter derivation with IN-only parameters")] - public async Task DeriveFunctionParameters_InOnly() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = GetTempFunctionName(conn, out var function); - - // This function returns record because of the two Out (InOut & Out) parameters - await conn.ExecuteNonQueryAsync($@" - CREATE OR REPLACE FUNCTION {function}(IN param1 INT, IN param2 INT) RETURNS int AS - ' - BEGIN - RETURN param1 + param2; - END; - ' LANGUAGE 'plpgsql'; - "); - - var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - 
NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(2)); - Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); - Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Input)); - cmd.Parameters[0].Value = 5; - cmd.Parameters[1].Value = 4; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(9)); - } - } - - [Test, Description("Tests function parameter derivation with no parameters")] - public async Task DeriveFunctionParameters_NoParams() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = GetTempFunctionName(conn, out var function); - - // This function returns record because of the two Out (InOut & Out) parameters - await conn.ExecuteNonQueryAsync($@" - CREATE OR REPLACE FUNCTION {function}() RETURNS int AS - ' - BEGIN - RETURN 4; - END; - ' LANGUAGE 'plpgsql'; - "); - - var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Is.Empty); - } - } - - [Test] - public async Task DeriveFunctionParameters_CaseSensitiveName() - { - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync( - @"CREATE OR REPLACE FUNCTION ""FunctionCaseSensitive""(int4, text) returns int4 as - $BODY$ - begin - return 0; - end - $BODY$ - language 'plpgsql';"); - await using var _ = DeferAsync(() => conn.ExecuteNonQueryAsync(@"DROP FUNCTION ""FunctionCaseSensitive""")); - - var command = new NpgsqlCommand(@"""FunctionCaseSensitive""", conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); - } - } - - [Test] - public async Task DeriveFunctionParameters_ParameterNameFromFunction() - { - using (var conn = await OpenConnectionAsync()) - { - await 
using var _ = GetTempFunctionName(conn, out var function); - - await conn.ExecuteNonQueryAsync($@"CREATE OR REPLACE FUNCTION {function}(x int, y int, out sum int, out product int) as 'select $1 + $2, $1 * $2' language 'sql';"); - var command = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual("x", command.Parameters[0].ParameterName); - Assert.AreEqual("y", command.Parameters[1].ParameterName); - } - } - - [Test] - public async Task DeriveFunctionParameters_NonExistingFunction() - { - using (var conn = await OpenConnectionAsync()) - { - var invalidCommandName = new NpgsqlCommand("invalidfunctionname", conn) { CommandType = CommandType.StoredProcedure }; - Assert.That(() => NpgsqlCommandBuilder.DeriveParameters(invalidCommandName), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("42883")); - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1212")] - public async Task DeriveFunctionParameters_TableParameters() - { - using (var conn = await OpenConnectionAsync()) - { - MinimumPgVersion(conn, "9.2.0"); - await using var _ = GetTempFunctionName(conn, out var function); - - // This function returns record because of the two Out (InOut & Out) parameters - await conn.ExecuteNonQueryAsync($@" - CREATE FUNCTION {function}(IN in1 INT) RETURNS TABLE(t1 INT, t2 INT) AS - 'SELECT in1,in1+1' LANGUAGE 'sql'; - "); - - var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(3)); - Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); - Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); - Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.Output)); - cmd.Parameters[0].Value = 5; - cmd.ExecuteNonQuery(); - 
Assert.That(cmd.Parameters[1].Value, Is.EqualTo(5)); - Assert.That(cmd.Parameters[2].Value, Is.EqualTo(6)); - } - } - - [Test, Description("Tests function parameter derivation for quoted functions with double quotes in the name works")] - public async Task DeriveFunctionParameters_QuoteCharactersInFunctionName() - { - using (var conn = await OpenConnectionAsync()) - { - var function = @"""""""FunctionQuote""""CharactersInName"""""""; - await conn.ExecuteNonQueryAsync( - $@"CREATE OR REPLACE FUNCTION {function}(int4, text) returns int4 as - $BODY$ - begin - return 0; - end - $BODY$ - language 'plpgsql';"); - await using var _ = DeferAsync(() => conn.ExecuteNonQueryAsync("DROP FUNCTION " + function)); - - var command = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); - } - } - - [Test, Description("Tests function parameter derivation for quoted functions with dots in the name works")] - public async Task DeriveFunctionParameters_DotCharacterInFunctionName() - { - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync( - @"CREATE OR REPLACE FUNCTION ""My.Dotted.Function""(int4, text) returns int4 as - $BODY$ - begin - return 0; - end - $BODY$ - language 'plpgsql';"); - await using var _ = DeferAsync(() => conn.ExecuteNonQueryAsync(@"DROP FUNCTION ""My.Dotted.Function""")); - - var command = new NpgsqlCommand(@"""My.Dotted.Function""", conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); - } - } - - [Test, Description("Tests if the right function according to search_path is used in function parameter derivation")] 
- public async Task DeriveFunctionParameters_CorrectSchemaResolution() - { - if (IsMultiplexing) - return; // Uses search_path - - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempSchema(conn, out var schema1); - await using var __ = await CreateTempSchema(conn, out var schema2); - - await conn.ExecuteNonQueryAsync( - $@" -CREATE FUNCTION {schema1}.redundantfunc() RETURNS int AS -$BODY$ -BEGIN - RETURN 1; -END; -$BODY$ -LANGUAGE 'plpgsql'; - -CREATE FUNCTION {schema2}.redundantfunc(IN param1 INT, IN param2 INT) RETURNS int AS -$BODY$ -BEGIN -RETURN param1 + param2; -END; -$BODY$ -LANGUAGE 'plpgsql'; - -SET search_path TO {schema2}; -"); - var command = new NpgsqlCommand("redundantfunc", conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(command); - Assert.That(command.Parameters, Has.Count.EqualTo(2)); - Assert.That(command.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); - Assert.That(command.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Input)); - command.Parameters[0].Value = 5; - command.Parameters[1].Value = 4; - Assert.That(command.ExecuteScalar(), Is.EqualTo(9)); - } - } - - [Test, Description("Tests if function parameter derivation throws an exception if the specified function is not in the search_path")] - public async Task DeriveFunctionParameters_ThrowsForExistingFunctionThatIsNotInSearchPath() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempSchema(conn, out var schema); - - await conn.ExecuteNonQueryAsync($@" -CREATE OR REPLACE FUNCTION {schema}.schema1func() RETURNS int AS -$BODY$ -BEGIN - RETURN 1; -END; -$BODY$ -LANGUAGE 'plpgsql'; - -RESET search_path; -"); - var command = new NpgsqlCommand("schema1func", conn) { CommandType = CommandType.StoredProcedure }; - Assert.That(() => NpgsqlCommandBuilder.DeriveParameters(command), - Throws.Exception.TypeOf() - 
.With.Property(nameof(PostgresException.SqlState)).EqualTo("42883")); - } - } - - [Test, Description("Tests if an exception is thrown if multiple functions with the specified name are in the search_path")] - public async Task DeriveFunctionParameters_ThrowsForMultipleFunctionNameHitsInSearchPath() - { - if (IsMultiplexing) - return; // Uses search_path - - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempSchema(conn, out var schema1); - await using var __ = await CreateTempSchema(conn, out var schema2); - - await conn.ExecuteNonQueryAsync( - $@" -CREATE FUNCTION {schema1}.redundantfunc() RETURNS int AS -$BODY$ -BEGIN - RETURN 1; -END; -$BODY$ -LANGUAGE 'plpgsql'; - -CREATE OR REPLACE FUNCTION {schema1}.redundantfunc(IN param1 INT, IN param2 INT) RETURNS int AS -$BODY$ -BEGIN -RETURN param1 + param2; -END; -$BODY$ -LANGUAGE 'plpgsql'; - -SET search_path TO {schema1}, {schema2}; -"); - var command = new NpgsqlCommand("redundantfunc", conn) { CommandType = CommandType.StoredProcedure }; - Assert.That(() => NpgsqlCommandBuilder.DeriveParameters(command), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("42725")); - } - } +namespace Npgsql.Tests; - #region Set returning functions - - [Test, Description("Tests parameter derivation for a function that returns SETOF sometype")] - public async Task DeriveFunctionParameters_FunctionReturningSetofType() - { - using (var conn = await OpenConnectionAsync()) - { - MinimumPgVersion(conn, "9.2.0"); - - await using var _ = await GetTempTableName(conn, out var table); - await using var __ = GetTempFunctionName(conn, out var function); - - // This function returns record because of the two Out (InOut & Out) parameters - await conn.ExecuteNonQueryAsync($@" -CREATE TABLE {table} (fooid int, foosubid int, fooname text); - -INSERT INTO {table} VALUES -(1, 1, 'Joe'), -(1, 2, 'Ed'), -(2, 1, 'Mary'); - -CREATE FUNCTION {function}(int) RETURNS SETOF {table} AS $$ 
- SELECT * FROM {table} WHERE {table}.fooid = $1 ORDER BY {table}.foosubid; -$$ LANGUAGE SQL; - "); - - var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(4)); - Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); - Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); - Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.Output)); - Assert.That(cmd.Parameters[3].Direction, Is.EqualTo(ParameterDirection.Output)); - cmd.Parameters[0].Value = 1; - cmd.ExecuteNonQuery(); - Assert.That(cmd.Parameters[0].Value, Is.EqualTo(1)); - } - } - - [Test, Description("Tests parameter derivation for a function that returns TABLE")] - public async Task DeriveFunctionParameters_FunctionReturningTable() - { - using (var conn = await OpenConnectionAsync()) - { - MinimumPgVersion(conn, "9.2.0"); - - await using var _ = await GetTempTableName(conn, out var table); - await using var __ = GetTempFunctionName(conn, out var function); - - // This function returns record because of the two Out (InOut & Out) parameters - await conn.ExecuteNonQueryAsync($@" -CREATE TABLE {table} (fooid int, foosubid int, fooname text); - -INSERT INTO {table} VALUES -(1, 1, 'Joe'), -(1, 2, 'Ed'), -(2, 1, 'Mary'); - -CREATE OR REPLACE FUNCTION {function}(int) RETURNS TABLE(fooid int, foosubid int, fooname text) AS $$ - SELECT * FROM {table} WHERE {table}.fooid = $1 ORDER BY {table}.foosubid; -$$ LANGUAGE SQL; - "); - - var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(4)); - Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); - Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); - 
Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.Output)); - Assert.That(cmd.Parameters[3].Direction, Is.EqualTo(ParameterDirection.Output)); - cmd.Parameters[0].Value = 1; - cmd.ExecuteNonQuery(); - Assert.That(cmd.Parameters[0].Value, Is.EqualTo(1)); - } - } - - [Test, Description("Tests parameter derivation for a function that returns SETOF record")] - public async Task DeriveFunctionParameters_FunctionReturningSetofRecord() - { - using (var conn = await OpenConnectionAsync()) - { - MinimumPgVersion(conn, "9.2.0"); - - await using var _ = await GetTempTableName(conn, out var table); - await using var __ = GetTempFunctionName(conn, out var function); - - // This function returns record because of the two Out (InOut & Out) parameters - await conn.ExecuteNonQueryAsync($@" -CREATE TABLE {table} (fooid int, foosubid int, fooname text); - -INSERT INTO {table} VALUES -(1, 1, 'Joe'), -(1, 2, 'Ed'), -(2, 1, 'Mary'); - -CREATE FUNCTION {function}(int, OUT fooid int, OUT foosubid int, OUT fooname text) RETURNS SETOF record AS $$ - SELECT * FROM {table} WHERE {table}.fooid = $1 ORDER BY {table}.foosubid; -$$ LANGUAGE SQL; - "); - - var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(4)); - Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); - Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); - Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.Output)); - Assert.That(cmd.Parameters[3].Direction, Is.EqualTo(ParameterDirection.Output)); - cmd.Parameters[0].Value = 1; - cmd.ExecuteNonQuery(); - Assert.That(cmd.Parameters[0].Value, Is.EqualTo(1)); - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2022")] - public async Task DeriveFunctionParameters_FunctionReturningSetofTypeWithDroppedColumn() - { - using (var conn = await 
OpenConnectionAsync()) - { - MinimumPgVersion(conn, "9.2.0"); - - await using var _ = await GetTempTableName(conn, out var table); - await using var __ = GetTempFunctionName(conn, out var function); - - await conn.ExecuteNonQueryAsync($@" - CREATE TABLE {table} (id serial PRIMARY KEY, t1 text, t2 text); - CREATE OR REPLACE FUNCTION {function}() RETURNS SETOF {table} AS $$ - SELECT * FROM {table} - $$LANGUAGE SQL; - ALTER TABLE {table} DROP t2; - "); - - var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(2)); - Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Output)); - Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); - Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); - Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); - } - } - - #endregion - - #region CommandType.Text +class CommandBuilderTests : TestBase +{ + // See function parameter derivation tests in FunctionTests, and stored procedure derivation tests in StoredProcedureTests - [Test, Description("Tests parameter derivation for parameterized queries (CommandType.Text)")] - public async Task DeriveTextCommandParameters_OneParameterWithSameType() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "id int, val text", out var table); + [Test, Description("Tests parameter derivation for parameterized queries (CommandType.Text)")] + public async Task DeriveParameters_text_one_parameter_with_same_type() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id int, val text"); - var cmd = new NpgsqlCommand( - $@"INSERT INTO {table} VALUES(:x, 'some value'); + var cmd = new NpgsqlCommand( + $@"INSERT INTO {table} VALUES(:x, 'some value'); UPDATE {table} SET val = 'changed value' 
WHERE id = :x; SELECT val FROM {table} WHERE id = :x;", - conn); - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(1)); - Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("x")); - Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); - cmd.Parameters[0].Value = 42; - var retVal = await cmd.ExecuteScalarAsync(); - Assert.That(retVal, Is.EqualTo("changed value")); - } - } + conn); + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(1)); + Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("x")); + Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + cmd.Parameters[0].Value = 42; + var retVal = await cmd.ExecuteScalarAsync(); + Assert.That(retVal, Is.EqualTo("changed value")); + } - [Test, Description("Tests parameter derivation for parameterized queries (CommandType.Text) where different types would be inferred for placeholders with the same name.")] - public async Task DeriveTextCommandParameters_OneParameterWithDifferentTypes() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "id int, val text", out var table); + [Test, Description("Tests parameter derivation for parameterized queries (CommandType.Text) where different types would be inferred for placeholders with the same name.")] + public async Task DeriveParameters_text_one_parameter_with_different_types() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id int, val text"); - var cmd = new NpgsqlCommand( - $@"INSERT INTO {table} VALUES(:x, 'some value'); + var cmd = new NpgsqlCommand( + $@"INSERT INTO {table} VALUES(:x, 'some value'); UPDATE {table} SET val = 'changed value' WHERE id = :x::double 
precision; SELECT val FROM {table} WHERE id = :x::numeric;", - conn); - var ex = Assert.Throws(() => NpgsqlCommandBuilder.DeriveParameters(cmd)); - Assert.That(ex.Message, Is.EqualTo("The backend parser inferred different types for parameters with the same name. Please try explicit casting within your SQL statement or batch or use different placeholder names.")); - } - } + conn); + var ex = Assert.Throws(() => NpgsqlCommandBuilder.DeriveParameters(cmd))!; + Assert.That(ex.Message, Is.EqualTo("The backend parser inferred different types for parameters with the same name. Please try explicit casting within your SQL statement or batch or use different placeholder names.")); + } - [Test, Description("Tests parameter derivation for parameterized queries (CommandType.Text) with multiple parameters")] - public async Task DeriveTextCommandParameters_MultipleParameters() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "id int, val text", out var table); + [Test, Description("Tests parameter derivation for parameterized queries (CommandType.Text) with multiple parameters")] + public async Task DeriveParameters_multiple_parameters() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id int, val text"); - var cmd = new NpgsqlCommand( - $@"INSERT INTO {table} VALUES(:x, 'some value'); + var cmd = new NpgsqlCommand( + $@"INSERT INTO {table} VALUES(:x, 'some value'); UPDATE {table} SET val = 'changed value' WHERE id = @y::double precision; SELECT val FROM {table} WHERE id = :z::numeric;", - conn); - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(3)); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("x")); - Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("y")); - Assert.That(cmd.Parameters[2].ParameterName, Is.EqualTo("z")); - Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); - 
Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Double)); - Assert.That(cmd.Parameters[2].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Numeric)); - - cmd.Parameters[0].Value = 42; - cmd.Parameters[1].Value = 42d; - cmd.Parameters[2].Value = 42; - var retVal = await cmd.ExecuteScalarAsync(); - Assert.That(retVal, Is.EqualTo("changed value")); - } - } + conn); + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(3)); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("x")); + Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("y")); + Assert.That(cmd.Parameters[2].ParameterName, Is.EqualTo("z")); + Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Double)); + Assert.That(cmd.Parameters[2].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Numeric)); + + cmd.Parameters[0].Value = 42; + cmd.Parameters[1].Value = 42d; + cmd.Parameters[2].Value = 42; + var retVal = await cmd.ExecuteScalarAsync(); + Assert.That(retVal, Is.EqualTo("changed value")); + } - [Test, Description("Tests parameter derivation a parameterized query (CommandType.Text) that is already prepared.")] - public async Task DeriveTextCommandParameters_PreparedStatement() + [Test, Description("Tests parameter derivation a parameterized query (CommandType.Text) that is already prepared.")] + public async Task DeriveParameters_text_prepared_statement() + { + const string query = "SELECT @p::integer"; + const int answer = 42; + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand(query, conn); + cmd.Parameters.AddWithValue("@p", NpgsqlDbType.Integer, answer); + cmd.Prepare(); + Assert.That(conn.Connector!.PreparedStatementManager.NumPrepared, Is.EqualTo(1)); + + var ex = Assert.Throws(() => { - const string query = "SELECT @p::integer"; - const int answer = 42; 
- using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand(query, conn)) - { - cmd.Parameters.AddWithValue("@p", NpgsqlDbType.Integer, answer); - cmd.Prepare(); - Assert.That(conn.Connector!.PreparedStatementManager.NumPrepared, Is.EqualTo(1)); - - var ex = Assert.Throws(() => - { - // Derive parameters for the already prepared statement - NpgsqlCommandBuilder.DeriveParameters(cmd); - - }); - - Assert.That(ex.Message, Is.EqualTo("Deriving parameters isn't supported for commands that are already prepared.")); - - // We leave the command intact when throwing so it should still be useable - Assert.That(cmd.Parameters.Count, Is.EqualTo(1)); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("@p")); - Assert.That(conn.Connector.PreparedStatementManager.NumPrepared, Is.EqualTo(1)); - cmd.Parameters["@p"].Value = answer; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(answer)); - - conn.UnprepareAll(); - } - } + // Derive parameters for the already prepared statement + NpgsqlCommandBuilder.DeriveParameters(cmd); - [Test, Description("Tests parameter derivation for array parameters in parameterized queries (CommandType.Text)")] - public async Task DeriveTextCommandParameters_Array() - { - using (var conn = await OpenConnectionAsync()) - { - var cmd = new NpgsqlCommand("SELECT :a::integer[]", conn); - var val = new[] { 7, 42 }; - - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(1)); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("a")); - Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer | NpgsqlDbType.Array)); - cmd.Parameters[0].Value = val; - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow)) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(val)); - } - } - } + })!; - [Test, Description("Tests parameter derivation for domain parameters in 
parameterized queries (CommandType.Text)")] - public async Task DeriveTextCommandParameters_Domain() - { - using (var conn = await OpenConnectionAsync()) - { - MinimumPgVersion(conn, "11.0", "Arrays of domains and domains over arrays were introduced in PostgreSQL 11"); - await conn.ExecuteNonQueryAsync("CREATE DOMAIN posint AS integer CHECK (VALUE > 0);" + - "CREATE DOMAIN int_array AS int[] CHECK(array_length(VALUE, 1) = 2);"); - conn.ReloadTypes(); - await using var _ = DeferAsync(async () => - { - await conn.ExecuteNonQueryAsync("DROP DOMAIN int_array; DROP DOMAIN posint"); - conn.ReloadTypes(); - }); - - var cmd = new NpgsqlCommand("SELECT :a::posint, :b::posint[], :c::int_array", conn); - var val = 23; - var arrayVal = new[] { 7, 42 }; - - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(3)); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("a")); - Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); - Assert.That(cmd.Parameters[0].DataTypeName, Does.EndWith("posint")); - Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("b")); - Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer | NpgsqlDbType.Array)); - Assert.That(cmd.Parameters[1].DataTypeName, Does.EndWith("posint[]")); - Assert.That(cmd.Parameters[2].ParameterName, Is.EqualTo("c")); - Assert.That(cmd.Parameters[2].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer | NpgsqlDbType.Array)); - Assert.That(cmd.Parameters[2].DataTypeName, Does.EndWith("int_array")); - cmd.Parameters[0].Value = val; - cmd.Parameters[1].Value = arrayVal; - cmd.Parameters[2].Value = arrayVal; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(val)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(arrayVal)); - Assert.That(reader.GetFieldValue(2), Is.EqualTo(arrayVal)); - } - } - } + Assert.That(ex.Message, Is.EqualTo("Deriving parameters isn't 
supported for commands that are already prepared.")); - [Test, Description("Tests parameter derivation for unmapped enum parameters in parameterized queries (CommandType.Text)")] - public async Task DeriveTextCommandParameters_UnmappedEnum() - { - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync("CREATE TYPE fruit AS ENUM ('Apple', 'Cherry', 'Plum')"); - conn.ReloadTypes(); - await using var _ = DeferAsync(async () => - { - await conn.ExecuteNonQueryAsync("DROP TYPE fruit"); - conn.ReloadTypes(); - }); - - var cmd = new NpgsqlCommand("SELECT :x::fruit", conn); - const string val1 = "Apple"; - var val2 = new string[] { "Cherry", "Plum" }; - - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(1)); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("x")); - Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); - Assert.That(cmd.Parameters[0].PostgresType, Is.InstanceOf()); - Assert.That(cmd.Parameters[0].PostgresType!.Name, Is.EqualTo("fruit")); - Assert.That(cmd.Parameters[0].DataTypeName, Does.EndWith("fruit")); - cmd.Parameters[0].Value = val1; - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow)) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetString(0), Is.EqualTo(val1)); - } - } - } + // We leave the command intact when throwing so it should still be useable + Assert.That(cmd.Parameters.Count, Is.EqualTo(1)); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("@p")); + Assert.That(conn.Connector.PreparedStatementManager.NumPrepared, Is.EqualTo(1)); + cmd.Parameters["@p"].Value = answer; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(answer)); + } - enum Fruit { Apple, Cherry, Plum } + [Test, Description("Tests parameter derivation for array parameters in parameterized queries (CommandType.Text)")] + public async Task DeriveParameters_text_array() + { + using var 
conn = await OpenConnectionAsync(); + var cmd = new NpgsqlCommand("SELECT :a::integer[]", conn); + var val = new[] { 7, 42 }; + + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(1)); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("a")); + Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer | NpgsqlDbType.Array)); + cmd.Parameters[0].Value = val; + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(val)); + } - [Test, Description("Tests parameter derivation for mapped enum parameters in parameterized queries (CommandType.Text)")] - public async Task DeriveTextCommandParameters_MappedEnum() - { - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync("CREATE TYPE fruit AS ENUM ('apple', 'cherry', 'plum')"); - conn.ReloadTypes(); - await using var _ = DeferAsync(async () => - { - await conn.ExecuteNonQueryAsync("DROP TYPE fruit"); - conn.ReloadTypes(); - }); - - conn.TypeMapper.MapEnum("fruit"); - var cmd = new NpgsqlCommand("SELECT :x::fruit, :y::fruit[]", conn); - const Fruit val1 = Fruit.Apple; - var val2 = new Fruit[] { Fruit.Cherry, Fruit.Plum }; - - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(2)); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("x")); - Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); - Assert.That(cmd.Parameters[0].PostgresType, Is.InstanceOf()); - Assert.That(cmd.Parameters[0].DataTypeName, Does.EndWith("fruit")); - Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("y")); - Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); - Assert.That(cmd.Parameters[1].PostgresType, Is.InstanceOf()); - Assert.That(cmd.Parameters[1].DataTypeName, Does.EndWith("fruit[]")); - 
cmd.Parameters[0].Value = val1; - cmd.Parameters[1].Value = val2; - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow)) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(val1)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(val2)); - } - } - } + [Test, Description("Tests parameter derivation for domain parameters in parameterized queries (CommandType.Text)")] + public async Task DeriveParameters_text_domain() + { + using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "11.0", "Arrays of domains and domains over arrays were introduced in PostgreSQL 11"); + var domainType = await GetTempTypeName(conn); + var domainArrayType = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($@" +CREATE DOMAIN {domainType} AS integer CHECK (VALUE > 0); +CREATE DOMAIN {domainArrayType} AS int[] CHECK(array_length(VALUE, 1) = 2);"); + conn.ReloadTypes(); + + var cmd = new NpgsqlCommand($"SELECT :a::{domainType}, :b::{domainType}[], :c::{domainArrayType}", conn); + var val = 23; + var arrayVal = new[] { 7, 42 }; + + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(3)); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("a")); + Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(cmd.Parameters[0].DataTypeName, Does.EndWith(domainType)); + Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("b")); + Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer | NpgsqlDbType.Array)); + Assert.That(cmd.Parameters[1].DataTypeName, Does.EndWith(domainType + "[]")); + Assert.That(cmd.Parameters[2].ParameterName, Is.EqualTo("c")); + Assert.That(cmd.Parameters[2].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer | NpgsqlDbType.Array)); + Assert.That(cmd.Parameters[2].DataTypeName, Does.EndWith(domainArrayType)); + cmd.Parameters[0].Value = val; + 
cmd.Parameters[1].Value = arrayVal; + cmd.Parameters[2].Value = arrayVal; + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(val)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(arrayVal)); + Assert.That(reader.GetFieldValue(2), Is.EqualTo(arrayVal)); + } - class SomeComposite - { - public int X { get; set; } + [Test, Description("Tests parameter derivation for unmapped enum parameters in parameterized queries (CommandType.Text)")] + public async Task DeriveParameters_text_unmapped_enum() + { + using var conn = await OpenConnectionAsync(); + var type = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($@"CREATE TYPE {type} AS ENUM ('Apple', 'Cherry', 'Plum')"); + conn.ReloadTypes(); + + var cmd = new NpgsqlCommand($"SELECT :x::{type}", conn); + const string val1 = "Apple"; + var val2 = new string[] { "Cherry", "Plum" }; + + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(1)); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("x")); + Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); + Assert.That(cmd.Parameters[0].PostgresType, Is.InstanceOf()); + Assert.That(cmd.Parameters[0].PostgresType!.Name, Is.EqualTo(type)); + Assert.That(cmd.Parameters[0].DataTypeName, Does.EndWith(type)); + cmd.Parameters[0].Value = val1; + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetString(0), Is.EqualTo(val1)); + } - [PgName("some_text")] - public string SomeText { get; set; } = ""; - } + enum Fruit { Apple, Cherry, Plum } - [Test] - public async Task DeriveTextCommandParameters_MappedComposite() - { - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync(@" -DROP TYPE IF EXISTS deriveparameterscomposite1; -CREATE TYPE deriveparameterscomposite1 AS (x int, some_text 
text)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("deriveparameterscomposite1"); - await using var _ = DeferAsync(async () => - { - await conn.ExecuteNonQueryAsync("DROP TYPE deriveparameterscomposite1"); - conn.ReloadTypes(); - }); - - var expected1 = new SomeComposite { X = 8, SomeText = "foo" }; - var expected2 = new[] { - expected1, - new SomeComposite {X = 9, SomeText = "bar"} - }; - - using (var cmd = new NpgsqlCommand("SELECT @p1::deriveparameterscomposite1, @p2::deriveparameterscomposite1[]", conn)) - { - NpgsqlCommandBuilder.DeriveParameters(cmd); - Assert.That(cmd.Parameters, Has.Count.EqualTo(2)); - Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("p1")); - Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); - Assert.That(cmd.Parameters[0].PostgresType, Is.InstanceOf()); - Assert.That(cmd.Parameters[0].DataTypeName, Does.EndWith("deriveparameterscomposite1")); - var p1Fields = ((PostgresCompositeType)cmd.Parameters[0].PostgresType!).Fields; - Assert.That(p1Fields[0].Name, Is.EqualTo("x")); - Assert.That(p1Fields[1].Name, Is.EqualTo("some_text")); - - Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("p2")); - Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); - Assert.That(cmd.Parameters[1].PostgresType, Is.InstanceOf()); - Assert.That(cmd.Parameters[1].DataTypeName, Does.EndWith("deriveparameterscomposite1[]")); - var p2Element = ((PostgresArrayType)cmd.Parameters[1].PostgresType!).Element; - Assert.That(p2Element, Is.InstanceOf()); - Assert.That(p2Element.Name, Is.EqualTo("deriveparameterscomposite1")); - var p2Fields = ((PostgresCompositeType)p2Element).Fields; - Assert.That(p2Fields[0].Name, Is.EqualTo("x")); - Assert.That(p2Fields[1].Name, Is.EqualTo("some_text")); - - cmd.Parameters[0].Value = expected1; - cmd.Parameters[1].Value = expected2; - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow)) - { - 
Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetFieldValue(0).SomeText, Is.EqualTo(expected1.SomeText)); - Assert.That(reader.GetFieldValue(0).X, Is.EqualTo(expected1.X)); - for (var i = 0; i < 2; i++) - { - Assert.That(reader.GetFieldValue(1)[i].SomeText, Is.EqualTo(expected2[i].SomeText)); - Assert.That(reader.GetFieldValue(1)[i].X, Is.EqualTo(expected2[i].X)); - } - } - } - } - } + [Test, Description("Tests parameter derivation for mapped enum parameters in parameterized queries (CommandType.Text)")] + public async Task DeriveParameters_text_mapped_enum() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($@"CREATE TYPE {type} AS ENUM ('apple', 'cherry', 'plum')"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var cmd = new NpgsqlCommand($"SELECT :x::{type}, :y::{type}[]", connection); + const Fruit val1 = Fruit.Apple; + var val2 = new[] { Fruit.Cherry, Fruit.Plum }; + + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(2)); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("x")); + Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); + Assert.That(cmd.Parameters[0].PostgresType, Is.InstanceOf()); + Assert.That(cmd.Parameters[0].DataTypeName, Does.EndWith(type)); + Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("y")); + Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); + Assert.That(cmd.Parameters[1].PostgresType, Is.InstanceOf()); + Assert.That(cmd.Parameters[1].DataTypeName, Does.EndWith(type + "[]")); + cmd.Parameters[0].Value = val1; + cmd.Parameters[1].Value = val2; + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleResult | 
CommandBehavior.SingleRow); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(val1)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(val2)); + } - #endregion + class SomeComposite + { + public int X { get; set; } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1591")] - public async Task GetUpdateCommandInfersParametersWithNpgsqDbType() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await GetTempTableName(conn, out var table); - await conn.ExecuteNonQueryAsync($@" - CREATE TABLE {table} ( - Cod varchar(5) NOT NULL, - Descr varchar(40), - Data date, - DataOra timestamp, - Intero smallInt NOT NULL, - Decimale money, - Singolo float, - Booleano bit, - Nota varchar(255), - BigIntArr bigint[], - VarCharArr character varying(20)[], - PRIMARY KEY (Cod) - ); - INSERT INTO {table} VALUES('key1', 'description', '2018-07-03', '2018-07-03 07:02:00', 123, 123.4, 1234.5, B'1', 'note'); - "); - - var daDataAdapter = - new NpgsqlDataAdapter( - $"SELECT Cod, Descr, Data, DataOra, Intero, Decimale, Singolo, Booleano, Nota, BigIntArr, VarCharArr FROM {table}", conn); - - var cbCommandBuilder = new NpgsqlCommandBuilder(daDataAdapter); - var dtTable = new DataTable(); - - daDataAdapter.InsertCommand = cbCommandBuilder.GetInsertCommand(); - daDataAdapter.UpdateCommand = cbCommandBuilder.GetUpdateCommand(); - daDataAdapter.DeleteCommand = cbCommandBuilder.GetDeleteCommand(); - - Assert.That(daDataAdapter.UpdateCommand.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[2].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Date)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[3].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[4].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Smallint)); - 
Assert.That(daDataAdapter.UpdateCommand.Parameters[5].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Money)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[6].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Double)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[7].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bit)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[8].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[9].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Bigint)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[10].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Varchar)); - - Assert.That(daDataAdapter.UpdateCommand.Parameters[11].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[13].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[15].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Date)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[17].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[18].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Smallint)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[20].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Money)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[22].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Double)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[24].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bit)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[26].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[28].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Bigint)); - Assert.That(daDataAdapter.UpdateCommand.Parameters[30].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Varchar)); - - daDataAdapter.Fill(dtTable); - - var row = dtTable.Rows[0]; - - Assert.That(row[0], Is.EqualTo("key1")); - Assert.That(row[1], 
Is.EqualTo("description")); - Assert.That(row[2], Is.EqualTo(new DateTime(2018, 7, 3))); - Assert.That(row[3], Is.EqualTo(new DateTime(2018, 7, 3, 7, 2, 0))); - Assert.That(row[4], Is.EqualTo(123)); - Assert.That(row[5], Is.EqualTo(123.4)); - Assert.That(row[6], Is.EqualTo(1234.5)); - Assert.That(row[7], Is.EqualTo(true)); - Assert.That(row[8], Is.EqualTo("note")); - - dtTable.Rows[0]["Singolo"] = 1.1D; - - Assert.That(daDataAdapter.Update(dtTable), Is.EqualTo(1)); - } - } + [PgName("some_text")] + public string SomeText { get; set; } = ""; + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2560")] - public void GetUpdateCommandWithColumnAliases() + [Test] + public async Task DeriveParameters_text_mapped_composite() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (x int, some_text text)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var expected1 = new SomeComposite { X = 8, SomeText = "foo" }; + var expected2 = new[] { expected1, new SomeComposite {X = 9, SomeText = "bar"} }; + + await using var cmd = new NpgsqlCommand($"SELECT @p1::{type}, @p2::{type}[]", connection); + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(2)); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("p1")); + Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); + Assert.That(cmd.Parameters[0].PostgresType, Is.InstanceOf()); + Assert.That(cmd.Parameters[0].DataTypeName, Does.EndWith(type)); + var p1Fields = ((PostgresCompositeType)cmd.Parameters[0].PostgresType!).Fields; + Assert.That(p1Fields[0].Name, Is.EqualTo("x")); + Assert.That(p1Fields[1].Name, Is.EqualTo("some_text")); + + 
Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("p2")); + Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); + Assert.That(cmd.Parameters[1].PostgresType, Is.InstanceOf()); + Assert.That(cmd.Parameters[1].DataTypeName, Does.EndWith(type + "[]")); + var p2Element = ((PostgresArrayType)cmd.Parameters[1].PostgresType!).Element; + Assert.That(p2Element, Is.InstanceOf()); + Assert.That(p2Element.Name, Is.EqualTo(type)); + var p2Fields = ((PostgresCompositeType)p2Element).Fields; + Assert.That(p2Fields[0].Name, Is.EqualTo("x")); + Assert.That(p2Fields[1].Name, Is.EqualTo("some_text")); + + cmd.Parameters[0].Value = expected1; + cmd.Parameters[1].Value = expected2; + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleResult | CommandBehavior.SingleRow); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetFieldValue(0).SomeText, Is.EqualTo(expected1.SomeText)); + Assert.That(reader.GetFieldValue(0).X, Is.EqualTo(expected1.X)); + for (var i = 0; i < 2; i++) { - using var conn = OpenConnection(); - - conn.ExecuteNonQuery(@" - CREATE TEMP TABLE data ( - Cod varchar(5) NOT NULL, - Descr varchar(40), - Data date, - CONSTRAINT PK_test_Cod PRIMARY KEY (Cod) - ); - "); - - using var cmd = new NpgsqlCommand("SELECT Cod as CodAlias, Descr as DescrAlias, Data as DataAlias FROM data", conn); - using var daDataAdapter = new NpgsqlDataAdapter(cmd); - using var cbCommandBuilder = new NpgsqlCommandBuilder(daDataAdapter); - - daDataAdapter.UpdateCommand = cbCommandBuilder.GetUpdateCommand(); - Assert.True(daDataAdapter.UpdateCommand.CommandText.Contains("SET \"cod\" = @p1, \"descr\" = @p2, \"data\" = @p3 WHERE ((\"cod\" = @p4) AND ((@p5 = 1 AND \"descr\" IS NULL) OR (\"descr\" = @p6)) AND ((@p7 = 1 AND \"data\" IS NULL) OR (\"data\" = @p8)))")); + Assert.That(reader.GetFieldValue(1)[i].SomeText, Is.EqualTo(expected2[i].SomeText)); + Assert.That(reader.GetFieldValue(1)[i].X, Is.EqualTo(expected2[i].X)); } + } + [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/1591")] + public async Task Get_update_command_infers_parameters_with_NpgsqDbType() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($@" +CREATE TABLE {table} ( + Cod varchar(5) NOT NULL, + Descr varchar(40), + Data date, + DataOra timestamp, + Intero smallInt NOT NULL, + Decimale money, + Singolo float, + Booleano bit, + Nota varchar(255), + BigIntArr bigint[], + VarCharArr character varying(20)[], + PRIMARY KEY (Cod) +); +INSERT INTO {table} VALUES('key1', 'description', '2018-07-03', '2018-07-03 07:02:00', 123, 123.4, 1234.5, B'1', 'note')"); + + var daDataAdapter = + new NpgsqlDataAdapter( + $"SELECT Cod, Descr, Data, DataOra, Intero, Decimale, Singolo, Booleano, Nota, BigIntArr, VarCharArr FROM {table}", conn); + + var cbCommandBuilder = new NpgsqlCommandBuilder(daDataAdapter); + var dtTable = new DataTable(); + + daDataAdapter.InsertCommand = cbCommandBuilder.GetInsertCommand(); + daDataAdapter.UpdateCommand = cbCommandBuilder.GetUpdateCommand(); + daDataAdapter.DeleteCommand = cbCommandBuilder.GetDeleteCommand(); + + Assert.That(daDataAdapter.UpdateCommand.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[2].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Date)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[3].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[4].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Smallint)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[5].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Money)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[6].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Double)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[7].NpgsqlDbType, 
Is.EqualTo(NpgsqlDbType.Bit)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[8].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[9].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Bigint)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[10].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Varchar)); + + Assert.That(daDataAdapter.UpdateCommand.Parameters[11].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[13].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[15].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Date)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[17].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[18].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Smallint)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[20].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Money)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[22].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Double)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[24].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bit)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[26].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[28].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Bigint)); + Assert.That(daDataAdapter.UpdateCommand.Parameters[30].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Varchar)); + + daDataAdapter.Fill(dtTable); + + var row = dtTable.Rows[0]; + + Assert.That(row[0], Is.EqualTo("key1")); + Assert.That(row[1], Is.EqualTo("description")); + Assert.That(row[2], Is.EqualTo(new DateTime(2018, 7, 3))); + Assert.That(row[3], Is.EqualTo(new DateTime(2018, 7, 3, 7, 2, 0))); + Assert.That(row[4], Is.EqualTo(123)); + Assert.That(row[5], Is.EqualTo(123.4)); + Assert.That(row[6], 
Is.EqualTo(1234.5)); + Assert.That(row[7], Is.EqualTo(true)); + Assert.That(row[8], Is.EqualTo("note")); + + dtTable.Rows[0]["Singolo"] = 1.1D; + + Assert.That(daDataAdapter.Update(dtTable), Is.EqualTo(1)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2846")] - public void GetUpdateCommandWithArrayColumType() - { - using var conn = OpenConnection(); - try - { - conn.ExecuteNonQuery(@" -DROP TABLE IF EXISTS Test; -CREATE TABLE Test ( -Cod varchar(5) NOT NULL, -Vettore character varying(20)[], -CONSTRAINT PK_test_Cod PRIMARY KEY (Cod) -) -"); - using var daDataAdapter = new NpgsqlDataAdapter("SELECT cod, vettore FROM test ORDER By cod", conn); - using var cbCommandBuilder = new NpgsqlCommandBuilder(daDataAdapter); - var dtTable = new DataTable(); - - cbCommandBuilder.SetAllValues = true; - - daDataAdapter.UpdateCommand = cbCommandBuilder.GetUpdateCommand(); - - daDataAdapter.Fill(dtTable); - dtTable.Rows.Add(); - dtTable.Rows[0]["cod"] = '0'; - dtTable.Rows[0]["vettore"] = new string[] { "aaa", "bbb" }; - - daDataAdapter.Update(dtTable); - } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS Test"); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2560")] + public async Task Get_update_command_with_column_aliases() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "Cod varchar(5) PRIMARY KEY, Descr varchar(40), Data date"); + using var cmd = new NpgsqlCommand($"SELECT Cod as CodAlias, Descr as DescrAlias, Data as DataAlias FROM {table}", conn); + using var daDataAdapter = new NpgsqlDataAdapter(cmd); + using var cbCommandBuilder = new NpgsqlCommandBuilder(daDataAdapter); + + daDataAdapter.UpdateCommand = cbCommandBuilder.GetUpdateCommand(); + Assert.True(daDataAdapter.UpdateCommand.CommandText.Contains("SET \"cod\" = @p1, \"descr\" = @p2, \"data\" = @p3 WHERE ((\"cod\" = @p4) AND ((@p5 = 1 AND \"descr\" IS NULL) OR (\"descr\" = @p6)) AND ((@p7 = 1 AND \"data\" IS NULL) OR 
(\"data\" = @p8)))")); } -} + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2846")] + public async Task Get_update_command_with_array_column_type() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "Cod varchar(5) PRIMARY KEY, Vettore character varying(20)[]"); + using var daDataAdapter = new NpgsqlDataAdapter($"SELECT cod, vettore FROM {table} ORDER By cod", conn); + using var cbCommandBuilder = new NpgsqlCommandBuilder(daDataAdapter); + var dtTable = new DataTable(); + + cbCommandBuilder.SetAllValues = true; + + daDataAdapter.UpdateCommand = cbCommandBuilder.GetUpdateCommand(); + + daDataAdapter.Fill(dtTable); + dtTable.Rows.Add(); + dtTable.Rows[0]["cod"] = '0'; + dtTable.Rows[0]["vettore"] = new[] { "aaa", "bbb" }; + + daDataAdapter.Update(dtTable); + } +} diff --git a/test/Npgsql.Tests/CommandParameterTests.cs b/test/Npgsql.Tests/CommandParameterTests.cs new file mode 100644 index 0000000000..1e4355df4b --- /dev/null +++ b/test/Npgsql.Tests/CommandParameterTests.cs @@ -0,0 +1,216 @@ +using System; +using System.Data; +using System.Threading.Tasks; +using NpgsqlTypes; +using NUnit.Framework; + +namespace Npgsql.Tests; + +public class CommandParameterTests : MultiplexingTestBase +{ + [Test] + [TestCase(CommandBehavior.Default)] + [TestCase(CommandBehavior.SequentialAccess)] + public async Task Input_and_output_parameters(CommandBehavior behavior) + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @c-1 AS c, @a+2 AS b", conn); + cmd.Parameters.Add(new NpgsqlParameter("a", 3)); + var b = new NpgsqlParameter { ParameterName = "b", Direction = ParameterDirection.Output }; + cmd.Parameters.Add(b); + var c = new NpgsqlParameter { ParameterName = "c", Direction = ParameterDirection.InputOutput, Value = 4 }; + cmd.Parameters.Add(c); + using (await cmd.ExecuteReaderAsync(behavior)) + { + Assert.AreEqual(5, b.Value); + Assert.AreEqual(3, c.Value); + } + } + + [Test] + 
public async Task Send_NpgsqlDbType_Unknown([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p::TIMESTAMP", conn); + cmd.CommandText = "SELECT @p::TIMESTAMP"; + cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Unknown) { Value = "2008-1-1" }); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetValue(0), Is.EqualTo(new DateTime(2008, 1, 1))); + } + + [Test] + public async Task Positional_parameter() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + } + + [Test] + public async Task Positional_parameters_are_not_supported_with_legacy_batching() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1; SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.SyntaxError)); + } + + [Test] + public async Task Unreferenced_named_parameter_works() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Parameters.AddWithValue("not_used", 8); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + } + + [Test] + public async Task Unreferenced_positional_parameter_works() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + 
cmd.Parameters.Add(new NpgsqlParameter { Value = 8 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + } + + [Test] + public async Task Mixing_positional_and_named_parameters_is_not_supported() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1, @p", conn); + cmd.Parameters.Add(new NpgsqlParameter { Value = 8 }); + cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = 9 }); + Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception.TypeOf()); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4171")] + public async Task Reuse_command_with_different_parameter_placeholder_types() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + + cmd.CommandText = "SELECT @p1"; + cmd.Parameters.AddWithValue("@p1", 8); + _ = await cmd.ExecuteScalarAsync(); + + cmd.CommandText = "SELECT $1"; + cmd.Parameters[0].ParameterName = null; + _ = await cmd.ExecuteScalarAsync(); + } + + [Test] + public async Task Positional_output_parameters_are_not_supported() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { Value = 8, Direction = ParameterDirection.InputOutput }); + Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception.TypeOf()); + } + + [Test] + public void Parameters_get_name() + { + var command = new NpgsqlCommand(); + + // Add parameters. + command.Parameters.Add(new NpgsqlParameter(":Parameter1", DbType.Boolean)); + command.Parameters.Add(new NpgsqlParameter(":Parameter2", DbType.Int32)); + command.Parameters.Add(new NpgsqlParameter(":Parameter3", DbType.DateTime)); + command.Parameters.Add(new NpgsqlParameter("Parameter4", DbType.DateTime)); + + var idbPrmtr = command.Parameters["Parameter1"]; + Assert.IsNotNull(idbPrmtr); + command.Parameters[0].Value = 1; + + // Get by indexers. 
+ + Assert.AreEqual(":Parameter1", command.Parameters["Parameter1"].ParameterName); + Assert.AreEqual(":Parameter2", command.Parameters["Parameter2"].ParameterName); + Assert.AreEqual(":Parameter3", command.Parameters["Parameter3"].ParameterName); + Assert.AreEqual("Parameter4", command.Parameters["Parameter4"].ParameterName); //Should this work? + + Assert.AreEqual(":Parameter1", command.Parameters[0].ParameterName); + Assert.AreEqual(":Parameter2", command.Parameters[1].ParameterName); + Assert.AreEqual(":Parameter3", command.Parameters[2].ParameterName); + Assert.AreEqual("Parameter4", command.Parameters[3].ParameterName); + } + + [Test] + public async Task Same_param_multiple_times() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p1, @p1", conn); + cmd.Parameters.AddWithValue("@p1", 8); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader[0], Is.EqualTo(8)); + Assert.That(reader[1], Is.EqualTo(8)); + } + + [Test] + public async Task Generic_parameter() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4", conn); + cmd.Parameters.Add(new NpgsqlParameter("p1", 8)); + cmd.Parameters.Add(new NpgsqlParameter("p2", 8) { NpgsqlDbType = NpgsqlDbType.Integer }); + cmd.Parameters.Add(new NpgsqlParameter("p3", "hello")); + cmd.Parameters.Add(new NpgsqlParameter("p4", new[] { 'f', 'o', 'o' })); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(8)); + Assert.That(reader.GetInt32(1), Is.EqualTo(8)); + Assert.That(reader.GetString(2), Is.EqualTo("hello")); + Assert.That(reader.GetString(3), Is.EqualTo("foo")); + } + + [Test] + [TestCase(false)] + [TestCase(true)] + public async Task Parameter_must_be_set(bool genericParam) + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1::TEXT", conn); + 
cmd.Parameters.Add( + genericParam + ? new NpgsqlParameter("p1", null) + : new NpgsqlParameter("p1", null) + ); + + Assert.That(async () => await cmd.ExecuteReaderAsync(), + Throws.Exception + .TypeOf() + .With.Message.EqualTo("Parameter 'p1' must have either its NpgsqlDbType or its DataTypeName or its Value set.")); + } + + [Test] + public async Task Object_generic_param_does_runtime_lookup() + { + await AssertTypeWrite(1, "1", "integer", NpgsqlDbType.Integer, DbType.Int32, DbType.Int32, isDefault: false, + isNpgsqlDbTypeInferredFromClrType: true, skipArrayCheck: true); + await AssertTypeWrite(new[] {1, 1}, "{1,1}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array, isDefault: false, + isNpgsqlDbTypeInferredFromClrType: true, skipArrayCheck: true); + } + + [Test] + public async Task Object_generic_parameter_works() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + } + + public CommandParameterTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) + { + } +} diff --git a/test/Npgsql.Tests/CommandTests.cs b/test/Npgsql.Tests/CommandTests.cs index afcaa373d3..5d3b35b01d 100644 --- a/test/Npgsql.Tests/CommandTests.cs +++ b/test/Npgsql.Tests/CommandTests.cs @@ -1,1082 +1,1771 @@ +using Npgsql.BackendMessages; +using Npgsql.Internal; +using Npgsql.Tests.Support; +using NpgsqlTypes; +using NUnit.Framework; using System; +using System.Buffers.Binary; +using System.Collections.Generic; using System.Data; -using System.IO; using System.Linq; -using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; -using Npgsql.Tests.Support; -using NpgsqlTypes; -using NUnit.Framework; +using Npgsql.Internal.Postgres; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + 
+public class CommandTests : MultiplexingTestBase { - public class CommandTests : MultiplexingTestBase - { - #region Multiple Statements in a Command - - /// - /// Tests various configurations of queries and non-queries within a multiquery - /// - [Test] - [TestCase(new[] { true }, TestName = "SingleQuery")] - [TestCase(new[] { false }, TestName = "SingleNonQuery")] - [TestCase(new[] { true, true }, TestName = "TwoQueries")] - [TestCase(new[] { false, false }, TestName = "TwoNonQueries")] - [TestCase(new[] { false, true }, TestName = "NonQueryQuery")] - [TestCase(new[] { true, false }, TestName = "QueryNonQuery")] - public async Task MultipleStatements(bool[] queries) + static uint Int4Oid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Int4).Value; + static uint TextOid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Text).Value; + + #region Legacy batching + + [Test] + [TestCase(new[] { true }, TestName = "SingleQuery")] + [TestCase(new[] { false }, TestName = "SingleNonQuery")] + [TestCase(new[] { true, true }, TestName = "TwoQueries")] + [TestCase(new[] { false, false }, TestName = "TwoNonQueries")] + [TestCase(new[] { false, true }, TestName = "NonQueryQuery")] + [TestCase(new[] { true, false }, TestName = "QueryNonQuery")] + public async Task Multiple_statements(bool[] queries) + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + var sb = new StringBuilder(); + foreach (var query in queries) + sb.Append(query ? 
"SELECT 1;" : $"UPDATE {table} SET name='yo' WHERE 1=0;"); + var sql = sb.ToString(); + foreach (var prepare in new[] { false, true }) { - using (var conn = await OpenConnectionAsync()) + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + if (prepare && !IsMultiplexing) + await cmd.PrepareAsync(); + await using var reader = await cmd.ExecuteReaderAsync(); + var numResultSets = queries.Count(q => q); + for (var i = 0; i < numResultSets; i++) { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - var sb = new StringBuilder(); - foreach (var query in queries) - sb.Append(query ? "SELECT 1;" : $"UPDATE {table} SET name='yo' WHERE 1=0;"); - var sql = sb.ToString(); - foreach (var prepare in new[] { false, true }) - { - using (var cmd = new NpgsqlCommand(sql, conn)) - { - if (prepare && !IsMultiplexing) - cmd.Prepare(); - using (var reader = await cmd.ExecuteReaderAsync()) - { - var numResultSets = queries.Count(q => q); - for (var i = 0; i < numResultSets; i++) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader[0], Is.EqualTo(1)); - Assert.That(reader.NextResult(), Is.EqualTo(i != numResultSets - 1)); - } - } - } - } + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader[0], Is.EqualTo(1)); + Assert.That(await reader.NextResultAsync(), Is.EqualTo(i != numResultSets - 1)); } } + } - [Test] - public async Task MultipleStatementsWithParameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + [Test] + public async Task Multiple_statements_with_parameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT @p1; SELECT @p2"; + var p1 = new NpgsqlParameter("p1", 
NpgsqlDbType.Integer); + var p2 = new NpgsqlParameter("p2", NpgsqlDbType.Text); + cmd.Parameters.Add(p1); + cmd.Parameters.Add(p2); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + p1.Value = 8; + p2.Value = "foo"; + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(8)); + Assert.That(await reader.NextResultAsync(), Is.True); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader.GetString(0), Is.EqualTo("foo")); + Assert.That(await reader.NextResultAsync(), Is.False); + } - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT @p1; SELECT @p2", conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Integer); - var p2 = new NpgsqlParameter("p2", NpgsqlDbType.Text); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - p1.Value = 8; - p2.Value = "foo"; - using (var reader = await cmd.ExecuteReaderAsync()) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(8)); - Assert.That(reader.NextResult(), Is.True); - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetString(0), Is.EqualTo("foo")); - Assert.That(reader.NextResult(), Is.False); - } - } - } - } + [Test] + public async Task SingleRow_legacy_batching([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleRow); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.NextResult(), Is.False); + } - [Test] - public 
async Task MultipleStatementsSingleRow([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + [Test, Description("Makes sure a later command can depend on an earlier one")] + [IssueLink("https://github.com/npgsql/npgsql/issues/641")] + public async Task Multiple_statements_with_dependencies() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "a INT"); - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) - { - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleRow)) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - Assert.That(reader.Read(), Is.False); - Assert.That(reader.NextResult(), Is.False); - } - } - } - } + await conn.ExecuteNonQueryAsync($"ALTER TABLE {table} ADD COLUMN b INT; INSERT INTO {table} (b) VALUES (8)"); + Assert.That(await conn.ExecuteScalarAsync($"SELECT b FROM {table}"), Is.EqualTo(8)); + } - [Test, Description("Makes sure a later command can depend on an earlier one")] - [IssueLink("https://github.com/npgsql/npgsql/issues/641")] - public async Task MultipleStatementsWithDependencies() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "a INT", out var table); + [Test, Description("Forces async write mode when the first statement in a multi-statement command is big")] + [IssueLink("https://github.com/npgsql/npgsql/issues/641")] + public async Task Multiple_statements_large_first_command() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand($"SELECT repeat('X', {conn.Settings.WriteBufferSize}); SELECT @p", conn); + var expected1 = new string('X', conn.Settings.WriteBufferSize); + var expected2 = new string('Y', 
conn.Settings.WriteBufferSize); + cmd.Parameters.AddWithValue("p", expected2); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetString(0), Is.EqualTo(expected1)); + reader.NextResult(); + reader.Read(); + Assert.That(reader.GetString(0), Is.EqualTo(expected2)); + } - await conn.ExecuteNonQueryAsync($"ALTER TABLE {table} ADD COLUMN b INT; INSERT INTO {table} (b) VALUES (8)"); - Assert.That(await conn.ExecuteScalarAsync($"SELECT b FROM {table}"), Is.EqualTo(8)); - } - } + [Test] + [NonParallelizable] // Disables sql rewriting + public async Task Legacy_batching_is_not_supported_when_EnableSqlParsing_is_disabled() + { + using var _ = DisableSqlRewriting(); - [Test, Description("Forces async write mode when the first statement in a multi-statement command is big")] - [IssueLink("https://github.com/npgsql/npgsql/issues/641")] - public async Task MultipleStatementsLargeFirstCommand() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand($"SELECT repeat('X', {conn.Settings.WriteBufferSize}); SELECT @p", conn)) - { - var expected1 = new string('X', conn.Settings.WriteBufferSize); - var expected2 = new string('Y', conn.Settings.WriteBufferSize); - cmd.Parameters.AddWithValue("p", expected2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetString(0), Is.EqualTo(expected1)); - reader.NextResult(); - reader.Read(); - Assert.That(reader.GetString(0), Is.EqualTo(expected2)); - } - } - } + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); + Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.SyntaxError)); + } - #endregion + [Test] + [NonParallelizable] // Disables sql rewriting + public async Task Positional_parameters_are_supported_when_EnableSqlParsing_is_disabled() + { + using 
var _ = DisableSqlRewriting(); - #region Timeout + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT $1", conn); + cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + } - [Test, Description("Checks that CommandTimeout gets enforced as a socket timeout")] - [IssueLink("https://github.com/npgsql/npgsql/issues/327")] - [Timeout(10000)] - public async Task Timeout() - { - if (IsMultiplexing) - return; // Multiplexing, Timeout - - // Mono throws a socket exception with WouldBlock instead of TimedOut (see #1330) - var isMono = Type.GetType("Mono.Runtime") != null; - using var conn = await OpenConnectionAsync(ConnectionString + ";CommandTimeout=1"); - using var cmd = CreateSleepCommand(conn, 10); - Assert.That(() => cmd.ExecuteNonQuery(), Throws.Exception + [Test] + [NonParallelizable] // Disables sql rewriting + public async Task Named_parameters_are_not_supported_when_EnableSqlParsing_is_disabled() + { + using var _ = DisableSqlRewriting(); + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.Add(new NpgsqlParameter("p", 8)); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + } + + #endregion + + #region Timeout + + [Test, Description("Checks that CommandTimeout gets enforced as a socket timeout")] + [IssueLink("https://github.com/npgsql/npgsql/issues/327")] + public async Task Timeout() + { + if (IsMultiplexing) + return; // Multiplexing, Timeout + + await using var dataSource = CreateDataSource(csb => csb.CommandTimeout = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = CreateSleepCommand(conn, 10); + Assert.That(() => cmd.ExecuteNonQuery(), Throws.Exception + .TypeOf() + .With.InnerException.TypeOf() + ); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + } + + 
[Test, Description("Times out an async operation, testing that cancellation occurs successfully")] + [IssueLink("https://github.com/npgsql/npgsql/issues/607")] + public async Task Timeout_async_soft() + { + if (IsMultiplexing) + return; // Multiplexing, Timeout + + await using var dataSource = CreateDataSource(csb => csb.CommandTimeout = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = CreateSleepCommand(conn, 10); + Assert.That(async () => await cmd.ExecuteNonQueryAsync(), + Throws.Exception .TypeOf() - .With.InnerException.TypeOf() - ); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - } + .With.InnerException.TypeOf()); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + } - [Test, Description("Times out an async operation, testing that cancellation occurs successfully")] - [IssueLink("https://github.com/npgsql/npgsql/issues/607")] - [Timeout(10000)] - public async Task TimeoutAsyncSoft() - { - if (IsMultiplexing) - return; // Multiplexing, Timeout - - using var conn = await OpenConnectionAsync(builder => builder.CommandTimeout = 1); - using var cmd = CreateSleepCommand(conn, 10); - Assert.That(async () => await cmd.ExecuteNonQueryAsync(), - Throws.Exception - .TypeOf() - .With.InnerException.TypeOf()); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - } + [Test, Description("Times out an async operation, with unsuccessful cancellation (socket break)")] + [IssueLink("https://github.com/npgsql/npgsql/issues/607")] + public async Task Timeout_async_hard() + { + if (IsMultiplexing) + return; // Multiplexing, Timeout - [Test, Description("Times out an async operation, with unsuccessful cancellation (socket break)")] - [IssueLink("https://github.com/npgsql/npgsql/issues/607")] - [Timeout(10000)] - public async Task TimeoutAsyncHard() - { - if (IsMultiplexing) - return; // Multiplexing, Timeout + var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { 
CommandTimeout = 1 }; + await using var postmasterMock = PgPostmasterMock.Start(builder.ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + await postmasterMock.WaitForServerConnection(); - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { CommandTimeout = 1 }; - await using var postmasterMock = PgPostmasterMock.Start(builder.ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); - await postmasterMock.WaitForServerConnection(); + var processId = conn.ProcessID; - var processId = conn.ProcessID; + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), + Throws.Exception + .TypeOf() + .With.InnerException.TypeOf()); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), - Throws.Exception - .TypeOf() - .With.InnerException.TypeOf()); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId, + Is.EqualTo(processId)); + } - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId, - Is.EqualTo(processId)); - } + [Test] + public async Task Timeout_from_connection_string() + { + Assert.That(NpgsqlConnector.MinimumInternalCommandTimeout, Is.Not.EqualTo(NpgsqlCommand.DefaultTimeout)); + var timeout = NpgsqlConnector.MinimumInternalCommandTimeout; + await using var dataSource = CreateDataSource(csb => csb.CommandTimeout = timeout); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var command = new NpgsqlCommand("SELECT 1", conn); + Assert.That(command.CommandTimeout, Is.EqualTo(timeout)); + command.CommandTimeout = 10; + await command.ExecuteScalarAsync(); + Assert.That(command.CommandTimeout, 
Is.EqualTo(10)); + } - [Test] - public async Task TimeoutFromConnectionString() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/395")] + public async Task Timeout_switch_connection() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString); + if (csb.CommandTimeout >= 100 && csb.CommandTimeout < 105) + IgnoreExceptOnBuildServer("Bad default command timeout"); + + await using var dataSource1 = CreateDataSource(ConnectionString + ";CommandTimeout=100"); + await using var c1 = dataSource1.CreateConnection(); + await using var cmd = c1.CreateCommand(); + Assert.That(cmd.CommandTimeout, Is.EqualTo(100)); + await using var dataSource2 = CreateDataSource(ConnectionString + ";CommandTimeout=101"); + await using (var c2 = dataSource2.CreateConnection()) { - Assert.That(NpgsqlConnector.MinimumInternalCommandTimeout, Is.Not.EqualTo(NpgsqlCommand.DefaultTimeout)); - var timeout = NpgsqlConnector.MinimumInternalCommandTimeout; - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - CommandTimeout = timeout - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) - { - var command = new NpgsqlCommand("SELECT 1", conn); - conn.Open(); - Assert.That(command.CommandTimeout, Is.EqualTo(timeout)); - command.CommandTimeout = 10; - await command.ExecuteScalarAsync(); - Assert.That(command.CommandTimeout, Is.EqualTo(10)); - } + cmd.Connection = c2; + Assert.That(cmd.CommandTimeout, Is.EqualTo(101)); } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/395")] - public async Task TimeoutSwitchConnection() + cmd.CommandTimeout = 102; + await using (var c2 = dataSource2.CreateConnection()) { - using (var conn = new NpgsqlConnection(ConnectionString)) - { - if (conn.CommandTimeout >= 100 && conn.CommandTimeout < 105) - TestUtil.IgnoreExceptOnBuildServer("Bad default command timeout"); - } + cmd.Connection = c2; + Assert.That(cmd.CommandTimeout, Is.EqualTo(102)); + } + } + + [Test] + public async Task 
Prepare_timeout_hard([Values] SyncOrAsync async) + { + if (IsMultiplexing) + return; // Multiplexing, Timeout + + var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { CommandTimeout = 1 }; + await using var postmasterMock = PgPostmasterMock.Start(builder.ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + await postmasterMock.WaitForServerConnection(); - using (var c1 = await OpenConnectionAsync(ConnectionString + ";CommandTimeout=100")) + var processId = conn.ProcessID; + + var cmd = new NpgsqlCommand("SELECT 1", conn); + Assert.That(async () => { - using (var cmd = c1.CreateCommand()) - { - Assert.That(cmd.CommandTimeout, Is.EqualTo(100)); - using (var c2 = new NpgsqlConnection(ConnectionString + ";CommandTimeout=101")) - { - cmd.Connection = c2; - Assert.That(cmd.CommandTimeout, Is.EqualTo(101)); - } - cmd.CommandTimeout = 102; - using (var c2 = new NpgsqlConnection(ConnectionString + ";CommandTimeout=101")) - { - cmd.Connection = c2; - Assert.That(cmd.CommandTimeout, Is.EqualTo(102)); - } - } - } - } + if (async == SyncOrAsync.Sync) + cmd.Prepare(); + else + await cmd.PrepareAsync(); + }, + Throws.Exception + .TypeOf() + .With.InnerException.TypeOf()); - #endregion + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId, + Is.EqualTo(processId)); + } - #region Cancel + #endregion - [Test, Description("Basic cancellation scenario")] - [Timeout(6000)] - public async Task Cancel() - { - if (IsMultiplexing) - return; + #region Cancel - using var conn = await OpenConnectionAsync(); - using var cmd = CreateSleepCommand(conn, 5); + [Test, Description("Basic cancellation scenario")] + [Ignore("Flaky, see https://github.com/npgsql/npgsql/issues/5070")] + public async Task Cancel() + { + if (IsMultiplexing) + return; + + await using var conn = await 
OpenConnectionAsync(); + await using var cmd = CreateSleepCommand(conn, 5); + + var queryTask = Task.Run(() => cmd.ExecuteNonQuery()); + // We have to be sure the command's state is InProgress, otherwise the cancellation request will never be sent + cmd.WaitUntilCommandIsInProgress(); + cmd.Cancel(); + Assert.That(async () => await queryTask, Throws + .TypeOf() + .With.InnerException.TypeOf() + .With.InnerException.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled) + ); + } - var cancelTask = Task.Run(() => - { - Thread.Sleep(300); - cmd.Cancel(); - }); - Assert.That(() => cmd.ExecuteNonQuery(), Throws - .TypeOf() - .With.InnerException.TypeOf() - .With.InnerException.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled) - ); + [Test] + public async Task Cancel_async_immediately() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation - await cancelTask; - } + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 1"; - [Test, Description("Cancels an async query with the cancellation token, with successful PG cancellation")] - public async Task CancelAsyncSoft() - { - if (IsMultiplexing) - return; // Multiplexing, cancellation + var t = cmd.ExecuteScalarAsync(new(canceled: true)); + Assert.That(t.IsCompleted, Is.True); // checks, if a query has completed synchronously + Assert.That(t.Status, Is.EqualTo(TaskStatus.Canceled)); + Assert.ThrowsAsync(async () => await t); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test, Description("Cancels an async query with the cancellation token, with successful PG cancellation")] + [Explicit("Flaky due to #5033")] + public async Task Cancel_async_soft() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + await using var conn = await OpenConnectionAsync(); + 
await using var cmd = CreateSleepCommand(conn); + using var cancellationSource = new CancellationTokenSource(); + var t = cmd.ExecuteNonQueryAsync(cancellationSource.Token); + cancellationSource.Cancel(); + + var exception = Assert.ThrowsAsync(async () => await t)!; + Assert.That(exception.InnerException, + Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test, Description("Cancels an async query with the cancellation token, with unsuccessful PG cancellation (socket break)")] + public async Task Cancel_async_hard() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation - await using var conn = await OpenConnectionAsync(); - using var cmd = CreateSleepCommand(conn); - var cancellationSource = new CancellationTokenSource(); - var t = cmd.ExecuteNonQueryAsync(cancellationSource.Token); - cancellationSource.Cancel(); + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + await postmasterMock.WaitForServerConnection(); - var exception = Assert.ThrowsAsync(async () => await t); - Assert.That(exception.InnerException, - Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); - Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + var processId = conn.ProcessID; + using var cancellationSource = new CancellationTokenSource(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + var t = cmd.ExecuteScalarAsync(cancellationSource.Token); + cancellationSource.Cancel(); + + var exception = Assert.ThrowsAsync(async () => await t)!; + 
Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId, + Is.EqualTo(processId)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3466")] + [Ignore("https://github.com/npgsql/npgsql/issues/4668")] + public async Task Bug3466([Values(false, true)] bool isBroken) + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + Pooling = false + }; + await using var postmasterMock = PgPostmasterMock.Start(csb.ToString(), completeCancellationImmediately: false); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + var serverMock = await postmasterMock.WaitForServerConnection(); + + var processId = conn.ProcessID; + + using var cancellationSource = new CancellationTokenSource(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn) + { + CommandTimeout = 4 + }; + var t = Task.Run(() => cmd.ExecuteScalar()); + // We have to be sure the command's state is InProgress, otherwise the cancellation request will never be sent + cmd.WaitUntilCommandIsInProgress(); + // Perform cancellation, which will block on the server side + var cancelTask = Task.Run(() => cmd.Cancel()); + // Note what we have to wait for the cancellation request, otherwise the connection might be closed concurrently + // and the cancellation request is never send + var cancellationRequest = await postmasterMock.WaitForCancellationRequest(); + + if (isBroken) + { + Assert.ThrowsAsync(async () => await t); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + else + { + await serverMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new 
FieldDescription(Int4Oid)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + Assert.DoesNotThrowAsync(async () => await t); Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + await conn.CloseAsync(); } - [Test, Description("Cancels an async query with the cancellation token, with unsuccessful PG cancellation (socket break)")] - public async Task CancelAsyncHard() + // Release the cancellation at the server side, and make sure it completes without an exception + cancellationRequest.Complete(); + Assert.DoesNotThrowAsync(async () => await cancelTask); + } + + [Test, Description("Check that cancel only affects the command on which its was invoked")] + [Explicit("Timing-sensitive")] + public async Task Cancel_cross_command() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd1 = CreateSleepCommand(conn, 2); + await using var cmd2 = new NpgsqlCommand("SELECT 1", conn); + var cancelTask = Task.Factory.StartNew(() => { - if (IsMultiplexing) - return; // Multiplexing, cancellation + Thread.Sleep(300); + cmd2.Cancel(); + }); + Assert.That(() => cmd1.ExecuteNonQueryAsync(), Throws.Nothing); + cancelTask.Wait(); + } - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); - await postmasterMock.WaitForServerConnection(); + #endregion - var processId = conn.ProcessID; + #region Cursors - var cancellationSource = new CancellationTokenSource(); - using var cmd = new NpgsqlCommand("SELECT 1", conn); - var t = cmd.ExecuteScalarAsync(cancellationSource.Token); - cancellationSource.Cancel(); + [Test] + public async Task Cursor_statement() + { + using var conn = await OpenConnectionAsync(); + var 
table = await CreateTempTable(conn, "name TEXT"); + using var t = conn.BeginTransaction(); + + for (var x = 0; x < 5; x++) + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')"); + + var i = 0; + var command = new NpgsqlCommand($"DECLARE TE CURSOR FOR SELECT * FROM {table}", conn); + command.ExecuteNonQuery(); + command.CommandText = "FETCH FORWARD 3 IN TE"; + var dr = command.ExecuteReader(); + + while (dr.Read()) + i++; + Assert.AreEqual(3, i); + dr.Close(); + + i = 0; + command.CommandText = "FETCH BACKWARD 1 IN TE"; + var dr2 = command.ExecuteReader(); + while (dr2.Read()) + i++; + Assert.AreEqual(1, i); + dr2.Close(); + + command.CommandText = "close te;"; + command.ExecuteNonQuery(); + } - var exception = Assert.ThrowsAsync(async () => await t); - Assert.That(exception.InnerException, Is.TypeOf()); - Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + [Test] + public async Task Cursor_move_RecordsAffected() + { + using var connection = await OpenConnectionAsync(); + using var transaction = connection.BeginTransaction(); + var command = new NpgsqlCommand("DECLARE curs CURSOR FOR SELECT * FROM (VALUES (1), (2), (3)) as t", connection); + command.ExecuteNonQuery(); + command.CommandText = "MOVE FORWARD ALL IN curs"; + var count = command.ExecuteNonQuery(); + Assert.AreEqual(3, count); + } - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId, - Is.EqualTo(processId)); - } + #endregion + + #region CommandBehavior.CloseConnection + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/693")] + public async Task CloseConnection() + { + using var conn = await OpenConnectionAsync(); + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.CloseConnection)) + while (reader.Read()) {} + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + } + + 
[Test, IssueLink("https://github.com/npgsql/npgsql/issues/1194")] + public async Task CloseConnection_with_open_reader_with_CloseConnection() + { + using var conn = await OpenConnectionAsync(); + var cmd = new NpgsqlCommand("SELECT 1", conn); + var reader = await cmd.ExecuteReaderAsync(CommandBehavior.CloseConnection); + var wasClosed = false; + reader.ReaderClosed += (sender, args) => { wasClosed = true; }; + conn.Close(); + Assert.That(wasClosed, Is.True); + } + + [Test] + public async Task CloseConnection_with_exception() + { + using var conn = await OpenConnectionAsync(); + using (var cmd = new NpgsqlCommand("SE", conn)) + Assert.That(() => cmd.ExecuteReaderAsync(CommandBehavior.CloseConnection), Throws.Exception.TypeOf()); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + } + + #endregion - [Test, Description("Check that cancel only affects the command on which its was invoked")] - [Explicit("Timing-sensitive")] - [Timeout(3000)] - public async Task CancelCrossCommand() + [Test] + public async Task SingleRow([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1, 2 UNION SELECT 3, 4", conn); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleRow); + Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.Read(), Is.False); + } + + [Test] + public async Task CommandText_not_set() + { + await using var conn = await OpenConnectionAsync(); + await using (var cmd = new NpgsqlCommand()) { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd1 = CreateSleepCommand(conn, 2)) - using (var cmd2 = new NpgsqlCommand("SELECT 1", conn)) - { 
- var cancelTask = Task.Factory.StartNew(() => - { - Thread.Sleep(300); - cmd2.Cancel(); - }); - Assert.That(() => cmd1.ExecuteNonQueryAsync(), Throws.Nothing); - cancelTask.Wait(); - } - } + cmd.Connection = conn; + Assert.That(cmd.ExecuteNonQueryAsync, Throws.Exception.TypeOf()); + cmd.CommandText = null; + Assert.That(cmd.ExecuteNonQueryAsync, Throws.Exception.TypeOf()); + cmd.CommandText = ""; } - #endregion + await using (var cmd = conn.CreateCommand()) + Assert.That(cmd.ExecuteNonQueryAsync, Throws.Exception.TypeOf()); + } + + [Test] + public async Task ExecuteScalar() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + await using var command = new NpgsqlCommand($"SELECT name FROM {table}", conn); + Assert.That(command.ExecuteScalarAsync, Is.Null); + + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES (NULL)"); + Assert.That(command.ExecuteScalarAsync, Is.EqualTo(DBNull.Value)); + + await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); + for (var i = 0; i < 2; i++) + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')"); + Assert.That(command.ExecuteScalarAsync, Is.EqualTo("X")); + } + + [Test] + public async Task ExecuteNonQuery() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand { Connection = conn }; + var table = await CreateTempTable(conn, "name TEXT"); + + cmd.CommandText = $"INSERT INTO {table} (name) VALUES ('John')"; + Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(1)); + + cmd.CommandText = $"INSERT INTO {table} (name) VALUES ('John'); INSERT INTO {table} (name) VALUES ('John')"; + Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(2)); + + cmd.CommandText = $"INSERT INTO {table} (name) VALUES ('{new string('x', conn.Settings.WriteBufferSize)}')"; + Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(1)); + } + + [Test, Description("Makes sure a command is unusable after it is disposed")] + public 
async Task Dispose() + { + await using var conn = await OpenConnectionAsync(); + var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Dispose(); + Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception.TypeOf()); + Assert.That(() => cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); + Assert.That(() => cmd.PrepareAsync(), Throws.Exception.TypeOf()); + } + + [Test, Description("Disposing a command with an open reader does not close the reader. This is the SqlClient behavior.")] + public async Task Command_Dispose_does_not_close_reader() + { + await using var conn = await OpenConnectionAsync(); + var cmd = new NpgsqlCommand("SELECT 1, 2", conn); + await cmd.ExecuteReaderAsync(); + cmd.Dispose(); + cmd = new NpgsqlCommand("SELECT 3", conn); + Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + } - #region Cursors + [Test] + public async Task Non_standards_conforming_strings() + { + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); - [Test] - public async Task CursorStatement() + if (IsMultiplexing) { - using (var conn = await OpenConnectionAsync()) - { - using (var t = conn.BeginTransaction()) - { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - - for (var x = 0; x < 5; x++) - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')"); - - var i = 0; - var command = new NpgsqlCommand($"DECLARE TE CURSOR FOR SELECT * FROM {table}", conn); - command.ExecuteNonQuery(); - command.CommandText = "FETCH FORWARD 3 IN TE"; - var dr = command.ExecuteReader(); - - while (dr.Read()) - i++; - Assert.AreEqual(3, i); - dr.Close(); - - i = 0; - command.CommandText = "FETCH BACKWARD 1 IN TE"; - var dr2 = command.ExecuteReader(); - while (dr2.Read()) - i++; - Assert.AreEqual(1, i); - dr2.Close(); - - command.CommandText = "close te;"; - command.ExecuteNonQuery(); - } - 
} + Assert.That(async () => await conn.ExecuteNonQueryAsync("set standard_conforming_strings=off"), + Throws.Exception.TypeOf()); } - - [Test] - public async Task CursorMoveRecordsAffected() + else { - using (var connection = await OpenConnectionAsync()) - using (var transaction = connection.BeginTransaction()) - { - var command = new NpgsqlCommand("DECLARE curs CURSOR FOR SELECT * FROM (VALUES (1), (2), (3)) as t", connection); - command.ExecuteNonQuery(); - command.CommandText = "MOVE FORWARD ALL IN curs"; - var count = command.ExecuteNonQuery(); - Assert.AreEqual(3, count); - } + await conn.ExecuteNonQueryAsync("set standard_conforming_strings=off"); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + await conn.ExecuteNonQueryAsync("set standard_conforming_strings=on"); } + } - #endregion + [Test] + public async Task Parameter_and_operator_unclear() + { + await using var conn = await OpenConnectionAsync(); + //Without parenthesis the meaning of [, . and potentially other characters is + //a syntax error. See comment in NpgsqlCommand.GetClearCommandText() on "usually-redundant parenthesis". 
+ await using var command = new NpgsqlCommand("select :arr[2]", conn); + command.Parameters.AddWithValue(":arr", new int[] {5, 4, 3, 2, 1}); + await using var rdr = await command.ExecuteReaderAsync(); + rdr.Read(); + Assert.AreEqual(rdr.GetInt32(0), 4); + } - #region CommandBehavior.CloseConnection + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4171")] + public async Task Cached_command_clears_parameters_placeholder_type() + { + await using var conn = await OpenConnectionAsync(); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/693")] - public async Task CloseConnection() + await using (var cmd1 = conn.CreateCommand()) { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.CloseConnection)) - while (reader.Read()) {} - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); - } + cmd1.CommandText = "SELECT @p1"; + cmd1.Parameters.AddWithValue("@p1", 8); + await using var reader1 = await cmd1.ExecuteReaderAsync(); + reader1.Read(); + Assert.That(reader1[0], Is.EqualTo(8)); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1194")] - public async Task CloseConnectionWithOpenReaderWithCloseConnection() + await using (var cmd2 = conn.CreateCommand()) { - using (var conn = await OpenConnectionAsync()) - { - var cmd = new NpgsqlCommand("SELECT 1", conn); - var reader = await cmd.ExecuteReaderAsync(CommandBehavior.CloseConnection); - var wasClosed = false; - reader.ReaderClosed += (sender, args) => { wasClosed = true; }; - conn.Close(); - Assert.That(wasClosed, Is.True); - } + cmd2.CommandText = "SELECT $1"; + cmd2.Parameters.AddWithValue(8); + await using var reader2 = await cmd2.ExecuteReaderAsync(); + reader2.Read(); + Assert.That(reader2[0], Is.EqualTo(8)); } + } - [Test] - public async Task CloseConnectionWithException() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new 
NpgsqlCommand("SE", conn)) - Assert.That(() => cmd.ExecuteReaderAsync(CommandBehavior.CloseConnection), Throws.Exception.TypeOf()); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); - } - } + [Test] + [TestCase(CommandBehavior.Default)] + [TestCase(CommandBehavior.SequentialAccess)] + public async Task Statement_mapped_output_parameters(CommandBehavior behavior) + { + await using var conn = await OpenConnectionAsync(); + var command = new NpgsqlCommand("select 3, 4 as param1, 5 as param2, 6;", conn); - #endregion + var p = new NpgsqlParameter("param2", NpgsqlDbType.Integer); + p.Direction = ParameterDirection.Output; + p.Value = -1; + command.Parameters.Add(p); - [Test] - public async Task SingleRow([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + p = new NpgsqlParameter("param1", NpgsqlDbType.Integer); + p.Direction = ParameterDirection.Output; + p.Value = -1; + command.Parameters.Add(p); - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT 1, 2 UNION SELECT 3, 4", conn)) - { - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleRow)) - { - Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - Assert.That(reader.Read(), Is.False); - } - } - } - } + p = new NpgsqlParameter("p", NpgsqlDbType.Integer); + p.Direction = ParameterDirection.Output; + p.Value = -1; + command.Parameters.Add(p); - [Test, Description("Makes sure writing an unset parameter isn't allowed")] - public async Task ParameterUnset() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("@p", NpgsqlDbType.Integer)); - Assert.That(() => 
cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); - } - } - } + await using var reader = await command.ExecuteReaderAsync(behavior); - [Test] - public void ParametersGetName() - { - var command = new NpgsqlCommand(); + Assert.AreEqual(4, command.Parameters["param1"].Value); + Assert.AreEqual(5, command.Parameters["param2"].Value); - // Add parameters. - command.Parameters.Add(new NpgsqlParameter(":Parameter1", DbType.Boolean)); - command.Parameters.Add(new NpgsqlParameter(":Parameter2", DbType.Int32)); - command.Parameters.Add(new NpgsqlParameter(":Parameter3", DbType.DateTime)); - command.Parameters.Add(new NpgsqlParameter("Parameter4", DbType.DateTime)); + reader.Read(); - var idbPrmtr = command.Parameters["Parameter1"]; - Assert.IsNotNull(idbPrmtr); - command.Parameters[0].Value = 1; + Assert.AreEqual(3, reader.GetInt32(0)); + Assert.AreEqual(4, reader.GetInt32(1)); + Assert.AreEqual(5, reader.GetInt32(2)); + Assert.AreEqual(6, reader.GetInt32(3)); + } - // Get by indexers. + [Test] + public async Task Bug1006158_output_parameters() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Stored procedure OUT parameters are only support starting with version 14"); + var sproc = await GetTempProcedureName(conn); - Assert.AreEqual(":Parameter1", command.Parameters["Parameter1"].ParameterName); - Assert.AreEqual(":Parameter2", command.Parameters["Parameter2"].ParameterName); - Assert.AreEqual(":Parameter3", command.Parameters["Parameter3"].ParameterName); - Assert.AreEqual("Parameter4", command.Parameters["Parameter4"].ParameterName); //Should this work? 
+ var createFunction = $@" +CREATE PROCEDURE {sproc}(OUT a integer, OUT b boolean) AS $$ +BEGIN + a := 3; + b := true; +END +$$ LANGUAGE plpgsql;"; - Assert.AreEqual(":Parameter1", command.Parameters[0].ParameterName); - Assert.AreEqual(":Parameter2", command.Parameters[1].ParameterName); - Assert.AreEqual(":Parameter3", command.Parameters[2].ParameterName); - Assert.AreEqual("Parameter4", command.Parameters[3].ParameterName); - } + var command = new NpgsqlCommand(createFunction, conn); + await command.ExecuteNonQueryAsync(); - [Test] - public async Task SameParamMultipleTimes() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p1", conn)) - { - cmd.Parameters.AddWithValue("@p1", 8); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(8)); - Assert.That(reader[1], Is.EqualTo(8)); - } - } - } + command = new NpgsqlCommand(sproc, conn); + command.CommandType = CommandType.StoredProcedure; - [Test] - public async Task GenericParameter() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", 8)); - cmd.Parameters.Add(new NpgsqlParameter("p2", 8) { NpgsqlDbType = NpgsqlDbType.Integer }); - cmd.Parameters.Add(new NpgsqlParameter("p3", "hello")); - cmd.Parameters.Add(new NpgsqlParameter("p4", new[] { 'f', 'o', 'o' })); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetInt32(0), Is.EqualTo(8)); - Assert.That(reader.GetInt32(1), Is.EqualTo(8)); - Assert.That(reader.GetString(2), Is.EqualTo("hello")); - Assert.That(reader.GetString(3), Is.EqualTo("foo")); - } - } - } + command.Parameters.Add(new NpgsqlParameter("a", DbType.Int32)); + command.Parameters[0].Direction = ParameterDirection.Output; + command.Parameters.Add(new NpgsqlParameter("b", DbType.Boolean)); + command.Parameters[1].Direction = 
ParameterDirection.Output; - [Test] - public async Task CommandTextNotSet() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand()) - { - cmd.Connection = conn; - Assert.That(cmd.ExecuteNonQueryAsync, Throws.Exception.TypeOf()); - cmd.CommandText = null; - Assert.That(cmd.ExecuteNonQueryAsync, Throws.Exception.TypeOf()); - cmd.CommandText = ""; - } + _ = await command.ExecuteScalarAsync(); - using (var cmd = conn.CreateCommand()) - Assert.That(cmd.ExecuteNonQueryAsync, Throws.Exception.TypeOf()); - } - } + Assert.AreEqual(3, command.Parameters[0].Value); + Assert.AreEqual(true, command.Parameters[1].Value); + } - [Test] - public async Task ExecuteScalar() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - using (var command = new NpgsqlCommand($"SELECT name FROM {table}", conn)) - { - Assert.That(command.ExecuteScalarAsync, Is.Null); + [Test] + public async Task Bug1010788_UpdateRowSource() + { + if (IsMultiplexing) + return; - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES (NULL)"); - Assert.That(command.ExecuteScalarAsync, Is.EqualTo(DBNull.Value)); + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id SERIAL PRIMARY KEY, name TEXT"); - await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); - for (var i = 0; i < 2; i++) - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')"); - Assert.That(command.ExecuteScalarAsync, Is.EqualTo("X")); - } - } - } + var command = new NpgsqlCommand($"SELECT * FROM {table}", conn); + Assert.AreEqual(UpdateRowSource.Both, command.UpdatedRowSource); - [Test] - public async Task ExecuteNonQuery() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand { Connection = conn }) - { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); + var cmdBuilder = new 
NpgsqlCommandBuilder(); + var da = new NpgsqlDataAdapter(command); + cmdBuilder.DataAdapter = da; + Assert.IsNotNull(da.SelectCommand); + Assert.IsNotNull(cmdBuilder.DataAdapter); + + var updateCommand = cmdBuilder.GetUpdateCommand(); + Assert.AreEqual(UpdateRowSource.None, updateCommand.UpdatedRowSource); + } + + [Test] + public async Task TableDirect() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('foo')"); + using var cmd = new NpgsqlCommand(table, conn) { CommandType = CommandType.TableDirect }; + using var rdr = await cmd.ExecuteReaderAsync(); + Assert.That(rdr.Read(), Is.True); + Assert.That(rdr["name"], Is.EqualTo("foo")); + } - cmd.CommandText = $"INSERT INTO {table} (name) VALUES ('John')"; - Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(1)); - cmd.CommandText = $"INSERT INTO {table} (name) VALUES ('John'); INSERT INTO {table} (name) VALUES ('John')"; - Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(2)); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/503")] + public async Task Invalid_UTF8() + { + const string badString = "SELECT 'abc\uD801\uD802d'"; + await using var dataSource = CreateDataSource(); + using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(() => conn.ExecuteScalarAsync(badString), Throws.Exception.TypeOf()); + } - cmd.CommandText = $"INSERT INTO {table} (name) VALUES ('{new string('x', conn.Settings.WriteBufferSize)}')"; - Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(1)); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/395")] + public async Task Use_across_connection_change([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + using var conn1 = await OpenConnectionAsync(); + using var conn2 = await OpenConnectionAsync(); + using var cmd = new 
NpgsqlCommand("SELECT 1", conn1); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + cmd.Connection = conn2; + Assert.That(cmd.IsPrepared, Is.False); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + } - cmd.Parameters.AddWithValue("not_used", DBNull.Value); - Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(1)); - } - } - } + // The asserts we're testing are debug only. + [Test] + public async Task Use_after_reload_types_invalidates_cached_infos() + { + if (IsMultiplexing) + return; - [Test, Description("Makes sure a command is unusable after it is disposed")] - public async Task Dispose() + using var conn1 = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 1", conn1); + cmd.Prepare(); + using (var reader = await cmd.ExecuteReaderAsync()) { - using (var conn = await OpenConnectionAsync()) - { - var cmd = new NpgsqlCommand("SELECT 1", conn); - cmd.Dispose(); - Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); - Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception.TypeOf()); - Assert.That(() => cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); - Assert.That(() => cmd.PrepareAsync(), Throws.Exception.TypeOf()); - } + await reader.ReadAsync(); + Assert.DoesNotThrow(() => reader.GetInt32(0)); } - [Test, Description("Disposing a command with an open reader does not close the reader. 
This is the SqlClient behavior.")] - public async Task DisposeCommandDoesNotCloseReader() + await conn1.ReloadTypesAsync(); + + using (var reader = await cmd.ExecuteReaderAsync()) { - using (var conn = await OpenConnectionAsync()) - { - var cmd = new NpgsqlCommand("SELECT 1, 2", conn); - await cmd.ExecuteReaderAsync(); - cmd.Dispose(); - cmd = new NpgsqlCommand("SELECT 3", conn); - Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); - } + await reader.ReadAsync(); + Assert.DoesNotThrow(() => reader.GetInt32(0)); } + } - [Test] - public async Task NonStandardsConformingStrings_NotSupported() - { - using var conn = await OpenConnectionAsync(); + [Test] + public async Task Parameter_overflow_message_length_throws() + { + await using var conn = CreateConnection(); + await conn.OpenAsync(); + await using var cmd = new NpgsqlCommand("SELECT @a, @b, @c, @d, @e, @f, @g, @h", conn); + + var largeParam = new string('A', 1 << 29); + cmd.Parameters.AddWithValue("a", largeParam); + cmd.Parameters.AddWithValue("b", largeParam); + cmd.Parameters.AddWithValue("c", largeParam); + cmd.Parameters.AddWithValue("d", largeParam); + cmd.Parameters.AddWithValue("e", largeParam); + cmd.Parameters.AddWithValue("f", largeParam); + cmd.Parameters.AddWithValue("g", largeParam); + cmd.Parameters.AddWithValue("h", largeParam); + + Assert.ThrowsAsync(() => cmd.ExecuteReaderAsync()); + } - Assert.That(() => conn.ExecuteNonQueryAsync("set standard_conforming_strings=off"), - Throws.Exception.TypeOf()); - } + [Test] + public async Task Composite_overflow_message_length_throws() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); - [Test] - public async Task ParameterAndOperatorUnclear() - { - using (var conn = await OpenConnectionAsync()) - { - //Without parenthesis the meaning of [, . and potentially other characters is - //a syntax error. 
See comment in NpgsqlCommand.GetClearCommandText() on "usually-redundant parenthesis". - using (var command = new NpgsqlCommand("select :arr[2]", conn)) - { - command.Parameters.AddWithValue(":arr", new int[] {5, 4, 3, 2, 1}); - using (var rdr = await command.ExecuteReaderAsync()) - { - rdr.Read(); - Assert.AreEqual(rdr.GetInt32(0), 4); - } - } - } - } + await adminConnection.ExecuteNonQueryAsync( + $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text)"); - [Test] - [TestCase(CommandBehavior.Default)] - [TestCase(CommandBehavior.SequentialAccess)] - public async Task StatementMappedOutputParameters(CommandBehavior behavior) + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var largeString = new string('A', 1 << 29); + + await using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT @a"; + cmd.Parameters.AddWithValue("a", new BigComposite { - using var conn = await OpenConnectionAsync(); - var command = new NpgsqlCommand("select 3, 4 as param1, 5 as param2, 6;", conn); + A = largeString, + B = largeString, + C = largeString, + D = largeString, + E = largeString, + F = largeString, + G = largeString, + H = largeString + }); + + Assert.ThrowsAsync(async () => await cmd.ExecuteNonQueryAsync()); + } - var p = new NpgsqlParameter("param2", NpgsqlDbType.Integer); - p.Direction = ParameterDirection.Output; - p.Value = -1; - command.Parameters.Add(p); + record BigComposite + { + public string A { get; set; } = null!; + public string B { get; set; } = null!; + public string C { get; set; } = null!; + public string D { get; set; } = null!; + public string E { get; set; } = null!; + public string F { get; set; } = null!; + public string G { get; set; } = null!; + public string H { get; set; } = null!; + } - p = new NpgsqlParameter("param1", 
NpgsqlDbType.Integer); - p.Direction = ParameterDirection.Output; - p.Value = -1; - command.Parameters.Add(p); + [Test] + public async Task Array_overflow_message_length_throws() + { + await using var connection = await OpenConnectionAsync(); - p = new NpgsqlParameter("p", NpgsqlDbType.Integer); - p.Direction = ParameterDirection.Output; - p.Value = -1; - command.Parameters.Add(p); + var largeString = new string('A', 1 << 29); - using var reader = await command.ExecuteReaderAsync(behavior); + await using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT @a"; + var array = new[] + { + largeString, + largeString, + largeString, + largeString, + largeString, + largeString, + largeString, + largeString + }; + cmd.Parameters.AddWithValue("a", array); + + Assert.ThrowsAsync(async () => await cmd.ExecuteNonQueryAsync()); + } - Assert.AreEqual(4, command.Parameters["param1"].Value); - Assert.AreEqual(5, command.Parameters["param2"].Value); + [Test] + public async Task Range_overflow_message_length_throws() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + var rangeType = await GetTempTypeName(adminConnection); - reader.Read(); + await adminConnection.ExecuteNonQueryAsync( + $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text);CREATE TYPE {rangeType} AS RANGE(subtype={type})"); - Assert.AreEqual(3, reader.GetInt32(0)); - Assert.AreEqual(4, reader.GetInt32(1)); - Assert.AreEqual(5, reader.GetInt32(2)); - Assert.AreEqual(6, reader.GetInt32(3)); - } + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + dataSourceBuilder.EnableUnmappedTypes(); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); - [Test] - public async Task CaseInsensitiveParameterNames() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new 
NpgsqlCommand("select :p1", conn)) - { - command.Parameters.Add(new NpgsqlParameter("P1", NpgsqlDbType.Integer)).Value = 5; - var result = await command.ExecuteScalarAsync(); - Assert.AreEqual(5, result); - } - } + var largeString = new string('A', (1 << 28) + 2000000); - [Test] - public async Task TestBug1006158OutputParameters() + await using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT @a"; + var composite = new BigComposite { - using (var conn = await OpenConnectionAsync()) - await using (GetTempFunctionName(conn, out var function)) - { - var createFunction = $@" -CREATE OR REPLACE FUNCTION {function}(OUT a integer, OUT b boolean) AS -$BODY$DECLARE -BEGIN - a := 3; - b := true; -END;$BODY$ -LANGUAGE 'plpgsql' VOLATILE;"; + A = largeString, + B = largeString, + C = largeString, + D = largeString + }; + var range = new NpgsqlRange(composite, composite); + cmd.Parameters.Add(new NpgsqlParameter + { + Value = range, + ParameterName = "a", + DataTypeName = rangeType + }); - var command = new NpgsqlCommand(createFunction, conn); - await command.ExecuteNonQueryAsync(); + Assert.ThrowsAsync(async () => await cmd.ExecuteNonQueryAsync()); + } - command = new NpgsqlCommand(function, conn); - command.CommandType = CommandType.StoredProcedure; + [Test] + public async Task Multirange_overflow_message_length_throws() + { + await using var adminConnection = await OpenConnectionAsync(); + MinimumPgVersion(adminConnection, "14.0", "Multirange types were introduced in PostgreSQL 14"); + var type = await GetTempTypeName(adminConnection); + var rangeType = await GetTempTypeName(adminConnection); - command.Parameters.Add(new NpgsqlParameter("a", DbType.Int32)); - command.Parameters[0].Direction = ParameterDirection.Output; - command.Parameters.Add(new NpgsqlParameter("b", DbType.Boolean)); - command.Parameters[1].Direction = ParameterDirection.Output; + await adminConnection.ExecuteNonQueryAsync( + $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, 
f text, g text, h text);CREATE TYPE {rangeType} AS RANGE(subtype={type})"); - var result = await command.ExecuteScalarAsync(); + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + dataSourceBuilder.EnableUnmappedTypes(); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); - Assert.AreEqual(3, command.Parameters[0].Value); - Assert.AreEqual(true, command.Parameters[1].Value); - } - } + var largeString = new string('A', (1 << 28) + 2000000); - [Test] - public async Task Bug1010788UpdateRowSource() + await using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT @a"; + var composite = new BigComposite + { + A = largeString + }; + var range = new NpgsqlRange(composite, composite); + var multirange = new[] + { + range, + range, + range, + range + }; + cmd.Parameters.Add(new NpgsqlParameter { - if (IsMultiplexing) - return; + Value = multirange, + ParameterName = "a", + DataTypeName = rangeType + "_multirange" + }); - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "id SERIAL PRIMARY KEY, name TEXT", out var table); + Assert.ThrowsAsync(async () => await cmd.ExecuteNonQueryAsync()); + } - var command = new NpgsqlCommand($"SELECT * FROM {table}", conn); - Assert.AreEqual(UpdateRowSource.Both, command.UpdatedRowSource); + [Test, Description("CreateCommand before connection open")] + [IssueLink("https://github.com/npgsql/npgsql/issues/565")] + public async Task Create_command_before_connection_open() + { + using var conn = new NpgsqlConnection(ConnectionString); + var cmd = new NpgsqlCommand("SELECT 1", conn); + conn.Open(); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + } - var cmdBuilder = new NpgsqlCommandBuilder(); - var da = new NpgsqlDataAdapter(command); - cmdBuilder.DataAdapter = da; - Assert.IsNotNull(da.SelectCommand); - Assert.IsNotNull(cmdBuilder.DataAdapter); + 
[Test] + public void Connection_not_set_throws() + { + var cmd = new NpgsqlCommand("SELECT 1"); + Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + } - var updateCommand = cmdBuilder.GetUpdateCommand(); - Assert.AreEqual(UpdateRowSource.None, updateCommand.UpdatedRowSource); - } - } + [Test] + public void Connection_not_open_throws() + { + using var conn = CreateConnection(); + var cmd = new NpgsqlCommand("SELECT 1", conn); + Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + } - [Test] - public async Task TableDirect() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); + [Test] + public async Task ExecuteNonQuery_Throws_PostgresException([Values] bool async) + { + if (!async && IsMultiplexing) + return; - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('foo')"); - using (var cmd = new NpgsqlCommand(table, conn) { CommandType = CommandType.TableDirect }) - using (var rdr = await cmd.ExecuteReaderAsync()) - { - Assert.That(rdr.Read(), Is.True); - Assert.That(rdr["name"], Is.EqualTo("foo")); - } - } - } + await using var conn = await OpenConnectionAsync(); - [Test] - [TestCase(CommandBehavior.Default)] - [TestCase(CommandBehavior.SequentialAccess)] - public async Task InputAndOutputParameters(CommandBehavior behavior) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @c-1 AS c, @a+2 AS b", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("a", 3)); - var b = new NpgsqlParameter { ParameterName = "b", Direction = ParameterDirection.Output }; - cmd.Parameters.Add(b); - var c = new NpgsqlParameter { ParameterName = "c", Direction = ParameterDirection.InputOutput, Value = 4 }; - cmd.Parameters.Add(c); - using (await cmd.ExecuteReaderAsync(behavior)) - { - Assert.AreEqual(5, b.Value); - Assert.AreEqual(3, c.Value); - } - } - } + var table1 = await CreateTempTable(conn, "id 
integer PRIMARY key, t varchar(40)"); + var table2 = await CreateTempTable(conn, $"id SERIAL primary key, {table1}_id integer references {table1}(id) INITIALLY DEFERRED"); - [Test] - public async Task SendUnknown([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + var sql = $"insert into {table2} ({table1}_id) values (1) returning id"; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p::TIMESTAMP", conn)) - { - cmd.CommandText = "SELECT @p::TIMESTAMP"; - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Unknown) { Value = "2008-1-1" }); - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo(new DateTime(2008, 1, 1))); - } - } - } + var ex = async + ? Assert.ThrowsAsync(async () => await conn.ExecuteNonQueryAsync(sql)) + : Assert.Throws(() => conn.ExecuteNonQuery(sql)); + Assert.That(ex!.SqlState, Is.EqualTo(PostgresErrorCodes.ForeignKeyViolation)); + } + + [Test] + public async Task ExecuteScalar_Throws_PostgresException([Values] bool async) + { + if (!async && IsMultiplexing) + return; + + await using var conn = await OpenConnectionAsync(); + + var table1 = await CreateTempTable(conn, "id integer PRIMARY key, t varchar(40)"); + var table2 = await CreateTempTable(conn, $"id SERIAL primary key, {table1}_id integer references {table1}(id) INITIALLY DEFERRED"); + + var sql = $"insert into {table2} ({table1}_id) values (1) returning id"; + + var ex = async + ? 
Assert.ThrowsAsync(async () => await conn.ExecuteScalarAsync(sql)) + : Assert.Throws(() => conn.ExecuteScalar(sql)); + Assert.That(ex!.SqlState, Is.EqualTo(PostgresErrorCodes.ForeignKeyViolation)); + } + + [Test] + public async Task ExecuteReader_Throws_PostgresException([Values] bool async) + { + if (!async && IsMultiplexing) + return; + + await using var conn = await OpenConnectionAsync(); + + var table1 = await CreateTempTable(conn, "id integer PRIMARY key, t varchar(40)"); + var table2 = await CreateTempTable(conn, $"id SERIAL primary key, {table1}_id integer references {table1}(id) INITIALLY DEFERRED"); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = $"insert into {table2} ({table1}_id) values (1) returning id"; + + await using var reader = async + ? await cmd.ExecuteReaderAsync() + : cmd.ExecuteReader(); + + Assert.IsTrue(async ? await reader.ReadAsync() : reader.Read()); + var value = reader.GetInt32(0); + Assert.That(value, Is.EqualTo(1)); + Assert.IsFalse(async ? await reader.ReadAsync() : reader.Read()); + var ex = async + ? 
Assert.ThrowsAsync(async () => await reader.NextResultAsync()) + : Assert.Throws(() => reader.NextResult()); + Assert.That(ex!.SqlState, Is.EqualTo(PostgresErrorCodes.ForeignKeyViolation)); + } + + [Test] + public void Command_is_recycled() + { + using var conn = OpenConnection(); + var cmd1 = conn.CreateCommand(); + cmd1.CommandText = "SELECT @p1"; + var tx = conn.BeginTransaction(); + cmd1.Transaction = tx; + cmd1.Parameters.AddWithValue("p1", 8); + _ = cmd1.ExecuteScalar(); + cmd1.Dispose(); + + var cmd2 = conn.CreateCommand(); + Assert.That(cmd2, Is.SameAs(cmd1)); + Assert.That(cmd2.CommandText, Is.Empty); + Assert.That(cmd2.CommandType, Is.EqualTo(CommandType.Text)); + Assert.That(cmd2.Transaction, Is.Null); + Assert.That(cmd2.Parameters, Is.Empty); + // TODO: Leaving this for now, since it'll be replaced by the new batching API + // Assert.That(cmd2.Statements, Is.Empty); + } + + [Test] + public void Command_recycled_resets_CommandType() + { + using var conn = CreateConnection(); + var cmd1 = conn.CreateCommand(); + cmd1.CommandType = CommandType.StoredProcedure; + cmd1.Dispose(); + + var cmd2 = conn.CreateCommand(); + Assert.That(cmd2.CommandType, Is.EqualTo(CommandType.Text)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/503")] - public async Task InvalidUTF8() + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/831")] + [IssueLink("https://github.com/npgsql/npgsql/issues/2795")] + public async Task Many_parameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "some_column INT"); + using var cmd = new NpgsqlCommand { Connection = conn }; + var sb = new StringBuilder($"INSERT INTO {table} (some_column) VALUES "); + for (var i = 0; i < ushort.MaxValue; i++) { - const string badString = "SELECT 'abc\uD801\uD802d'"; - using (var conn = 
await OpenConnectionAsync()) - { - Assert.That(() => conn.ExecuteScalarAsync(badString), Throws.Exception.TypeOf()); - } + var paramName = "p" + i; + cmd.Parameters.Add(new NpgsqlParameter(paramName, 8)); + if (i > 0) + sb.Append(", "); + sb.Append($"(@{paramName})"); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/395")] - public async Task UseAcrossConnectionChange([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + cmd.CommandText = sb.ToString(); - using (var conn1 = await OpenConnectionAsync()) - using (var conn2 = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT 1", conn1)) - { - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - cmd.Connection = conn2; - Assert.That(cmd.IsPrepared, Is.False); - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); - } - } + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + + await cmd.ExecuteNonQueryAsync(); + } - [Test, Description("CreateCommand before connection open")] - [IssueLink("https://github.com/npgsql/npgsql/issues/565")] - public async Task CreateCommandBeforeConnectionOpen() + [Test, Description("Bypasses PostgreSQL's uint16 limitation on the number of parameters")] + [IssueLink("https://github.com/npgsql/npgsql/issues/831")] + [IssueLink("https://github.com/npgsql/npgsql/issues/858")] + [IssueLink("https://github.com/npgsql/npgsql/issues/2703")] + public async Task Too_many_parameters_throws([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand { Connection = conn }; + var sb = new StringBuilder("SOME RANDOM SQL "); + for (var i = 0; i < ushort.MaxValue + 1; i++) { - using (var conn = new 
NpgsqlConnection(ConnectionString)) { - var cmd = new NpgsqlCommand("SELECT 1", conn); - conn.Open(); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); - } + var paramName = "p" + i; + cmd.Parameters.Add(new NpgsqlParameter(paramName, 8)); + if (i > 0) + sb.Append(", "); + sb.Append('@'); + sb.Append(paramName); } - [Test] - public void ConnectionNotSet() + cmd.CommandText = sb.ToString(); + + if (prepare == PrepareOrNot.Prepared) { - var cmd = new NpgsqlCommand("SELECT 1"); - Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + Assert.That(() => cmd.Prepare(), Throws.Exception + .InstanceOf() + .With.Message.EqualTo("A statement cannot have more than 65535 parameters")); } - - [Test] - public void ConnectionNotOpen() + else { - using var conn = new NpgsqlConnection(ConnectionString); - var cmd = new NpgsqlCommand("SELECT 1", conn); - Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception + .InstanceOf() + .With.Message.EqualTo("A statement cannot have more than 65535 parameters")); } + } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/831")] - [IssueLink("https://github.com/npgsql/npgsql/issues/2795")] - [Timeout(10000)] - public async Task ManyParameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + [Test, Description("An individual statement cannot have more than 65535 parameters, but a command can (across multiple statements).")] + [IssueLink("https://github.com/npgsql/npgsql/issues/1199")] + public async Task Many_parameters_across_statements() + { + // Create a command with 1000 statements which have 70 params each + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand { Connection = conn }; + var paramIndex = 0; + var sb = new StringBuilder(); + for (var statementIndex = 0; statementIndex < 1000; statementIndex++) { - if (prepare == PrepareOrNot.Prepared && 
IsMultiplexing) - return; - - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "some_column INT", out var table); - using var cmd = new NpgsqlCommand { Connection = conn }; - var sb = new StringBuilder($"INSERT INTO {table} (some_column) VALUES "); - for (var i = 0; i < ushort.MaxValue; i++) + if (statementIndex > 0) + sb.Append("; "); + sb.Append("SELECT "); + var startIndex = paramIndex; + var endIndex = paramIndex + 70; + for (; paramIndex < endIndex; paramIndex++) { - var paramName = "p" + i; + var paramName = "p" + paramIndex; cmd.Parameters.Add(new NpgsqlParameter(paramName, 8)); - if (i > 0) + if (paramIndex > startIndex) sb.Append(", "); - sb.Append($"(@{paramName})"); + sb.Append('@'); + sb.Append(paramName); } + } - cmd.CommandText = sb.ToString(); - - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); + cmd.CommandText = sb.ToString(); + await cmd.ExecuteNonQueryAsync(); + } - await cmd.ExecuteNonQueryAsync(); + [Test, Description("Makes sure that Npgsql doesn't attempt to send all data before the user can start reading. That would cause a deadlock.")] + public async Task Batched_big_statements_do_not_deadlock() + { + // We're going to send a large multistatement query that would exhaust both the client's and server's + // send and receive buffers (assume 64k per buffer). 
+ var data = new string('x', 1024); + using var conn = await OpenConnectionAsync(); + var sb = new StringBuilder(); + for (var i = 0; i < 500; i++) + sb.Append("SELECT @p;"); + using var cmd = new NpgsqlCommand(sb.ToString(), conn); + cmd.Parameters.AddWithValue("p", NpgsqlDbType.Text, data); + using var reader = await cmd.ExecuteReaderAsync(); + for (var i = 0; i < 500; i++) + { + reader.Read(); + Assert.That(reader.GetString(0), Is.EqualTo(data)); + reader.NextResult(); } + } + + [Test] + public void Batched_small_then_big_statements_do_not_deadlock_in_sync_io() + { + if (IsMultiplexing) + return; // Multiplexing, sync I/O + + // This makes sure we switch to async writing for batches, starting from the 2nd statement at the latest. + // Otherwise, a small first first statement followed by a huge big one could cause us to deadlock, as we're stuck + // synchronously sending the 2nd statement while PG is stuck sending the results of the 1st. + using var conn = OpenConnection(); + var data = new string('x', 5_000_000); + using var cmd = new NpgsqlCommand("SELECT generate_series(1, 500000); SELECT @p", conn); + cmd.Parameters.AddWithValue("p", NpgsqlDbType.Text, data); + cmd.ExecuteNonQuery(); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1429")] + public async Task Same_command_different_param_values() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.AddWithValue("p", 8); + await cmd.ExecuteNonQueryAsync(); + + cmd.Parameters[0].Value = 9; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(9)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1429")] + public async Task Same_command_different_param_instances() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.AddWithValue("p", 8); + await cmd.ExecuteNonQueryAsync(); + + cmd.Parameters.RemoveAt(0); + 
cmd.Parameters.AddWithValue("p", 9); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(9)); + } - [Test, Description("Bypasses PostgreSQL's uint16 limitation on the number of parameters")] - [IssueLink("https://github.com/npgsql/npgsql/issues/831")] - [IssueLink("https://github.com/npgsql/npgsql/issues/858")] - [IssueLink("https://github.com/npgsql/npgsql/issues/2703")] - public async Task TooManyParameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3509"), Ignore("Flaky")] + public async Task Bug3509() + { + if (IsMultiplexing) + return; + + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + KeepAlive = 1, + }; + await using var postmasterMock = PgPostmasterMock.Start(csb.ToString()); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + var serverMock = await postmasterMock.WaitForServerConnection(); + // Wait for a keepalive to arrive at the server, reply with an error + await serverMock.WaitForData(); + var queryTask = Task.Run(async () => await conn.ExecuteNonQueryAsync("SELECT 1")); + // TODO: kind of flaky - think of the way to rewrite + // giving a queryTask some time to get stuck on a lock + await Task.Delay(300); + await serverMock + .WriteErrorResponse("42") + .WriteReadyForQuery() + .FlushAsync(); + + await serverMock + .WriteScalarResponseAndFlush(1); + + var ex = Assert.ThrowsAsync(async () => await queryTask)!; + Assert.That(ex.InnerException, Is.TypeOf() + .With.InnerException.TypeOf()); + } - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand { Connection = conn }; - var sb = new StringBuilder("SOME RANDOM SQL "); - for (var i = 0; i < ushort.MaxValue + 1; i++) - { - var paramName = "p" + i; - cmd.Parameters.Add(new 
NpgsqlParameter(paramName, 8)); - if (i > 0) - sb.Append(", "); - sb.Append('@'); - sb.Append(paramName); - } - cmd.CommandText = sb.ToString(); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4134")] + public async Task Cached_command_double_dispose() + { + await using var conn = await OpenConnectionAsync(); - if (prepare == PrepareOrNot.Prepared) - { - Assert.That(() => cmd.Prepare(), Throws.Exception - .InstanceOf() - .With.Message.EqualTo("A statement cannot have more than 65535 parameters")); - } - else - { - Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception - .InstanceOf() - .With.Message.EqualTo("A statement cannot have more than 65535 parameters")); - } - } + var cmd1 = conn.CreateCommand(); + cmd1.Dispose(); + cmd1.Dispose(); + + var cmd2 = conn.CreateCommand(); + Assert.That(cmd2, Is.SameAs(cmd1)); + + cmd2.CommandText = "SELECT 1"; + Assert.That(await cmd2.ExecuteScalarAsync(), Is.EqualTo(1)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4330")] + public async Task Prepare_with_positional_placeholders_after_named() + { + if (IsMultiplexing) + return; // Explicit preparation + + await using var conn = await OpenConnectionAsync(); + + await using var command = new NpgsqlCommand("SELECT @p", conn); + command.Parameters.AddWithValue("p", 10); + await command.ExecuteNonQueryAsync(); + + command.Parameters.Clear(); + + command.CommandText = "SELECT $1"; + command.Parameters.Add(new() { NpgsqlDbType = NpgsqlDbType.Integer }); + Assert.DoesNotThrowAsync(() => command.PrepareAsync()); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4621")] + [Description("Most of 08* errors are coming whenever there was an error while connecting to a remote server from a cluster, so the connection to the cluster is still OK")] + public async Task Postgres_connection_errors_not_break_connection() + { + if (IsMultiplexing) + return; + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await 
using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 1"; + var queryTask = cmd.ExecuteNonQueryAsync(); + + var server = await postmasterMock.WaitForServerConnection(); + await server + .WriteErrorResponse(PostgresErrorCodes.SqlClientUnableToEstablishSqlConnection) + .WriteReadyForQuery() + .FlushAsync(); + + var ex = Assert.ThrowsAsync(async () => await queryTask)!; + Assert.That(ex.SqlState, Is.EqualTo(PostgresErrorCodes.SqlClientUnableToEstablishSqlConnection)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4804")] + [Description("Concurrent write and read failure can lead to deadlocks while cleaning up the connector.")] + public async Task Concurrent_read_write_failure_deadlock() + { + if (IsMultiplexing) + return; + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + + await using var cmd = conn.CreateCommand(); + // Attempt to send a big enough query to fill buffers + // That way the write side should be stuck, waiting for the server to empty buffers + cmd.CommandText = new string('a', 8_000_000); + var queryTask = cmd.ExecuteNonQueryAsync(); + + var server = await postmasterMock.WaitForServerConnection(); + server.Close(); + + Assert.ThrowsAsync(async () => await queryTask); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4906")] + [Description("Make sure we don't cancel a prepended query (and do not deadlock in case of a failure)")] + [Explicit("Flaky due to #5033")] + public async Task Not_cancel_prepended_query([Values] bool failPrependedQuery) + { + if (IsMultiplexing) + return; - [Test, Description("An individual statement 
cannot have more than 65535 parameters, but a command can (across multiple statements).")] - [IssueLink("https://github.com/npgsql/npgsql/issues/1199")] - public async Task ManyParametersAcrossStatements() + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + var csb = new NpgsqlConnectionStringBuilder(postmasterMock.ConnectionString) { - // Create a command with 1000 statements which have 70 params each - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand { Connection = conn }) - { - var paramIndex = 0; - var sb = new StringBuilder(); - for (var statementIndex = 0; statementIndex < 1000; statementIndex++) - { - if (statementIndex > 0) - sb.Append("; "); - sb.Append("SELECT "); - var startIndex = paramIndex; - var endIndex = paramIndex + 70; - for (; paramIndex < endIndex; paramIndex++) - { - var paramName = "p" + paramIndex; - cmd.Parameters.Add(new NpgsqlParameter(paramName, 8)); - if (paramIndex > startIndex) - sb.Append(", "); - sb.Append('@'); - sb.Append(paramName); - } - } + NoResetOnClose = false + }; + await using var dataSource = CreateDataSource(csb.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + // reopen connection to append prepended query + await conn.CloseAsync(); + await conn.OpenAsync(); + + using var cts = new CancellationTokenSource(); + var queryTask = conn.ExecuteNonQueryAsync("SELECT 1", cancellationToken: cts.Token); + + var server = await postmasterMock.WaitForServerConnection(); + await server.ExpectSimpleQuery("DISCARD ALL"); + await server.ExpectExtendedQuery(); + + var cancelTask = Task.Run(cts.Cancel); + var cancellationRequestTask = postmasterMock.WaitForCancellationRequest().AsTask(); + // Give 1 second to make sure we didn't send cancellation request + await Task.Delay(1000); + Assert.IsFalse(cancelTask.IsCompleted); + Assert.IsFalse(cancellationRequestTask.IsCompleted); + + if (failPrependedQuery) + { + await server + 
.WriteErrorResponse(PostgresErrorCodes.SyntaxError) + .WriteReadyForQuery() + .FlushAsync(); - cmd.CommandText = sb.ToString(); - await cmd.ExecuteNonQueryAsync(); - } + await cancelTask; + await cancellationRequestTask; + Assert.ThrowsAsync(async () => await queryTask); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + return; } - [Test, Description("Makes sure that Npgsql doesn't attempt to send all data before the user can start reading. That would cause a deadlock.")] - public async Task ReadWriteDeadlock() + await server + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + await cancelTask; + await cancellationRequestTask; + + await server + .WriteErrorResponse(PostgresErrorCodes.QueryCanceled) + .WriteReadyForQuery() + .FlushAsync(); + + Assert.ThrowsAsync(async () => await queryTask); + + queryTask = conn.ExecuteNonQueryAsync("SELECT 1"); + await server.ExpectExtendedQuery(); + await server + .WriteParseComplete() + .WriteBindComplete() + .WriteNoData() + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + await queryTask; + } + + [Test] + public async Task Cancel_while_reading_from_long_running_query() + { + if (IsMultiplexing) + return; + + await using var conn = await OpenConnectionAsync(); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = """ +SELECT *, CASE WHEN "t"."i" = 50000 THEN pg_sleep(100) ELSE NULL END +FROM +( + SELECT generate_series(1, 1000000) AS "i" +) AS "t" +"""; + + using (var cts = new CancellationTokenSource()) + await using (var reader = await cmd.ExecuteReaderAsync(cts.Token)) { - // We're going to send a large multistatement query that would exhaust both the client's and server's - // send and receive buffers (assume 64k per buffer). 
- var data = new string('x', 1024); - using (var conn = await OpenConnectionAsync()) + Assert.ThrowsAsync(async () => { - var sb = new StringBuilder(); - for (var i = 0; i < 500; i++) - sb.Append("SELECT @p;"); - using (var cmd = new NpgsqlCommand(sb.ToString(), conn)) + var i = 0; + while (await reader.ReadAsync(cts.Token)) { - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Text, data); - using (var reader = await cmd.ExecuteReaderAsync()) - { - for (var i = 0; i < 500; i++) - { - reader.Read(); - Assert.That(reader.GetString(0), Is.EqualTo(data)); - reader.NextResult(); - } - } + i++; + if (i == 10) + cts.Cancel(); } - } + }); } - [Test] - public async Task StatementOID() + cmd.CommandText = "SELECT 42"; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(42)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5218")] + [Description("Make sure we do not lose unread messages after resetting oversize buffer")] + public async Task Oversize_buffer_lost_messages() + { + if (IsMultiplexing) + return; + + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { - using var conn = await OpenConnectionAsync(); + NoResetOnClose = true + }; + await using var mock = PgPostmasterMock.Start(csb.ConnectionString); + await using var dataSource = CreateDataSource(mock.ConnectionString); + await using var connection = await dataSource.OpenConnectionAsync(); + var connector = connection.Connector!; + + var server = await mock.WaitForServerConnection(); + await server + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(TextOid)) + .WriteDataRowWithFlush(Encoding.ASCII.GetBytes(new string('a', connection.Settings.ReadBufferSize * 2))); + // Just to make sure we have enough space + await server.FlushAsync(); + await server + .WriteDataRow(Encoding.ASCII.GetBytes("abc")) + .WriteCommandComplete() + .WriteReadyForQuery() + .WriteParameterStatus("SomeKey", "SomeValue") + .FlushAsync(); + + await using var cmd = 
connection.CreateCommand(); + cmd.CommandText = "SELECT 1"; + await using (await cmd.ExecuteReaderAsync()) { } + + await connection.CloseAsync(); + await connection.OpenAsync(); + + Assert.AreSame(connector, connection.Connector); + // We'll get new value after the next query reads ParameterStatus from the buffer + Assert.That(connection.PostgresParameters, Does.Not.ContainKey("SomeKey").WithValue("SomeValue")); + + await server + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(TextOid)) + .WriteDataRow(Encoding.ASCII.GetBytes("abc")) + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + await cmd.ExecuteNonQueryAsync(); + + Assert.That(connection.PostgresParameters, Contains.Key("SomeKey").WithValue("SomeValue")); + } + + #region Logging + + [Test] + public async Task Log_ExecuteScalar_single_statement_without_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); - MaximumPgVersionExclusive(conn, "12.0", - "Support for 'CREATE TABLE ... WITH OIDS' has been removed in 12.0. 
See https://www.postgresql.org/docs/12/release-12.html#id-1.11.6.5.4"); + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } - await using var _ = await GetTempTableName(conn, out var table); - await conn.ExecuteNonQueryAsync($"CREATE TABLE {table} (name TEXT) WITH OIDS"); + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed").And.Contains("SELECT 1")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT 1"); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); - using var cmd = new NpgsqlCommand( - $"INSERT INTO {table} (name) VALUES (@p1);" + - $"UPDATE {table} SET name='b' WHERE name=@p2", - conn); + if (!IsMultiplexing) + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } - cmd.Parameters.AddWithValue("p1", "foo"); - cmd.Parameters.AddWithValue("p2", "bar"); - await cmd.ExecuteNonQueryAsync(); + [Test] + public async Task Log_ExecuteScalar_single_statement_with_positional_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1, $2", conn); + cmd.Parameters.Add(new() { Value = 8 }); + cmd.Parameters.Add(new() { NpgsqlDbType = NpgsqlDbType.Integer, Value = DBNull.Value }); - Assert.That(cmd.Statements[0].OID, Is.Not.EqualTo(0)); - Assert.That(cmd.Statements[1].OID, Is.EqualTo(0)); + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1429")] - public async Task SameCommandDifferentParamValues() + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, 
Does.Contain("Command execution completed") + .And.Contains("SELECT $1, $2") + .And.Contains("Parameters: [8, NULL]")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2"); + AssertLoggingStateContains(executingCommandEvent, "Parameters", new object[] { 8, "NULL" }); + + if (!IsMultiplexing) + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Log_ExecuteScalar_single_statement_with_named_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); + cmd.Parameters.Add(new() { ParameterName = "p1", Value = 8 }); + cmd.Parameters.Add(new() { ParameterName = "p2", NpgsqlDbType = NpgsqlDbType.Integer, Value = DBNull.Value }); + + using (listLoggerProvider.Record()) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", 8); - await cmd.ExecuteNonQueryAsync(); + await cmd.ExecuteScalarAsync(); + } - cmd.Parameters[0].Value = 9; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(9)); - } + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed") + .And.Contains("SELECT $1, $2") + .And.Contains("Parameters: [8, NULL]")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2"); + AssertLoggingStateContains(executingCommandEvent, "Parameters", new object[] { 8, "NULL" }); + + if (!IsMultiplexing) + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Log_ExecuteScalar_single_statement_with_parameter_logging_off() + { + await using var dataSource = CreateLoggingDataSource(out var 
listLoggerProvider, sensitiveDataLoggingEnabled: false); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1, $2", conn); + cmd.Parameters.Add(new() { Value = 8 }); + cmd.Parameters.Add(new() { Value = 9 }); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1429")] - public async Task SameCommandDifferentParamInstances() + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed").And.Contains($"SELECT $1, $2")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2"); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + } + + [Test] + public async Task Log_ExecuteScalar_multiple_statement_without_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); + + using (listLoggerProvider.Record()) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", 8); - await cmd.ExecuteNonQueryAsync(); + await cmd.ExecuteScalarAsync(); + } - cmd.Parameters.RemoveAt(0); - cmd.Parameters.AddWithValue("p", 9); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(9)); - } + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[(SELECT 1, System.Object[]), (SELECT 2, System.Object[])]")); + var batchCommands = (IList<(string CommandText, object[] Parameters)>)AssertLoggingStateContains(executingCommandEvent, 
"BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0].CommandText, Is.EqualTo("SELECT 1")); + Assert.That(batchCommands[0].Parameters, Is.Empty); + Assert.That(batchCommands[1].CommandText, Is.EqualTo("SELECT 2")); + Assert.That(batchCommands[1].Parameters, Is.Empty); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + + if (!IsMultiplexing) + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Log_ExecuteScalar_multiple_statement_with_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1; SELECT @p2", conn); + cmd.Parameters.Add(new() { ParameterName = "p1", Value = 8 }); + cmd.Parameters.Add(new() { ParameterName = "p2", Value = 9 }); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); } - public CommandTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[(SELECT $1, System.Object[]), (SELECT $1, System.Object[])]")); + var batchCommands = (IList<(string CommandText, object[] Parameters)>)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0].CommandText, Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[0].Parameters[0], Is.EqualTo(8)); + Assert.That(batchCommands[1].CommandText, Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[1].Parameters[0], Is.EqualTo(9)); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + + if (!IsMultiplexing) + AssertLoggingStateContains(executingCommandEvent, 
"ConnectorId", conn.ProcessID); } + + [Test] + public async Task Log_ExecuteScalar_multiple_statement_with_parameter_logging_off() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, sensitiveDataLoggingEnabled: false); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1; SELECT @p2", conn); + cmd.Parameters.Add(new() { ParameterName = "p1", Value = 8 }); + cmd.Parameters.Add(new() { ParameterName = "p2", Value = 9 }); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[SELECT $1, SELECT $1]")); + var batchCommands = (IList)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0], Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[1], Is.EqualTo("SELECT $1")); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + + if (!IsMultiplexing) + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + #endregion Logging + + public CommandTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/ConnectionStringBuilderTests.cs b/test/Npgsql.Tests/ConnectionStringBuilderTests.cs index 50cc8d3b3e..6e2d2e3a04 100644 --- a/test/Npgsql.Tests/ConnectionStringBuilderTests.cs +++ b/test/Npgsql.Tests/ConnectionStringBuilderTests.cs @@ -1,112 +1,121 @@ using System; using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +class ConnectionStringBuilderTests { - class ConnectionStringBuilderTests + [Test] + public void Basic() { - [Test] - public void Basic() - { - var builder = new NpgsqlConnectionStringBuilder(); - 
Assert.That(builder.Count, Is.EqualTo(0)); - Assert.That(builder.ContainsKey("server"), Is.True); - builder.Host = "myhost"; - Assert.That(builder["host"], Is.EqualTo("myhost")); - Assert.That(builder.Count, Is.EqualTo(1)); - Assert.That(builder.ConnectionString, Is.EqualTo("Host=myhost")); - builder.Remove("HOST"); - Assert.That(builder["host"], Is.EqualTo("")); - Assert.That(builder.Count, Is.EqualTo(0)); - } + var builder = new NpgsqlConnectionStringBuilder(); + Assert.That(builder.Count, Is.EqualTo(0)); + Assert.That(builder.ContainsKey("server"), Is.True); + builder.Host = "myhost"; + Assert.That(builder["host"], Is.EqualTo("myhost")); + Assert.That(builder.Count, Is.EqualTo(1)); + Assert.That(builder.ConnectionString, Is.EqualTo("Host=myhost")); + builder.Remove("HOST"); + Assert.That(builder["host"], Is.EqualTo("")); + Assert.That(builder.Count, Is.EqualTo(0)); + } - [Test] - public void FromString() - { - var builder = new NpgsqlConnectionStringBuilder(); - builder.ConnectionString = "Host=myhost;EF Template Database=foo"; - Assert.That(builder.Host, Is.EqualTo("myhost")); - Assert.That(builder.EntityTemplateDatabase, Is.EqualTo("foo")); - } + [Test] + public void TryGetValue() + { + var builder = new NpgsqlConnectionStringBuilder(); + builder.ConnectionString = "Host=myhost"; + Assert.That(builder.TryGetValue("Host", out var value), Is.True); + Assert.That(value, Is.EqualTo("myhost")); + Assert.That(builder.TryGetValue("SomethingUnknown", out value), Is.False); + } - [Test] - public void TryGetValue() - { - var builder = new NpgsqlConnectionStringBuilder(); - builder.ConnectionString = "Host=myhost"; - Assert.That(builder.TryGetValue("Host", out var value), Is.True); - Assert.That(value, Is.EqualTo("myhost")); - Assert.That(builder.TryGetValue("SomethingUnknown", out value), Is.False); - } + [Test] + public void Remove() + { + var builder = new NpgsqlConnectionStringBuilder(); + builder.SslMode = SslMode.Require; + Assert.That(builder["SSL Mode"], 
Is.EqualTo(SslMode.Require)); + builder.Remove("SSL Mode"); + Assert.That(builder.ConnectionString, Is.EqualTo("")); + builder.CommandTimeout = 120; + Assert.That(builder["Command Timeout"], Is.EqualTo(120)); + builder.Remove("Command Timeout"); + Assert.That(builder.ConnectionString, Is.EqualTo("")); + } - [Test] - public void Remove() - { - var builder = new NpgsqlConnectionStringBuilder(); - builder.SslMode = SslMode.Prefer; - Assert.That(builder["SSL Mode"], Is.EqualTo(SslMode.Prefer)); - builder.Remove("SSL Mode"); - Assert.That(builder.ConnectionString, Is.EqualTo("")); - builder.CommandTimeout = 120; - Assert.That(builder["Command Timeout"], Is.EqualTo(120)); - builder.Remove("Command Timeout"); - Assert.That(builder.ConnectionString, Is.EqualTo("")); - } + [Test] + public void Clear() + { + var builder = new NpgsqlConnectionStringBuilder { Host = "myhost" }; + builder.Clear(); + Assert.That(builder.Count, Is.EqualTo(0)); + Assert.That(builder["host"], Is.EqualTo("")); + Assert.That(builder.Host, Is.Null); + } + + [Test] + public void Removing_resets_to_default() + { + var builder = new NpgsqlConnectionStringBuilder(); + Assert.That(builder.Port, Is.EqualTo(NpgsqlConnection.DefaultPort)); + builder.Port = 8; + builder.Remove("Port"); + Assert.That(builder.Port, Is.EqualTo(NpgsqlConnection.DefaultPort)); + } - [Test] - public void Clear() - { - var builder = new NpgsqlConnectionStringBuilder { Host = "myhost" }; - builder.Clear(); - Assert.That(builder.Count, Is.EqualTo(0)); - Assert.That(builder["host"], Is.EqualTo("")); - Assert.That(builder.Host, Is.Null); - } + [Test] + public void Setting_to_null_resets_to_default() + { + var builder = new NpgsqlConnectionStringBuilder(); + Assert.That(builder.Port, Is.EqualTo(NpgsqlConnection.DefaultPort)); + builder.Port = 8; + builder["Port"] = null; + Assert.That(builder.Port, Is.EqualTo(NpgsqlConnection.DefaultPort)); + } - [Test] - public void Default() - { - var builder = new NpgsqlConnectionStringBuilder(); - 
Assert.That(builder.Port, Is.EqualTo(NpgsqlConnection.DefaultPort)); - builder.Port = 8; - builder.Remove("Port"); - Assert.That(builder.Port, Is.EqualTo(NpgsqlConnection.DefaultPort)); - } + [Test] + public void Enum() + { + var builder = new NpgsqlConnectionStringBuilder(); + builder.ConnectionString = "SslMode=Require"; + Assert.That(builder.SslMode, Is.EqualTo(SslMode.Require)); + Assert.That(builder.Count, Is.EqualTo(1)); + } - [Test] - public void Enum() - { - var builder = new NpgsqlConnectionStringBuilder(); - builder.ConnectionString = "SslMode=Prefer"; - Assert.That(builder.SslMode, Is.EqualTo(SslMode.Prefer)); - Assert.That(builder.Count, Is.EqualTo(1)); - } + [Test] + public void Enum_insensitive() + { + var builder = new NpgsqlConnectionStringBuilder(); + builder.ConnectionString = "SslMode=require"; + Assert.That(builder.SslMode, Is.EqualTo(SslMode.Require)); + Assert.That(builder.Count, Is.EqualTo(1)); + } - [Test] - public void Clone() - { - var builder = new NpgsqlConnectionStringBuilder(); - builder.Host = "myhost"; - var builder2 = builder.Clone(); - Assert.That(builder2.Host, Is.EqualTo("myhost")); - Assert.That(builder2["Host"], Is.EqualTo("myhost")); - Assert.That(builder.Port, Is.EqualTo(NpgsqlConnection.DefaultPort)); - } + [Test] + public void Clone() + { + var builder = new NpgsqlConnectionStringBuilder(); + builder.Host = "myhost"; + var builder2 = builder.Clone(); + Assert.That(builder2.Host, Is.EqualTo("myhost")); + Assert.That(builder2["Host"], Is.EqualTo("myhost")); + Assert.That(builder.Port, Is.EqualTo(NpgsqlConnection.DefaultPort)); + } - [Test] - public void ConversionError() - { - var builder = new NpgsqlConnectionStringBuilder(); - Assert.That(() => builder["Port"] = "hello", - Throws.Exception.TypeOf().With.Message.Contains("Port")); - } + [Test] + public void Conversion_error_throws() + { + var builder = new NpgsqlConnectionStringBuilder(); + Assert.That(() => builder["Port"] = "hello", + 
Throws.Exception.TypeOf().With.Message.Contains("Port")); + } - [Test] - public void InvalidConnectionString() - { - var builder = new NpgsqlConnectionStringBuilder(); - Assert.That(() => builder.ConnectionString = "Server=127.0.0.1;User Id=npgsql_tests;Pooling:false", - Throws.Exception.TypeOf()); - } + [Test] + public void Invalid_connection_string_throws() + { + var builder = new NpgsqlConnectionStringBuilder(); + Assert.That(() => builder.ConnectionString = "Server=127.0.0.1;User Id=npgsql_tests;Pooling:false", + Throws.Exception.TypeOf()); } } diff --git a/test/Npgsql.Tests/ConnectionTests.cs b/test/Npgsql.Tests/ConnectionTests.cs index 0a96e52897..497cb888a2 100644 --- a/test/Npgsql.Tests/ConnectionTests.cs +++ b/test/Npgsql.Tests/ConnectionTests.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Data; using System.Diagnostics; using System.IO; @@ -6,1469 +7,1768 @@ using System.Net; using System.Net.Security; using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading; using System.Threading.Tasks; -using Npgsql.Tests.Support; +using Npgsql.Internal; +using Npgsql.PostgresTypes; +using Npgsql.Util; +using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; -using static Npgsql.Util.Statics; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class ConnectionTests : MultiplexingTestBase { - public class ConnectionTests : MultiplexingTestBase + [Test, Description("Makes sure the connection goes through the proper state lifecycle")] + public async Task Basic_lifecycle() { - [Test, Description("Makes sure the connection goes through the proper state lifecycle")] - //[Timeout(5000)] - public async Task BasicLifecycle() + await using var conn = CreateConnection(); + + var eventOpen = false; + var eventClosed = false; + + conn.StateChange += (s, e) => { - using var conn = new NpgsqlConnection(ConnectionString); + if (e.OriginalState == 
ConnectionState.Closed && + e.CurrentState == ConnectionState.Open) + eventOpen = true; - bool eventOpen = false, eventClosed = false; - conn.StateChange += (s, e) => - { - if (e.OriginalState == ConnectionState.Closed && e.CurrentState == ConnectionState.Open) - eventOpen = true; - if (e.OriginalState == ConnectionState.Open && e.CurrentState == ConnectionState.Closed) - eventClosed = true; - }; + if (e.OriginalState == ConnectionState.Open && + e.CurrentState == ConnectionState.Closed) + eventClosed = true; + }; - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); - conn.Open(); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - Assert.That(eventOpen, Is.True); + await conn.OpenAsync(); - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); - } + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(eventOpen, Is.True); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + await using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + await using (var reader = await cmd.ExecuteReaderAsync()) + { + await reader.ReadAsync(); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + } + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + + await conn.CloseAsync(); + + Assert.That(conn.State, 
Is.EqualTo(ConnectionState.Closed)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); + Assert.That(eventClosed, Is.True); + } + + [Test, Description("Makes sure the connection goes through the proper state lifecycle")] + public async Task Broken_lifecycle([Values] bool openFromClose) + { + if (IsMultiplexing) + return; + + await using var dataSource = CreateDataSource(); + await using var conn = dataSource.CreateConnection(); + + var eventOpen = false; + var eventClosed = false; + + conn.StateChange += (s, e) => + { + if (e.OriginalState == ConnectionState.Closed && + e.CurrentState == ConnectionState.Open) + eventOpen = true; + + if (e.OriginalState == ConnectionState.Open && + e.CurrentState == ConnectionState.Closed) + eventClosed = true; + }; + + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); + + await conn.OpenAsync(); + await using var transaction = await conn.BeginTransactionAsync(); + + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(eventOpen, Is.True); + + var sleep = conn.ExecuteNonQueryAsync("SELECT pg_sleep(5)"); + + // Wait for a query + await Task.Delay(1000); + await using (var killingConn = await OpenConnectionAsync()) + killingConn.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); + + Assert.ThrowsAsync(() => sleep); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + Assert.That(eventClosed, Is.True); + Assert.That(conn.Connector is null); + Assert.AreEqual(0, conn.NpgsqlDataSource.Statistics.Total); + + if (openFromClose) + { + await conn.CloseAsync(); - conn.Close(); Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); Assert.That(eventClosed, Is.True); } - [Test] - public async Task 
BreakWhileOpen() - { - if (IsMultiplexing) - return; + Assert.DoesNotThrowAsync(conn.OpenAsync); + Assert.AreEqual(1, await conn.ExecuteScalarAsync("SELECT 1")); + Assert.AreEqual(1, conn.NpgsqlDataSource.Statistics.Total); + Assert.DoesNotThrowAsync(conn.CloseAsync); + } - using var conn = new NpgsqlConnection(ConnectionString); + [Test] + [Platform(Exclude = "MacOsX", Reason = "Flaky on MacOS")] + public async Task Break_while_open() + { + if (IsMultiplexing) + return; - conn.Open(); + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); - using (var conn2 = await OpenConnectionAsync()) - conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); + using (var conn2 = await OpenConnectionAsync()) + conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); - // Allow some time for the pg_terminate to kill our connection - using (var cmd = CreateSleepCommand(conn, 10)) - Assert.That(() => cmd.ExecuteNonQuery(), Throws.Exception - .AssignableTo() - ); + // Allow some time for the pg_terminate to kill our connection + using (var cmd = CreateSleepCommand(conn, 10)) + Assert.That(() => cmd.ExecuteNonQuery(), Throws.Exception + .AssignableTo()); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - #region Connection Errors + #region Connection Errors #if IGNORE - [Test] - [TestCase(true)] - [TestCase(false)] - public async Task ConnectionRefused(bool pooled) - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Port = 44444, Pooling = pooled }; - using (var conn = new NpgsqlConnection(csb)) { - Assert.That(() => conn.Open(), Throws.Exception - .TypeOf() - // CoreCLR currently has an issue which causes the wrong SocketErrorCode to be set on Linux: - // 
https://github.com/dotnet/corefx/issues/8464 - .With.Property(nameof(SocketException.SocketErrorCode)).EqualTo(SocketError.ConnectionRefused) - ); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); - } + [Test] + [TestCase(true)] + [TestCase(false)] + public async Task Connection_refused(bool pooled) + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Port = 44444, Pooling = pooled }; + using (var conn = new NpgsqlConnection(csb)) { + Assert.That(() => conn.Open(), Throws.Exception + .TypeOf() + // CoreCLR currently has an issue which causes the wrong SocketErrorCode to be set on Linux: + // https://github.com/dotnet/corefx/issues/8464 + .With.Property(nameof(SocketException.SocketErrorCode)).EqualTo(SocketError.ConnectionRefused) + ); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); } + } - [Test] - [TestCase(true)] - [TestCase(false)] - public async Task ConnectionRefusedAsync(bool pooled) + [Test] + [TestCase(true)] + [TestCase(false)] + public async Task Connection_refused_async(bool pooled) + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Port = 44444, Pooling = pooled }; + using (var conn = new NpgsqlConnection(csb)) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Port = 44444, Pooling = pooled }; - using (var conn = new NpgsqlConnection(csb)) - { - Assert.That(async () => await conn.OpenAsync(), Throws.Exception - .TypeOf() - .With.Property(nameof(SocketException.SocketErrorCode)).EqualTo(SocketError.ConnectionRefused) - ); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); - } + Assert.That(async () => await conn.OpenAsync(), Throws.Exception + .TypeOf() + .With.Property(nameof(SocketException.SocketErrorCode)).EqualTo(SocketError.ConnectionRefused) + ); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); } + } #endif - [Test] - [Ignore("Fails in a non-determinstic manner and only on the build server... 
investigate...")] - public void InvalidUserId() + [Test] + [Ignore("Fails in a non-determinstic manner and only on the build server... investigate...")] + public void Invalid_Username() + { + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) + { + Username = "unknown", Pooling = false + }.ToString(); + using var conn = new NpgsqlConnection(connString); + Assert.That(conn.Open, Throws.Exception + .TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.InvalidPassword) + ); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); + } + + [Test] + public void Bad_database() + { + using var dataSource = CreateDataSource(csb => csb.Database = "does_not_exist"); + using var conn = dataSource.CreateConnection(); + + Assert.That(() => conn.Open(), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.InvalidCatalogName) + ); + } + + [Test, Description("Tests that mandatory connection string parameters are indeed mandatory")] + public void Mandatory_connection_string_params() + => Assert.Throws(() => + new NpgsqlConnection("User ID=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests")); + + [Test, Description("Reuses the same connection instance for a failed connection, then a successful one")] + public async Task Fail_connect_then_succeed([Values] bool pooling) + { + if (IsMultiplexing && !pooling) // Multiplexing doesn't work without pooling + return; + + var dbName = GetUniqueIdentifier(nameof(Fail_connect_then_succeed)); + await using var conn1 = await OpenConnectionAsync(); + await conn1.ExecuteNonQueryAsync($"DROP DATABASE IF EXISTS \"{dbName}\""); + try { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) + await using var dataSource = CreateDataSource(csb => { - Username = "unknown", Pooling = false - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) - { - Assert.That(conn.Open, Throws.Exception - .TypeOf() - 
.With.Property(nameof(PostgresException.SqlState)).EqualTo("28P01") - ); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); - } - } + csb.Database = dbName; + csb.Pooling = pooling; + }); + + await using var conn2 = dataSource.CreateConnection(); + var pgEx = Assert.ThrowsAsync(conn2.OpenAsync)!; + Assert.That(pgEx.SqlState, Is.EqualTo(PostgresErrorCodes.InvalidCatalogName)); // database doesn't exist + Assert.That(conn2.FullState, Is.EqualTo(ConnectionState.Closed)); - [Test, Description("Connects with a bad password to ensure the proper error is thrown")] - public void AuthenticationFailure() + await conn1.ExecuteNonQueryAsync($"CREATE DATABASE \"{dbName}\" TEMPLATE template0"); + + Assert.DoesNotThrowAsync(conn2.OpenAsync); + Assert.DoesNotThrowAsync(conn2.CloseAsync); + } + finally { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Password = "bad" - }; - using (CreateTempPool(builder, out var connectionString)) - using (var conn = new NpgsqlConnection(connectionString)) - { - Assert.That(() => conn.OpenAsync(), Throws.Exception - .TypeOf() - .With.Property(nameof(PostgresException.SqlState)).StartsWith("28") - ); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Closed)); - } + await conn1.ExecuteNonQueryAsync($"DROP DATABASE IF EXISTS \"{dbName}\""); } + } - #region ProvidePasswordCallback Tests + [Test] + public void Open_timeout_unknown_ip([Values(true, false)] bool async) + { + const int timeoutSeconds = 2; - [Test, Description("ProvidePasswordCallback is used when password is not supplied in connection string")] - public async Task ProvidePasswordCallbackDelegateIsUsed() + var unknownIp = Environment.GetEnvironmentVariable("NPGSQL_UNKNOWN_IP"); + if (unknownIp is null) { - using var _ = CreateTempPool(ConnectionString, out var connString); - var builder = new NpgsqlConnectionStringBuilder(connString); - var goodPassword = builder.Password; - var getPasswordDelegateWasCalled = false; - builder.Password = 
null; + Assert.Ignore("NPGSQL_UNKNOWN_IP isn't defined and is required for connection timeout tests"); + return; + } - Assume.That(goodPassword, Is.Not.Null); + using var dataSource = CreateDataSource(csb => + { + csb.Host = unknownIp; + csb.Timeout = timeoutSeconds; + }); + using var conn = dataSource.CreateConnection(); - using (var conn = new NpgsqlConnection(builder.ConnectionString) { ProvidePasswordCallback = ProvidePasswordCallback }) - { - conn.Open(); - Assert.True(getPasswordDelegateWasCalled, "ProvidePasswordCallback delegate not used"); - - // Do this again, since with multiplexing the very first connection attempt is done via - // the non-multiplexing path, to surface any exceptions. - NpgsqlConnection.ClearPool(conn); - conn.Close(); - getPasswordDelegateWasCalled = false; - conn.Open(); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - Assert.True(getPasswordDelegateWasCalled, "ProvidePasswordCallback delegate not used"); - } + var sw = Stopwatch.StartNew(); + if (async) + { + Assert.That(async () => await conn.OpenAsync(), Throws.Exception + .TypeOf() + .With.InnerException.TypeOf()); + } + else + { + Assert.That(() => conn.Open(), Throws.Exception + .TypeOf() + .With.InnerException.TypeOf()); + } - string ProvidePasswordCallback(string host, int port, string database, string username) - { - getPasswordDelegateWasCalled = true; - return goodPassword!; - } + Assert.That(sw.Elapsed.TotalMilliseconds, Is.GreaterThanOrEqualTo(timeoutSeconds * 1000 - 100), + $"Timeout was supposed to happen after {timeoutSeconds} seconds, but fired after {sw.Elapsed.TotalSeconds}"); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + } + + [Test] + public void Connect_timeout_cancel() + { + var unknownIp = Environment.GetEnvironmentVariable("NPGSQL_UNKNOWN_IP"); + if (unknownIp is null) + { + Assert.Ignore("NPGSQL_UNKNOWN_IP isn't defined and is required for connection cancellation tests"); + return; } - [Test, 
Description("ProvidePasswordCallback is not used when password is supplied in connection string")] - public void ProvidePasswordCallbackDelegateIsNotUsed() + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { - using var _ = CreateTempPool(ConnectionString, out var connString); + Host = unknownIp, + Pooling = false, + Timeout = 30 + }.ToString(); + using var conn = new NpgsqlConnection(connString); + var cts = new CancellationTokenSource(1000); + Assert.That(async () => await conn.OpenAsync(cts.Token), Throws.Exception.TypeOf()); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + } - using (var conn = new NpgsqlConnection(connString) { ProvidePasswordCallback = ProvidePasswordCallback }) - { - conn.Open(); + #endregion - // Do this again, since with multiplexing the very first connection attempt is done via - // the non-multiplexing path, to surface any exceptions. - NpgsqlConnection.ClearPool(conn); - conn.Close(); - conn.Open(); - } + #region Client Encoding - string ProvidePasswordCallback(string host, int port, string database, string username) - { - throw new Exception("password should come from connection string, not delegate"); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1065")] + public async Task Client_encoding_is_UTF8_by_default() + { + using var conn = await OpenConnectionAsync(); + Assert.That(await conn.ExecuteScalarAsync("SHOW client_encoding"), Is.EqualTo("UTF8")); + } - [Test, Description("Exceptions thrown from client application are wrapped when using ProvidePasswordCallback Delegate")] - public void ProvidePasswordCallbackDelegateExceptionsAreWrapped() - { - using var _ = CreateTempPool(ConnectionString, out var connString); - var builder = new NpgsqlConnectionStringBuilder(connString) - { - Password = null - }; + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1065")] + [NonParallelizable] // Sets environment variable + public async Task Client_encoding_env_var() + { + using 
(var testConn = await OpenConnectionAsync()) + Assert.That(await testConn.ExecuteScalarAsync("SHOW client_encoding"), Is.Not.EqualTo("SQL_ASCII")); + + // Note that the pool is unaware of the environment variable, so if a connection is + // returned from the pool it may contain the wrong client_encoding + using var _ = SetEnvironmentVariable("PGCLIENTENCODING", "SQL_ASCII"); + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(await conn.ExecuteScalarAsync("SHOW client_encoding"), Is.EqualTo("SQL_ASCII")); + } - using (var conn = new NpgsqlConnection(builder.ConnectionString) { ProvidePasswordCallback = ProvidePasswordCallback }) - { - Assert.That(() => conn.Open(), Throws.Exception - .TypeOf() - .With.InnerException.Message.EqualTo("inner exception from ProvidePasswordCallback") - ); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1065")] + public async Task Client_encoding_connection_param() + { + using (var conn = await OpenConnectionAsync()) + Assert.That(await conn.ExecuteScalarAsync("SHOW client_encoding"), Is.Not.EqualTo("SQL_ASCII")); + await using var dataSource = CreateDataSource(csb => csb.ClientEncoding = "SQL_ASCII"); + using (var conn = await dataSource.OpenConnectionAsync()) + Assert.That(await conn.ExecuteScalarAsync("SHOW client_encoding"), Is.EqualTo("SQL_ASCII")); + } - string ProvidePasswordCallback(string host, int port, string database, string username) - { - throw new Exception("inner exception from ProvidePasswordCallback"); - } + #endregion Client Encoding + + #region Timezone + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1634")] + [NonParallelizable] // Sets environment variable + public async Task Timezone_env_var() + { + string newTimezone; + using (var conn1 = await OpenConnectionAsync()) + { + newTimezone = (string?)await conn1.ExecuteScalarAsync("SHOW TIMEZONE") == "Africa/Bamako" + ? 
"Africa/Lagos" + : "Africa/Bamako"; } - [Test, Description("Parameters passed to ProvidePasswordCallback delegate are correct")] - public void ProvidePasswordCallbackDelegateGetsCorrectArguments() + // Note that the pool is unaware of the environment variable, so if a connection is + // returned from the pool it may contain the wrong timezone + using var _ = SetEnvironmentVariable("PGTZ", newTimezone); + await using var dataSource = CreateDataSource(); + using var conn2 = await dataSource.OpenConnectionAsync(); + Assert.That(await conn2.ExecuteScalarAsync("SHOW TIMEZONE"), Is.EqualTo(newTimezone)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1634")] + public async Task Timezone_connection_param() + { + string newTimezone; + using (var conn = await OpenConnectionAsync()) { - using var _ = CreateTempPool(ConnectionString, out var connString); - var builder = new NpgsqlConnectionStringBuilder(connString); - var goodPassword = builder.Password; - builder.Password = null; + newTimezone = (string?)await conn.ExecuteScalarAsync("SHOW TIMEZONE") == "Africa/Bamako" + ? "Africa/Lagos" + : "Africa/Bamako"; + } - Assume.That(goodPassword, Is.Not.Null); + await using var dataSource = CreateDataSource(csb => csb.Timezone = newTimezone); + using (var conn = await dataSource.OpenConnectionAsync()) + Assert.That(await conn.ExecuteScalarAsync("SHOW TIMEZONE"), Is.EqualTo(newTimezone)); + } - string? receivedHost = null; - int? receivedPort = null; - string? receivedDatabase = null; - string? 
receivedUsername = null; + #endregion Timezone + + #region ConnectionString - Host + + [TestCase("127.0.0.1", ExpectedResult = new [] { "127.0.0.1:5432" })] + [TestCase("127.0.0.1:5432", ExpectedResult = new [] { "127.0.0.1:5432" })] + [TestCase("::1", ExpectedResult = new [] { "::1:5432" })] + [TestCase("[::1]", ExpectedResult = new [] { "[::1]:5432" })] + [TestCase("[::1]:5432", ExpectedResult = new [] { "[::1]:5432" })] + [TestCase("localhost", ExpectedResult = new [] { "localhost:5432" })] + [TestCase("localhost:5432", ExpectedResult = new [] { "localhost:5432" })] + [TestCase("127.0.0.1,127.0.0.1:5432,::1,[::1],[::1]:5432,localhost,localhost:5432", + ExpectedResult = new [] + { + "127.0.0.1:5432", + "127.0.0.1:5432", + "::1:5432", + "[::1]:5432", + "[::1]:5432", + "localhost:5432", + "localhost:5432" + })] + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3802")] + public string[] ConnectionString_Host(string host) + { + var dataSourceBuilder = new NpgsqlDataSourceBuilder + { + ConnectionStringBuilder = { Host = host } + }; + using var dataSource = dataSourceBuilder.BuildMultiHost(); + return dataSource.Pools.Select(ds => $"{ds.Settings.Host}:{ds.Settings.Port}").ToArray()!; + } - using (var conn = new NpgsqlConnection(builder.ConnectionString) { ProvidePasswordCallback = ProvidePasswordCallback }) - { - conn.Open(); - Assert.AreEqual(builder.Host, receivedHost); - Assert.AreEqual(builder.Port, receivedPort); - Assert.AreEqual(builder.Database, receivedDatabase); - Assert.AreEqual(builder.Username, receivedUsername); - } + #endregion ConnectionString - Host - string ProvidePasswordCallback(string host, int port, string database, string username) - { - receivedHost = host; - receivedPort = port; - receivedDatabase = database; - receivedUsername = username; + [Test] + public async Task Unix_domain_socket() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + if (Environment.OSVersion.Version.Major < 10 || 
Environment.OSVersion.Version.Build < 17093) + Assert.Ignore("Unix-domain sockets support was introduced in Windows build 17093"); - return goodPassword!; - } + // On Windows we first need a classic IP connection to make sure we're running against the + // right backend version + using var versionConnection = await OpenConnectionAsync(); + MinimumPgVersion(versionConnection, "13.0", "Unix-domain sockets support on Windows was introduced in PostgreSQL 13"); } - #endregion - [Test] - public void BadDatabase() + var port = new NpgsqlConnectionStringBuilder(ConnectionString).Port; + var candidateDirectories = new[] { "/var/run/postgresql", "/tmp", Environment.GetEnvironmentVariable("TMP") ?? "C:\\" }; + var dir = candidateDirectories.FirstOrDefault(d => File.Exists(Path.Combine(d, $".s.PGSQL.{port}"))); + if (dir == null) { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Database = "does_not_exist" - }; - using (CreateTempPool(builder, out var connectionString)) - using (var conn = new NpgsqlConnection(connectionString)) - Assert.That(() => conn.Open(), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("3D000") - ); + IgnoreExceptOnBuildServer("No PostgreSQL unix domain socket was found"); + return; } - [Test, Description("Tests that mandatory connection string parameters are indeed mandatory")] - public void MandatoryConnectionStringParams() + try { - Assert.That(() => new NpgsqlConnection("User ID=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests").Open(), Throws.Exception.TypeOf()); + await using var dataSource = CreateDataSource(csb => csb.Host = dir); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var tx = await conn.BeginTransactionAsync(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1", tx), Is.EqualTo(1)); + Assert.That(conn.DataSource, Is.EqualTo(Path.Combine(dir, $".s.PGSQL.{port}"))); } - - [Test, Description("Reuses the same connection instance 
for a failed connection, then a successful one")] - public async Task FailConnectThenSucceed() + catch (Exception ex) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); + IgnoreExceptOnBuildServer($"Connection via unix domain socket failed: {ex}"); + } + } - var dbName = GetUniqueIdentifier(nameof(FailConnectThenSucceed)); - using (var conn1 = await OpenConnectionAsync()) - { - conn1.ExecuteNonQuery($"DROP DATABASE IF EXISTS \"{dbName}\""); - try - { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Database = dbName, - Pooling = false - }.ToString(); - - using (var conn2 = new NpgsqlConnection(connString)) - { - Assert.That(() => conn2.Open(), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("3D000") // database doesn't exist - ); - Assert.That(conn2.FullState, Is.EqualTo(ConnectionState.Closed)); - - conn1.ExecuteNonQuery($"CREATE DATABASE \"{dbName}\" TEMPLATE template0"); - - conn2.Open(); - conn2.Close(); - } - } - finally - { - //conn1.ExecuteNonQuery($"DROP DATABASE IF EXISTS \"{dbName}\""); - } - } + [Test] + [Platform(Exclude = "MacOsX", Reason = "Fails only on mac, needs to be investigated")] + public async Task Unix_abstract_domain_socket() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("Abstract unix-domain sockets are not supported on windows"); } - [Test] - [Timeout(10000)] - public void ConnectTimeout() + // We first need a classic IP connection to make sure we're running against the + // right backend version + using var versionConnection = await OpenConnectionAsync(); + MinimumPgVersion(versionConnection, "14.0", "Abstract unix-domain sockets support was introduced in PostgreSQL 14"); + + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { - var unknownIp = Environment.GetEnvironmentVariable("NPGSQL_UNKNOWN_IP"); - if (unknownIp == null) - return; // https://github.com/nunit/nunit/issues/3282 - 
//Assert.Ignore("NPGSQL_UNKNOWN_IP isn't defined and is required for connection timeout tests"); + Host = "@/npgsql_unix" + }; - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { - Host = unknownIp, - Pooling = false, - Timeout = 2 - }; - using (var conn = new NpgsqlConnection(csb.ToString())) - { - var sw = Stopwatch.StartNew(); - Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); - Assert.That(sw.Elapsed.TotalMilliseconds, Is.GreaterThanOrEqualTo((csb.Timeout * 1000) - 100), - $"Timeout was supposed to happen after {csb.Timeout} seconds, but fired after {sw.Elapsed.TotalSeconds}"); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); - } + try + { + await using var dataSource = CreateDataSource(csb.ToString()); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var tx = await conn.BeginTransactionAsync(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1", tx), Is.EqualTo(1)); + Assert.That(conn.DataSource, Is.EqualTo(Path.Combine(csb.Host, $".s.PGSQL.{csb.Port}"))); } - - [Test] - [Timeout(10000)] - public void ConnectTimeoutAsync() + catch (Exception ex) { - var unknownIp = Environment.GetEnvironmentVariable("NPGSQL_UNKNOWN_IP"); - if (unknownIp == null) - return; // https://github.com/nunit/nunit/issues/3282 - // Assert.Ignore("NPGSQL_UNKNOWN_IP isn't defined and is required for connection timeout tests"); + IgnoreExceptOnBuildServer($"Connection via abstract unix domain socket failed: {ex}"); + } + } - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Host = unknownIp, - Pooling = false, - Timeout = 2 - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) - { - Assert.That(async () => await conn.OpenAsync(), Throws.Exception - .TypeOf() - .With.InnerException.TypeOf()); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/903")] + public void DataSource_property() + { + using var conn 
= new NpgsqlConnection(); + Assert.That(conn.DataSource, Is.EqualTo(string.Empty)); + + var csb = new NpgsqlConnectionStringBuilder(ConnectionString); + + conn.ConnectionString = csb.ConnectionString; + Assert.That(conn.DataSource, Is.EqualTo($"tcp://{csb.Host}:{csb.Port}")); + + // Multiplexing isn't supported with multiple hosts + if (IsMultiplexing) + return; + + csb.Host = "127.0.0.1, 127.0.0.2"; + conn.ConnectionString = csb.ConnectionString; + Assert.That(conn.DataSource, Is.EqualTo(string.Empty)); + } + + #region Server version + + [Test] + public async Task PostgreSqlVersion_ServerVersion() + { + await using var c = new NpgsqlConnection(ConnectionString); + + Assert.That(() => c.PostgreSqlVersion, Throws.InvalidOperationException + .With.Message.EqualTo("Connection is not open")); + + Assert.That(() => c.ServerVersion, Throws.InvalidOperationException + .With.Message.EqualTo("Connection is not open")); + + await c.OpenAsync(); + var backendVersionString = (string)(await c.ExecuteScalarAsync("SHOW server_version"))!; + + Assert.That(backendVersionString, Is.EqualTo(c.ServerVersion)); + + Assert.That(backendVersionString, Does.Contain( + new[] { "rc", "beta", "devel" }.Any(x => backendVersionString.Contains(x)) + ? 
c.PostgreSqlVersion.Major.ToString() + : c.PostgreSqlVersion.ToString())); + } + + [TestCase("X13.0")] + [TestCase("13.")] + [TestCase("13.1.")] + [TestCase("13.1.1.")] + [TestCase("13.1.1.1.")] + [TestCase("13.1.1.1.1")] + public void ParseVersion_fails(string versionString) + => Assert.That(() => TestDbInfo.ParseServerVersion(versionString), Throws.Exception); + + [TestCase("13.3", ExpectedResult = "13.3")] + [TestCase("13.3X", ExpectedResult = "13.3")] + [TestCase("9.6.4", ExpectedResult = "9.6.4")] + [TestCase("9.6.4X", ExpectedResult = "9.6.4")] + [TestCase("9.5alpha2", ExpectedResult = "9.5")] + [TestCase("9.5alpha2X", ExpectedResult = "9.5")] + [TestCase("9.5devel", ExpectedResult = "9.5")] + [TestCase("9.5develX", ExpectedResult = "9.5")] + [TestCase("9.5deveX", ExpectedResult = "9.5")] + [TestCase("9.4beta3", ExpectedResult = "9.4")] + [TestCase("9.4rc1", ExpectedResult = "9.4")] + [TestCase("9.4rc1X", ExpectedResult = "9.4")] + [TestCase("13devel", ExpectedResult = "13.0")] + [TestCase("13beta1", ExpectedResult = "13.0")] + // The following should not occur as PostgreSQL version string in the wild these days but we support it. 
+ [TestCase("13", ExpectedResult = "13.0")] + [TestCase("13X", ExpectedResult = "13.0")] + [TestCase("13alpha1", ExpectedResult = "13.0")] + [TestCase("13alpha", ExpectedResult = "13.0")] + [TestCase("13alphX", ExpectedResult = "13.0")] + [TestCase("13beta", ExpectedResult = "13.0")] + [TestCase("13betX", ExpectedResult = "13.0")] + [TestCase("13rc1", ExpectedResult = "13.0")] + [TestCase("13rc", ExpectedResult = "13.0")] + [TestCase("13rX", ExpectedResult = "13.0")] + [TestCase("99999.99999.99999.99999", ExpectedResult = "99999.99999.99999.99999")] + [TestCase("99999.99999.99999.99999X", ExpectedResult = "99999.99999.99999.99999")] + [TestCase("99999.99999.99999.99999devel", ExpectedResult = "99999.99999.99999.99999")] + [TestCase("99999.99999.99999.99999alpha99999", ExpectedResult = "99999.99999.99999.99999")] + [TestCase("99999.99999.99999alpha99999", ExpectedResult = "99999.99999.99999")] + [TestCase("99999.99999.99999.99999beta99999", ExpectedResult = "99999.99999.99999.99999")] + [TestCase("99999.99999.99999beta99999", ExpectedResult = "99999.99999.99999")] + [TestCase("99999.99999.99999.99999rc99999", ExpectedResult = "99999.99999.99999.99999")] + [TestCase("99999.99999.99999rc99999", ExpectedResult = "99999.99999.99999")] + public string ParseVersion_succeeds(string versionString) + => TestDbInfo.ParseServerVersion(versionString).ToString(); + + class TestDbInfo : NpgsqlDatabaseInfo + { + public TestDbInfo(string host, int port, string databaseName, Version version) : base(host, port, databaseName, version) + => throw new NotImplementedException(); + + protected override IEnumerable GetTypes() + => throw new NotImplementedException(); + + public new static Version ParseServerVersion(string versionString) + => NpgsqlDatabaseInfo.ParseServerVersion(versionString); + } + + #endregion Server version + + [Test] + public void Setting_connection_string_while_open_throws() + { + using var conn = new NpgsqlConnection(); + conn.ConnectionString = ConnectionString; + 
conn.Open(); + Assert.That(() => conn.ConnectionString = "", Throws.Exception.TypeOf()); + } + + [Test] + public void Empty_constructor() + { + var conn = new NpgsqlConnection(); + Assert.That(conn.ConnectionTimeout, Is.EqualTo(NpgsqlConnectionStringBuilder.DefaultTimeout)); + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test] + public void Constructor_with_null_connection_string() + { + var conn = new NpgsqlConnection(null); + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test] + public void Constructor_with_empty_connection_string() + { + var conn = new NpgsqlConnection(""); + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test] + public void Set_connection_string_to_null() + { + var conn = new NpgsqlConnection(ConnectionString); + conn.ConnectionString = null; + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(conn.Settings.Host, Is.Null); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test] + public void Set_connection_string_to_empty() + { + var conn = new NpgsqlConnection(ConnectionString); + conn.ConnectionString = ""; + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(conn.Settings.Host, Is.Null); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/703")] + public async Task No_database_defaults_to_username() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Database = null }; + using var conn = new NpgsqlConnection(csb.ToString()); + Assert.That(conn.Database, Is.EqualTo(csb.Username)); + conn.Open(); + Assert.That(await conn.ExecuteScalarAsync("SELECT current_database()"), Is.EqualTo(csb.Username)); + Assert.That(conn.Database, 
Is.EqualTo(csb.Username)); + } + + [Test, Description("Breaks a connector while it's in the pool, with a keepalive and without")] + [Platform(Exclude = "MacOsX", Reason = "Fails only on mac, needs to be investigated")] + [TestCase(false, TestName = nameof(Break_connector_in_pool) + "_without_keep_alive")] + [TestCase(true, TestName = nameof(Break_connector_in_pool) + "_with_keep_alive")] + public async Task Break_connector_in_pool(bool keepAlive) + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing, hanging"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.MaxPoolSize = 1; + if (keepAlive) + dataSourceBuilder.ConnectionStringBuilder.KeepAlive = 1; + await using var dataSource = dataSourceBuilder.Build(); + await using var conn = await dataSource.OpenConnectionAsync(); + var connector = conn.Connector; + Assert.That(connector, Is.Not.Null); + await conn.CloseAsync(); + + // Use another connection to kill the connector currently in the pool + await using (var conn2 = await OpenConnectionAsync()) + await conn2.ExecuteNonQueryAsync($"SELECT pg_terminate_backend({connector!.BackendProcessId})"); + + // Allow some time for the terminate to occur + await Task.Delay(3000); + + await conn.OpenAsync(); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + if (keepAlive) + { + Assert.That(conn.Connector, Is.Not.SameAs(connector)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); } + else + { + Assert.That(conn.Connector, Is.SameAs(connector)); + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Throws.Exception + .AssignableTo()); + } + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4603")] + public async Task Reload_types_keepalive_concurrent() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing doesn't support keepalive"); + + await using var dataSource = CreateDataSource(csb => csb.KeepAlive = 1); + await using var conn = await 
dataSource.OpenConnectionAsync(); + + var startTimestamp = Stopwatch.GetTimestamp(); + // Give a few seconds for a KeepAlive to possibly perform + while (GetElapsedTime(startTimestamp).TotalSeconds < 2) + Assert.DoesNotThrow(conn.ReloadTypes); + + // dotnet 3.1 doesn't have Stopwatch.GetElapsedTime method. + static TimeSpan GetElapsedTime(long startingTimestamp) => + new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)10000000 / Stopwatch.Frequency))); + } - [Test] - [Timeout(10000)] - public void ConnectTimeoutCancel() + #region ChangeDatabase + + [Test] + public async Task ChangeDatabase() + { + using var conn = await OpenConnectionAsync(); + conn.ChangeDatabase("template1"); + using var cmd = new NpgsqlCommand("select current_database()", conn); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("template1")); + } + + [Test] + public async Task ChangeDatabase_does_not_affect_other_connections() + { + using var conn1 = new NpgsqlConnection(ConnectionString); + using var conn2 = new NpgsqlConnection(ConnectionString); + // Connection 1 changes database + conn1.Open(); + conn1.ChangeDatabase("template1"); + Assert.That(await conn1.ExecuteScalarAsync("SELECT current_database()"), Is.EqualTo("template1")); + + // Connection 2's database should not changed + conn2.Open(); + Assert.That(await conn2.ExecuteScalarAsync("SELECT current_database()"), Is.Not.EqualTo(conn1.Database)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1331")] + public void ChangeDatabase_connection_on_closed_connection_throws() + { + using var conn = new NpgsqlConnection(ConnectionString); + Assert.That(() => conn.ChangeDatabase("template1"), Throws.Exception + .TypeOf() + .With.Message.EqualTo("Connection is not open")); + } + + #endregion + + [Test, Description("Tests closing a connector while a reader is open")] + public async Task Close_during_read([Values(PooledOrNot.Pooled, PooledOrNot.Unpooled)] PooledOrNot pooled) + { + if (IsMultiplexing && 
pooled == PooledOrNot.Unpooled) + return; // Multiplexing requires pooling + + await using var dataSource = CreateDataSource(csb => csb.Pooling = pooled == PooledOrNot.Pooled); + await using var conn = await dataSource.OpenConnectionAsync(); + await using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + await using (var reader = await cmd.ExecuteReaderAsync()) { - var unknownIp = Environment.GetEnvironmentVariable("NPGSQL_UNKNOWN_IP"); - if (unknownIp == null) - return; // https://github.com/nunit/nunit/issues/3282 - //Assert.Ignore("NPGSQL_UNKNOWN_IP isn't defined and is required for connection cancellation tests"); + reader.Read(); + conn.Close(); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + Assert.That(reader.IsClosed); + } - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Host = unknownIp, - Pooling = false, - Timeout = 30 - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) - { - var cts = new CancellationTokenSource(1000); - Assert.That(async () => await conn.OpenAsync(cts.Token), Throws.Exception.TypeOf()); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); - } + conn.Open(); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task Search_path() + { + await using var dataSource = CreateDataSource(csb => csb.SearchPath = "foo"); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(await conn.ExecuteScalarAsync("SHOW search_path"), Contains.Substring("foo")); + } + + [Test] + public async Task Set_options() + { + await using var dataSource = CreateDataSource(csb => + csb.Options = + "-c default_transaction_isolation=serializable -c default_transaction_deferrable=on -c foo.bar=My\\ Famous\\\\Thing"); + await using var conn = await dataSource.OpenConnectionAsync(); + + Assert.That(await conn.ExecuteScalarAsync("SHOW 
default_transaction_isolation"), Is.EqualTo("serializable")); + Assert.That(await conn.ExecuteScalarAsync("SHOW default_transaction_deferrable"), Is.EqualTo("on")); + Assert.That(await conn.ExecuteScalarAsync("SHOW foo.bar"), Is.EqualTo("My Famous\\Thing")); + } + + [Test] + public async Task Connector_not_initialized_exception() + { + var command = new NpgsqlCommand(); + command.CommandText = @"SELECT 123"; + + for (var i = 0; i < 2; i++) + { + await using var connection = await OpenConnectionAsync(); + command.Connection = connection; + await using var tx = await connection.BeginTransactionAsync(); + await command.ExecuteScalarAsync(); + await tx.CommitAsync(); } + } + + [Test] + public void Bug1011001() + { + //[#1011001] Bug in NpgsqlConnectionStringBuilder affects on cache and connection pool + + var csb1 = new NpgsqlConnectionStringBuilder(@"Server=server;Port=5432;User Id=user;Password=passwor;Database=database;"); + var cs1 = csb1.ToString(); + var csb2 = new NpgsqlConnectionStringBuilder(cs1); + var cs2 = csb2.ToString(); + Assert.IsTrue(cs1 == cs2); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/pull/164")] + public void Connection_State_is_Closed_when_disposed() + { + var c = new NpgsqlConnection(); + c.Dispose(); + Assert.AreEqual(ConnectionState.Closed, c.State); + } - #endregion + [Test] + public void Change_ApplicationName_with_connection_string_builder() + { + // Test for issue #165 on github. 
+ var builder = new NpgsqlConnectionStringBuilder(); + builder.ApplicationName = "test"; + } - #region Client Encoding + [Test, Description("Makes sure notices are probably received and emitted as events")] + public async Task Notice() + { + // Make sure messages are in English + await using var dataSource = CreateDataSource(csb => csb.Options = "-c lc_messages=en_US.UTF-8"); + await using var conn = await dataSource.OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + await conn.ExecuteNonQueryAsync($@" +CREATE OR REPLACE FUNCTION {function}() RETURNS VOID AS +'BEGIN RAISE NOTICE ''testnotice''; END;' +LANGUAGE 'plpgsql'"); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1065")] - public async Task ClientEncodingIsUTF8ByDefault() + var mre = new ManualResetEvent(false); + PostgresNotice? notice = null; + NoticeEventHandler action = (sender, args) => + { + notice = args.Notice; + mre.Set(); + }; + conn.Notice += action; + try { - using (var conn = await OpenConnectionAsync()) - Assert.That(await conn.ExecuteScalarAsync("SHOW client_encoding"), Is.EqualTo("UTF8")); + // See docs for CreateSleepCommand + await conn.ExecuteNonQueryAsync($"SELECT {function}()::TEXT"); + mre.WaitOne(5000); + Assert.That(notice, Is.Not.Null, "No notice was emitted"); + Assert.That(notice!.MessageText, Is.EqualTo("testnotice")); + Assert.That(notice.Severity, Is.EqualTo("NOTICE")); } + finally + { + conn.Notice -= action; + } + } + + [Test, Description("Makes sure that concurrent use of the connection throws an exception")] + public async Task Concurrent_use_throws() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + using (await cmd.ExecuteReaderAsync()) + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 2"), + Throws.Exception.TypeOf() + 
.With.Property(nameof(NpgsqlOperationInProgressException.CommandInProgress)).SameAs(cmd)); + + await conn.ExecuteNonQueryAsync("CREATE TEMP TABLE foo (bar INT)"); + using (conn.BeginBinaryImport("COPY foo (bar) FROM STDIN BINARY")) + { + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 2"), + Throws.Exception.TypeOf() + .With.Message.Contains("Copy")); + } + } + + #region PersistSecurityInfo + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/783")] + public void PersistSecurityInfo_is_true([Values(true, false)] bool pooling) + { + if (IsMultiplexing && !pooling) + return; + + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) + { + PersistSecurityInfo = true, + Pooling = pooling + }.ToString(); + using var conn = new NpgsqlConnection(connString); + var passwd = new NpgsqlConnectionStringBuilder(conn.ConnectionString).Password; + Assert.That(passwd, Is.Not.Null); + conn.Open(); + Assert.That(new NpgsqlConnectionStringBuilder(conn.ConnectionString).Password, Is.EqualTo(passwd)); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/783")] + public void No_password_without_PersistSecurityInfo([Values(true, false)] bool pooling) + { + if (IsMultiplexing && !pooling) + return; + + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) + { + Pooling = pooling + }.ToString(); + using var conn = new NpgsqlConnection(connString); + var csb = new NpgsqlConnectionStringBuilder(conn.ConnectionString); + Assert.That(csb.PersistSecurityInfo, Is.False); + Assert.That(csb.Password, Is.Not.Null); + conn.Open(); + Assert.That(new NpgsqlConnectionStringBuilder(conn.ConnectionString).Password, Is.Null); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1065")] - [NonParallelizable] - public async Task ClientEncodingEnvVar() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2725")] + public void Clone_with_PersistSecurityInfo() + { + var builder = new 
NpgsqlConnectionStringBuilder(ConnectionString) { - using (var testConn = await OpenConnectionAsync()) - Assert.That(await testConn.ExecuteScalarAsync("SHOW client_encoding"), Is.Not.EqualTo("SQL_ASCII")); + PersistSecurityInfo = true + }; + using var _ = CreateTempPool(builder, out var connStringWithPersist); - // Note that the pool is unaware of the environment variable, so if a connection is - // returned from the pool it may contain the wrong client_encoding - using var _ = SetEnvironmentVariable("PGCLIENTENCODING", "SQL_ASCII"); - using var __ = CreateTempPool(ConnectionString, out var connectionString); + using var connWithPersist = new NpgsqlConnection(connStringWithPersist); - var connString = new NpgsqlConnectionStringBuilder(connectionString); - using var conn = await OpenConnectionAsync(connString); - Assert.That(await conn.ExecuteScalarAsync("SHOW client_encoding"), Is.EqualTo("SQL_ASCII")); + // First un-persist, should work + builder.PersistSecurityInfo = false; + var connStringWithoutPersist = builder.ToString(); + using var clonedWithoutPersist = connWithPersist.CloneWith(connStringWithoutPersist); + clonedWithoutPersist.Open(); + + Assert.That(clonedWithoutPersist.ConnectionString, Does.Not.Contain("Password=")); + + // Then attempt to re-persist, should not work + using var clonedConn = clonedWithoutPersist.CloneWith(connStringWithPersist); + clonedConn.Open(); + + Assert.That(clonedConn.ConnectionString, Does.Not.Contain("Password=")); + } + + [Test] + public async Task CloneWith_and_data_source_with_password() + { + var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString); + // Set the password via the data source property later to make sure that's picked up by CloneWith + var password = dataSourceBuilder.ConnectionStringBuilder.Password!; + dataSourceBuilder.ConnectionStringBuilder.Password = null; + await using var dataSource = dataSourceBuilder.Build(); + + await using var connection = dataSource.CreateConnection(); + 
dataSource.Password = password; + + // Test that the up-to-date password gets copied to the clone, as if we opened the original connection instead of cloning it + using var _ = CreateTempPool(new NpgsqlConnectionStringBuilder(ConnectionString) { Password = null }, out var tempConnectionString); + await using var clonedConnection = connection.CloneWith(tempConnectionString); + await clonedConnection.OpenAsync(); + } + + [Test] + public async Task CloneWith_and_data_source_with_auth_callbacks() + { + var (userCertificateValidationCallbackCalled, clientCertificatesCallbackCalled) = (false, false); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UseUserCertificateValidationCallback(UserCertificateValidationCallback); + dataSourceBuilder.UseClientCertificatesCallback(ClientCertificatesCallback); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = dataSource.CreateConnection(); + + using var _ = CreateTempPool(ConnectionString, out var tempConnectionString); + await using var clonedConnection = connection.CloneWith(tempConnectionString); + + clonedConnection.UserCertificateValidationCallback!(null!, null, null, SslPolicyErrors.None); + Assert.True(userCertificateValidationCallbackCalled); + clonedConnection.ProvideClientCertificatesCallback!(null!); + Assert.True(clientCertificatesCallbackCalled); + + bool UserCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? 
chain, SslPolicyErrors errors) + => userCertificateValidationCallbackCalled = true; + + void ClientCertificatesCallback(X509CertificateCollection certs) + => clientCertificatesCallbackCalled = true; + } + + #endregion PersistSecurityInfo + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/743")] + [IssueLink("https://github.com/npgsql/npgsql/issues/783")] + public void Clone() + { + using var pool = CreateTempPool(ConnectionString, out var connectionString); + using var conn = new NpgsqlConnection(connectionString); + ProvideClientCertificatesCallback callback1 = certificates => { }; + conn.ProvideClientCertificatesCallback = callback1; + RemoteCertificateValidationCallback callback2 = (sender, certificate, chain, errors) => true; + conn.UserCertificateValidationCallback = callback2; + + conn.Open(); + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + + using var conn2 = (NpgsqlConnection)((ICloneable)conn).Clone(); + Assert.That(conn2.ConnectionString, Is.EqualTo(conn.ConnectionString)); + Assert.That(conn2.ProvideClientCertificatesCallback, Is.SameAs(callback1)); + Assert.That(conn2.UserCertificateValidationCallback, Is.SameAs(callback2)); + conn2.Open(); + Assert.That(async () => await conn2.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task Clone_with_data_source() + { + await using var connection = await DataSource.OpenConnectionAsync(); + await using var clonedConnection = (NpgsqlConnection)((ICloneable)connection).Clone(); + + Assert.That(clonedConnection.NpgsqlDataSource, Is.SameAs(DataSource)); + Assert.DoesNotThrowAsync(() => clonedConnection.OpenAsync()); + } + + [Test] + public async Task DatabaseInfo_is_shared() + { + if (IsMultiplexing) + return; + // Create a temp pool to make sure the second connection will be new and not idle + await using var dataSource = CreateDataSource(); + await using var conn1 = await dataSource.OpenConnectionAsync(); + // Call RealoadTypes to 
force reload DatabaseInfo + conn1.ReloadTypes(); + await using var conn2 = await dataSource.OpenConnectionAsync(); + Assert.That(conn1.Connector!.DatabaseInfo, Is.SameAs(conn2.Connector!.DatabaseInfo)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/736")] + public async Task ManyOpenClose() + { + await using var dataSource = CreateDataSource(); + // The connector's _sentRfqPrependedMessages is a byte, too many open/closes made it overflow + for (var i = 0; i < 255; i++) + { + await using var conn = await dataSource.OpenConnectionAsync(); + } + await using (var conn = dataSource.CreateConnection()) + { + await conn.OpenAsync(); + } + await using (var conn = dataSource.CreateConnection()) + { + await conn.OpenAsync(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1065")] - public async Task ClientEncodingConnectionParam() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/736")] + public async Task Many_open_close_with_transaction() + { + await using var dataSource = CreateDataSource(); + // The connector's _sentRfqPrependedMessages is a byte, too many open/closes made it overflow + for (var i = 0; i < 255; i++) { - using (var conn = await OpenConnectionAsync()) - Assert.That(await conn.ExecuteScalarAsync("SHOW client_encoding"), Is.Not.EqualTo("SQL_ASCII")); - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { ClientEncoding = "SQL_ASCII" }; - using (var conn = await OpenConnectionAsync(connString)) - Assert.That(await conn.ExecuteScalarAsync("SHOW client_encoding"), Is.EqualTo("SQL_ASCII")); + await using var conn = await dataSource.OpenConnectionAsync(); + await conn.BeginTransactionAsync(); } + await using (var conn = await dataSource.OpenConnectionAsync()) + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - #endregion Client Encoding + [Test] + 
[IssueLink("https://github.com/npgsql/npgsql/issues/927")] + [IssueLink("https://github.com/npgsql/npgsql/issues/736")] + [Ignore("Fails when running the entire test suite but not on its own...")] + public async Task Rollback_on_close() + { + // Npgsql 3.0.0 to 3.0.4 prepended a rollback for the next time the connector is used, as an optimization. + // This caused some issues (#927) and was removed. - #region Timezone + await using var dataSource = CreateDataSource(); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1634")] - [NonParallelizable] - public async Task TimezoneEnvVar() + int processId; + await using (var conn = await dataSource.OpenConnectionAsync()) { - string newTimezone; - using (var conn1 = await OpenConnectionAsync()) - { - newTimezone = (string?)await conn1.ExecuteScalarAsync("SHOW TIMEZONE") == "Africa/Bamako" - ? "Africa/Lagos" - : "Africa/Bamako"; - } - - // Note that the pool is unaware of the environment variable, so if a connection is - // returned from the pool it may contain the wrong timezone - using var _ = SetEnvironmentVariable("PGTZ", newTimezone); - using var __ = CreateTempPool(ConnectionString, out var connectionString); - using var conn2 = await OpenConnectionAsync(connectionString); - Assert.That(await conn2.ExecuteScalarAsync("SHOW TIMEZONE"), Is.EqualTo(newTimezone)); + processId = conn.Connector!.BackendProcessId; + await conn.BeginTransactionAsync(); + await conn.ExecuteNonQueryAsync("SELECT 1"); + Assert.That(conn.Connector.TransactionStatus, Is.EqualTo(TransactionStatus.InTransactionBlock)); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1634")] - public async Task TimezoneConnectionParam() + await using (var conn = await dataSource.OpenConnectionAsync()) { - string newTimezone; - using (var conn = await OpenConnectionAsync()) - { - newTimezone = (string?)await conn.ExecuteScalarAsync("SHOW TIMEZONE") == "Africa/Bamako" - ? 
"Africa/Lagos" - : "Africa/Bamako"; - } - - var _ = CreateTempPool(ConnectionString, out var connString); - var builder = new NpgsqlConnectionStringBuilder(connString) - { - Timezone = newTimezone - }; - using (var conn = await OpenConnectionAsync(builder.ConnectionString)) - Assert.That(await conn.ExecuteScalarAsync("SHOW TIMEZONE"), Is.EqualTo(newTimezone)); + Assert.That(conn.Connector!.BackendProcessId, Is.EqualTo(processId)); + Assert.That(conn.Connector.TransactionStatus, Is.EqualTo(TransactionStatus.Idle)); } + } - #endregion Timezone - - [Test] - public async Task UnixDomainSocket() - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - if (Environment.OSVersion.Version.Major < 10 || Environment.OSVersion.Version.Build < 17093) - Assert.Ignore("Unix-domain sockets support was introduced in Windows build 17093"); - - // On Windows we first need a classic IP connection to make sure we're running against the - // right backend version - using var versionConnection = await OpenConnectionAsync(); - MinimumPgVersion(versionConnection, "13.0", "Unix-domain sockets support on Windows was introduced in PostgreSQL 13"); - } - - var port = new NpgsqlConnectionStringBuilder(ConnectionString).Port; - var candidateDirectories = new[] { "/var/run/postgresql", "/tmp", Environment.GetEnvironmentVariable("TMP") ?? 
"C:\\" }; - var dir = candidateDirectories.FirstOrDefault(d => File.Exists(Path.Combine(d, $".s.PGSQL.{port}"))); - if (dir == null) - { - IgnoreExceptOnBuildServer("No PostgreSQL unix domain socket was found"); - return; - } + [Test, Description("Tests an exception happening when sending the Terminate message while closing a ready connector")] + [IssueLink("https://github.com/npgsql/npgsql/issues/777")] + public async Task Exception_during_close() + { + // Pooling must be on to use multiplexing + if (IsMultiplexing) + return; - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Host = dir - }; + await using var dataSource = CreateDataSource(csb => csb.Pooling = false); + await using var conn = await dataSource.OpenConnectionAsync(); + var connectorId = conn.ProcessID; - try - { - using var conn = await OpenConnectionAsync(csb); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - Assert.That(conn.DataSource, Is.EqualTo(Path.Combine(csb.Host, $".s.PGSQL.{port}"))); - } - catch (PostgresException e) when (e.SqlState.StartsWith("28")) - { - if (TestUtil.IsOnBuildServer) - throw; - Assert.Ignore("Connection via unix domain socket failed"); - } - } + using (var conn2 = await OpenConnectionAsync()) + await conn2.ExecuteNonQueryAsync($"SELECT pg_terminate_backend({connectorId})"); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/903")] - public void DataSource() - { - using (var conn = new NpgsqlConnection(ConnectionString)) - Assert.That(conn.DataSource, Is.EqualTo($"tcp://{conn.Host}:{conn.Port}")); + conn.Close(); + } - var bld = new NpgsqlConnectionStringBuilder(ConnectionString); - bld.Host = "Otherhost"; + [Test, Description("Some pseudo-PG database don't support pg_type loading, we have a minimal DatabaseInfo for this")] + public async Task NoTypeLoading() + { + await using var dataSource = CreateDataSource(csb => csb.ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading); + await using var conn = await 
dataSource.OpenConnectionAsync(); - using (var conn = new NpgsqlConnection(bld.ToString())) - Assert.That(conn.DataSource, Is.EqualTo($"tcp://{conn.Host}:{conn.Port}")); + Assert.That(await conn.ExecuteScalarAsync("SELECT 8"), Is.EqualTo(8)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 'foo'"), Is.EqualTo("foo")); + Assert.That(await conn.ExecuteScalarAsync("SELECT TRUE"), Is.EqualTo(true)); + Assert.That(await conn.ExecuteScalarAsync("SELECT INET '192.168.1.1'"), Is.EqualTo(IPAddress.Parse("192.168.1.1"))); - bld = new NpgsqlConnectionStringBuilder(ConnectionString); - bld.Port = 5435; + Assert.That(await conn.ExecuteScalarAsync("SELECT '{1,2,3}'::int[]"), Is.EqualTo(new[] { 1, 2, 3 })); + Assert.That(await conn.ExecuteScalarAsync("SELECT '[1,10)'::int4range"), Is.EqualTo(new NpgsqlRange(1, true, 10, false))); - using (var conn = new NpgsqlConnection(bld.ToString())) - Assert.That(conn.DataSource, Is.EqualTo($"tcp://{conn.Host}:{conn.Port}")); + if (conn.PostgreSqlVersion >= new Version(14, 0)) + { + var multirangeArray = (NpgsqlRange[])(await conn.ExecuteScalarAsync("SELECT '{[3,7), (8,]}'::int4multirange"))!; + Assert.That(multirangeArray.Length, Is.EqualTo(2)); + Assert.That(multirangeArray[0], Is.EqualTo(new NpgsqlRange(3, true, false, 7, false, false))); + Assert.That(multirangeArray[1], Is.EqualTo(new NpgsqlRange(9, true, false, 0, false, true))); } - - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2763")] - public void DataSourceDefault() + else { - using (var conn = new NpgsqlConnection()) + using var cmd = new NpgsqlCommand("SELECT $1", conn) { - Assert.That(conn.DataSource, Is.EqualTo(string.Empty)); + Parameters = { new() { Value = DBNull.Value, NpgsqlDbType = NpgsqlDbType.IntegerMultirange } } + }; - conn.ConnectionString = ConnectionString; - Assert.That(conn.DataSource, Is.EqualTo($"tcp://{conn.Host}:{conn.Port}")); - } + Assert.That(async () => await cmd.ExecuteScalarAsync(), + Throws.Exception.TypeOf() + 
.With.Message.EqualTo("The NpgsqlDbType 'IntegerMultirange' isn't present in your database. You may need to install an extension or upgrade to a newer version.")); } + } - [Test] - public void SetConnectionString() - { - using (var conn = new NpgsqlConnection()) - { - conn.ConnectionString = ConnectionString; - conn.Open(); - Assert.That(() => conn.ConnectionString = "", Throws.Exception.TypeOf()); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1158")] + public async Task Table_named_record() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing, ReloadTypes"); - [Test] - public void EmptyCtor() - { - var conn = new NpgsqlConnection(); - Assert.That(conn.ConnectionTimeout, Is.EqualTo(NpgsqlConnectionStringBuilder.DefaultTimeout)); - Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); - Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); - } + using var conn = await OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync(@" - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/703")] - public async Task NoDatabaseDefaultsToUsername() +DROP TABLE IF EXISTS record; +CREATE TABLE record ()"); + try { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Database = null }; - using (var conn = new NpgsqlConnection(csb.ToString())) - { - Assert.That(conn.Database, Is.EqualTo(csb.Username)); - conn.Open(); - Assert.That(await conn.ExecuteScalarAsync("SELECT current_database()"), Is.EqualTo(csb.Username)); - Assert.That(conn.Database, Is.EqualTo(csb.Username)); - } + conn.ReloadTypes(); + Assert.That(await conn.ExecuteScalarAsync("SELECT COUNT(*) FROM record"), Is.Zero); } - - [Test, Description("Breaks a connector while it's in the pool, with a keepalive and without")] - [TestCase(false, TestName = "BreakConnectorInPoolWithoutKeepAlive")] - [TestCase(true, TestName = "BreakConnectorInPoolWithKeepAlive")] - public async Task BreakConnectorInPool(bool keepAlive) + finally { - if (IsMultiplexing) - 
Assert.Ignore("Multiplexing, hanging"); - - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { MaxPoolSize = 1 }; - if (keepAlive) - csb.KeepAlive = 1; - using (var conn = new NpgsqlConnection(csb.ToString())) - { - conn.Open(); - var connectorId = conn.ProcessID; - conn.Close(); - - // Use another connection to kill the connector currently in the pool - using (var conn2 = await OpenConnectionAsync()) - conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({connectorId})"); - - // Allow some time for the terminate to occur - Thread.Sleep(2000); - - conn.Open(); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - if (keepAlive) - { - Assert.That(conn.ProcessID, Is.Not.EqualTo(connectorId)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - else - { - Assert.That(conn.ProcessID, Is.EqualTo(connectorId)); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Throws.Exception - .AssignableTo()); - } - } + await conn.ExecuteNonQueryAsync("DROP TABLE record"); } + } - #region ChangeDatabase - - [Test] - public async Task ChangeDatabase() - { - using (var conn = await OpenConnectionAsync()) - { - conn.ChangeDatabase("template1"); - using (var cmd = new NpgsqlCommand("select current_database()", conn)) - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("template1")); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/392")] + [NonParallelizable] + [Platform(Exclude = "MacOsX", Reason = "Flaky in CI on Mac")] + public async Task Non_UTF8_Encoding() + { + Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); + await using var adminConn = await OpenConnectionAsync(); + + // Create the database with server encoding sql-ascii + // Starting with PG16, the default locale provider is icu, which does not support encoding sql_ascii. Specify libc explicitly as the + // locale provider (except for older versions where specifying explicitly isn't supported, and libc is the only possibility). 
+ await adminConn.ExecuteNonQueryAsync("DROP DATABASE IF EXISTS sqlascii"); + await adminConn.ExecuteNonQueryAsync( + adminConn.PostgreSqlVersion >= new Version(15, 0) + ? "CREATE DATABASE sqlascii ENCODING 'sql_ascii' LOCALE_PROVIDER libc TEMPLATE template0" + : "CREATE DATABASE sqlascii ENCODING 'sql_ascii' TEMPLATE template0"); + + try + { + // Insert some win1252 data + await using var goodDataSource = CreateDataSource(csb => + { + csb.Database = "sqlascii"; + csb.Encoding = "windows-1252"; + csb.ClientEncoding = "sql-ascii"; + }); + + await using (var conn = await goodDataSource.OpenConnectionAsync()) + { + const string value = "éàç"; + await conn.ExecuteNonQueryAsync("CREATE TABLE foo (bar TEXT)"); + await conn.ExecuteNonQueryAsync($"INSERT INTO foo (bar) VALUES ('{value}')"); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT * FROM foo"; + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.IsTrue(await reader.ReadAsync()); + + using (var textReader = await reader.GetTextReaderAsync(0)) + Assert.That(textReader.ReadToEnd(), Is.EqualTo(value)); + Assert.That(reader.GetString(0), Is.EqualTo(value)); } - } - [Test] - public async Task ChangeDatabaseDoesNotAffectOtherConnections() - { - using (var conn1 = new NpgsqlConnection(ConnectionString)) - using (var conn2 = new NpgsqlConnection(ConnectionString)) + // A normal connection with the default UTF8 encoding and client_encoding should fail + await using var badDataSource = CreateDataSource(csb => csb.Database = "sqlascii"); + await using (var conn = await badDataSource.OpenConnectionAsync()) { - // Connection 1 changes database - conn1.Open(); - conn1.ChangeDatabase("template1"); - Assert.That(await conn1.ExecuteScalarAsync("SELECT current_database()"), Is.EqualTo("template1")); - - // Connection 2's database should not changed - conn2.Open(); - Assert.That(await conn2.ExecuteScalarAsync("SELECT current_database()"), Is.Not.EqualTo(conn1.Database)); + Assert.That(async () 
=> await conn.ExecuteScalarAsync("SELECT * FROM foo"), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.CharacterNotInRepertoire) + .Or.TypeOf() + ); } } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1331")] - public void ChangeDatabaseConnectionNotOpen() + finally { - using (var conn = new NpgsqlConnection(ConnectionString)) - Assert.That(() => conn.ChangeDatabase("template1"), Throws.Exception - .TypeOf() - .With.Message.EqualTo("Connection is not open")); + await adminConn.ExecuteNonQueryAsync("DROP DATABASE IF EXISTS sqlascii"); } + } - #endregion + [Test] + public async Task Oversize_buffer() + { + if (IsMultiplexing) + return; - [Test, Description("Tests closing a connector while a reader is open")] - [Timeout(10000)] - public async Task CloseDuringRead([Values(PooledOrNot.Pooled, PooledOrNot.Unpooled)] PooledOrNot pooled) - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString); - if (pooled == PooledOrNot.Unpooled) - { - if (IsMultiplexing) - return; // Multiplexing requires pooling - csb.Pooling = false; - } + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + var csb = new NpgsqlConnectionStringBuilder(ConnectionString); - using (var conn = await OpenConnectionAsync(csb)) - { - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - conn.Close(); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); - Assert.That(reader.IsClosed); - } - - conn.Open(); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } + Assert.That(conn.Connector!.ReadBuffer.Size, Is.EqualTo(csb.ReadBufferSize)); - [Test] - public async Task SearchPath() + // Read a big row, we should now be using an oversize buffer + var bigString1 = new string('x', 
conn.Connector.ReadBuffer.Size + 1); + using (var cmd = new NpgsqlCommand($"SELECT '{bigString1}'", conn)) + using (var reader = await cmd.ExecuteReaderAsync()) { - using (var conn = await OpenConnectionAsync(new NpgsqlConnectionStringBuilder(ConnectionString) { SearchPath = "foo" })) - Assert.That(await conn.ExecuteScalarAsync("SHOW search_path"), Contains.Substring("foo")); + reader.Read(); + Assert.That(reader.GetString(0), Is.EqualTo(bigString1)); } + var size1 = conn.Connector.ReadBuffer.Size; + Assert.That(conn.Connector.ReadBuffer.Size, Is.GreaterThan(csb.ReadBufferSize)); - [Test] - public async Task SetOptions() + // Even bigger oversize buffer + var bigString2 = new string('x', conn.Connector.ReadBuffer.Size + 1); + using (var cmd = new NpgsqlCommand($"SELECT '{bigString2}'", conn)) + using (var reader = await cmd.ExecuteReaderAsync()) { - using var _ = CreateTempPool(new NpgsqlConnectionStringBuilder(ConnectionString) - { - Options = "-c default_transaction_isolation=serializable -c default_transaction_deferrable=on -c foo.bar=My\\ Famous\\\\Thing" - }, out var connectionString); + reader.Read(); + Assert.That(reader.GetString(0), Is.EqualTo(bigString2)); + } + Assert.That(conn.Connector.ReadBuffer.Size, Is.GreaterThan(size1)); - using var conn = await OpenConnectionAsync(connectionString); + var processId = conn.ProcessID; + conn.Close(); + conn.Open(); + Assert.That(conn.ProcessID, Is.EqualTo(processId)); + Assert.That(conn.Connector.ReadBuffer.Size, Is.EqualTo(csb.ReadBufferSize)); + } - Assert.That(await conn.ExecuteScalarAsync("SHOW default_transaction_isolation"), Is.EqualTo("serializable")); - Assert.That(await conn.ExecuteScalarAsync("SHOW default_transaction_deferrable"), Is.EqualTo("on")); - Assert.That(await conn.ExecuteScalarAsync("SHOW foo.bar"), Is.EqualTo("My Famous\\Thing")); - } + #region Keepalive - [Test] - public async Task ConnectorNotInitializedException1000581() - { - var command = new NpgsqlCommand(); - command.CommandText = 
@"SELECT 123"; + [Test, Explicit, Description("Turns on TCP keepalive and sleeps forever, good for wiresharking")] + public async Task TcpKeepaliveTime() + { + await using var dataSource = CreateDataSource(csb => csb.TcpKeepAliveTime = 2); + using (await dataSource.OpenConnectionAsync()) + Thread.Sleep(Timeout.Infinite); + } - for (var i = 0; i < 2; i++) - { - using (var connection = new NpgsqlConnection(ConnectionString)) - { - connection.Open(); - command.Connection = connection; - var tx = connection.BeginTransaction(); - await command.ExecuteScalarAsync(); - await tx.CommitAsync(); - } - } - } + [Test, Explicit, Description("Turns on TCP keepalive and sleeps forever, good for wiresharking")] + public async Task TcpKeepalive() + { + await using var dataSource = CreateDataSource(csb => csb.TcpKeepAlive = true); + await using (await dataSource.OpenConnectionAsync()) + Thread.Sleep(Timeout.Infinite); + } - [Test] - [Ignore("")] - public void NpgsqlErrorRepro1() - { - throw new NotImplementedException(); -#if WHAT_TO_DO_WITH_THIS - using (var connection = new NpgsqlConnection(ConnectionString)) - { - connection.Open(); - using (var transaction = connection.BeginTransaction()) - { - var largeObjectMgr = new LargeObjectManager(connection); - try - { - var largeObject = largeObjectMgr.Open(-1, LargeObjectManager.READWRITE); - transaction.Commit(); - } - catch - { - // ignore the LO failure - } - } // *1* sometimes it throws "System.NotSupportedException: This stream does not support seek operations" - - using (var command = connection.CreateCommand()) - { - command.CommandText = "SELECT * FROM pg_database"; - using (var reader = command.ExecuteReader()) - { - Assert.IsTrue(reader.Read()); // *2* this fails if the initial connection is used - } - } - } // *3* sometimes it throws "System.NotSupportedException: This stream does not support seek operations" -#endif - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3511")] + public async Task 
Keepalive_with_failed_transaction() + { + if (IsMultiplexing) + return; - [Test] - public void Bug1011001() - { - //[#1011001] Bug in NpgsqlConnectionStringBuilder affects on cache and connection pool + await using var dataSource = CreateDataSource(csb => csb.KeepAlive = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var tx = await conn.BeginTransactionAsync(); - var csb1 = new NpgsqlConnectionStringBuilder(@"Server=server;Port=5432;User Id=user;Password=passwor;Database=database;"); - var cs1 = csb1.ToString(); - var csb2 = new NpgsqlConnectionStringBuilder(cs1); - var cs2 = csb2.ToString(); - Assert.IsTrue(cs1 == cs2); - } + Assert.ThrowsAsync(async () => await conn.ExecuteScalarAsync("SELECT non_existent_table")); + // Connection is now in a failed transaction state. Wait a bit to allow for the keepalive to execute. + Thread.Sleep(3000); - [Test] - public void NpgsqlErrorRepro2() - { -#if WHAT_TO_DO_WITH_THIS - var connection = new NpgsqlConnection(ConnectionString); - connection.Open(); - var transaction = connection.BeginTransaction(); - var largeObjectMgr = new LargeObjectManager(connection); - try - { - var largeObject = largeObjectMgr.Open(-1, LargeObjectManager.READWRITE); - transaction.Commit(); - } - catch - { - // ignore the LO failure - try - { - transaction.Dispose(); - } - catch - { - // ignore dispose failure - } - try - { - connection.Dispose(); - } - catch - { - // ignore dispose failure - } - } + await tx.RollbackAsync(); - using (connection = new NpgsqlConnection(ConnectionString)) - { - connection.Open(); - using (var command = connection.CreateCommand()) - { - command.CommandText = "SELECT * FROM pg_database"; - using (var reader = command.ExecuteReader()) - { - Assert.IsTrue(reader.Read()); - // *1* this fails if the connection for the pool happens to be the bad one from above - Assert.IsTrue(!String.IsNullOrEmpty((string)reader["datname"])); - } - } - } -#endif - } + // Confirm that the connection is still 
open and usable + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/pull/164")] - public void voidConnectionStateWhenDisposed() - { - var c = new NpgsqlConnection(); - c.Dispose(); - Assert.AreEqual(ConnectionState.Closed, c.State); - } + #endregion Keepalive - [Test] - public void ChangeApplicationNameWithConnectionStringBuilder() + [Test] + public async Task Change_parameter() + { + if (IsMultiplexing) + return; + + using var conn = await OpenConnectionAsync(); + var defaultApplicationName = conn.PostgresParameters["application_name"]; + await conn.ExecuteNonQueryAsync("SET application_name = 'some_test_value'"); + Assert.That(conn.PostgresParameters["application_name"], Is.EqualTo("some_test_value")); + await conn.ExecuteNonQueryAsync("SET application_name = 'some_test_value2'"); + Assert.That(conn.PostgresParameters["application_name"], Is.EqualTo("some_test_value2")); + await conn.ExecuteNonQueryAsync($"SET application_name = '{defaultApplicationName}'"); + Assert.That(conn.PostgresParameters["application_name"], Is.EqualTo(defaultApplicationName)); + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Connect_OptionsFromEnvironment_Succeeds() + { + using (SetEnvironmentVariable("PGOPTIONS", "-c default_transaction_isolation=serializable -c default_transaction_deferrable=on -c foo.bar=My\\ Famous\\\\Thing")) { - // Test for issue #165 on github. 
- var builder = new NpgsqlConnectionStringBuilder(); - builder.ApplicationName = "test"; + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(await conn.ExecuteScalarAsync("SHOW default_transaction_isolation"), Is.EqualTo("serializable")); + Assert.That(await conn.ExecuteScalarAsync("SHOW default_transaction_deferrable"), Is.EqualTo("on")); + Assert.That(await conn.ExecuteScalarAsync("SHOW foo.bar"), Is.EqualTo("My Famous\\Thing")); } + } - [Test, Description("Makes sure notices are probably received and emitted as events")] - public async Task Notice() - { - await using (var conn = await OpenConnectionAsync(new NpgsqlConnectionStringBuilder(ConnectionString) - { - // Make sure messages are in English - Options = "-c lc_messages=en_US.UTF-8" - })) - await using (GetTempFunctionName(conn, out var function)) - { - await conn.ExecuteNonQueryAsync($@" -CREATE OR REPLACE FUNCTION {function}() RETURNS VOID AS -'BEGIN RAISE NOTICE ''testnotice''; END;' -LANGUAGE 'plpgsql'"); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3030")] + [TestCase(true, TestName = "NoResetOnClose")] + [TestCase(false, TestName = "NoNoResetOnClose")] + public async Task NoResetOnClose(bool noResetOnClose) + { + var originalApplicationName = new NpgsqlConnectionStringBuilder(ConnectionString).ApplicationName ?? ""; + + await using var dataSource = CreateDataSource(csb => + { + csb.MaxPoolSize = 1; + csb.NoResetOnClose = noResetOnClose; + }); + + await using var conn = await dataSource.OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync("SET application_name = 'modified'"); + await conn.CloseAsync(); + await conn.OpenAsync(); + Assert.That(await conn.ExecuteScalarAsync("SHOW application_name"), Is.EqualTo( + noResetOnClose || IsMultiplexing + ? "modified" + : originalApplicationName)); + } - var mre = new ManualResetEvent(false); - PostgresNotice? 
notice = null; - NoticeEventHandler action = (sender, args) => - { - notice = args.Notice; - mre.Set(); - }; - conn.Notice += action; - try - { - // See docs for CreateSleepCommand - await conn.ExecuteNonQueryAsync($"SELECT {function}()::TEXT"); - mre.WaitOne(5000); - Assert.That(notice, Is.Not.Null, "No notice was emitted"); - Assert.That(notice!.MessageText, Is.EqualTo("testnotice")); - Assert.That(notice.Severity, Is.EqualTo("NOTICE")); - } - finally - { - conn.Notice -= action; - } - } - } + [Test] + [Description("Test whether the internal NpgsqlConnection.Open method stays on the same thread with async=false")] + public async Task Sync_open_blocked_same_thread() + { + if (IsMultiplexing) + return; - [Test, Description("Makes sure that concurrent use of the connection throws an exception")] - public async Task ConcurrentUse() + await using var dataSource = CreateDataSource(csb => { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - using (await cmd.ExecuteReaderAsync()) - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 2"), - Throws.Exception.TypeOf() - .With.Property(nameof(NpgsqlOperationInProgressException.CommandInProgress)).SameAs(cmd)); - - await conn.ExecuteNonQueryAsync("CREATE TEMP TABLE foo (bar INT)"); - using (conn.BeginBinaryImport("COPY foo (bar) FROM STDIN BINARY")) - { - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 2"), - Throws.Exception.TypeOf() - .With.Message.Contains("Copy")); - } - } - } + csb.MaxPoolSize = 1; + }); - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/783")] - public void PersistSecurityInfoIsOn([Values(true, false)] bool pooling) - { - if (IsMultiplexing && !pooling) - return; + await using var openConnection = await dataSource.OpenConnectionAsync(); - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - PersistSecurityInfo = true, - 
Pooling = pooling - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) - { - var passwd = new NpgsqlConnectionStringBuilder(conn.ConnectionString).Password; - Assert.That(passwd, Is.Not.Null); - conn.Open(); - Assert.That(new NpgsqlConnectionStringBuilder(conn.ConnectionString).Password, Is.EqualTo(passwd)); - } - } + // 2 tasks are usually enough to reproduce the issue + const int taskCount = 2; - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/783")] - public void NoPasswordWithoutPersistSecurityInfo([Values(true, false)] bool pooling) + var tcs = new TaskCompletionSource[taskCount]; + for (var i = 0; i < tcs.Length; i++) { - if (IsMultiplexing && !pooling) - return; - - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = pooling - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) - { - var csb = new NpgsqlConnectionStringBuilder(conn.ConnectionString); - Assert.That(csb.PersistSecurityInfo, Is.False); - Assert.That(csb.Password, Is.Not.Null); - conn.Open(); - Assert.That(new NpgsqlConnectionStringBuilder(conn.ConnectionString).Password, Is.Null); - } + tcs[i] = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } + var sameThreadTasks = Enumerable.Range(0, taskCount).Select(x => Task.Run(async () => + { + var beforeOpenThread = Thread.CurrentThread; + tcs[x].SetResult(null); + using var conn = dataSource.CreateConnection(); + // even though we await it should complete synchronously due to async = false + await conn.Open(async: false, CancellationToken.None); + return beforeOpenThread == Thread.CurrentThread; + })).ToList(); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2725")] - public void CloneWithAndPersistSecurityInfo() + await Task.WhenAll(tcs.Select(x => x.Task)); + // Just in case give them a second to block on getting a connection from the pool + await Task.Delay(1000); + await openConnection.CloseAsync(); + + foreach (var 
sameThreadTask in sameThreadTasks) { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - PersistSecurityInfo = true - }; - using var _ = CreateTempPool(builder, out var connStringWithPersist); + Assert.IsTrue(await sameThreadTask, "Synchronous open completed on different thread"); + } + } - using var connWithPersist = new NpgsqlConnection(connStringWithPersist); + #region Physical connection initialization - // First un-persist, should work - builder.PersistSecurityInfo = false; - var connStringWithoutPersist = builder.ToString(); - using var clonedWithoutPersist = connWithPersist.CloneWith(connStringWithoutPersist); - clonedWithoutPersist.Open(); + [Test] + public async Task PhysicalConnectionInitializer_sync() + { + if (IsMultiplexing) // Sync I/O + return; - Assert.That(clonedWithoutPersist.ConnectionString, Does.Not.Contain("Password=")); + await using var adminConn = await OpenConnectionAsync(); + var table = await CreateTempTable(adminConn, "ID INTEGER"); - // Then attempt to re-persist, should not work - using var clonedConn = clonedWithoutPersist.CloneWith(connStringWithPersist); - clonedConn.Open(); + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UsePhysicalConnectionInitializer( + conn => conn.ExecuteNonQuery($"INSERT INTO {table} VALUES (1)"), + _ => throw new NotSupportedException()); + await using var dataSource = dataSourceBuilder.Build(); - Assert.That(clonedConn.ConnectionString, Does.Not.Contain("Password=")); + await using (var conn = dataSource.OpenConnection()) + { + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM \"{table}\""), Is.EqualTo(1)); } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/743")] - [IssueLink("https://github.com/npgsql/npgsql/issues/783")] - public void Clone() + // Opening a second time should get us an idle connection, which should not cause the initializer to get executed + await using (var conn = dataSource.OpenConnection()) { - using 
(CreateTempPool(ConnectionString, out var connectionString)) - using (var conn = new NpgsqlConnection(connectionString)) - { - ProvideClientCertificatesCallback callback1 = certificates => { }; - conn.ProvideClientCertificatesCallback = callback1; - RemoteCertificateValidationCallback callback2 = (sender, certificate, chain, errors) => true; - conn.UserCertificateValidationCallback = callback2; - - conn.Open(); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - - using (var conn2 = (NpgsqlConnection)((ICloneable)conn).Clone()) - { - Assert.That(conn2.ConnectionString, Is.EqualTo(conn.ConnectionString)); - Assert.That(conn2.ProvideClientCertificatesCallback, Is.SameAs(callback1)); - Assert.That(conn2.UserCertificateValidationCallback, Is.SameAs(callback2)); - conn2.Open(); - Assert.That(async () => await conn2.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM \"{table}\""), Is.EqualTo(1)); } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/824")] - public async Task ReloadTypes() - { - if (IsMultiplexing) - return; + [Test] + public async Task PhysicalConnectionInitializer_async() + { + // With multiplexing the connector might become idle at undetermined point after the query is executed. + // Which is why we ignore it. 
+ if (IsMultiplexing) + return; - using (CreateTempPool(ConnectionString, out var connectionString)) - using (var conn = await OpenConnectionAsync(connectionString)) - using (var conn2 = await OpenConnectionAsync(connectionString)) - { - Assert.That(await conn.ExecuteScalarAsync("SELECT EXISTS (SELECT * FROM pg_type WHERE typname='reload_types_enum')"), - Is.False); - await conn.ExecuteNonQueryAsync("CREATE TYPE pg_temp.reload_types_enum AS ENUM ('First', 'Second')"); - Assert.That(() => conn.TypeMapper.MapEnum(), Throws.Exception.TypeOf()); - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(); - - // Make sure conn2 picks up the new type after a pooled close - var connId = conn2.ProcessID; - conn2.Close(); - conn2.Open(); - Assert.That(conn2.ProcessID, Is.EqualTo(connId), "Didn't get the same connector back"); - conn2.TypeMapper.MapEnum(); - } - } - enum ReloadTypesEnum { First, Second }; + await using var adminConn = await OpenConnectionAsync(); + var table = await CreateTempTable(adminConn, "ID INTEGER"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UsePhysicalConnectionInitializer( + _ => throw new NotSupportedException(), + async conn => await conn.ExecuteNonQueryAsync($"INSERT INTO {table} VALUES (1)")); + await using var dataSource = dataSourceBuilder.Build(); - [Test] - public async Task DatabaseInfoIsShared() + await using (var conn = await dataSource.OpenConnectionAsync()) { - if (IsMultiplexing) - return; - using (var conn1 = await OpenConnectionAsync()) - using (var conn2 = await OpenConnectionAsync()) - Assert.That(conn1.Connector!.DatabaseInfo, Is.SameAs(conn2.Connector!.DatabaseInfo)); + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM \"{table}\""), Is.EqualTo(1)); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/736")] - public async Task ManyOpenClose() + // Opening a second time should get us an idle connection, which should not cause the initializer to get executed + await using (var conn = 
await dataSource.OpenConnectionAsync()) { - // The connector's _sentRfqPrependedMessages is a byte, too many open/closes made it overflow - for (var i = 0; i < 255; i++) - { - using (var conn = new NpgsqlConnection(ConnectionString)) - { - conn.Open(); - } - } - using (var conn = new NpgsqlConnection(ConnectionString)) - { - conn.Open(); - } - using (var conn = new NpgsqlConnection(ConnectionString)) - { - conn.Open(); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM \"{table}\""), Is.EqualTo(1)); } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/736")] - public async Task ManyOpenCloseWithTransaction() - { - // The connector's _sentRfqPrependedMessages is a byte, too many open/closes made it overflow - for (var i = 0; i < 255; i++) + [Test] + public async Task PhysicalConnectionInitializer_sync_with_break() + { + if (IsMultiplexing) // Sync I/O + return; + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UsePhysicalConnectionInitializer( + conn => { - using (var conn = await OpenConnectionAsync()) - conn.BeginTransaction(); - } - using (var conn = await OpenConnectionAsync()) - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + // Use another connection to kill the connector currently in the pool + using (var conn2 = OpenConnection()) + conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/927")] - [IssueLink("https://github.com/npgsql/npgsql/issues/736")] - [Ignore("Fails when running the entire test suite but not on its own...")] - public async Task RollbackOnClose() - { - // Npgsql 3.0.0 to 3.0.4 prepended a rollback for the next time the connector is used, as an optimization. - // This caused some issues (#927) and was removed. 
+ conn.ExecuteScalar("SELECT 1"); + }, + _ => throw new NotSupportedException()); + await using var dataSource = dataSourceBuilder.Build(); - // Clear connections in pool as we're going to need to reopen the same connection - var dummyConn = new NpgsqlConnection(ConnectionString); - NpgsqlConnection.ClearPool(dummyConn); + Assert.That(() => dataSource.OpenConnection(), Throws.Exception.InstanceOf()); + Assert.That(dataSource.Statistics, Is.EqualTo((0, 0, 0))); + } - int processId; - using (var conn = await OpenConnectionAsync()) - { - processId = conn.Connector!.BackendProcessId; - conn.BeginTransaction(); - await conn.ExecuteNonQueryAsync("SELECT 1"); - Assert.That(conn.Connector.TransactionStatus, Is.EqualTo(TransactionStatus.InTransactionBlock)); - } - using (var conn = await OpenConnectionAsync()) + [Test] + public async Task PhysicalConnectionInitializer_async_with_break() + { + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UsePhysicalConnectionInitializer( + _ => throw new NotSupportedException(), + async conn => { - Assert.That(conn.Connector!.BackendProcessId, Is.EqualTo(processId)); - Assert.That(conn.Connector.TransactionStatus, Is.EqualTo(TransactionStatus.Idle)); - } - } + // Use another connection to kill the connector currently in the pool + await using (var conn2 = await OpenConnectionAsync()) + await conn2.ExecuteNonQueryAsync($"SELECT pg_terminate_backend({conn.ProcessID})"); - [Test, Description("Tests an exception happening when sending the Terminate message while closing a ready connector")] - [IssueLink("https://github.com/npgsql/npgsql/issues/777")] - [Ignore("Flaky")] - public async Task ExceptionDuringClose() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Pooling = false }; - using (var conn = await OpenConnectionAsync(csb)) - { - var connectorId = conn.ProcessID; + await conn.ExecuteScalarAsync("SELECT 1"); + }); + await using var dataSource = dataSourceBuilder.Build(); - using (var conn2 = 
await OpenConnectionAsync()) - conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({connectorId})"); + Assert.That(async () => await dataSource.OpenConnectionAsync(), Throws.Exception.InstanceOf()); + Assert.That(dataSource.Statistics, Is.EqualTo((0, 0, 0))); + } - conn.Close(); - } - } + [Test] + public async Task PhysicalConnectionInitializer_async_throws_on_second_open() + { + // With multiplexing a physical connection might open on NpgsqlConnection.OpenAsync (if there was no completed bootstrap beforehand) + // or on NpgsqlCommand.ExecuteReaderAsync. + // We've already tested the first case in PhysicalConnectionInitializer_async_throws above, testing the second one below. + var count = 0; + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UsePhysicalConnectionInitializer( + _ => throw new NotSupportedException(), + _ => + { + if (++count == 1) + return Task.CompletedTask; + throw new Exception("INTENTIONAL FAILURE"); + }); + await using var dataSource = dataSourceBuilder.Build(); + + await using var conn1 = dataSource.CreateConnection(); + Assert.DoesNotThrowAsync(async () => await conn1.OpenAsync()); + + // We start a transaction specifically for multiplexing (to bind a connector to the connection) + await using var tx = await conn1.BeginTransactionAsync(); + + await using var conn2 = dataSource.CreateConnection(); + Exception exception; + if (IsMultiplexing) + { + await conn2.OpenAsync(); + exception = Assert.ThrowsAsync(async () => await conn2.BeginTransactionAsync())!; + } + else + exception = Assert.ThrowsAsync(async () => await conn2.OpenAsync())!; + Assert.That(exception.Message, Is.EqualTo("INTENTIONAL FAILURE")); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1180")] - public void PoolByPassword() - { - using var _ = CreateTempPool(ConnectionString, out var connectionString); - using (var goodConn = new NpgsqlConnection(connectionString)) - goodConn.Open(); + [Test] + public async Task 
PhysicalConnectionInitializer_disposes_connection() + { + NpgsqlConnection? initializerConnection = null; - var badConnectionString = new NpgsqlConnectionStringBuilder(connectionString) + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.UsePhysicalConnectionInitializer( + _ => throw new NotSupportedException(), + conn => { - Password = "badpasswd" - }.ConnectionString; - using (var conn = new NpgsqlConnection(badConnectionString)) - Assert.That(conn.Open, Throws.Exception.TypeOf()); - } + initializerConnection = conn; + return Task.CompletedTask; + }); + await using var dataSource = dataSourceBuilder.Build(); - [Test, Description("Some pseudo-PG database don't support pg_type loading, we have a minimal DatabaseInfo for this")] - public async Task NoTypeLoading() - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading - }; + await using var conn = await dataSource.OpenConnectionAsync(); - using var _ = CreateTempPool(builder, out var connectionString); - using var conn = await OpenConnectionAsync(connectionString); - // Arrays should not be supported in this mode - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT '{1,2,3}'::INTEGER[]"), - Throws.Exception.TypeOf()); - // Test that some basic types do work - Assert.That(await conn.ExecuteScalarAsync("SELECT 8"), Is.EqualTo(8)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 'foo'"), Is.EqualTo("foo")); - Assert.That(await conn.ExecuteScalarAsync("SELECT TRUE"), Is.EqualTo(true)); - Assert.That(await conn.ExecuteScalarAsync("SELECT INET '192.168.1.1'"), Is.EqualTo(IPAddress.Parse("192.168.1.1"))); - } + Assert.That(initializerConnection, Is.Not.Null); + Assert.That(conn, Is.Not.SameAs(initializerConnection)); + Assert.That(() => initializerConnection!.Open(), Throws.Exception.TypeOf()); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1158")] - public async Task 
TableNamedRecord() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); + #endregion Physical connection initialization - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync(@" + [Test] + [NonParallelizable] // Modifies global database info factories + [IssueLink("https://github.com/npgsql/npgsql/issues/4425")] + public async Task Breaking_connection_while_loading_database_info() + { + if (IsMultiplexing) + return; -DROP TABLE IF EXISTS record; -CREATE TABLE record ()"); - try - { - conn.ReloadTypes(); - Assert.That(await conn.ExecuteScalarAsync("SELECT COUNT(*) FROM record"), Is.Zero); - } - finally - { - await conn.ExecuteNonQueryAsync("DROP TABLE record"); - } - } - } + await using var dataSource = CreateDataSource(); -// TODO: Port this test to .NET Core somehow -#if NET461 - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/392")] - public async Task NonUTF8Encoding() + await using var firstConn = dataSource.CreateConnection(); + NpgsqlDatabaseInfo.RegisterFactory(new BreakingDatabaseInfoFactory()); + try { - using (var adminConn = await OpenConnectionAsync()) - { - // Create the database with server encoding sql-ascii - await adminConn.ExecuteNonQueryAsync("DROP DATABASE IF EXISTS sqlascii"); - await adminConn.ExecuteNonQueryAsync("CREATE DATABASE sqlascii ENCODING 'sql_ascii' TEMPLATE template0"); - try - { - // Insert some win1252 data - var goodBuilder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Database = "sqlascii", - Encoding = "windows-1252", - ClientEncoding = "sql-ascii", - }; - - using var _ = CreateTempPool(goodBuilder, out var goodConnectionString); - - using (var conn = await OpenConnectionAsync(goodConnectionString)) - { - await conn.ExecuteNonQueryAsync("CREATE TABLE foo (bar TEXT)"); - await conn.ExecuteNonQueryAsync("INSERT INTO foo (bar) VALUES ('éàç')"); - Assert.That(await conn.ExecuteScalarAsync("SELECT * FROM foo"), Is.EqualTo("éàç")); - } - - // A normal 
connection with the default UTF8 encoding and client_encoding should fail - var badBuilder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Database = "sqlascii", - }; - using var __ = CreateTempPool(badBuilder, out var badConnectionString); - using (var conn = await OpenConnectionAsync(badConnectionString)) - { - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT * FROM foo"), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("22021") - .Or.TypeOf() - ); - } - } - finally - { - await adminConn.ExecuteNonQueryAsync("DROP DATABASE IF EXISTS sqlascii"); - } - } + // Test the first time we load the database info + Assert.ThrowsAsync(firstConn.OpenAsync); } -#endif - - [Test] - public async Task OversizeBuffer() + finally { - if (IsMultiplexing) - return; - - using (CreateTempPool(ConnectionString, out var connectionString)) - using (var conn = await OpenConnectionAsync(connectionString)) - { - var csb = new NpgsqlConnectionStringBuilder(connectionString); - - Assert.That(conn.Connector!.ReadBuffer.Size, Is.EqualTo(csb.ReadBufferSize)); - - // Read a big row, we should now be using an oversize buffer - var bigString1 = new string('x', csb.ReadBufferSize + 10); - using (var cmd = new NpgsqlCommand($"SELECT '{bigString1}'", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetString(0), Is.EqualTo(bigString1)); - } - var size1 = conn.Connector.ReadBuffer.Size; - Assert.That(conn.Connector.ReadBuffer.Size, Is.GreaterThan(csb.ReadBufferSize)); - - // Even bigger oversize buffer - var bigString2 = new string('x', csb.ReadBufferSize + 20); - using (var cmd = new NpgsqlCommand($"SELECT '{bigString2}'", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetString(0), Is.EqualTo(bigString2)); - } - Assert.That(conn.Connector.ReadBuffer.Size, Is.GreaterThan(size1)); - - var processId = conn.ProcessID; - 
conn.Close(); - conn.Open(); - Assert.That(conn.ProcessID, Is.EqualTo(processId)); - Assert.That(conn.Connector.ReadBuffer.Size, Is.EqualTo(csb.ReadBufferSize)); - } + NpgsqlDatabaseInfo.ResetFactories(); } - [Test, Explicit, Description("Turns on TCP keepalive and sleeps forever, good for wiresharking")] - public async Task TcpKeepaliveTime() + await firstConn.OpenAsync(); + await using var secondConn = await dataSource.OpenConnectionAsync(); + await secondConn.CloseAsync(); + await firstConn.ReloadTypesAsync(); + + NpgsqlDatabaseInfo.RegisterFactory(new BreakingDatabaseInfoFactory()); + try { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - TcpKeepAliveTime = 2 - }; - using (await OpenConnectionAsync(csb)) - Thread.Sleep(Timeout.Infinite); + // Make sure that the database info is now cached and won't be reloaded + Assert.DoesNotThrowAsync(secondConn.OpenAsync); } - - [Test, Explicit, Description("Turns on TCP keepalive and sleeps forever, good for wiresharking")] - public async Task TcpKeepalive() + finally { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - TcpKeepAlive = true - }; - using (await OpenConnectionAsync(csb)) - Thread.Sleep(Timeout.Infinite); + NpgsqlDatabaseInfo.ResetFactories(); } + } - [Test] - public async Task ChangeParameter() - { - if (IsMultiplexing) - return; + class BreakingDatabaseInfoFactory : INpgsqlDatabaseInfoFactory + { + public Task Load(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async) + => throw conn.Break(new IOException()); + } - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync("SET application_name = 'some_test_value'"); - Assert.That(conn.PostgresParameters["application_name"], Is.EqualTo("some_test_value")); - await conn.ExecuteNonQueryAsync("SET application_name = 'some_test_value2'"); - Assert.That(conn.PostgresParameters["application_name"], Is.EqualTo("some_test_value2")); - } - } + #region Logging tests - [Test] - [NonParallelizable] - 
public async Task Connect_UserNameFromEnvironment_Succeeds() - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { IntegratedSecurity = false }; - using var _ = SetEnvironmentVariable("PGUSER", builder.Username); - builder.Username = null; - using var __ = CreateTempPool(builder.ConnectionString, out var connectionString); - using var ___ = await OpenConnectionAsync(connectionString); - } + [Test] + public async Task Log_Open_Close_pooled() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = dataSource.CreateConnection(); - [Test] - [NonParallelizable] - public async Task Connect_PasswordFromEnvironment_Succeeds() - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { IntegratedSecurity = false }; - using var _ = SetEnvironmentVariable("PGPASSWORD", builder.Password); - builder.Password = null; - using var __ = CreateTempPool(builder.ConnectionString, out var connectionString); - using var ___ = await OpenConnectionAsync(connectionString); - } + // Open and close to have an idle connection in the pool - we don't want to test physical open/close + await conn.OpenAsync(); + await conn.CloseAsync(); - [Test] - [NonParallelizable] - public async Task Connect_OptionsFromEnvironment_Succeeds() + int processId, port; + string host, database; + using (listLoggerProvider.Record()) { - using (SetEnvironmentVariable("PGOPTIONS", "-c default_transaction_isolation=serializable -c default_transaction_deferrable=on -c foo.bar=My\\ Famous\\\\Thing")) - { - using var _ = CreateTempPool(ConnectionString, out var connectionString); - using var conn = await OpenConnectionAsync(connectionString); - Assert.That(await conn.ExecuteScalarAsync("SHOW default_transaction_isolation"), Is.EqualTo("serializable")); - Assert.That(await conn.ExecuteScalarAsync("SHOW default_transaction_deferrable"), Is.EqualTo("on")); - Assert.That(await conn.ExecuteScalarAsync("SHOW foo.bar"), Is.EqualTo("My 
Famous\\Thing")); - } - } + await conn.OpenAsync(); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3030")] - [TestCase(true, TestName = "NoResetOnClose")] - [TestCase(false, TestName = "NoNoResetOnClose")] - public async Task NoResetOnClose(bool noResetOnClose) - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxPoolSize = 1, - NoResetOnClose = noResetOnClose - }; - using var _ = CreateTempPool(builder, out var connectionString); - var original = new NpgsqlConnectionStringBuilder(connectionString).ApplicationName; + var tx = await conn.BeginTransactionAsync(); + (processId, host, port, database) = (conn.ProcessID, conn.Host!, conn.Port, conn.Database); + await tx.CommitAsync(); - using var conn = await OpenConnectionAsync(connectionString); - await conn.ExecuteNonQueryAsync("SET application_name = 'modified'"); await conn.CloseAsync(); - await conn.OpenAsync(); - Assert.That(await conn.ExecuteScalarAsync("SHOW application_name"), Is.EqualTo( - noResetOnClose || IsMultiplexing - ? 
"modified" - : original)); } - [Test] - [NonParallelizable] - public async Task UsePgPassFile() - { - using var resetPassword = SetEnvironmentVariable("PGPASSWORD", null); - var builder = new NpgsqlConnectionStringBuilder(ConnectionString); - - var password = builder.Password; - var passFile = Path.GetTempFileName(); + var openingConnectionEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.OpeningConnection); + AssertLoggingConnectionString(conn, openingConnectionEvent.State); + AssertLoggingStateContains(openingConnectionEvent, "Host", host); + AssertLoggingStateContains(openingConnectionEvent, "Port", port); + AssertLoggingStateContains(openingConnectionEvent, "Database", database); - builder.Password = password; - builder.Passfile = passFile; + var openedConnectionEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.OpenedConnection); + AssertLoggingConnectionString(conn, openedConnectionEvent.State); + AssertLoggingStateContains(openedConnectionEvent, "Host", host); + AssertLoggingStateContains(openedConnectionEvent, "Port", port); + AssertLoggingStateContains(openedConnectionEvent, "Database", database); - using var deletePassFile = Defer(() => File.Delete(passFile)); + var closingConnectionEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.ClosingConnection); + AssertLoggingConnectionString(conn, closingConnectionEvent.State); + AssertLoggingStateContains(closingConnectionEvent, "Host", host); + AssertLoggingStateContains(closingConnectionEvent, "Port", port); + AssertLoggingStateContains(closingConnectionEvent, "Database", database); - File.WriteAllText(passFile, $"*:*:*:{builder.Username}:{password}"); + var closedConnectionEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.ClosedConnection); + AssertLoggingConnectionString(conn, closedConnectionEvent.State); + AssertLoggingStateContains(closedConnectionEvent, "Host", host); + AssertLoggingStateContains(closedConnectionEvent, "Port", port); + 
AssertLoggingStateContains(closedConnectionEvent, "Database", database); - using var passFileVariable = SetEnvironmentVariable("PGPASSFILE", passFile); - using var pool = CreateTempPool(builder.ConnectionString, out var connectionString); - using var conn = await OpenConnectionAsync(connectionString); + if (!IsMultiplexing) + { + AssertLoggingStateContains(openedConnectionEvent, "ConnectorId", processId); + AssertLoggingStateContains(closingConnectionEvent, "ConnectorId", processId); + AssertLoggingStateContains(closedConnectionEvent, "ConnectorId", processId); } - [Test] - [NonParallelizable] - public void PasswordSourcePrecendence() + var ids = new[] { - using var resetPassword = SetEnvironmentVariable("PGPASSWORD", null); - var builder = new NpgsqlConnectionStringBuilder(ConnectionString); - - var password = builder.Password; - var passwordBad = password + "_bad"; - - var passFile = Path.GetTempFileName(); - var passFileBad = passFile + "_bad"; - - using var deletePassFile = Defer(() => File.Delete(passFile)); - using var deletePassFileBad = Defer(() => File.Delete(passFileBad)); - - File.WriteAllText(passFile, $"*:*:*:{builder.Username}:{password}"); - File.WriteAllText(passFileBad, $"*:*:*:{builder.Username}:{passwordBad}"); - - using (var passFileVariable = SetEnvironmentVariable("PGPASSFILE", passFileBad)) - { - // Password from the connection string goes first - using (var passwordVariable = SetEnvironmentVariable("PGPASSWORD", passwordBad)) - Assert.That(OpenConnection(password, passFileBad), Throws.Nothing); - - // Password from the environment variable goes second - using (var passwordVariable = SetEnvironmentVariable("PGPASSWORD", password)) - Assert.That(OpenConnection(password: null, passFileBad), Throws.Nothing); + NpgsqlEventId.OpeningPhysicalConnection, + NpgsqlEventId.OpenedPhysicalConnection, + NpgsqlEventId.ClosingPhysicalConnection, + NpgsqlEventId.ClosedPhysicalConnection + }; - // Passfile from the connection string goes third - 
Assert.That(OpenConnection(password: null, passFile: passFile), Throws.Nothing); - } + foreach (var id in ids) + Assert.That(listLoggerProvider.Log.Count(l => l.Id == id), Is.Zero); + } - // Passfile from the environment variable goes fourth - using (var passFileVariable = SetEnvironmentVariable("PGPASSFILE", passFile)) - Assert.That(OpenConnection(password: null, passFile: null), Throws.Nothing); + [Test] + public async Task Log_Open_Close_physical() + { + if (IsMultiplexing) + return; - Func OpenConnection(string? password, string? passFile) => async () => - { - builder.Password = password; - builder.Passfile = passFile; - builder.IntegratedSecurity = false; - builder.ApplicationName = $"{nameof(PasswordSourcePrecendence)}:{Guid.NewGuid()}"; + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Pooling = false }; + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, csb.ToString()); + await using var conn = dataSource.CreateConnection(); - using var pool = CreateTempPool(builder.ConnectionString, out var connectionString); - using var connection = await OpenConnectionAsync(connectionString); - }; + int processId, port; + string host, database; + using (listLoggerProvider.Record()) + { + await conn.OpenAsync(); + (processId, host, port, database) = (conn.ProcessID, conn.Host!, conn.Port, conn.Database); + await conn.CloseAsync(); } - [Test, Description("Simulates a timeout during the authentication phase")] - [IssueLink("https://github.com/npgsql/npgsql/issues/3227")] - [Timeout(10000)] - public async Task TimeoutDuringAuthentication() - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { Timeout = 1 }; - await using var postmasterMock = new PgPostmasterMock(builder.ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + var openingConnectionEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.OpeningPhysicalConnection); + 
AssertLoggingConnectionString(conn, openingConnectionEvent.State); + AssertLoggingStateContains(openingConnectionEvent, "Host", host); + AssertLoggingStateContains(openingConnectionEvent, "Port", port); + AssertLoggingStateContains(openingConnectionEvent, "Database", database); + + var openedConnectionEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.OpenedPhysicalConnection); + AssertLoggingConnectionString(conn, openedConnectionEvent.State); + AssertLoggingStateContains(openedConnectionEvent, "ConnectorId", processId); + AssertLoggingStateContains(openingConnectionEvent, "Host", host); + AssertLoggingStateContains(openingConnectionEvent, "Port", port); + AssertLoggingStateContains(openingConnectionEvent, "Database", database); + AssertLoggingStateContains(openedConnectionEvent, "DurationMs"); + + var closingConnectionEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.ClosingPhysicalConnection); + AssertLoggingConnectionString(conn, closingConnectionEvent.State); + AssertLoggingStateContains(closingConnectionEvent, "ConnectorId", processId); + AssertLoggingStateContains(closingConnectionEvent, "Host", host); + AssertLoggingStateContains(closingConnectionEvent, "Port", port); + AssertLoggingStateContains(closingConnectionEvent, "Database", database); + + var closededConnectionEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.ClosedPhysicalConnection); + AssertLoggingConnectionString(conn, closededConnectionEvent.State); + AssertLoggingStateContains(closededConnectionEvent, "ConnectorId", processId); + AssertLoggingStateContains(closededConnectionEvent, "Host", host); + AssertLoggingStateContains(closededConnectionEvent, "Port", port); + AssertLoggingStateContains(closededConnectionEvent, "Database", database); + } - var __ = postmasterMock.AcceptServer(); + void AssertLoggingConnectionString(NpgsqlConnection connection, object? 
logState) + { + var keyValuePairs = (IEnumerable>)logState!; + var connectionString = keyValuePairs.Single(kvp => kvp.Key == "ConnectionString").Value; + Assert.That(connectionString, Is.EqualTo(connection.ConnectionString)); + Assert.That(connectionString, Does.Not.Contain("Password")); + } - // The server will accept a connection from the client, but will not respond to the client's authentication - // request. This should trigger a timeout - Assert.That(async () => await OpenConnectionAsync(connectionString), - Throws.Exception.TypeOf() - .With.InnerException.TypeOf()); - } + #endregion Logging tests - public ConnectionTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} - } + public ConnectionTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/CopyTests.cs b/test/Npgsql.Tests/CopyTests.cs index 83903e9ed1..5abe7e9c80 100644 --- a/test/Npgsql.Tests/CopyTests.cs +++ b/test/Npgsql.Tests/CopyTests.cs @@ -1,1134 +1,1380 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Collections.Specialized; using System.Data; +using System.Diagnostics; using System.IO; +using System.Numerics; using System.Text; using System.Threading; using System.Threading.Tasks; +using Npgsql.Internal; +using Npgsql.Tests.Support; using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class CopyTests : MultiplexingTestBase { - public class CopyTests : MultiplexingTestBase + #region Issue 2257 + + [Test, Description("Reproduce #2257")] + public async Task Issue2257() { - #region issue 2257 + await using var conn = await OpenConnectionAsync(); + var table1 = await GetTempTableName(conn); + var table2 = await GetTempTableName(conn); - [Test, Description("Reproduce #2257")] - public async Task Issue2257() + const int rowCount = 1000000; + using (var cmd = conn.CreateCommand()) { - await using var conn = await 
OpenConnectionAsync(); - await using var _ = await GetTempTableName(conn, out var table1); - await using var __ = await GetTempTableName(conn, out var table2); + cmd.CommandText = $"CREATE TABLE {table1} AS SELECT * FROM generate_series(1, {rowCount}) id"; + await cmd.ExecuteNonQueryAsync(); + cmd.CommandText = $"ALTER TABLE {table1} ADD CONSTRAINT {table1}_pk PRIMARY KEY (id)"; + await cmd.ExecuteNonQueryAsync(); + cmd.CommandText = $"CREATE TABLE {table2} (master_id integer NOT NULL REFERENCES {table1} (id))"; + await cmd.ExecuteNonQueryAsync(); + } - const int rowCount = 1000000; - using (var cmd = conn.CreateCommand()) + await using var writer = conn.BeginBinaryImport($"COPY {table2} FROM STDIN BINARY"); + writer.Timeout = TimeSpan.FromMilliseconds(3); + var e = Assert.Throws(() => + { + for (var i = 1; i <= rowCount; ++i) { - cmd.CommandText = $"CREATE TABLE {table1} AS SELECT * FROM generate_series(1, {rowCount}) id"; - await cmd.ExecuteNonQueryAsync(); - cmd.CommandText = $"ALTER TABLE {table1} ADD CONSTRAINT {table1}_pk PRIMARY KEY (id)"; - await cmd.ExecuteNonQueryAsync(); - cmd.CommandText = $"CREATE TABLE {table2} (master_id integer NOT NULL REFERENCES {table1} (id))"; - await cmd.ExecuteNonQueryAsync(); + writer.StartRow(); + writer.Write(i); } - await using var writer = conn.BeginBinaryImport($"COPY {table2} FROM STDIN BINARY"); - writer.Timeout = TimeSpan.FromMilliseconds(3); - var e = Assert.Throws(() => + writer.Complete(); + })!; + Assert.That(e.InnerException, Is.TypeOf()); + } + + #endregion + + #region Raw + + [Test, Description("Exports data in binary format (raw mode) and then loads it back in")] + public async Task Raw_binary_roundtrip([Values(false, true)] bool async) + { + using var conn = await OpenConnectionAsync(); + //var iterations = Conn.BufferSize / 10 + 100; + //var iterations = Conn.BufferSize / 10 - 100; + const int iterations = 500; + + var table = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($@"CREATE TABLE 
{table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); + using (var tx = conn.BeginTransaction()) + { + + // Preload some data into the table + using (var cmd = + new NpgsqlCommand($"INSERT INTO {table} (field_text, field_int4) VALUES (@p1, @p2)", conn)) { - for (var i = 1; i <= rowCount; ++i) + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Text, "HELLO"); + cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Integer, 8); + for (var i = 0; i < iterations; i++) { - writer.StartRow(); - writer.Write(i); + await cmd.ExecuteNonQueryAsync(); } + } - writer.Complete(); - }); - Assert.That(e.InnerException, Is.TypeOf()); + await tx.CommitAsync(); } - #endregion - - #region Raw - - [Test, Description("Exports data in binary format (raw mode) and then loads it back in")] - public async Task RawBinaryRoundtrip() + var data = new byte[10000]; + var len = 0; + using (var outStream = async + ? await conn.BeginRawBinaryCopyAsync($"COPY {table} (field_text, field_int4) TO STDIN BINARY") + : conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) TO STDIN BINARY")) { - using (var conn = await OpenConnectionAsync()) + StateAssertions(conn); + + while (true) { - //var iterations = Conn.BufferSize / 10 + 100; - //var iterations = Conn.BufferSize / 10 - 100; - const int iterations = 500; + var read = outStream.Read(data, len, data.Length - len); + if (read == 0) + break; + len += read; + } - await using var _ = await GetTempTableName(conn, out var table); + Assert.That(len, Is.GreaterThan(conn.Settings.ReadBufferSize) & Is.LessThan(data.Length)); + } - using (var tx = conn.BeginTransaction()) - { - await conn.ExecuteNonQueryAsync($@"CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); - - // Preload some data into the table - using (var cmd = - new NpgsqlCommand($"INSERT INTO {table} (field_text, field_int4) VALUES (@p1, @p2)", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Text, "HELLO"); - 
cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Integer, 8); - for (var i = 0; i < iterations; i++) - { - await cmd.ExecuteNonQueryAsync(); - } - } - - await tx.CommitAsync(); - } + await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); - var data = new byte[10000]; - var len = 0; - using (var outStream = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) TO STDIN BINARY")) - { - StateAssertions(conn); + using (var inStream = async + ? await conn.BeginRawBinaryCopyAsync($"COPY {table} (field_text, field_int4) FROM STDIN BINARY") + : conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) + { + StateAssertions(conn); - while (true) - { - var read = outStream.Read(data, len, data.Length - len); - if (read == 0) - break; - len += read; - } + inStream.Write(data, 0, len); + } - Assert.That(len, Is.GreaterThan(conn.Settings.ReadBufferSize) & Is.LessThan(data.Length)); - } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(iterations)); + } - await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); + [Test, Description("Disposes a raw binary stream in the middle of an export")] + public async Task Dispose_in_middle_of_raw_binary_export() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($@" +CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER); +INSERT INTO {table} (field_text, field_int4) VALUES ('HELLO', 8)"); - using (var inStream = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) - { - StateAssertions(conn); + var data = new byte[3]; + using (var inStream = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) TO STDIN BINARY")) + { + // Read some bytes + var len = inStream.Read(data, 0, data.Length); + Assert.That(len, Is.EqualTo(data.Length)); + } + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - 
inStream.Write(data, 0, len); - } + [Test, Description("Disposes a raw binary stream in the middle of an import")] + public async Task Dispose_in_middle_of_raw_binary_import() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($@"CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); + + var inStream = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY"); + inStream.Write(NpgsqlRawCopyStream.BinarySignature, 0, NpgsqlRawCopyStream.BinarySignature.Length); + Assert.That(() => inStream.Dispose(), Throws.Exception + .TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.BadCopyFileFormat) + ); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(iterations)); + [Test, Description("Cancels a binary write")] + public async Task Cancel_raw_binary_import() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($@"CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); + await using (var tx = await conn.BeginTransactionAsync()) + { + var garbage = new byte[] {1, 2, 3, 4}; + using (var s = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) + { + s.Write(garbage, 0, garbage.Length); + s.Cancel(); } } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } - [Test, Description("Disposes a raw binary stream in the middle of an export")] - public async Task DisposeInMiddleOfRawBinaryExport() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await GetTempTableName(conn, out var table); - await conn.ExecuteNonQueryAsync($@" -CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 
INTEGER); -INSERT INTO {table} (field_text, field_int4) VALUES ('HELLO', 8)"); + [Test] + public async Task Import_large_value_raw() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "blob BYTEA"); - var data = new byte[3]; - using (var inStream = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) TO STDIN BINARY")) - { - // Read some bytes - var len = inStream.Read(data, 0, data.Length); - Assert.That(len, Is.EqualTo(data.Length)); - } - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + var data = new byte[conn.Settings.WriteBufferSize + 10]; + var dump = new byte[conn.Settings.WriteBufferSize + 200]; + var len = 0; + + // Insert a blob with a regular insert + using (var cmd = new NpgsqlCommand($"INSERT INTO {table} (blob) VALUES (@p)", conn)) + { + cmd.Parameters.AddWithValue("p", data); + await cmd.ExecuteNonQueryAsync(); } - [Test, Description("Disposes a raw binary stream in the middle of an import")] - public async Task DisposeInMiddleOfRawBinaryImport() + // Raw dump out + using (var outStream = conn.BeginRawBinaryCopy($"COPY {table} (blob) TO STDIN BINARY")) { - using (var conn = await OpenConnectionAsync()) + while (true) { - await using var _ = await GetTempTableName(conn, out var table); - await conn.ExecuteNonQueryAsync($@"CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); - - var inStream = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY"); - inStream.Write(NpgsqlRawCopyStream.BinarySignature, 0, NpgsqlRawCopyStream.BinarySignature.Length); - Assert.That(() => inStream.Dispose(), Throws.Exception - .TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("22P04") - ); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + var read = outStream.Read(dump, len, dump.Length - len); + if (read == 0) + break; + len += read; } + Assert.That(len < dump.Length); } - [Test, 
Description("Cancels a binary write")] - public async Task CancelRawBinaryImport() + await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); + + // And raw dump back in + using (var inStream = conn.BeginRawBinaryCopy($"COPY {table} (blob) FROM STDIN BINARY")) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await GetTempTableName(conn, out var table); - await conn.ExecuteNonQueryAsync($@"CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER)"); + inStream.Write(dump, 0, len); + } + } - var garbage = new byte[] {1, 2, 3, 4}; - using (var s = conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) - { - s.Write(garbage, 0, garbage.Length); - s.Cancel(); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_table_definition_raw_binary_copy() + { + using var conn = await OpenConnectionAsync(); + Assert.Throws(() => conn.BeginRawBinaryCopy("COPY table_is_not_exist (blob) TO STDOUT BINARY")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + + Assert.Throws(() => conn.BeginRawBinaryCopy("COPY table_is_not_exist (blob) FROM STDIN BINARY")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_format_raw_binary_copy() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using (var conn = await OpenConnectionAsync()) + { + var table = await CreateTempTable(conn, "blob BYTEA"); + Assert.Throws(() => conn.BeginRawBinaryCopy($"COPY {table} (blob) TO STDOUT")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); } - [Test] - public async Task 
ImportLargeValueRaw() + using (var conn = await OpenConnectionAsync()) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "blob BYTEA", out var table); + var table = await CreateTempTable(conn, "blob BYTEA"); + Assert.Throws(() => conn.BeginRawBinaryCopy($"COPY {table} (blob) FROM STDIN")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + } - var data = new byte[conn.Settings.WriteBufferSize + 10]; - var dump = new byte[conn.Settings.WriteBufferSize + 200]; - var len = 0; + #endregion - // Insert a blob with a regular insert - using (var cmd = new NpgsqlCommand($"INSERT INTO {table} (blob) VALUES (@p)", conn)) - { - cmd.Parameters.AddWithValue("p", data); - await cmd.ExecuteNonQueryAsync(); - } + #region Binary - // Raw dump out - using (var outStream = conn.BeginRawBinaryCopy($"COPY {table} (blob) TO STDIN BINARY")) - { - while (true) - { - var read = outStream.Read(dump, len, dump.Length - len); - if (read == 0) - break; - len += read; - } - Assert.That(len < dump.Length); - } + [Test, Description("Roundtrips some data")] + public async Task Binary_roundtrip([Values(false, true)] bool async) + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT"); - await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); + var longString = new StringBuilder(conn.Settings.WriteBufferSize + 50).Append('a').ToString(); - // And raw dump back in - using (var inStream = conn.BeginRawBinaryCopy($"COPY {table} (blob) FROM STDIN BINARY")) - { - inStream.Write(dump, 0, len); - } - } + using (var writer = async + ? 
await conn.BeginBinaryImportAsync($"COPY {table} (field_text, field_int2) FROM STDIN BINARY") + : conn.BeginBinaryImport($"COPY {table} (field_text, field_int2) FROM STDIN BINARY")) + { + StateAssertions(conn); + + writer.StartRow(); + writer.Write("Hello"); + writer.Write((short)8, NpgsqlDbType.Smallint); + + writer.WriteRow("Something", (short)9); + + writer.StartRow(); + writer.Write(longString, "text"); + writer.WriteNull(); + + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(3)); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongTableDefinitionRawBinaryCopy() + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + + using (var reader = async + ? await conn.BeginBinaryExportAsync($"COPY {table} (field_text, field_int2) TO STDIN BINARY") + : conn.BeginBinaryExport($"COPY {table} (field_text, field_int2) TO STDIN BINARY")) { - using (var conn = await OpenConnectionAsync()) - { - Assert.Throws(() => conn.BeginRawBinaryCopy("COPY table_is_not_exist (blob) TO STDOUT BINARY")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + StateAssertions(conn); - Assert.Throws(() => conn.BeginRawBinaryCopy("COPY table_is_not_exist (blob) FROM STDIN BINARY")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + Assert.That(reader.StartRow(), Is.EqualTo(2)); + Assert.That(reader.Read(), Is.EqualTo("Hello")); + Assert.That(reader.Read(NpgsqlDbType.Smallint), Is.EqualTo(8)); + + Assert.That(reader.StartRow(), Is.EqualTo(2)); + Assert.That(reader.IsNull, Is.False); + Assert.That(reader.Read(), Is.EqualTo("Something")); + reader.Skip(); + + Assert.That(reader.StartRow(), Is.EqualTo(2)); + Assert.That(reader.Read(), Is.EqualTo(longString)); + Assert.That(reader.IsNull, Is.True); + Assert.That(reader.IsNull, Is.True); 
+ reader.Skip(); + + Assert.That(reader.StartRow(), Is.EqualTo(-1)); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongFormatRawBinaryCopy() + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task Cancel_binary_import() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); + await using (var tx = await conn.BeginTransactionAsync()) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); - using (var conn = await OpenConnectionAsync()) + using (var writer = conn.BeginBinaryImport($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) { - await using var _ = await CreateTempTable(conn, "blob BYTEA", out var table); - Assert.Throws(() => conn.BeginRawBinaryCopy($"COPY {table} (blob) TO STDOUT")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + writer.StartRow(); + writer.Write("Hello"); + writer.Write(8); + // No commit should rollback } + } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "blob BYTEA", out var table); - Assert.Throws(() => conn.BeginRawBinaryCopy($"COPY {table} (blob) FROM STDIN")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/657")] + public async Task Import_bytea() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field BYTEA"); + + var data = new byte[] {1, 5, 8}; + + using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) + { + writer.StartRow(); + writer.Write(data, NpgsqlDbType.Bytea); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - #endregion + Assert.That(await 
conn.ExecuteScalarAsync($"SELECT field FROM {table}"), Is.EqualTo(data)); + } - #region Binary + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4693")] + public async Task Import_numeric() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field NUMERIC(1000)"); - [Test, Description("Roundtrips some data")] - public async Task BinaryRoundtrip() + await using (var writer = await conn.BeginBinaryImportAsync($"COPY {table} (field) FROM STDIN BINARY")) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT", out var table); + await writer.StartRowAsync(); + await writer.WriteAsync(new BigInteger(1234), NpgsqlDbType.Numeric); + await writer.StartRowAsync(); + await writer.WriteAsync(new BigInteger(5678), NpgsqlDbType.Numeric); - var longString = new StringBuilder(conn.Settings.WriteBufferSize + 50).Append('a').ToString(); + var rowsWritten = await writer.CompleteAsync(); + Assert.That(rowsWritten, Is.EqualTo(2)); + } - using (var writer = conn.BeginBinaryImport($"COPY {table} (field_text, field_int2) FROM STDIN BINARY")) - { - StateAssertions(conn); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = $"SELECT field FROM {table}"; + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.IsTrue(await reader.ReadAsync()); + Assert.That(reader.GetValue(0), Is.EqualTo(1234m)); + Assert.IsTrue(await reader.ReadAsync()); + Assert.That(reader.GetValue(0), Is.EqualTo(5678m)); + } - writer.StartRow(); - writer.Write("Hello"); - writer.Write((short)8, NpgsqlDbType.Smallint); + [Test] + public async Task Import_string_array() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field TEXT[]"); - writer.WriteRow("Something", (short)9); + var data = new[] {"foo", "a", "bar"}; + using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) + { + 
writer.StartRow(); + writer.Write(data, NpgsqlDbType.Array | NpgsqlDbType.Text); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(1)); + } - writer.StartRow(); - writer.Write(longString, "text"); - writer.WriteNull(); + Assert.That(await conn.ExecuteScalarAsync($"SELECT field FROM {table}"), Is.EqualTo(data)); + } - var rowsWritten = writer.Complete(); - Assert.That(rowsWritten, Is.EqualTo(3)); - } + [Test] + public async Task Import_reused_instance_mapping_info_identical_or_throws() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field int4"); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + var data = 8; + using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) + { + writer.StartRow(); + writer.Write(data, NpgsqlDbType.Integer); + writer.StartRow(); + Assert.Throws(Is.TypeOf().With.Property("Message").StartsWith("Write for column 0 resolves to a different PostgreSQL type"), + () => writer.Write(data, "int2")); + // Should be recoverable by using the same type again. 
+ writer.Write(data, "int4"); + writer.Complete(); + } + } - using (var reader = conn.BeginBinaryExport($"COPY {table} (field_text, field_int2) TO STDIN BINARY")) - { - StateAssertions(conn); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/816")] + public async Task Import_string_with_buffer_length() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field TEXT"); - Assert.That(reader.StartRow(), Is.EqualTo(2)); - Assert.That(reader.Read(), Is.EqualTo("Hello")); - Assert.That(reader.Read(NpgsqlDbType.Smallint), Is.EqualTo(8)); + var data = new string('a', conn.Settings.WriteBufferSize); + using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) + { + writer.StartRow(); + writer.Write(data, NpgsqlDbType.Text); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(1)); + } + Assert.That(await conn.ExecuteScalarAsync($"SELECT field FROM {table}"), Is.EqualTo(data)); + } - Assert.That(reader.StartRow(), Is.EqualTo(2)); - Assert.That(reader.IsNull, Is.False); - Assert.That(reader.Read(), Is.EqualTo("Something")); - reader.Skip(); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/662")] + public async Task Import_direct_buffer() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "blob BYTEA"); - Assert.That(reader.StartRow(), Is.EqualTo(2)); - Assert.That(reader.Read(), Is.EqualTo(longString)); - Assert.That(reader.IsNull, Is.True); - reader.Skip(); + using var writer = conn.BeginBinaryImport($"COPY {table} (blob) FROM STDIN BINARY"); + // Big value - triggers use of the direct write optimization + var data = new byte[conn.Settings.WriteBufferSize + 10]; - Assert.That(reader.StartRow(), Is.EqualTo(-1)); - } + writer.StartRow(); + writer.Write(data); + writer.StartRow(); + writer.Write(data); + } - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } + [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/5330")] + public async Task Import_object_null() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field TEXT[]"); - [Test] - public async Task CancelBinaryImport() + using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); - - using (var writer = conn.BeginBinaryImport($"COPY {table} (field_text, field_int4) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write("Hello"); - writer.Write(8); - // No commit should rollback - } - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - } + writer.StartRow(); + writer.Write(null, NpgsqlDbType.Boolean); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/657")] - public async Task ImportBytea() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "field BYTEA", out var table); + Assert.That(await conn.ExecuteScalarAsync($"SELECT field FROM {table}"), Is.EqualTo(DBNull.Value)); + } - var data = new byte[] {1, 5, 8}; + static readonly TestCaseData[] DBNullValues = + { + new TestCaseData(DBNull.Value).SetName("DBNull.Value"), + new TestCaseData(null).SetName("null") + }; - using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(data, NpgsqlDbType.Bytea); - var rowsWritten = writer.Complete(); - Assert.That(rowsWritten, Is.EqualTo(1)); - } + [Test, TestCaseSource(nameof(DBNullValues))] + public async Task Import_dbnull(DBNull? 
value) + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field TEXT[]"); - Assert.That(await conn.ExecuteScalarAsync($"SELECT field FROM {table}"), Is.EqualTo(data)); - } + using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) + { + writer.StartRow(); + writer.Write(value, NpgsqlDbType.Boolean); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - [Test] - public async Task ImportStringArray() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "field TEXT[]", out var table); + Assert.That(await conn.ExecuteScalarAsync($"SELECT field FROM {table}"), Is.EqualTo(DBNull.Value)); + } - var data = new[] {"foo", "a", "bar"}; - using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(data, NpgsqlDbType.Array | NpgsqlDbType.Text); - var rowsWritten = writer.Complete(); - Assert.That(rowsWritten, Is.EqualTo(1)); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_table_definition_binary_import() + { + using var conn = await OpenConnectionAsync(); + // Connection should be kept alive after PostgresException was triggered + Assert.Throws(() => conn.BeginBinaryImport("COPY table_is_not_exist (blob) FROM STDIN BINARY")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - Assert.That(await conn.ExecuteScalarAsync($"SELECT field FROM {table}"), Is.EqualTo(data)); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_format_binary_import() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "blob BYTEA"); + Assert.Throws(() => conn.BeginBinaryImport($"COPY {table} 
(blob) FROM STDIN")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/816")] - public async Task ImportStringWithBufferLength() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "field TEXT", out var table); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_table_definition_binary_export() + { + using var conn = await OpenConnectionAsync(); + // Connection should be kept alive after PostgresException was triggered + Assert.Throws(() => conn.BeginBinaryExport("COPY table_is_not_exist (blob) TO STDOUT BINARY")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - var data = new string('a', conn.Settings.WriteBufferSize); - using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(data, NpgsqlDbType.Text); - var rowsWritten = writer.Complete(); - Assert.That(rowsWritten, Is.EqualTo(1)); - } - Assert.That(await conn.ExecuteScalarAsync($"SELECT field FROM {table}"), Is.EqualTo(data)); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5457")] + public async Task MixedOperations() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + + using var reader = conn.BeginBinaryExport(""" + COPY (values ('foo', 1), ('bar', null), (null, 2)) TO STDOUT BINARY + """); + while(reader.StartRow() != -1) + { + string? col1 = null; + if (reader.IsNull) + reader.Skip(); + else + col1 = reader.Read(); + int? 
col2 = null; + if (reader.IsNull) + reader.Skip(); + else + col2 = reader.Read(); } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/662")] - public async Task ImportDirectBuffer() + [Test] + public async Task ReadMoreColumnsThanExist() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + + using var reader = conn.BeginBinaryExport(""" + COPY (values ('foo', 1), ('bar', null), (null, 2)) TO STDOUT BINARY + """); + while(reader.StartRow() != -1) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "blob BYTEA", out var table); - - using (var writer = conn.BeginBinaryImport($"COPY {table} (blob) FROM STDIN BINARY")) - { - // Big value - triggers use of the direct write optimization - var data = new byte[conn.Settings.WriteBufferSize + 10]; - - writer.StartRow(); - writer.Write(data); - writer.StartRow(); - writer.Write(data); - } - } + string? col1 = null; + if (reader.IsNull) + reader.Skip(); + else + col1 = reader.Read(); + int? 
col2 = null; + if (reader.IsNull) + reader.Skip(); + else + col2 = reader.Read(); + + Assert.Throws(() => _ = reader.IsNull); } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongTableDefinitionBinaryImport() + [Test] + public async Task ReadZeroSizedColumns() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + + using var reader = conn.BeginBinaryExport(""" + COPY (values (1, '', ''), (2, null, ''), (3, '', null)) TO STDOUT BINARY + """); + while(reader.StartRow() != -1) { - using (var conn = await OpenConnectionAsync()) - { - // Connection should be kept alive after PostgresException was triggered - Assert.Throws(() => conn.BeginBinaryImport("COPY table_is_not_exist (blob) FROM STDIN BINARY")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + int? col1 = null; + if (reader.IsNull) + reader.Skip(); + else + col1 = reader.Read(); + + string? col2 = null; + if (reader.IsNull) + reader.Skip(); + else + col2 = reader.Read(); + + string? 
col3 = null; + if (reader.IsNull) + reader.Skip(); + else + col3 = reader.Read(); } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongFormatBinaryImport() + [Test] + public async Task ReadConverterResolverType() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + + using (var reader = conn.BeginBinaryExport(""" + COPY (values (NOW()), (NULL)) TO STDOUT BINARY + """)) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); - using (var conn = await OpenConnectionAsync()) + while (reader.StartRow() != -1) { - await using var _ = await CreateTempTable(conn, "blob BYTEA", out var table); - Assert.Throws(() => conn.BeginBinaryImport($"COPY {table} (blob) FROM STDIN")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + DateTime? col1 = null; + if (reader.IsNull) + reader.Skip(); + else + col1 = reader.Read(); } } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongTableDefinitionBinaryExport() + using (var reader = conn.BeginBinaryExport(""" + COPY (values (NOW()), (NULL)) TO STDOUT BINARY + """)) { - using (var conn = await OpenConnectionAsync()) + while (reader.StartRow() != -1) { - // Connection should be kept alive after PostgresException was triggered - Assert.Throws(() => conn.BeginBinaryExport("COPY table_is_not_exist (blob) TO STDOUT BINARY")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + DateTimeOffset? 
col1 = null; + if (reader.IsNull) + reader.Skip(); + else + col1 = reader.Read(); } } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongFormatBinaryExport() + [Test] + public async Task StreamingRead() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + + var str = new string('a', PgReader.MaxPreparedTextReaderSize + 1); + var reader = conn.BeginBinaryExport($"""COPY (values ('{str}')) TO STDOUT BINARY"""); + while (reader.StartRow() != -1) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "blob BYTEA", out var table); - Assert.Throws(() => conn.BeginBinaryExport($"COPY {table} (blob) TO STDOUT")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } + using var _ = reader.Read(NpgsqlDbType.Text); } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/661")] - [Ignore("Unreliable")] - public async Task UnexpectedExceptionBinaryImport() - { - if (IsMultiplexing) - return; + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_format_binary_export() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "blob BYTEA"); + Assert.Throws(() => conn.BeginBinaryExport($"COPY {table} (blob) TO STDOUT")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "blob BYTEA", out var table); + [Test, NonParallelizable, IssueLink("https://github.com/npgsql/npgsql/issues/661")] + [Ignore("Unreliable")] + public async Task Unexpected_exception_binary_import() + { + if (IsMultiplexing) + return; - var data = new byte[conn.Settings.WriteBufferSize + 10]; + // Use a private 
data source since we terminate the connection below (affects database state) + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + var table = await CreateTempTable(conn, "blob BYTEA"); - var writer = conn.BeginBinaryImport($"COPY {table} (blob) FROM STDIN BINARY"); + var data = new byte[conn.Settings.WriteBufferSize + 10]; - using (var conn2 = await OpenConnectionAsync()) - conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); + var writer = conn.BeginBinaryImport($"COPY {table} (blob) FROM STDIN BINARY"); - Thread.Sleep(50); - Assert.That(() => - { - writer.StartRow(); - writer.Write(data); - writer.Dispose(); - }, Throws.Exception.TypeOf()); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } - } + using (var conn2 = await OpenConnectionAsync()) + conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/657")] - [Explicit] - public async Task ImportByteaMassive() + Thread.Sleep(50); + Assert.That(() => { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "field BYTEA", out var table); + writer.StartRow(); + writer.Write(data); + writer.Dispose(); + }, Throws.Exception.TypeOf()); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - const int iterations = 10000; - var data = new byte[1024*1024]; + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/657")] + [Explicit] + public async Task Import_bytea_massive() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field BYTEA"); - using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) - { - for (var i = 0; i < iterations; i++) - { - if (i%100 == 0) - Console.WriteLine("Iteration " + i); - writer.StartRow(); - writer.Write(data, NpgsqlDbType.Bytea); - } - } + const int iterations = 10000; + 
var data = new byte[1024*1024]; - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(iterations)); + using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) + { + for (var i = 0; i < iterations; i++) + { + if (i%100 == 0) + Console.WriteLine("Iteration " + i); + writer.StartRow(); + writer.Write(data, NpgsqlDbType.Bytea); } } - [Test] - public async Task ExportLongString() - { - const int iterations = 100; - using (var conn = await OpenConnectionAsync()) - { - var len = conn.Settings.WriteBufferSize; - await using var _ = await CreateTempTable(conn, "foo1 TEXT, foo2 TEXT, foo3 TEXT, foo4 TEXT, foo5 TEXT", out var table); - using (var cmd = new NpgsqlCommand($"INSERT INTO {table} VALUES (@p, @p, @p, @p, @p)", conn)) - { - cmd.Parameters.AddWithValue("p", new string('x', len)); - for (var i = 0; i < iterations; i++) - await cmd.ExecuteNonQueryAsync(); - } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(iterations)); + } - using (var reader = conn.BeginBinaryExport($"COPY {table} (foo1, foo2, foo3, foo4, foo5) TO STDIN BINARY")) - { - for (var row = 0; row < iterations; row++) - { - Assert.That(reader.StartRow(), Is.EqualTo(5)); - for (var col = 0; col < 5; col++) - Assert.That(reader.Read().Length, Is.EqualTo(len)); - } - } - } + [Test] + public async Task Export_long_string() + { + const int iterations = 100; + using var conn = await OpenConnectionAsync(); + var len = conn.Settings.WriteBufferSize; + var table = await CreateTempTable(conn, "foo1 TEXT, foo2 TEXT, foo3 TEXT, foo4 TEXT, foo5 TEXT"); + using (var cmd = new NpgsqlCommand($"INSERT INTO {table} VALUES (@p, @p, @p, @p, @p)", conn)) + { + cmd.Parameters.AddWithValue("p", new string('x', len)); + for (var i = 0; i < iterations; i++) + await cmd.ExecuteNonQueryAsync(); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1134")] - public async Task ReadBitString() + using (var reader = 
conn.BeginBinaryExport($"COPY {table} (foo1, foo2, foo3, foo4, foo5) TO STDIN BINARY")) { - using (var conn = await OpenConnectionAsync()) + int row, col = 0; + for (row = 0; row < iterations; row++) { - await using var _ = await GetTempTableName(conn, out var table); - - await conn.ExecuteNonQueryAsync($@" -CREATE TABLE {table} (bits BIT(3), bitarray BIT(3)[]); -INSERT INTO {table} (bits, bitarray) VALUES (B'101', ARRAY[B'101', B'111'])"); - - using (var reader = conn.BeginBinaryExport($"COPY {table} (bits, bitarray) TO STDIN BINARY")) + Assert.That(reader.StartRow(), Is.EqualTo(5)); + for (col = 0; col < 5; col++) { - reader.StartRow(); - Assert.That(reader.Read(), Is.EqualTo(new BitArray(new[] { true, false, true }))); - Assert.That(reader.Read(), Is.EqualTo(new[] - { - new BitArray(new[] { true, false, true }), - new BitArray(new[] { true, true, true }) - })); + var str = reader.Read(); + Assert.That(str.Length, Is.EqualTo(len)); +#if NET6_0_OR_GREATER + Assert.True(str.AsSpan().IndexOfAnyExcept('x') is -1); +#endif } } + Assert.That(row, Is.EqualTo(100)); + Assert.That(col, Is.EqualTo(5)); } + } - [Test] - public async Task Array() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1134")] + public async Task Read_bit_string() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE TABLE {table} (bits BIT(11), bitvector BIT(11), bitarray BIT(3)[]); +INSERT INTO {table} (bits, bitvector, bitarray) VALUES (B'00000001101', B'00000001101', ARRAY[B'101', B'111'])"); + + using var reader = conn.BeginBinaryExport($"COPY {table} (bits, bitvector, bitarray) TO STDIN BINARY"); + reader.StartRow(); + Assert.That(reader.Read(), Is.EqualTo(new BitArray(new[] { false, false, false, false, false, false, false, true, true, false, true }))); + Assert.That(reader.Read(), Is.EqualTo(new BitVector32(0b00000001101000000000000000000000))); + Assert.That(reader.Read(), Is.EqualTo(new[] { 
- var expected = new[] { 8 }; + new BitArray(new[] { true, false, true }), + new BitArray(new[] { true, true, true }) + })); + } - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "arr INTEGER[]", out var table); + [Test] + public async Task Array() + { + var expected = new[] { 8 }; - using (var writer = conn.BeginBinaryImport($"COPY {table} (arr) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(expected); - var rowsWritten = writer.Complete(); - Assert.That(rowsWritten, Is.EqualTo(1)); - } + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "arr INTEGER[]"); - using (var reader = conn.BeginBinaryExport($"COPY {table} (arr) TO STDIN BINARY")) - { - reader.StartRow(); - Assert.That(reader.Read(), Is.EqualTo(expected)); - } - } + using (var writer = conn.BeginBinaryImport($"COPY {table} (arr) FROM STDIN BINARY")) + { + writer.StartRow(); + writer.Write(expected); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - [Test] - public async Task Enum() + using (var reader = conn.BeginBinaryExport($"COPY {table} (arr) TO STDIN BINARY")) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: connection-specific mapping"); + reader.StartRow(); + Assert.That(reader.Read(), Is.EqualTo(expected)); + } + } - using var conn = await OpenConnectionAsync(); - await conn.ExecuteNonQueryAsync("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy')"); - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(); + [Test] + public async Task Enum() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - await conn.ExecuteNonQueryAsync("CREATE TEMP TABLE data (mymood mood, mymoodarr mood[])"); + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type); + await using var 
dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); - using (var writer = conn.BeginBinaryImport("COPY data (mymood, mymoodarr) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(Mood.Happy); - writer.Write(new[] { Mood.Happy }); - var rowsWritten = writer.Complete(); - Assert.That(rowsWritten, Is.EqualTo(1)); - } + var table = await CreateTempTable(connection, $"mymood {type}, mymoodarr {type}[]"); - using (var reader = conn.BeginBinaryExport("COPY data (mymood, mymoodarr) TO STDIN BINARY")) - { - reader.StartRow(); - Assert.That(reader.Read(), Is.EqualTo(Mood.Happy)); - Assert.That(reader.Read(), Is.EqualTo(new[] { Mood.Happy })); - } + await using (var writer = await connection.BeginBinaryImportAsync($"COPY {table} (mymood, mymoodarr) FROM STDIN BINARY")) + { + await writer.StartRowAsync(); + await writer.WriteAsync(Mood.Happy); + await writer.WriteAsync(new[] { Mood.Happy }); + var rowsWritten = await writer.CompleteAsync(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - enum Mood { Sad, Ok, Happy }; - - [Test] - public async Task Read_NullAsNullable_Succeeds() + await using (var reader = await connection.BeginBinaryExportAsync($"COPY {table} (mymood, mymoodarr) TO STDIN BINARY")) { - using var connection = await OpenConnectionAsync(); - using var exporter = connection.BeginBinaryExport("COPY (SELECT NULL::int) TO STDOUT BINARY"); + await reader.StartRowAsync(); + Assert.That(reader.Read(), Is.EqualTo(Mood.Happy)); + Assert.That(reader.Read(), Is.EqualTo(new[] { Mood.Happy })); + } + } - exporter.StartRow(); + enum Mood { Sad, Ok, Happy }; - Assert.That(exporter.Read(), Is.Null); - } + [Test] + public async Task Read_null_as_nullable() + { + using var connection = await OpenConnectionAsync(); + using var exporter = connection.BeginBinaryExport("COPY (SELECT NULL::int) TO STDOUT BINARY"); - [Test] - public async Task Read_NullAsValue_ThrowsInvalidCastException() - { - using var connection = 
await OpenConnectionAsync(); - using var exporter = connection.BeginBinaryExport("COPY (SELECT NULL::int) TO STDOUT BINARY"); + exporter.StartRow(); - exporter.StartRow(); + Assert.That(exporter.Read(), Is.Null); + } - Assert.Throws(() => exporter.Read()); - } + [Test] + public async Task Read_null_as_non_nullable_throws() + { + using var connection = await OpenConnectionAsync(); + using var exporter = connection.BeginBinaryExport("COPY (SELECT NULL::int) TO STDOUT BINARY"); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1440")] - public async Task ErrorDuringImport() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "foo INT, CONSTRAINT uq UNIQUE(foo)", out var table); + exporter.StartRow(); - var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY"); - writer.StartRow(); - writer.Write(8); - writer.StartRow(); - writer.Write(8); - Assert.That(() => writer.Complete(), Throws.Exception - .TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("23505")); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } + Assert.Throws(() => exporter.Read()); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1440")] + public async Task Error_during_import() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INT UNIQUE"); + var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY"); + writer.StartRow(); + writer.Write(8); + writer.StartRow(); + writer.Write(8); + Assert.That(() => writer.Complete(), Throws.Exception + .TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UniqueViolation)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - [Test] - public async Task ImportCannotWriteAfterCommit() + [Test] + public async Task Import_cannot_write_after_commit() + { + using var conn = await 
OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INT"); + try { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "foo INT", out var table); - try - { - using (var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(8); - var rowsWritten = writer.Complete(); - Assert.That(rowsWritten, Is.EqualTo(1)); - writer.StartRow(); - Assert.Fail("StartRow should have thrown"); - } - } - catch (InvalidOperationException) - { - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); - } - } + using var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY"); + writer.StartRow(); + writer.Write(8); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(1)); + writer.StartRow(); + Assert.Fail("StartRow should have thrown"); } - - [Test] - public async Task ImportCommitInMiddleOfRow() + catch (InvalidOperationException) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "foo INT, bar TEXT", out var table); - - try - { - using (var writer = conn.BeginBinaryImport($"COPY {table} (foo, bar) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(8); - writer.Write("hello"); - writer.StartRow(); - writer.Write(9); - writer.Complete(); - Assert.Fail("Commit should have thrown"); - } - } - catch (InvalidOperationException) - { - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - } - } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); } + } - [Test] - public async Task ImportExceptionDoesNotCommit() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "foo INT", out var table); + [Test] + public async Task Import_commit_in_middle_of_row() + { + using var conn = await 
OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INT, bar TEXT"); - try - { - using (var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(8); - throw new Exception("FOO"); - } - } - catch (Exception e) when (e.Message == "FOO") - { - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.Zero); - } - } + try + { + using var writer = conn.BeginBinaryImport($"COPY {table} (foo, bar) FROM STDIN BINARY"); + writer.StartRow(); + writer.Write(8); + writer.Write("hello"); + writer.StartRow(); + writer.Write(9); + writer.Complete(); + Assert.Fail("Commit should have thrown"); } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2347")] - public async Task Write_ColumnOutOfBounds_ThrowsInvalidOperationException() + catch (InvalidOperationException) { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 INTEGER", out var table); + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } + } - using var writer = conn.BeginBinaryImport($"COPY {table} (field_text, field_int2) FROM STDIN BINARY"); - StateAssertions(conn); + [Test] + public async Task Import_exception_does_not_commit() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INT"); + try + { + using var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY"); writer.StartRow(); - writer.Write("Hello"); - writer.Write(8, NpgsqlDbType.Smallint); + writer.Write(8); + throw new Exception("FOO"); + } + catch (Exception e) when (e.Message == "FOO") + { + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.Zero); + } + } - Assert.Throws(() => writer.Write("I should not be here")); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2347")] + public async Task Write_column_out_of_bounds_throws() + { 
+ using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 INTEGER"); - writer.StartRow(); - writer.Write("Hello"); - writer.Write(8, NpgsqlDbType.Smallint); + using var writer = conn.BeginBinaryImport($"COPY {table} (field_text, field_int2) FROM STDIN BINARY"); + StateAssertions(conn); - Assert.Throws(() => writer.Write("I should not be here", NpgsqlDbType.Text)); + writer.StartRow(); + writer.Write("Hello"); + writer.Write(8, NpgsqlDbType.Smallint); - writer.StartRow(); - writer.Write("Hello"); - writer.Write(8, NpgsqlDbType.Smallint); + Assert.Throws(() => writer.Write("I should not be here")); - Assert.Throws(() => writer.Write("I should not be here", "text")); - Assert.Throws(() => writer.WriteRow("Hello", 8, "I should not be here")); - } + writer.StartRow(); + writer.Write("Hello"); + writer.Write(8, NpgsqlDbType.Smallint); + + Assert.Throws(() => writer.Write("I should not be here", NpgsqlDbType.Text)); + + writer.StartRow(); + writer.Write("Hello"); + writer.Write(8, NpgsqlDbType.Smallint); + + Assert.Throws(() => writer.Write("I should not be here", "text")); + Assert.Throws(() => writer.WriteRow("Hello", 8, "I should not be here")); + } - [Test] - public async Task CancelRawBinaryExportWhenNotConsumedAndThenDispose() + [Test] + public async Task Cancel_raw_binary_export_when_not_consumed_and_then_Dispose() + { + await using var conn = await OpenConnectionAsync(); + await using (var tx = await conn.BeginTransactionAsync()) { - await using var conn = await OpenConnectionAsync(); // This must be large enough to cause Postgres to queue up CopyData messages. 
var stream = conn.BeginRawBinaryCopy("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT BINARY"); var buffer = new byte[32]; await stream.ReadAsync(buffer, 0, buffer.Length); stream.Cancel(); Assert.DoesNotThrowAsync(async () => await stream.DisposeAsync()); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1), "The connection is still OK"); } + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1), "The connection is still OK"); + } - [Test] - public async Task CancelBinaryExportWhenNotConsumedAndThenDispose() + [Test] + public async Task Cancel_binary_export_when_not_consumed_and_then_Dispose() + { + await using var conn = await OpenConnectionAsync(); + await using (var tx = await conn.BeginTransactionAsync()) { - await using var conn = await OpenConnectionAsync(); // This must be large enough to cause Postgres to queue up CopyData messages. var exporter = conn.BeginBinaryExport("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT BINARY"); await exporter.StartRowAsync(); await exporter.ReadAsync(); exporter.Cancel(); Assert.DoesNotThrowAsync(async () => await exporter.DisposeAsync()); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1), "The connection is still OK"); } + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1), "The connection is still OK"); + } - #endregion + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/5110")] + public async Task Binary_copy_read_char_column() + { + await using var conn = await OpenConnectionAsync(); + var tableName = await CreateTempTable(conn, "id serial, value char"); - #region Text + await using var cmd = conn.CreateCommand(); + cmd.CommandText = $"INSERT INTO {tableName}(value) VALUES ('d'), ('s')"; + await cmd.ExecuteNonQueryAsync(); - [Test] - public async Task TextImport() + await using var export = await 
conn.BeginBinaryExportAsync($"COPY {tableName}(id, value) TO STDOUT (FORMAT BINARY)"); + while (await export.StartRowAsync() != -1) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); - const string line = "HELLO\t1\n"; - - // Short write - var writer = conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); - StateAssertions(conn); - writer.Write(line); - writer.Dispose(); - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table} WHERE field_int4=1"), Is.EqualTo(1)); - Assert.That(() => writer.Write(line), Throws.Exception.TypeOf()); - await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); - - // Long (multi-buffer) write - var iterations = NpgsqlWriteBuffer.MinimumSize/line.Length + 100; - writer = conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); - for (var i = 0; i < iterations; i++) - writer.Write(line); - writer.Dispose(); - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table} WHERE field_int4=1"), Is.EqualTo(iterations)); - } + var id = export.Read(); + var value = export.Read(); } + } - [Test] - public async Task CancelTextImport() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); + #endregion - var writer = (NpgsqlCopyTextWriter)conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); - writer.Write("HELLO\t1\n"); - writer.Cancel(); - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - } - } + #region Text - [Test] - public async Task TextImportEmpty() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); + [Test] + public async Task 
Text_import([Values(false, true)] bool async) + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); + const string line = "HELLO\t1\n"; + + // Short write + var writer = async + ? await conn.BeginTextImportAsync($"COPY {table} (field_text, field_int4) FROM STDIN") + : conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); + StateAssertions(conn); + writer.Write(line); + writer.Dispose(); + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table} WHERE field_int4=1"), Is.EqualTo(1)); + Assert.That(() => writer.Write(line), Throws.Exception.TypeOf()); + await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); + + // Long (multi-buffer) write + var iterations = NpgsqlWriteBuffer.MinimumSize/line.Length + 100; + writer = async + ? await conn.BeginTextImportAsync($"COPY {table} (field_text, field_int4) FROM STDIN") + : conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); + for (var i = 0; i < iterations; i++) + writer.Write(line); + writer.Dispose(); + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table} WHERE field_int4=1"), Is.EqualTo(iterations)); + } - using (conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN")) - { - } - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - } + [Test] + public async Task Cancel_text_import() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); + await using (var tx = await conn.BeginTransactionAsync()) + { + var writer = (NpgsqlCopyTextWriter)conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); + writer.Write("HELLO\t1\n"); + writer.Cancel(); } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } + + [Test] + public async Task 
Text_import_empty() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); - [Test] - public async Task TextExport() + using (conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN")) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await GetTempTableName(conn, out var table); + } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } + + [Test] + public async Task Text_export([Values(false, true)] bool async) + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); - await conn.ExecuteNonQueryAsync($@" + await conn.ExecuteNonQueryAsync($@" CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER); INSERT INTO {table} (field_text, field_int4) VALUES ('HELLO', 1)"); - var chars = new char[30]; - - // Short read - var reader = conn.BeginTextExport($"COPY {table} (field_text, field_int4) TO STDIN"); - StateAssertions(conn); - Assert.That(reader.Read(chars, 0, chars.Length), Is.EqualTo(8)); - Assert.That(new string(chars, 0, 8), Is.EqualTo("HELLO\t1\n")); - Assert.That(reader.Read(chars, 0, chars.Length), Is.EqualTo(0)); - Assert.That(reader.Read(chars, 0, chars.Length), Is.EqualTo(0)); - reader.Dispose(); - Assert.That(() => reader.Read(chars, 0, chars.Length), Throws.Exception.TypeOf()); - await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); - } - } + var chars = new char[30]; + + // Short read + var reader = async + ? 
await conn.BeginTextExportAsync($"COPY {table} (field_text, field_int4) TO STDIN") + : conn.BeginTextExport($"COPY {table} (field_text, field_int4) TO STDIN"); + StateAssertions(conn); + Assert.That(reader.Read(chars, 0, chars.Length), Is.EqualTo(8)); + Assert.That(new string(chars, 0, 8), Is.EqualTo("HELLO\t1\n")); + Assert.That(reader.Read(chars, 0, chars.Length), Is.EqualTo(0)); + Assert.That(reader.Read(chars, 0, chars.Length), Is.EqualTo(0)); + reader.Dispose(); + Assert.That(() => reader.Read(chars, 0, chars.Length), Throws.Exception.TypeOf()); + await conn.ExecuteNonQueryAsync($"TRUNCATE {table}"); + } - [Test] - public async Task DisposeInMiddleOfTextExport() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await GetTempTableName(conn, out var table); + [Test] + public async Task Dispose_in_middle_of_text_export() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); - await conn.ExecuteNonQueryAsync($@" + await conn.ExecuteNonQueryAsync($@" CREATE TABLE {table} (field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER); INSERT INTO {table} (field_text, field_int4) VALUES ('HELLO', 1)"); - var reader = conn.BeginTextExport($"COPY {table} (field_text, field_int4) TO STDIN"); - reader.Dispose(); - // Make sure the connection is still OK - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } + var reader = conn.BeginTextExport($"COPY {table} (field_text, field_int4) TO STDIN"); + reader.Dispose(); + // Make sure the connection is still OK + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongTableDefinitionTextImport() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); - using (var conn = await OpenConnectionAsync()) - { - Assert.Throws(() => conn.BeginTextImport("COPY table_is_not_exist (blob) FROM STDIN")); - 
Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_table_definition_text_import() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + Assert.Throws(() => conn.BeginTextImport("COPY table_is_not_exist (blob) FROM STDIN")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongFormatTextImport() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "blob BYTEA", out var table); - Assert.Throws(() => conn.BeginTextImport($"COPY {table} (blob) FROM STDIN BINARY")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_format_text_import() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "blob BYTEA"); + Assert.Throws(() => conn.BeginTextImport($"COPY {table} (blob) FROM STDIN BINARY")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongTableDefinitionTextExport() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); - using (var conn = await OpenConnectionAsync()) - { - Assert.Throws(() => conn.BeginTextExport("COPY table_is_not_exist (blob) TO STDOUT")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); 
- } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_table_definition_text_export() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + Assert.Throws(() => conn.BeginTextExport("COPY table_is_not_exist (blob) TO STDOUT")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] - public async Task WrongFormatTextExport() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "blob BYTEA", out var table); - Assert.Throws(() => conn.BeginTextExport($"COPY {table} (blob) TO STDOUT BINARY")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] + public async Task Wrong_format_text_export() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "blob BYTEA"); + Assert.Throws(() => conn.BeginTextExport($"COPY {table} (blob) TO STDOUT BINARY")); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - [Test] - public async Task CancelTextExportWhenNotConsumedAndThenDispose() + [Test] + public async Task Cancel_text_export_when_not_consumed_and_then_Dispose() + { + await using var conn = await OpenConnectionAsync(); + await using (var tx = await conn.BeginTransactionAsync()) { - await using var conn = await OpenConnectionAsync(); // This must be large enough to cause Postgres to queue up CopyData messages. 
var reader = (NpgsqlCopyTextReader) conn.BeginTextExport("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT"); var buffer = new char[32]; await reader.ReadAsync(buffer, 0, buffer.Length); reader.Cancel(); Assert.DoesNotThrow(reader.Dispose); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1), "The connection is still OK"); } + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1), "The connection is still OK"); + } - #endregion + #endregion - #region Other + #region Other - [Test, Description("Starts a transaction before a COPY, testing that prepended messages are handled well")] - public async Task PrependedMessages() - { - using (var conn = await OpenConnectionAsync()) - { - conn.BeginTransaction(); - await TextImport(); - } + [Test, Description("Starts a transaction before a COPY, testing that prepended messages are handled well")] + public async Task Prepended_messages() + { + using var conn = await OpenConnectionAsync(); + conn.BeginTransaction(); + await Text_import(async: false); + } + + [Test] + public async Task Undefined_table_throws() + { + using var conn = await OpenConnectionAsync(); + Assert.That(() => conn.BeginBinaryImport("COPY undefined_table (field_text, field_int2) FROM STDIN BINARY"), + Throws.Exception + .TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedTable) + ); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/621")] + public async Task Close_during_copy_throws() + { + // TODO: Check no broken connections were returned to the pool + using (var conn = await OpenConnectionAsync()) { + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); + conn.BeginBinaryImport($"COPY {table} (field_text, field_int4) FROM STDIN BINARY"); } - [Test] - public async Task UndefinedTable() - { - using (var conn = await OpenConnectionAsync()) - Assert.That(() => 
conn.BeginBinaryImport("COPY undefined_table (field_text, field_int2) FROM STDIN BINARY"), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("42P01") - ); + using (var conn = await OpenConnectionAsync()) { + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); + conn.BeginBinaryExport($"COPY {table} (field_text, field_int2) TO STDIN BINARY"); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/621")] - public async Task CloseDuringCopy() - { - // TODO: Check no broken connections were returned to the pool - using (var conn = await OpenConnectionAsync()) { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); - conn.BeginBinaryImport($"COPY {table} (field_text, field_int4) FROM STDIN BINARY"); - } + using (var conn = await OpenConnectionAsync()) { + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); + conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY"); + } - using (var conn = await OpenConnectionAsync()) { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); - conn.BeginBinaryExport($"COPY {table} (field_text, field_int2) TO STDIN BINARY"); - } + using (var conn = await OpenConnectionAsync()) { + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); + conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) TO STDIN BINARY"); + } - using (var conn = await OpenConnectionAsync()) { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); - conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) FROM STDIN BINARY"); - } + using (var conn = await OpenConnectionAsync()) { + var table = await CreateTempTable(conn, 
"field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); + conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); + } - using (var conn = await OpenConnectionAsync()) { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); - conn.BeginRawBinaryCopy($"COPY {table} (field_text, field_int4) TO STDIN BINARY"); - } + using (var conn = await OpenConnectionAsync()) { + var table = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER"); + conn.BeginTextExport($"COPY {table} (field_text, field_int4) TO STDIN"); + } + } - using (var conn = await OpenConnectionAsync()) { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); - conn.BeginTextImport($"COPY {table} (field_text, field_int4) FROM STDIN"); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/994")] + public async Task Non_ascii_column_name() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "non_ascii_éè TEXT"); + using (conn.BeginBinaryImport($"COPY {table} (non_ascii_éè) FROM STDIN BINARY")) { } + } - using (var conn = await OpenConnectionAsync()) { - await using var _ = await CreateTempTable(conn, "field_text TEXT, field_int2 SMALLINT, field_int4 INTEGER", out var table); - conn.BeginTextExport($"COPY {table} (field_text, field_int4) TO STDIN"); - } - } + [Test, IssueLink("https://stackoverflow.com/questions/37431054/08p01-insufficient-data-left-in-message-for-nullable-datetime/37431464")] + public async Task Write_null_values() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo1 INT, foo2 UUID, foo3 INT, foo4 UUID"); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/994")] - public async Task NonAsciiColumnName() + using (var writer = conn.BeginBinaryImport($"COPY {table} (foo1, foo2, foo3, foo4) FROM 
STDIN BINARY")) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "non_ascii_éè TEXT", out var table); - using (conn.BeginBinaryImport($"COPY {table} (non_ascii_éè) FROM STDIN BINARY")) { } - } + writer.StartRow(); + writer.Write(DBNull.Value, NpgsqlDbType.Integer); + writer.Write(null, NpgsqlDbType.Uuid); + writer.Write(DBNull.Value); + writer.Write((string?)null); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(1)); } - - [Test, IssueLink("https://stackoverflow.com/questions/37431054/08p01-insufficient-data-left-in-message-for-nullable-datetime/37431464")] - public async Task WriteNullValues() + using (var cmd = new NpgsqlCommand($"SELECT foo1,foo2,foo3,foo4 FROM {table}", conn)) + using (var reader = await cmd.ExecuteReaderAsync()) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "foo1 INT, foo2 UUID, foo3 INT, foo4 UUID", out var table); - - using (var writer = conn.BeginBinaryImport($"COPY {table} (foo1, foo2, foo3, foo4) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(DBNull.Value, NpgsqlDbType.Integer); - writer.Write((string?)null, NpgsqlDbType.Uuid); - writer.Write(DBNull.Value); - writer.Write((string?)null); - var rowsWritten = writer.Complete(); - Assert.That(rowsWritten, Is.EqualTo(1)); - } - using (var cmd = new NpgsqlCommand($"SELECT foo1,foo2,foo3,foo4 FROM {table}", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - Assert.That(reader.Read(), Is.True); - for (var i = 0; i < reader.FieldCount; i++) - Assert.That(reader.IsDBNull(i), Is.True); - } - } + Assert.That(reader.Read(), Is.True); + for (var i = 0; i < reader.FieldCount; i++) + Assert.That(reader.IsDBNull(i), Is.True); } + } - [Test] - public async Task WriteDifferentTypes() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "foo INT, bar INT[]", out var table); + [Test] 
+ public async Task Write_different_types() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INT, bar INT[]"); - using (var writer = conn.BeginBinaryImport($"COPY {table} (foo, bar) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(3.0, NpgsqlDbType.Integer); - writer.Write((object)new[] { 1, 2, 3 }); - writer.StartRow(); - writer.Write(3, NpgsqlDbType.Integer); - writer.Write((object)new List { 4, 5, 6 }); - var rowsWritten = writer.Complete(); - Assert.That(rowsWritten, Is.EqualTo(2)); - } - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(2)); - } + using (var writer = conn.BeginBinaryImport($"COPY {table} (foo, bar) FROM STDIN BINARY")) + { + writer.StartRow(); + writer.Write(3.0, NpgsqlDbType.Integer); + writer.Write(new[] { 1, 2, 3 }); + writer.StartRow(); + writer.Write(3, NpgsqlDbType.Integer); + writer.Write((object)new List { 4, 5, 6 }); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(2)); } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(2)); + } - [Test, Description("Tests nested binding scopes in multiplexing")] - public async Task WithinTransaction() + [Test, Description("Tests nested binding scopes in multiplexing")] + public async Task Within_transaction() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INT"); + + using (var tx = conn.BeginTransaction()) + using (var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY")) { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INT", out var table); + writer.StartRow(); + writer.Write(1); + writer.Dispose(); + // Don't complete + await tx.CommitAsync(); + } - using (var tx = conn.BeginTransaction()) - using (var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY")) - { - writer.StartRow(); - 
writer.Write(1); - writer.Dispose(); - // Don't complete - await tx.CommitAsync(); - } + using (var tx = conn.BeginTransaction()) + using (var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY")) + { + writer.StartRow(); + writer.Write(2); + writer.Complete(); + // Don't commit + } - using (var tx = conn.BeginTransaction()) + using (var tx = conn.BeginTransaction()) + { using (var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY")) { writer.StartRow(); - writer.Write(2); + writer.Write(3); writer.Complete(); - // Don't commit } + await tx.CommitAsync(); + } - using (var tx = conn.BeginTransaction()) - { - using (var writer = conn.BeginBinaryImport($"COPY {table} (foo) FROM STDIN BINARY")) - { - writer.StartRow(); - writer.Write(3); - writer.Complete(); - } - await tx.CommitAsync(); - } + Assert.That(async () => await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); + Assert.That(async () => await conn.ExecuteScalarAsync($"SELECT foo FROM {table}"), Is.EqualTo(3)); + } - Assert.That(async () => await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); - Assert.That(async () => await conn.ExecuteScalarAsync($"SELECT foo FROM {table}"), Is.EqualTo(3)); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4199")] + public async Task Copy_from_is_not_supported_in_regular_command_execution() + { + // Run in a separate pool to protect other queries in multiplexing + // because we're going to break the connection on CopyInResponse + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INT"); - #endregion + Assert.That(() => conn.ExecuteNonQuery($@"COPY {table} (foo) FROM stdin"), Throws.Exception.TypeOf()); + } - #region Utils + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4974")] + public async Task Copy_to_is_not_supported_in_regular_command_execution() 
+ { + // Run in a separate pool to protect other queries in multiplexing + // because we're going to break the connection on CopyInResponse + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INT"); - /// - /// Checks that the connector state is properly managed for COPY operations - /// - void StateAssertions(NpgsqlConnection conn) - { - Assert.That(conn.Connector!.State, Is.EqualTo(ConnectorState.Copy)); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Throws.Exception.TypeOf()); - } + Assert.That(() => conn.ExecuteNonQuery($@"COPY {table} (foo) TO stdin"), Throws.Exception.TypeOf()); + } - #endregion + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5209")] + [Platform(Exclude = "MacOsX", Reason = "Write might not throw an exception")] + public async Task RawBinaryCopy_write_nre([Values] bool async) + { + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + + var server = await postmasterMock.WaitForServerConnection(); + await server + .WriteCopyInResponse(isBinary: true) + .FlushAsync(); + + await using var stream = await conn.BeginRawBinaryCopyAsync("COPY SomeTable (field_text, field_int4) FROM STDIN"); + server.Close(); + var value = Encoding.UTF8.GetBytes(new string('a', conn.Settings.WriteBufferSize * 2)); + if (async) + Assert.ThrowsAsync(async () => await stream.WriteAsync(value)); + else + Assert.Throws(() => stream.Write(value)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + + #endregion + + #region Utils - public CopyTests(MultiplexingMode multiplexingMode) : 
base(multiplexingMode) {} + /// + /// Checks that the connector state is properly managed for COPY operations + /// + void StateAssertions(NpgsqlConnection conn) + { + Assert.That(conn.Connector!.State, Is.EqualTo(ConnectorState.Copy)); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Throws.Exception.TypeOf()); } + + #endregion + + public CopyTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/DataAdapterTests.cs b/test/Npgsql.Tests/DataAdapterTests.cs index c479b75c36..4b413409d7 100644 --- a/test/Npgsql.Tests/DataAdapterTests.cs +++ b/test/Npgsql.Tests/DataAdapterTests.cs @@ -5,608 +5,543 @@ using NUnit.Framework; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class DataAdapterTests : TestBase { - public class DataAdapterTests : TestBase + [Test] + public async Task DataAdapter_SelectCommand() { - [Test] - public async Task UseDataAdapter() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("SELECT 1", conn)) - { - var da = new NpgsqlDataAdapter(); - da.SelectCommand = command; - var ds = new DataSet(); - da.Fill(ds); - //ds.WriteXml("TestUseDataAdapter.xml"); - } - } - - [Test] - public async Task UseDataAdapterNpgsqlConnectionConstructor() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("SELECT 1", conn)) - { - command.Connection = conn; - var da = new NpgsqlDataAdapter(command); - var ds = new DataSet(); - da.Fill(ds); - //ds.WriteXml("TestUseDataAdapterNpgsqlConnectionConstructor.xml"); - } - } - - [Test] - public async Task UseDataAdapterStringNpgsqlConnectionConstructor() - { - using (var conn = await OpenConnectionAsync()) - { - var da = new NpgsqlDataAdapter("SELECT 1", conn); - var ds = new 
DataSet(); - da.Fill(ds); - //ds.WriteXml("TestUseDataAdapterStringNpgsqlConnectionConstructor.xml"); - } - } - - [Test] - public void UseDataAdapterStringStringConstructor() - { - var da = new NpgsqlDataAdapter("SELECT 1", ConnectionString); - var ds = new DataSet(); - da.Fill(ds); - //ds.WriteXml("TestUseDataAdapterStringStringConstructor.xml"); - } - - [Test] - public void UseDataAdapterStringStringConstructor2() - { - var da = new NpgsqlDataAdapter("SELECT 1", ConnectionString); - var ds = new DataSet(); - da.Fill(ds); - //ds.WriteXml("TestUseDataAdapterStringStringConstructor2.xml"); - } - - [Test] - [MonoIgnore("Bug in mono, submitted pull request: https://github.com/mono/mono/pull/1172")] - public async Task InsertWithDataSet() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); - var ds = new DataSet(); - var da = new NpgsqlDataAdapter($"SELECT * FROM {table}", conn); + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand("SELECT 1", conn); + var da = new NpgsqlDataAdapter(); + da.SelectCommand = command; + var ds = new DataSet(); + da.Fill(ds); + //ds.WriteXml("TestUseDataAdapter.xml"); + } - da.InsertCommand = new NpgsqlCommand($"INSERT INTO {table} (field_int2, field_timestamp, field_numeric) VALUES (:a, :b, :c)", conn); + [Test] + public async Task DataAdapter_NpgsqlCommand_in_constructor() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand("SELECT 1", conn); + command.Connection = conn; + var da = new NpgsqlDataAdapter(command); + var ds = new DataSet(); + da.Fill(ds); + //ds.WriteXml("TestUseDataAdapterNpgsqlConnectionConstructor.xml"); + } - da.InsertCommand.Parameters.Add(new NpgsqlParameter("a", DbType.Int16)); - da.InsertCommand.Parameters.Add(new NpgsqlParameter("b", DbType.DateTime)); - da.InsertCommand.Parameters.Add(new NpgsqlParameter("c", DbType.Decimal)); + [Test] + public async Task 
DataAdapter_string_command_in_constructor() + { + using var conn = await OpenConnectionAsync(); + var da = new NpgsqlDataAdapter("SELECT 1", conn); + var ds = new DataSet(); + da.Fill(ds); + //ds.WriteXml("TestUseDataAdapterStringNpgsqlConnectionConstructor.xml"); + } - da.InsertCommand.Parameters[0].Direction = ParameterDirection.Input; - da.InsertCommand.Parameters[1].Direction = ParameterDirection.Input; - da.InsertCommand.Parameters[2].Direction = ParameterDirection.Input; + [Test] + public void DataAdapter_connection_string_in_constructor() + { + var da = new NpgsqlDataAdapter("SELECT 1", ConnectionString); + var ds = new DataSet(); + da.Fill(ds); + //ds.WriteXml("TestUseDataAdapterStringStringConstructor.xml"); + } - da.InsertCommand.Parameters[0].SourceColumn = "field_int2"; - da.InsertCommand.Parameters[1].SourceColumn = "field_timestamp"; - da.InsertCommand.Parameters[2].SourceColumn = "field_numeric"; + [Test] + public async Task Insert_with_DataSet() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + var ds = new DataSet(); + var da = new NpgsqlDataAdapter($"SELECT * FROM {table}", conn); - da.Fill(ds); + da.InsertCommand = new NpgsqlCommand($"INSERT INTO {table} (field_int2, field_timestamp, field_numeric) VALUES (:a, :b, :c)", conn); - var dt = ds.Tables[0]; - var dr = dt.NewRow(); - dr["field_int2"] = 4; - dr["field_timestamp"] = new DateTime(2003, 01, 30, 14, 0, 0); - dr["field_numeric"] = 7.3M; - dt.Rows.Add(dr); + da.InsertCommand.Parameters.Add(new NpgsqlParameter("a", DbType.Int16)); + da.InsertCommand.Parameters.Add(new NpgsqlParameter("b", DbType.DateTime2)); + da.InsertCommand.Parameters.Add(new NpgsqlParameter("c", DbType.Decimal)); - var ds2 = ds.GetChanges()!; - da.Update(ds2); + da.InsertCommand.Parameters[0].Direction = ParameterDirection.Input; + da.InsertCommand.Parameters[1].Direction = ParameterDirection.Input; + da.InsertCommand.Parameters[2].Direction = ParameterDirection.Input; - 
ds.Merge(ds2); - ds.AcceptChanges(); + da.InsertCommand.Parameters[0].SourceColumn = "field_int2"; + da.InsertCommand.Parameters[1].SourceColumn = "field_timestamp"; + da.InsertCommand.Parameters[2].SourceColumn = "field_numeric"; - var dr2 = new NpgsqlCommand($"SELECT field_int2, field_numeric, field_timestamp FROM {table}", conn).ExecuteReader(); - dr2.Read(); + da.Fill(ds); - Assert.AreEqual(4, dr2[0]); - Assert.AreEqual(7.3000000M, dr2[1]); - dr2.Close(); - } - } + var dt = ds.Tables[0]; + var dr = dt.NewRow(); + dr["field_int2"] = 4; + dr["field_timestamp"] = new DateTime(2003, 01, 30, 14, 0, 0); + dr["field_numeric"] = 7.3M; + dt.Rows.Add(dr); - [Test] - public async Task DataAdapterUpdateReturnValue() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); - var ds = new DataSet(); - var da = new NpgsqlDataAdapter($"SELECT * FROM {table}", conn); - - da.InsertCommand = new NpgsqlCommand($@"INSERT INTO {table} (field_int2, field_timestamp, field_numeric) VALUES (:a, :b, :c)", conn); - - da.InsertCommand.Parameters.Add(new NpgsqlParameter("a", DbType.Int16)); - da.InsertCommand.Parameters.Add(new NpgsqlParameter("b", DbType.DateTime)); - da.InsertCommand.Parameters.Add(new NpgsqlParameter("c", DbType.Decimal)); - - da.InsertCommand.Parameters[0].Direction = ParameterDirection.Input; - da.InsertCommand.Parameters[1].Direction = ParameterDirection.Input; - da.InsertCommand.Parameters[2].Direction = ParameterDirection.Input; - - da.InsertCommand.Parameters[0].SourceColumn = "field_int2"; - da.InsertCommand.Parameters[1].SourceColumn = "field_timestamp"; - da.InsertCommand.Parameters[2].SourceColumn = "field_numeric"; - - da.Fill(ds); - - var dt = ds.Tables[0]; - var dr = dt.NewRow(); - dr["field_int2"] = 4; - dr["field_timestamp"] = new DateTime(2003, 01, 30, 14, 0, 0); - dr["field_numeric"] = 7.3M; - dt.Rows.Add(dr); - - dr = dt.NewRow(); - dr["field_int2"] = 4; - dr["field_timestamp"] = new 
DateTime(2003, 01, 30, 14, 0, 0); - dr["field_numeric"] = 7.3M; - dt.Rows.Add(dr); - - var ds2 = ds.GetChanges()!; - var daupdate = da.Update(ds2); - - Assert.AreEqual(2, daupdate); - } - } - - [Test] - [Ignore("")] - public async Task DataAdapterUpdateReturnValue2() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); - - var cmd = conn.CreateCommand(); - var da = new NpgsqlDataAdapter($"select * from {table}", conn); - var cb = new NpgsqlCommandBuilder(da); - var ds = new DataSet(); - da.Fill(ds); - - //## Insert a new row with id = 1 - ds.Tables[0].Rows.Add(0.4, 0.5); - da.Update(ds); - - //## change id from 1 to 2 - cmd.CommandText = $"update {table} set field_float4 = 0.8"; - cmd.ExecuteNonQuery(); - - //## change value to newvalue - ds.Tables[0].Rows[0][1] = 0.7; - //## update should fail, and make a DBConcurrencyException - var count = da.Update(ds); - //## count is 1, even if the isn't updated in the database - Assert.AreEqual(0, count); - } - } - - [Test] - public async Task FillWithEmptyResultset() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); - - var ds = new DataSet(); - var da = new NpgsqlDataAdapter($"SELECT field_serial, field_int2, field_timestamp, field_numeric FROM {table} WHERE field_serial = -1", conn); - - da.Fill(ds); - - Assert.AreEqual(1, ds.Tables.Count); - Assert.AreEqual(4, ds.Tables[0].Columns.Count); - Assert.AreEqual("field_serial", ds.Tables[0].Columns[0].ColumnName); - Assert.AreEqual("field_int2", ds.Tables[0].Columns[1].ColumnName); - Assert.AreEqual("field_timestamp", ds.Tables[0].Columns[2].ColumnName); - Assert.AreEqual("field_numeric", ds.Tables[0].Columns[3].ColumnName); - } - } - - [Test] - [Ignore("")] - public async Task FillAddWithKey() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); - - var ds = new 
DataSet(); - var da = new NpgsqlDataAdapter($"select field_serial, field_int2, field_timestamp, field_numeric from {table}", conn); - - da.MissingSchemaAction = MissingSchemaAction.AddWithKey; - da.Fill(ds); - - var field_serial = ds.Tables[0].Columns[0]; - var field_int2 = ds.Tables[0].Columns[1]; - var field_timestamp = ds.Tables[0].Columns[2]; - var field_numeric = ds.Tables[0].Columns[3]; - - Assert.IsFalse(field_serial.AllowDBNull); - Assert.IsTrue(field_serial.AutoIncrement); - Assert.AreEqual("field_serial", field_serial.ColumnName); - Assert.AreEqual(typeof(int), field_serial.DataType); - Assert.AreEqual(0, field_serial.Ordinal); - Assert.IsTrue(field_serial.Unique); - - Assert.IsTrue(field_int2.AllowDBNull); - Assert.IsFalse(field_int2.AutoIncrement); - Assert.AreEqual("field_int2", field_int2.ColumnName); - Assert.AreEqual(typeof(short), field_int2.DataType); - Assert.AreEqual(1, field_int2.Ordinal); - Assert.IsFalse(field_int2.Unique); - - Assert.IsTrue(field_timestamp.AllowDBNull); - Assert.IsFalse(field_timestamp.AutoIncrement); - Assert.AreEqual("field_timestamp", field_timestamp.ColumnName); - Assert.AreEqual(typeof(DateTime), field_timestamp.DataType); - Assert.AreEqual(2, field_timestamp.Ordinal); - Assert.IsFalse(field_timestamp.Unique); - - Assert.IsTrue(field_numeric.AllowDBNull); - Assert.IsFalse(field_numeric.AutoIncrement); - Assert.AreEqual("field_numeric", field_numeric.ColumnName); - Assert.AreEqual(typeof(decimal), field_numeric.DataType); - Assert.AreEqual(3, field_numeric.Ordinal); - Assert.IsFalse(field_numeric.Unique); - } - } - - [Test] - public async Task FillAddColumns() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); - - var ds = new DataSet(); - var da = new NpgsqlDataAdapter($"SELECT field_serial, field_int2, field_timestamp, field_numeric FROM {table}", conn); - - da.MissingSchemaAction = MissingSchemaAction.Add; - da.Fill(ds); - - var field_serial = 
ds.Tables[0].Columns[0]; - var field_int2 = ds.Tables[0].Columns[1]; - var field_timestamp = ds.Tables[0].Columns[2]; - var field_numeric = ds.Tables[0].Columns[3]; - - Assert.AreEqual("field_serial", field_serial.ColumnName); - Assert.AreEqual(typeof(int), field_serial.DataType); - Assert.AreEqual(0, field_serial.Ordinal); - - Assert.AreEqual("field_int2", field_int2.ColumnName); - Assert.AreEqual(typeof(short), field_int2.DataType); - Assert.AreEqual(1, field_int2.Ordinal); - - Assert.AreEqual("field_timestamp", field_timestamp.ColumnName); - Assert.AreEqual(typeof(DateTime), field_timestamp.DataType); - Assert.AreEqual(2, field_timestamp.Ordinal); - - Assert.AreEqual("field_numeric", field_numeric.ColumnName); - Assert.AreEqual(typeof(decimal), field_numeric.DataType); - Assert.AreEqual(3, field_numeric.Ordinal); - } - } - - [Test] - [MonoIgnore("Bug in mono, submitted pull request: https://github.com/mono/mono/pull/1172")] - public async Task UpdateLettingNullFieldValue() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); + var ds2 = ds.GetChanges()!; + da.Update(ds2); - var command = new NpgsqlCommand($"INSERT INTO {table} (field_int2) VALUES (2)", conn); - command.ExecuteNonQuery(); + ds.Merge(ds2); + ds.AcceptChanges(); - var ds = new DataSet(); + var dr2 = new NpgsqlCommand($"SELECT field_int2, field_numeric, field_timestamp FROM {table}", conn).ExecuteReader(); + dr2.Read(); - var da = new NpgsqlDataAdapter($"SELECT * FROM {table}", conn); - da.InsertCommand = new NpgsqlCommand(";", conn); - da.UpdateCommand = new NpgsqlCommand($"UPDATE {table} SET field_int2 = :a, field_timestamp = :b, field_numeric = :c WHERE field_serial = :d", conn); + Assert.AreEqual(4, dr2[0]); + Assert.AreEqual(7.3000000M, dr2[1]); + dr2.Close(); + } - da.UpdateCommand.Parameters.Add(new NpgsqlParameter("a", DbType.Int16)); - da.UpdateCommand.Parameters.Add(new NpgsqlParameter("b", DbType.DateTime)); - 
da.UpdateCommand.Parameters.Add(new NpgsqlParameter("c", DbType.Decimal)); - da.UpdateCommand.Parameters.Add(new NpgsqlParameter("d", NpgsqlDbType.Bigint)); + [Test] + public async Task DataAdapter_update_return_value() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + var ds = new DataSet(); + var da = new NpgsqlDataAdapter($"SELECT * FROM {table}", conn); - da.UpdateCommand.Parameters[0].Direction = ParameterDirection.Input; - da.UpdateCommand.Parameters[1].Direction = ParameterDirection.Input; - da.UpdateCommand.Parameters[2].Direction = ParameterDirection.Input; - da.UpdateCommand.Parameters[3].Direction = ParameterDirection.Input; + da.InsertCommand = new NpgsqlCommand($@"INSERT INTO {table} (field_int2, field_timestamp, field_numeric) VALUES (:a, :b, :c)", conn); - da.UpdateCommand.Parameters[0].SourceColumn = "field_int2"; - da.UpdateCommand.Parameters[1].SourceColumn = "field_timestamp"; - da.UpdateCommand.Parameters[2].SourceColumn = "field_numeric"; - da.UpdateCommand.Parameters[3].SourceColumn = "field_serial"; + da.InsertCommand.Parameters.Add(new NpgsqlParameter("a", DbType.Int16)); + da.InsertCommand.Parameters.Add(new NpgsqlParameter("b", DbType.DateTime2)); + da.InsertCommand.Parameters.Add(new NpgsqlParameter("c", DbType.Decimal)); - da.Fill(ds); + da.InsertCommand.Parameters[0].Direction = ParameterDirection.Input; + da.InsertCommand.Parameters[1].Direction = ParameterDirection.Input; + da.InsertCommand.Parameters[2].Direction = ParameterDirection.Input; - var dt = ds.Tables[0]; - Assert.IsNotNull(dt); + da.InsertCommand.Parameters[0].SourceColumn = "field_int2"; + da.InsertCommand.Parameters[1].SourceColumn = "field_timestamp"; + da.InsertCommand.Parameters[2].SourceColumn = "field_numeric"; - var dr = ds.Tables[0].Rows[ds.Tables[0].Rows.Count - 1]; - dr["field_int2"] = 4; + da.Fill(ds); - var ds2 = ds.GetChanges()!; - da.Update(ds2); - ds.Merge(ds2); - ds.AcceptChanges(); + var dt = ds.Tables[0]; + 
var dr = dt.NewRow(); + dr["field_int2"] = 4; + dr["field_timestamp"] = new DateTime(2003, 01, 30, 14, 0, 0); + dr["field_numeric"] = 7.3M; + dt.Rows.Add(dr); - using (var dr2 = new NpgsqlCommand($"SELECT field_int2 FROM {table}", conn).ExecuteReader()) - { - dr2.Read(); - Assert.AreEqual(4, dr2["field_int2"]); - } - } - } + dr = dt.NewRow(); + dr["field_int2"] = 4; + dr["field_timestamp"] = new DateTime(2003, 01, 30, 14, 0, 0); + dr["field_numeric"] = 7.3M; + dt.Rows.Add(dr); - [Test] - public async Task FillWithDuplicateColumnName() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); + var ds2 = ds.GetChanges()!; + var daupdate = da.Update(ds2); - var ds = new DataSet(); - var da = new NpgsqlDataAdapter($"SELECT field_serial, field_serial FROM {table}", conn); - da.Fill(ds); - } - } + Assert.AreEqual(2, daupdate); + } - [Test] - [Ignore("")] - public Task UpdateWithDataSet() => DoUpdateWithDataSet(); + [Test] + [Ignore("")] + public async Task DataAdapter_update_return_value2() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + + var cmd = conn.CreateCommand(); + var da = new NpgsqlDataAdapter($"select * from {table}", conn); + var cb = new NpgsqlCommandBuilder(da); + var ds = new DataSet(); + da.Fill(ds); + + //## Insert a new row with id = 1 + ds.Tables[0].Rows.Add(0.4, 0.5); + da.Update(ds); + + //## change id from 1 to 2 + cmd.CommandText = $"update {table} set field_float4 = 0.8"; + cmd.ExecuteNonQuery(); + + //## change value to newvalue + ds.Tables[0].Rows[0][1] = 0.7; + //## update should fail, and make a DBConcurrencyException + var count = da.Update(ds); + //## count is 1, even if the isn't updated in the database + Assert.AreEqual(0, count); + } - public async Task DoUpdateWithDataSet() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); + [Test] + public async Task 
Fill_with_empty_resultset() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); - var command = new NpgsqlCommand($"insert into {table} (field_int2) values (2)", conn); - command.ExecuteNonQuery(); + var ds = new DataSet(); + var da = new NpgsqlDataAdapter($"SELECT field_serial, field_int2, field_timestamp, field_numeric FROM {table} WHERE field_serial = -1", conn); - var ds = new DataSet(); - var da = new NpgsqlDataAdapter($"select * from {table}", conn); - var cb = new NpgsqlCommandBuilder(da); - Assert.IsNotNull(cb); + da.Fill(ds); - da.Fill(ds); + Assert.AreEqual(1, ds.Tables.Count); + Assert.AreEqual(4, ds.Tables[0].Columns.Count); + Assert.AreEqual("field_serial", ds.Tables[0].Columns[0].ColumnName); + Assert.AreEqual("field_int2", ds.Tables[0].Columns[1].ColumnName); + Assert.AreEqual("field_timestamp", ds.Tables[0].Columns[2].ColumnName); + Assert.AreEqual("field_numeric", ds.Tables[0].Columns[3].ColumnName); + } - var dt = ds.Tables[0]; - Assert.IsNotNull(dt); + [Test] + [Ignore("")] + public async Task Fill_add_with_key() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + + var ds = new DataSet(); + var da = new NpgsqlDataAdapter($"select field_serial, field_int2, field_timestamp, field_numeric from {table}", conn); + + da.MissingSchemaAction = MissingSchemaAction.AddWithKey; + da.Fill(ds); + + var field_serial = ds.Tables[0].Columns[0]; + var field_int2 = ds.Tables[0].Columns[1]; + var field_timestamp = ds.Tables[0].Columns[2]; + var field_numeric = ds.Tables[0].Columns[3]; + + Assert.IsFalse(field_serial.AllowDBNull); + Assert.IsTrue(field_serial.AutoIncrement); + Assert.AreEqual("field_serial", field_serial.ColumnName); + Assert.AreEqual(typeof(int), field_serial.DataType); + Assert.AreEqual(0, field_serial.Ordinal); + Assert.IsTrue(field_serial.Unique); + + Assert.IsTrue(field_int2.AllowDBNull); + Assert.IsFalse(field_int2.AutoIncrement); + 
Assert.AreEqual("field_int2", field_int2.ColumnName); + Assert.AreEqual(typeof(short), field_int2.DataType); + Assert.AreEqual(1, field_int2.Ordinal); + Assert.IsFalse(field_int2.Unique); + + Assert.IsTrue(field_timestamp.AllowDBNull); + Assert.IsFalse(field_timestamp.AutoIncrement); + Assert.AreEqual("field_timestamp", field_timestamp.ColumnName); + Assert.AreEqual(typeof(DateTime), field_timestamp.DataType); + Assert.AreEqual(2, field_timestamp.Ordinal); + Assert.IsFalse(field_timestamp.Unique); + + Assert.IsTrue(field_numeric.AllowDBNull); + Assert.IsFalse(field_numeric.AutoIncrement); + Assert.AreEqual("field_numeric", field_numeric.ColumnName); + Assert.AreEqual(typeof(decimal), field_numeric.DataType); + Assert.AreEqual(3, field_numeric.Ordinal); + Assert.IsFalse(field_numeric.Unique); + } - var dr = ds.Tables[0].Rows[ds.Tables[0].Rows.Count - 1]; + [Test] + public async Task Fill_add_columns() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); - dr["field_int2"] = 4; + var ds = new DataSet(); + var da = new NpgsqlDataAdapter($"SELECT field_serial, field_int2, field_timestamp, field_numeric FROM {table}", conn); - var ds2 = ds.GetChanges()!; - da.Update(ds2); - ds.Merge(ds2); - ds.AcceptChanges(); + da.MissingSchemaAction = MissingSchemaAction.Add; + da.Fill(ds); - using (var dr2 = new NpgsqlCommand($"select * from {table}", conn).ExecuteReader()) - { - dr2.Read(); - Assert.AreEqual(4, dr2["field_int2"]); - } - } - } + var field_serial = ds.Tables[0].Columns[0]; + var field_int2 = ds.Tables[0].Columns[1]; + var field_timestamp = ds.Tables[0].Columns[2]; + var field_numeric = ds.Tables[0].Columns[3]; - [Test] - [Ignore("")] - public Task InsertWithCommandBuilderCaseSensitive() - => DoInsertWithCommandBuilderCaseSensitive(); + Assert.AreEqual("field_serial", field_serial.ColumnName); + Assert.AreEqual(typeof(int), field_serial.DataType); + Assert.AreEqual(0, field_serial.Ordinal); - public async Task 
DoInsertWithCommandBuilderCaseSensitive() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); - - var ds = new DataSet(); - var da = new NpgsqlDataAdapter($"select * from {table}", conn); - var builder = new NpgsqlCommandBuilder(da); - Assert.IsNotNull(builder); - - da.Fill(ds); - - var dt = ds.Tables[0]; - var dr = dt.NewRow(); - dr["Field_Case_Sensitive"] = 4; - dt.Rows.Add(dr); - - var ds2 = ds.GetChanges()!; - da.Update(ds2); - ds.Merge(ds2); - ds.AcceptChanges(); - - using (var dr2 = new NpgsqlCommand($"select * from {table}", conn).ExecuteReader()) - { - dr2.Read(); - Assert.AreEqual(4, dr2[1]); - } - } - } - - [Test] - public async Task IntervalAsTimeSpan() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await GetTempTableName(conn, out var table); - await conn.ExecuteNonQueryAsync($@" + Assert.AreEqual("field_int2", field_int2.ColumnName); + Assert.AreEqual(typeof(short), field_int2.DataType); + Assert.AreEqual(1, field_int2.Ordinal); + + Assert.AreEqual("field_timestamp", field_timestamp.ColumnName); + Assert.AreEqual(typeof(DateTime), field_timestamp.DataType); + Assert.AreEqual(2, field_timestamp.Ordinal); + + Assert.AreEqual("field_numeric", field_numeric.ColumnName); + Assert.AreEqual(typeof(decimal), field_numeric.DataType); + Assert.AreEqual(3, field_numeric.Ordinal); + } + + [Test] + public async Task Update_letting_null_field_falue() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + + var command = new NpgsqlCommand($"INSERT INTO {table} (field_int2) VALUES (2)", conn); + command.ExecuteNonQuery(); + + var ds = new DataSet(); + + var da = new NpgsqlDataAdapter($"SELECT * FROM {table}", conn); + da.InsertCommand = new NpgsqlCommand(";", conn); + da.UpdateCommand = new NpgsqlCommand($"UPDATE {table} SET field_int2 = :a, field_timestamp = :b, field_numeric = :c WHERE field_serial = :d", conn); + + 
da.UpdateCommand.Parameters.Add(new NpgsqlParameter("a", DbType.Int16)); + da.UpdateCommand.Parameters.Add(new NpgsqlParameter("b", DbType.DateTime)); + da.UpdateCommand.Parameters.Add(new NpgsqlParameter("c", DbType.Decimal)); + da.UpdateCommand.Parameters.Add(new NpgsqlParameter("d", NpgsqlDbType.Bigint)); + + da.UpdateCommand.Parameters[0].Direction = ParameterDirection.Input; + da.UpdateCommand.Parameters[1].Direction = ParameterDirection.Input; + da.UpdateCommand.Parameters[2].Direction = ParameterDirection.Input; + da.UpdateCommand.Parameters[3].Direction = ParameterDirection.Input; + + da.UpdateCommand.Parameters[0].SourceColumn = "field_int2"; + da.UpdateCommand.Parameters[1].SourceColumn = "field_timestamp"; + da.UpdateCommand.Parameters[2].SourceColumn = "field_numeric"; + da.UpdateCommand.Parameters[3].SourceColumn = "field_serial"; + + da.Fill(ds); + + var dt = ds.Tables[0]; + Assert.IsNotNull(dt); + + var dr = ds.Tables[0].Rows[ds.Tables[0].Rows.Count - 1]; + dr["field_int2"] = 4; + + var ds2 = ds.GetChanges()!; + da.Update(ds2); + ds.Merge(ds2); + ds.AcceptChanges(); + + using var dr2 = new NpgsqlCommand($"SELECT field_int2 FROM {table}", conn).ExecuteReader(); + dr2.Read(); + Assert.AreEqual(4, dr2["field_int2"]); + } + + [Test] + public async Task Fill_with_duplicate_column_name() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + + var ds = new DataSet(); + var da = new NpgsqlDataAdapter($"SELECT field_serial, field_serial FROM {table}", conn); + da.Fill(ds); + } + + [Test] + [Ignore("")] + public Task Update_with_DataSet() => DoUpdateWithDataSet(); + + public async Task DoUpdateWithDataSet() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + + var command = new NpgsqlCommand($"insert into {table} (field_int2) values (2)", conn); + command.ExecuteNonQuery(); + + var ds = new DataSet(); + var da = new NpgsqlDataAdapter($"select * from {table}", conn); + var cb 
= new NpgsqlCommandBuilder(da); + Assert.IsNotNull(cb); + + da.Fill(ds); + + var dt = ds.Tables[0]; + Assert.IsNotNull(dt); + + var dr = ds.Tables[0].Rows[ds.Tables[0].Rows.Count - 1]; + + dr["field_int2"] = 4; + + var ds2 = ds.GetChanges()!; + da.Update(ds2); + ds.Merge(ds2); + ds.AcceptChanges(); + + using var dr2 = new NpgsqlCommand($"select * from {table}", conn).ExecuteReader(); + dr2.Read(); + Assert.AreEqual(4, dr2["field_int2"]); + } + + [Test] + [Ignore("")] + public async Task Insert_with_CommandBuilder_case_sensitive() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + + var ds = new DataSet(); + var da = new NpgsqlDataAdapter($"select * from {table}", conn); + var builder = new NpgsqlCommandBuilder(da); + Assert.IsNotNull(builder); + + da.Fill(ds); + + var dt = ds.Tables[0]; + var dr = dt.NewRow(); + dr["Field_Case_Sensitive"] = 4; + dt.Rows.Add(dr); + + var ds2 = ds.GetChanges()!; + da.Update(ds2); + ds.Merge(ds2); + ds.AcceptChanges(); + + using var dr2 = new NpgsqlCommand($"select * from {table}", conn).ExecuteReader(); + dr2.Read(); + Assert.AreEqual(4, dr2[1]); + } + + [Test] + public async Task Interval_as_TimeSpan() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($@" CREATE TABLE {table} ( pk SERIAL PRIMARY KEY, interval INTERVAL ); INSERT INTO {table} (interval) VALUES ('1 hour'::INTERVAL);"); - var dt = new DataTable("data"); - var command = new NpgsqlCommand - { - CommandType = CommandType.Text, - CommandText = $"SELECT interval FROM {table}", - Connection = conn - }; - var da = new NpgsqlDataAdapter { SelectCommand = command }; - da.Fill(dt); - } - } - - [Test] - public async Task IntervalAsTimeSpan2() + var dt = new DataTable("data"); + var command = new NpgsqlCommand { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await GetTempTableName(conn, out var table); - await 
conn.ExecuteNonQueryAsync($@" + CommandType = CommandType.Text, + CommandText = $"SELECT interval FROM {table}", + Connection = conn + }; + var da = new NpgsqlDataAdapter { SelectCommand = command }; + da.Fill(dt); + } + + [Test] + public async Task Interval_as_TimeSpan2() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($@" CREATE TABLE {table} ( pk SERIAL PRIMARY KEY, interval INTERVAL ); INSERT INTO {table} (interval) VALUES ('1 hour'::INTERVAL);"); - var dt = new DataTable("data"); - //DataColumn c = dt.Columns.Add("dauer", typeof(TimeSpan)); - // DataColumn c = dt.Columns.Add("dauer", typeof(NpgsqlInterval)); - //c.AllowDBNull = true; - var command = new NpgsqlCommand(); - command.CommandType = CommandType.Text; - command.CommandText = $"SELECT interval FROM {table}"; - command.Connection = conn; - var da = new NpgsqlDataAdapter(); - da.SelectCommand = command; - da.Fill(dt); - } - } - - [Test] - public async Task DbDataAdapterCommandAccess() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("SELECT CAST('1 hour' AS interval) AS dauer", conn)) - { - var da = new NpgsqlDataAdapter(); - da.SelectCommand = command; - System.Data.Common.DbDataAdapter common = da; - Assert.IsNotNull(common.SelectCommand); - } - } - - [Test, Description("Makes sure that the INSERT/UPDATE/DELETE commands are auto-populated on NpgsqlDataAdapter")] - [IssueLink("https://github.com/npgsql/npgsql/issues/179")] - [Ignore("Somehow related to us using a temporary table???")] - public async Task AutoPopulateAdapterCommands() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); - - var da = new NpgsqlDataAdapter($"SELECT field_pk,field_int4 FROM {table}", conn); - var builder = new NpgsqlCommandBuilder(da); - var ds = new DataSet(); - da.Fill(ds); - - var t = ds.Tables[0]; - var row = t.NewRow(); - 
row["field_pk"] = 1; - row["field_int4"] = 8; - t.Rows.Add(row); - da.Update(ds); - Assert.That(await conn.ExecuteScalarAsync($"SELECT field_int4 FROM {table}"), Is.EqualTo(8)); - - row["field_int4"] = 9; - da.Update(ds); - Assert.That(await conn.ExecuteScalarAsync($"SELECT field_int4 FROM {table}"), Is.EqualTo(9)); - - row.Delete(); - da.Update(ds); - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - } - } - - [Test] - public void CommandBuilderQuoting() - { - var cb = new NpgsqlCommandBuilder(); - const string orig = "some\"column"; - var quoted = cb.QuoteIdentifier(orig); - Assert.That(quoted, Is.EqualTo("\"some\"\"column\"")); - Assert.That(cb.UnquoteIdentifier(quoted), Is.EqualTo(orig)); - } - - [Test, Description("Makes sure a correct SQL string is built with GetUpdateCommand(true) using correct parameter names and placeholders")] - [IssueLink("https://github.com/npgsql/npgsql/issues/397")] - [Ignore("Somehow related to us using a temporary table???")] - public async Task GetUpdateCommand() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await SetupTempTable(conn, out var table); - - using (var da = new NpgsqlDataAdapter($"SELECT field_pk, field_int4 FROM {table}", conn)) - { - using (var cb = new NpgsqlCommandBuilder(da)) - { - var updateCommand = cb.GetUpdateCommand(true); - da.UpdateCommand = updateCommand; - - var ds = new DataSet(); - da.Fill(ds); - - var t = ds.Tables[0]; - var row = t.Rows.Add(); - row["field_pk"] = 1; - row["field_int4"] = 1; - da.Update(ds); - - row["field_int4"] = 2; - da.Update(ds); - - row.Delete(); - da.Update(ds); - } - } - } - } - - [Test] - public async Task LoadDataTable() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "char5 CHAR(5), varchar5 VARCHAR(5)", out var table); - using (var command = new NpgsqlCommand($"SELECT char5, varchar5 FROM {table}", conn)) - using (var dr = 
command.ExecuteReader()) - { - var dt = new DataTable(); - dt.Load(dr); - dr.Close(); - - Assert.AreEqual(5, dt.Columns[0].MaxLength); - Assert.AreEqual(5, dt.Columns[1].MaxLength); - } - } - } - - public Task SetupTempTable(NpgsqlConnection conn, out string table) - => CreateTempTable(conn, @" + var dt = new DataTable("data"); + //DataColumn c = dt.Columns.Add("dauer", typeof(TimeSpan)); + // DataColumn c = dt.Columns.Add("dauer", typeof(NpgsqlInterval)); + //c.AllowDBNull = true; + var command = new NpgsqlCommand(); + command.CommandType = CommandType.Text; + command.CommandText = $"SELECT interval FROM {table}"; + command.Connection = conn; + var da = new NpgsqlDataAdapter(); + da.SelectCommand = command; + da.Fill(dt); + } + + [Test] + public async Task DataAdapter_command_access() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand("SELECT CAST('1 hour' AS interval) AS dauer", conn); + var da = new NpgsqlDataAdapter(); + da.SelectCommand = command; + System.Data.Common.DbDataAdapter common = da; + Assert.IsNotNull(common.SelectCommand); + } + + [Test, Description("Makes sure that the INSERT/UPDATE/DELETE commands are auto-populated on NpgsqlDataAdapter")] + [IssueLink("https://github.com/npgsql/npgsql/issues/179")] + [Ignore("Somehow related to us using a temporary table???")] + public async Task Auto_populate_adapter_commands() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + + var da = new NpgsqlDataAdapter($"SELECT field_pk,field_int4 FROM {table}", conn); + var builder = new NpgsqlCommandBuilder(da); + var ds = new DataSet(); + da.Fill(ds); + + var t = ds.Tables[0]; + var row = t.NewRow(); + row["field_pk"] = 1; + row["field_int4"] = 8; + t.Rows.Add(row); + da.Update(ds); + Assert.That(await conn.ExecuteScalarAsync($"SELECT field_int4 FROM {table}"), Is.EqualTo(8)); + + row["field_int4"] = 9; + da.Update(ds); + Assert.That(await conn.ExecuteScalarAsync($"SELECT 
field_int4 FROM {table}"), Is.EqualTo(9)); + + row.Delete(); + da.Update(ds); + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } + + [Test] + public void Command_builder_quoting() + { + var cb = new NpgsqlCommandBuilder(); + const string orig = "some\"column"; + var quoted = cb.QuoteIdentifier(orig); + Assert.That(quoted, Is.EqualTo("\"some\"\"column\"")); + Assert.That(cb.UnquoteIdentifier(quoted), Is.EqualTo(orig)); + } + + [Test, Description("Makes sure a correct SQL string is built with GetUpdateCommand(true) using correct parameter names and placeholders")] + [IssueLink("https://github.com/npgsql/npgsql/issues/397")] + [Ignore("Somehow related to us using a temporary table???")] + public async Task Get_UpdateCommand() + { + using var conn = await OpenConnectionAsync(); + var table = await SetupTempTable(conn); + + using var da = new NpgsqlDataAdapter($"SELECT field_pk, field_int4 FROM {table}", conn); + using var cb = new NpgsqlCommandBuilder(da); + var updateCommand = cb.GetUpdateCommand(true); + da.UpdateCommand = updateCommand; + + var ds = new DataSet(); + da.Fill(ds); + + var t = ds.Tables[0]; + var row = t.Rows.Add(); + row["field_pk"] = 1; + row["field_int4"] = 1; + da.Update(ds); + + row["field_int4"] = 2; + da.Update(ds); + + row.Delete(); + da.Update(ds); + } + + [Test] + public async Task Load_DataTable() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "char5 CHAR(5), varchar5 VARCHAR(5)"); + using var command = new NpgsqlCommand($"SELECT char5, varchar5 FROM {table}", conn); + using var dr = command.ExecuteReader(); + var dt = new DataTable(); + dt.Load(dr); + dr.Close(); + + Assert.AreEqual(5, dt.Columns[0].MaxLength); + Assert.AreEqual(5, dt.Columns[1].MaxLength); + } + + public Task SetupTempTable(NpgsqlConnection conn) + => CreateTempTable(conn, @" field_pk SERIAL PRIMARY KEY, field_serial SERIAL, field_int2 SMALLINT, field_int4 INTEGER, field_numeric 
NUMERIC, -field_timestamp TIMESTAMP", out table); - } +field_timestamp TIMESTAMP"); } diff --git a/test/Npgsql.Tests/DataSourceTests.cs b/test/Npgsql.Tests/DataSourceTests.cs new file mode 100644 index 0000000000..639e83a795 --- /dev/null +++ b/test/Npgsql.Tests/DataSourceTests.cs @@ -0,0 +1,359 @@ +using System; +using System.Data; +using System.Data.Common; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using NUnit.Framework; + +// ReSharper disable MethodHasAsyncOverload + +namespace Npgsql.Tests; + +public class DataSourceTests : TestBase +{ + [Test] + public new async Task CreateConnection() + { + await using var dataSource = NpgsqlDataSource.Create(ConnectionString); + await using var connection = dataSource.CreateConnection(); + Assert.That(connection.State, Is.EqualTo(ConnectionState.Closed)); + + await connection.OpenAsync(); + Assert.That(await connection.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task OpenConnection([Values] bool async) + { + await using var dataSource = NpgsqlDataSource.Create(ConnectionString); + await using var connection = async + ? 
await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + + Assert.That(connection.State, Is.EqualTo(ConnectionState.Open)); + + Assert.That(await connection.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task ExecuteScalar_on_connectionless_command([Values] bool async) + { + await using var dataSource = NpgsqlDataSource.Create(ConnectionString); + await using var command = dataSource.CreateCommand(); + command.CommandText = "SELECT 1"; + + if (async) + Assert.That(await command.ExecuteScalarAsync(), Is.EqualTo(1)); + else + Assert.That(command.ExecuteScalar(), Is.EqualTo(1)); + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 1, Busy: 0))); + } + + [Test] + public async Task ExecuteNonQuery_on_connectionless_command([Values] bool async) + { + await using var dataSource = NpgsqlDataSource.Create(ConnectionString); + await using var command = dataSource.CreateCommand(); + command.CommandText = "SELECT 1"; + + if (async) + Assert.That(await command.ExecuteNonQueryAsync(), Is.EqualTo(-1)); + else + Assert.That(command.ExecuteNonQuery(), Is.EqualTo(-1)); + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 1, Busy: 0))); + } + + [Test] + public async Task ExecuteReader_on_connectionless_command([Values] bool async) + { + await using var dataSource = NpgsqlDataSource.Create(ConnectionString); + await using var command = dataSource.CreateCommand(); + command.CommandText = "SELECT 1"; + + await using (var reader = async ? 
await command.ExecuteReaderAsync() : command.ExecuteReader()) + { + Assert.True(reader.Read()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + } + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 1, Busy: 0))); + } + + [Test] + public async Task ExecuteScalar_on_connectionless_batch([Values] bool async) + { + await using var dataSource = NpgsqlDataSource.Create(ConnectionString); + await using var batch = dataSource.CreateBatch(); + batch.BatchCommands.Add(new("SELECT 1")); + batch.BatchCommands.Add(new("SELECT 2")); + + if (async) + Assert.That(await batch.ExecuteScalarAsync(), Is.EqualTo(1)); + else + Assert.That(batch.ExecuteScalar(), Is.EqualTo(1)); + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 1, Busy: 0))); + } + + [Test] + public async Task ExecuteNonQuery_on_connectionless_batch([Values] bool async) + { + await using var dataSource = NpgsqlDataSource.Create(ConnectionString); + await using var batch = dataSource.CreateBatch(); + batch.BatchCommands.Add(new("SELECT 1")); + batch.BatchCommands.Add(new("SELECT 2")); + + if (async) + Assert.That(await batch.ExecuteNonQueryAsync(), Is.EqualTo(-1)); + else + Assert.That(batch.ExecuteNonQuery(), Is.EqualTo(-1)); + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 1, Busy: 0))); + } + + [Test] + public async Task ExecuteReader_on_connectionless_batch([Values] bool async) + { + await using var dataSource = NpgsqlDataSource.Create(ConnectionString); + await using var batch = dataSource.CreateBatch(); + batch.BatchCommands.Add(new("SELECT 1")); + batch.BatchCommands.Add(new("SELECT 2")); + + using (var reader = async ? 
await batch.ExecuteReaderAsync() : batch.ExecuteReader()) + { + Assert.True(reader.Read()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.True(reader.NextResult()); + Assert.True(reader.Read()); + Assert.That(reader.GetInt32(0), Is.EqualTo(2)); + } + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 1, Busy: 0))); + } + + [Test] + public void Dispose() + { + using var dataSource = NpgsqlDataSource.Create(ConnectionString); + var connection1 = dataSource.OpenConnection(); + var connection2 = dataSource.OpenConnection(); + connection1.Close(); + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 2, Idle: 1, Busy: 1))); + + dataSource.Dispose(); + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 0, Busy: 1))); + + Assert.That(() => dataSource.OpenConnection(), Throws.Exception.TypeOf()); + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 0, Busy: 1))); + + connection2.Close(); + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 0, Idle: 0, Busy: 0))); + } + + [Test] + public async Task DisposeAsync() + { + await using var dataSource = NpgsqlDataSource.Create(ConnectionString); + var connection1 = await dataSource.OpenConnectionAsync(); + var connection2 = await dataSource.OpenConnectionAsync(); + await connection1.CloseAsync(); + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 2, Idle: 1, Busy: 1))); + + await dataSource.DisposeAsync(); + Assert.That(() => dataSource.OpenConnectionAsync(), Throws.Exception.TypeOf()); + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 0, Busy: 1))); + + await connection2.CloseAsync(); + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 0, Idle: 0, Busy: 0))); + } + + [Test] + public void No_password_without_PersistSecurityInfo() + { + if (string.IsNullOrEmpty(new NpgsqlConnectionStringBuilder(ConnectionString).Password)) + Assert.Fail("No password in default connection string, test cannot run"); + + using var dataSource = 
NpgsqlDataSource.Create(ConnectionString); + var parsedConnectionString = new NpgsqlConnectionStringBuilder(dataSource.ConnectionString); + Assert.That(parsedConnectionString.Password, Is.Null); + } + + [Test] + public async Task Cannot_access_connection_transaction_on_data_source_command() + { + await using var command = DataSource.CreateCommand(); + + Assert.That(() => command.Connection, Throws.Exception.TypeOf()); + Assert.That(() => command.Connection = null, Throws.Exception.TypeOf()); + Assert.That(() => command.Transaction, Throws.Exception.TypeOf()); + Assert.That(() => command.Transaction = null, Throws.Exception.TypeOf()); + + Assert.That(() => command.Prepare(), Throws.Exception.TypeOf()); + Assert.That(() => command.PrepareAsync(), Throws.Exception.TypeOf()); + } + + [Test] + public async Task Cannot_access_connection_transaction_on_data_source_batch() + { + await using var batch = DataSource.CreateBatch(); + + Assert.That(() => batch.Connection, Throws.Exception.TypeOf()); + Assert.That(() => batch.Connection = null, Throws.Exception.TypeOf()); + Assert.That(() => batch.Transaction, Throws.Exception.TypeOf()); + Assert.That(() => batch.Transaction = null, Throws.Exception.TypeOf()); + + Assert.That(() => batch.Prepare(), Throws.Exception.TypeOf()); + Assert.That(() => batch.PrepareAsync(), Throws.Exception.TypeOf()); + } + + [Test] + public async Task Cannot_get_connection_after_dispose_pooled([Values] bool async) + { + var dataSource = NpgsqlDataSource.Create(ConnectionString); + + if (async) + { + await dataSource.DisposeAsync(); + Assert.That(() => dataSource.OpenConnectionAsync(), Throws.Exception.TypeOf()); + } + else + { + dataSource.Dispose(); + Assert.That(() => dataSource.OpenConnection(), Throws.Exception.TypeOf()); + } + } + + [Test] + public async Task Cannot_get_connection_after_dispose_unpooled([Values] bool async) + { + var connectionStringBuilder = new NpgsqlConnectionStringBuilder(ConnectionString) { Pooling = false }; + var 
dataSource = NpgsqlDataSource.Create(connectionStringBuilder); + + if (async) + { + await dataSource.DisposeAsync(); + Assert.That(() => dataSource.OpenConnectionAsync(), Throws.Exception.TypeOf()); + } + else + { + dataSource.Dispose(); + Assert.That(() => dataSource.OpenConnection(), Throws.Exception.TypeOf()); + } + } + + [Test] // #4752 + public async Task As_DbDataSource([Values] bool async) + { + await using DbDataSource dataSource = NpgsqlDataSource.Create(ConnectionString); + await using var connection = async + ? await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + Assert.That(connection.State, Is.EqualTo(ConnectionState.Open)); + + await using var command = dataSource.CreateCommand("SELECT 1"); + + Assert.That(async + ? await command.ExecuteScalarAsync() + : command.ExecuteScalar(), Is.EqualTo(1)); + } + + [Test] + public async Task Executing_command_on_disposed_datasource([Values] bool multiplexing) + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + Multiplexing = multiplexing + }; + DbDataSource dataSource = NpgsqlDataSource.Create(csb.ConnectionString); + await using (var _ = await dataSource.OpenConnectionAsync()) {} + await dataSource.DisposeAsync(); + await using var command = dataSource.CreateCommand("SELECT 1"); + Assert.ThrowsAsync(command.ExecuteNonQueryAsync); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4840")] + public async Task Multiplexing_connectionless_command_open_connection() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + Multiplexing = true + }; + await using var dataSource = NpgsqlDataSource.Create(csb.ConnectionString); + + await using var conn = await dataSource.OpenConnectionAsync(); + await using var _ = await conn.BeginTransactionAsync(); + + await using var command = dataSource.CreateCommand(); + command.CommandText = "SELECT 1"; + + await using var reader = await command.ExecuteReaderAsync(); + Assert.True(reader.Read()); + 
Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + } + + [Test] + public async Task Connection_string_builder_settings_are_frozen_on_Build() + { + var builder = CreateDataSourceBuilder(); + builder.ConnectionStringBuilder.ApplicationName = "foo"; + await using var dataSource = builder.Build(); + + builder.ConnectionStringBuilder.ApplicationName = "bar"; + + await using var command = dataSource.CreateCommand("SHOW application_name"); + Assert.That(await command.ExecuteScalarAsync(), Is.EqualTo("foo")); + } + + class Test + { + public int Id { get; set; } + } + + [Test] + public async Task ConfigureJsonOptions_is_order_independent() + { + // Expect failure, no options + { + var builder = CreateDataSourceBuilder(); + builder.EnableDynamicJson(); + await using var dataSource = builder.Build(); + + await using var command = dataSource.CreateCommand("SELECT '{\"id\": 1}'::json;"); + using var reader = await command.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetFieldValue(0).Id, Is.EqualTo(default(int))); + } + + // Expect success, ConfigureJsonOptions before EnableDynamicJson + { + var builder = CreateDataSourceBuilder(); + builder.ConfigureJsonOptions(new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + builder.EnableDynamicJson(); + await using var dataSource = builder.Build(); + + await using var command = dataSource.CreateCommand("SELECT '{\"id\": 1}'::json;"); + using var reader = await command.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetFieldValue(0).Id, Is.EqualTo(1)); + } + + // Expect success, EnableDynamicJson before ConfigureJsonOptions + { + var builder = CreateDataSourceBuilder(); + builder.EnableDynamicJson(); + builder.ConfigureJsonOptions(new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + await using var dataSource = builder.Build(); + + await using var command = dataSource.CreateCommand("SELECT '{\"id\": 1}'::json;"); + using var reader = await 
command.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetFieldValue(0).Id, Is.EqualTo(1)); + } + } +} diff --git a/test/Npgsql.Tests/DataTypeNameTests.cs b/test/Npgsql.Tests/DataTypeNameTests.cs new file mode 100644 index 0000000000..fd366d8258 --- /dev/null +++ b/test/Npgsql.Tests/DataTypeNameTests.cs @@ -0,0 +1,26 @@ +using System; +using Npgsql.Internal.Postgres; +using NUnit.Framework; + +namespace Npgsql.Tests; + +public class DataTypeNameTests +{ + [Test] + public void MaxLengthDataTypeName() + { + var name = new string('a', DataTypeName.NAMEDATALEN); + var fullyQualifiedDataTypeName= $"public.{name}"; + Assert.DoesNotThrow(() => new DataTypeName(fullyQualifiedDataTypeName)); + Assert.AreEqual(new DataTypeName(fullyQualifiedDataTypeName).Value, fullyQualifiedDataTypeName); + } + + [Test] + public void TooLongDataTypeName() + { + var name = new string('a', DataTypeName.NAMEDATALEN + 1); + var fullyQualifiedDataTypeName= $"public.{name}"; + var exception = Assert.Throws(() => new DataTypeName(fullyQualifiedDataTypeName)); + Assert.That(exception!.Message, Does.EndWith($": public.{new string('a', DataTypeName.NAMEDATALEN)}")); + } +} diff --git a/test/Npgsql.Tests/DistributedTransactionTests.cs b/test/Npgsql.Tests/DistributedTransactionTests.cs index 4208832460..e55d6e7bd9 100644 --- a/test/Npgsql.Tests/DistributedTransactionTests.cs +++ b/test/Npgsql.Tests/DistributedTransactionTests.cs @@ -1,626 +1,639 @@ +#if NET7_0_OR_GREATER + using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Runtime.InteropServices; using System.Text; using System.Threading; using System.Transactions; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -// TransactionScope exists in netstandard20, but distributed transactions do not. 
-// We used to support distributed transactions back when we targeted .NET Framework, keeping them here in case -// they get ported to .NET Core (https://github.com/dotnet/runtime/issues/715) -#if DISTRIBUTED_TRANSACTIONS +namespace Npgsql.Tests; -namespace Npgsql.Tests +[NonParallelizable] +public class DistributedTransactionTests : TestBase { - [NonParallelizable] - public class DistributedTransactionTests : TestBase + [Test] + public void Two_connections_rollback_implicit_enlistment() { - [Test] - public void TwoConnections() - { - using (var conn1 = OpenConnection(ConnectionStringEnlistOff)) - using (var conn2 = OpenConnection(ConnectionStringEnlistOff)) - { - using (var scope = new TransactionScope()) - { - conn1.EnlistTransaction(Transaction.Current); - conn2.EnlistTransaction(Transaction.Current); + using var adminConn = OpenConnection(); + var table = CreateTempTable(adminConn, "name TEXT"); - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); - Assert.That(conn2.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test2')"), Is.EqualTo(1), "Unexpected second insert rowcount"); + var dataSource = EnlistOnDataSource; - scope.Complete(); - } - } - // TODO: There may be a race condition here, where the prepared transaction above still hasn't committed. 
+ using (new TransactionScope()) + using (var conn1 = dataSource.OpenConnection()) + using (var conn2 = dataSource.OpenConnection()) + { + conn1.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test1')"); + conn2.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test2')"); + } + + Retry(() => + { AssertNoDistributedIdentifier(); AssertNoPreparedTransactions(); - AssertNumberOfRows(2); + AssertNumberOfRows(adminConn, table, 0); + }); + } + + [Test] + public void Two_connections_rollback_explicit_enlistment() + { + using var adminConn = OpenConnection(); + var table = CreateTempTable(adminConn, "name TEXT"); + + var dataSource = EnlistOffDataSource; + + using (var conn1 = dataSource.OpenConnection()) + using (var conn2 = dataSource.OpenConnection()) + using (new TransactionScope()) + { + conn1.EnlistTransaction(Transaction.Current); + conn2.EnlistTransaction(Transaction.Current); + + Assert.That(conn1.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); + Assert.That(conn2.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test2')"), Is.EqualTo(1), "Unexpected second insert rowcount"); } - [Test] - public void TwoConnectionsRollback() + Retry(() => { - using (new TransactionScope()) - using (var conn1 = OpenConnection(ConnectionStringEnlistOn)) - using (var conn2 = OpenConnection(ConnectionStringEnlistOn)) - { - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); - Assert.That(conn2.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test2')"), Is.EqualTo(1), "Unexpected second insert rowcount"); - } - // TODO: There may be a race condition here, where the prepared transaction above still hasn't committed. 
AssertNoDistributedIdentifier(); AssertNoPreparedTransactions(); - AssertNumberOfRows(0); + AssertNumberOfRows(adminConn, table, 0); + }); + } + + [Test] + public void Two_connections_commit() + { + using var adminConn = OpenConnection(); + var table = CreateTempTable(adminConn, "name TEXT"); + + var dataSource = EnlistOnDataSource; + + using (var scope = new TransactionScope()) + using (var conn1 = dataSource.OpenConnection()) + using (var conn2 = dataSource.OpenConnection()) + { + conn1.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test1')"); + conn2.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test2')"); + + scope.Complete(); } - [Test, Ignore("Flaky")] - public void DistributedRollback() + Retry(() => + { + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + AssertNumberOfRows(adminConn, table, 2); + }); + } + + [Test] + public void Two_connections_with_failure() + { + // Use our own data source since this test breaks the connection with a critical failure, affecting database state tracking. 
+ using var dataSource = CreateDataSource(csb => csb.Enlist = true); + using var adminConn = dataSource.OpenConnection(); + var table = CreateTempTable(adminConn, "name TEXT"); + + using var scope = new TransactionScope(); + using var conn1 = dataSource.OpenConnection(); + using var conn2 = dataSource.OpenConnection(); + + conn1.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test1')"); + conn2.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test2')"); + + conn1.ExecuteNonQuery($"SELECT pg_terminate_backend({conn2.ProcessID})"); + scope.Complete(); + Assert.That(() => scope.Dispose(), Throws.Exception.TypeOf()); + + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + AssertNumberOfRows(adminConn, table, 0); + } + + [Test(Description = "Transaction race, bool distributed")] + [Explicit("Fails on Appveyor (https://ci.appveyor.com/project/roji/npgsql/build/3.3.0-250)")] + public void Transaction_race([Values(false, true)] bool distributed) + { + using var adminConn = OpenConnection(); + var table = CreateTempTable(adminConn, "name TEXT"); + + var dataSource = EnlistOnDataSource; + + for (var i = 1; i <= 100; i++) { - var disposedCalled = false; - var tx = new TransactionScope(); + var eventQueue = new ConcurrentQueue(); try { - using (var conn1 = OpenConnection(ConnectionStringEnlistOn)) + using (var tx = new TransactionScope()) + using (var conn1 = dataSource.OpenConnection()) { - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); + eventQueue.Enqueue(new TransactionEvent("Scope started, connection enlisted")); + conn1.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test1')"); + eventQueue.Enqueue(new TransactionEvent("Insert done")); + + if (distributed) + { + EnlistResource.EscalateToDistributed(eventQueue); + AssertHasDistributedIdentifier(); + } + else + { + EnlistResource.EnlistVolatile(eventQueue); + AssertNoDistributedIdentifier(); + } - 
EnlistResource.EscalateToDistributed(true); - AssertHasDistributedIdentifier(); tx.Complete(); + eventQueue.Enqueue(new TransactionEvent("Scope completed")); } - disposedCalled = true; - Assert.That(() => tx.Dispose(), Throws.TypeOf()); - // TODO: There may be a race condition here, where the prepared transaction above still hasn't completed. + + eventQueue.Enqueue(new TransactionEvent("Scope disposed")); AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - AssertNumberOfRows(0); - } - finally - { - if (!disposedCalled) - tx.Dispose(); - } - } - [Test(Description = "Transaction race, bool distributed")] - [Explicit("Fails on Appveyor (https://ci.appveyor.com/project/roji/npgsql/build/3.3.0-250)")] - public void TransactionRace([Values(false, true)] bool distributed) - { - for (var i = 1; i <= 100; i++) - { - var eventQueue = new ConcurrentQueue(); - try + if (distributed) { - using (var tx = new TransactionScope()) - using (var conn1 = OpenConnection(ConnectionStringEnlistOn)) + // There may be a race condition here, where the prepared transaction above still hasn't completed. + // This is by design of MS DTC. Giving it up to 100ms to complete. If it proves flaky, raise + // maxLoop. 
+ const int maxLoop = 20; + for (var j = 0; j < maxLoop; j++) { - eventQueue.Enqueue(new TransactionEvent("Scope started, connection enlisted")); - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); - eventQueue.Enqueue(new TransactionEvent("Insert done")); - - if (distributed) - { - EnlistResource.EscalateToDistributed(eventQueue); - AssertHasDistributedIdentifier(); - } - else + Thread.Sleep(10); + try { - EnlistResource.EnlistVolatile(eventQueue); - AssertNoDistributedIdentifier(); + AssertNumberOfRows(adminConn, table, i); + break; } - - tx.Complete(); - eventQueue.Enqueue(new TransactionEvent("Scope completed")); - } - eventQueue.Enqueue(new TransactionEvent("Scope disposed")); - AssertNoDistributedIdentifier(); - if (distributed) - { - // There may be a race condition here, where the prepared transaction above still hasn't completed. - // This is by design of MS DTC. Giving it up to 100ms to complete. If it proves flaky, raise - // maxLoop. - const int maxLoop = 20; - for (var j = 0; j < maxLoop; j++) + catch { - Thread.Sleep(10); - try - { - AssertNumberOfRows(i); - break; - } - catch - { - if (j == maxLoop - 1) - throw; - } + if (j == maxLoop - 1) + throw; } } - else - AssertNumberOfRows(i); } - catch (Exception ex) - { - Assert.Fail( - @"Failed at iteration {0}. + else + AssertNumberOfRows(adminConn, table, i); + } + catch (Exception ex) + { + Assert.Fail( + @"Failed at iteration {0}. 
Events: {1} Exception {2}", - i, FormatEventQueue(eventQueue), ex); - } + i, FormatEventQueue(eventQueue), ex); } } + } - [Test] - public void TwoConnectionsWithFailure() - { - using (var conn1 = OpenConnection(ConnectionStringEnlistOff)) - using (var conn2 = OpenConnection(ConnectionStringEnlistOff)) - { - var scope = new TransactionScope(); - conn1.EnlistTransaction(Transaction.Current); - conn2.EnlistTransaction(Transaction.Current); - - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); - Assert.That(conn2.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test2')"), Is.EqualTo(1), "Unexpected second insert rowcount"); - - conn1.ExecuteNonQuery($"SELECT pg_terminate_backend({conn2.ProcessID})"); - scope.Complete(); - Assert.That(() => scope.Dispose(), Throws.Exception.TypeOf()); + [Test(Description = "Connection reuse race after transaction, bool distributed"), Explicit] + public void Connection_reuse_race_after_transaction([Values(false, true)] bool distributed) + { + using var adminConn = OpenConnection(); + var table = CreateTempTable(adminConn, "name TEXT"); - AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - using (var tx = conn1.BeginTransaction()) - { - Assert.That(conn1.ExecuteScalar(@"SELECT COUNT(*) FROM data"), Is.EqualTo(0), "Unexpected data count"); - tx.Rollback(); - } - } - } + var dataSource = EnlistOffDataSource; - [Test(Description = "Connection reuse race after transaction, bool distributed"), Explicit] - public void ConnectionReuseRaceAfterTransaction([Values(false, true)] bool distributed) + for (var i = 1; i <= 100; i++) { - for (var i = 1; i <= 100; i++) + var eventQueue = new ConcurrentQueue(); + try { - var eventQueue = new ConcurrentQueue(); - try + using var conn1 = dataSource.OpenConnection(); + + using (var scope = new TransactionScope()) { - using (var conn1 = OpenConnection(ConnectionStringEnlistOff)) - { - using (var scope = new 
TransactionScope()) - { - conn1.EnlistTransaction(Transaction.Current); - eventQueue.Enqueue(new TransactionEvent("Scope started, connection enlisted")); - - if (distributed) - { - EnlistResource.EscalateToDistributed(eventQueue); - AssertHasDistributedIdentifier(); - } - else - { - EnlistResource.EnlistVolatile(eventQueue); - AssertNoDistributedIdentifier(); - } - - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); - eventQueue.Enqueue(new TransactionEvent("Insert done")); - - scope.Complete(); - eventQueue.Enqueue(new TransactionEvent("Scope completed")); - } - eventQueue.Enqueue(new TransactionEvent("Scope disposed")); + conn1.EnlistTransaction(Transaction.Current); + eventQueue.Enqueue(new TransactionEvent("Scope started, connection enlisted")); - Assert.DoesNotThrow(() => conn1.ExecuteScalar(@"SELECT COUNT(*) FROM data")); + if (distributed) + { + EnlistResource.EscalateToDistributed(eventQueue); + AssertHasDistributedIdentifier(); + } + else + { + EnlistResource.EnlistVolatile(eventQueue); + AssertNoDistributedIdentifier(); } + + conn1.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test1')"); + eventQueue.Enqueue(new TransactionEvent("Insert done")); + + scope.Complete(); + eventQueue.Enqueue(new TransactionEvent("Scope completed")); } - catch (Exception ex) - { - Assert.Fail( - @"Failed at iteration {0}. + + eventQueue.Enqueue(new TransactionEvent("Scope disposed")); + + Assert.DoesNotThrow(() => conn1.ExecuteScalar($"SELECT COUNT(*) FROM {table}")); + } + catch (Exception ex) + { + Assert.Fail( + @"Failed at iteration {0}. 
Events: {1} Exception {2}", - i, FormatEventQueue(eventQueue), ex); - } + i, FormatEventQueue(eventQueue), ex); } } + } - [Test(Description = "Connection reuse race after rollback, bool distributed"), Explicit("Currently failing.")] - public void ConnectionReuseRaceAfterRollback([Values(false, true)] bool distributed) + [Test(Description = "Connection reuse race after rollback, bool distributed"), Explicit("Currently failing.")] + public void Connection_reuse_race_after_rollback([Values(false, true)] bool distributed) + { + using var adminConn = OpenConnection(); + var table = CreateTempTable(adminConn, "name TEXT"); + + var dataSource = EnlistOffDataSource; + + for (var i = 1; i <= 100; i++) { - for (var i = 1; i <= 100; i++) + var eventQueue = new ConcurrentQueue(); + try { - var eventQueue = new ConcurrentQueue(); - try + using var conn1 = dataSource.OpenConnection(); + + using (new TransactionScope()) { - using (var conn1 = OpenConnection(ConnectionStringEnlistOff)) + conn1.EnlistTransaction(Transaction.Current); + eventQueue.Enqueue(new TransactionEvent("Scope started, connection enlisted")); + + if (distributed) { - using (new TransactionScope()) - { - conn1.EnlistTransaction(Transaction.Current); - eventQueue.Enqueue(new TransactionEvent("Scope started, connection enlisted")); - - if (distributed) - { - EnlistResource.EscalateToDistributed(eventQueue); - AssertHasDistributedIdentifier(); - } - else - { - EnlistResource.EnlistVolatile(eventQueue); - AssertNoDistributedIdentifier(); - } - - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); - eventQueue.Enqueue(new TransactionEvent("Insert done")); - - eventQueue.Enqueue(new TransactionEvent("Scope not completed")); - } - eventQueue.Enqueue(new TransactionEvent("Scope disposed")); - conn1.EnlistTransaction(null); - eventQueue.Enqueue(new TransactionEvent("Connection enlisted with null")); - Assert.DoesNotThrow(() => 
conn1.ExecuteScalar(@"SELECT COUNT(*) FROM data")); + EnlistResource.EscalateToDistributed(eventQueue); + AssertHasDistributedIdentifier(); + } + else + { + EnlistResource.EnlistVolatile(eventQueue); + AssertNoDistributedIdentifier(); } + + conn1.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test1')"); + eventQueue.Enqueue(new TransactionEvent("Insert done")); + + eventQueue.Enqueue(new TransactionEvent("Scope not completed")); } - catch (Exception ex) - { - Assert.Fail( - @"Failed at iteration {0}. + + eventQueue.Enqueue(new TransactionEvent("Scope disposed")); + conn1.EnlistTransaction(null); + eventQueue.Enqueue(new TransactionEvent("Connection enlisted with null")); + Assert.DoesNotThrow(() => conn1.ExecuteScalar($"SELECT COUNT(*) FROM {table}")); + } + catch (Exception ex) + { + Assert.Fail( + @"Failed at iteration {0}. Events: {1} Exception {2}", - i, FormatEventQueue(eventQueue), ex); - } + i, FormatEventQueue(eventQueue), ex); } } + } + + [Test(Description = "Connection reuse race chaining transactions, bool distributed")] + [Explicit] + public void Connection_reuse_race_chaining_transaction([Values(false, true)] bool distributed) + { + using var adminConn = OpenConnection(); + var table = CreateTempTable(adminConn, "name TEXT"); - [Test(Description = "Connection reuse race chaining transactions, bool distributed")] - [Explicit] - public void ConnectionReuseRaceChainingTransaction([Values(false, true)] bool distributed) + var dataSource = EnlistOffDataSource; + + for (var i = 1; i <= 100; i++) { - for (var i = 1; i <= 100; i++) + var eventQueue = new ConcurrentQueue(); + try { - var eventQueue = new ConcurrentQueue(); - try + using var conn1 = dataSource.OpenConnection(); + + using (var scope = new TransactionScope()) { - using (var conn1 = OpenConnection(ConnectionStringEnlistOff)) - { - using (var scope = new TransactionScope()) - { - eventQueue.Enqueue(new TransactionEvent("First scope started")); - conn1.EnlistTransaction(Transaction.Current); - 
eventQueue.Enqueue(new TransactionEvent("First scope, connection enlisted")); - - if (distributed) - { - EnlistResource.EscalateToDistributed(eventQueue); - AssertHasDistributedIdentifier(); - } - else - { - EnlistResource.EnlistVolatile(eventQueue); - AssertNoDistributedIdentifier(); - } - - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); - eventQueue.Enqueue(new TransactionEvent("First insert done")); - - scope.Complete(); - eventQueue.Enqueue(new TransactionEvent("First scope completed")); - } - eventQueue.Enqueue(new TransactionEvent("First scope disposed")); + eventQueue.Enqueue(new TransactionEvent("First scope started")); + conn1.EnlistTransaction(Transaction.Current); + eventQueue.Enqueue(new TransactionEvent("First scope, connection enlisted")); - using (var scope = new TransactionScope()) - { - eventQueue.Enqueue(new TransactionEvent("Second scope started")); - conn1.EnlistTransaction(Transaction.Current); - eventQueue.Enqueue(new TransactionEvent("Second scope, connection enlisted")); - - if (distributed) - { - EnlistResource.EscalateToDistributed(eventQueue); - AssertHasDistributedIdentifier(); - } - else - { - EnlistResource.EnlistVolatile(eventQueue); - AssertNoDistributedIdentifier(); - } - - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected second insert rowcount"); - eventQueue.Enqueue(new TransactionEvent("Second insert done")); - - scope.Complete(); - eventQueue.Enqueue(new TransactionEvent("Second scope completed")); - } - eventQueue.Enqueue(new TransactionEvent("Second scope disposed")); + if (distributed) + { + EnlistResource.EscalateToDistributed(eventQueue); + AssertHasDistributedIdentifier(); } + else + { + EnlistResource.EnlistVolatile(eventQueue); + AssertNoDistributedIdentifier(); + } + + conn1.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test1')"); + eventQueue.Enqueue(new 
TransactionEvent("First insert done")); + + scope.Complete(); + eventQueue.Enqueue(new TransactionEvent("First scope completed")); } - catch (Exception ex) + eventQueue.Enqueue(new TransactionEvent("First scope disposed")); + + using (var scope = new TransactionScope()) { - Assert.Fail( - @"Failed at iteration {0}. + eventQueue.Enqueue(new TransactionEvent("Second scope started")); + conn1.EnlistTransaction(Transaction.Current); + eventQueue.Enqueue(new TransactionEvent("Second scope, connection enlisted")); + + if (distributed) + { + EnlistResource.EscalateToDistributed(eventQueue); + AssertHasDistributedIdentifier(); + } + else + { + EnlistResource.EnlistVolatile(eventQueue); + AssertNoDistributedIdentifier(); + } + + conn1.ExecuteNonQuery($"INSERT INTO {table} (name) VALUES ('test1')"); + eventQueue.Enqueue(new TransactionEvent("Second insert done")); + + scope.Complete(); + eventQueue.Enqueue(new TransactionEvent("Second scope completed")); + } + eventQueue.Enqueue(new TransactionEvent("Second scope disposed")); + } + catch (Exception ex) + { + Assert.Fail( + @"Failed at iteration {0}. 
Events: {1} Exception {2}", - i, FormatEventQueue(eventQueue), ex); - } + i, FormatEventQueue(eventQueue), ex); } } + } - [Test] - public void ReuseConnectionWithEscalation() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5246")] + public void Transaction_complete_with_undisposed_connections() + { + using var deleteOuter = new TransactionScope(); + using (var delImidiate = new TransactionScope(TransactionScopeOption.RequiresNew)) { - using (new TransactionScope()) - { - using (var conn1 = new NpgsqlConnection(ConnectionStringEnlistOn)) - { - conn1.Open(); - var processId = conn1.ProcessID; - using (new NpgsqlConnection(ConnectionStringEnlistOn)) { } - conn1.Close(); - - conn1.Open(); - Assert.That(conn1.ProcessID, Is.EqualTo(processId)); - conn1.Close(); - } - } + var deleteNow = EnlistOnDataSource.OpenConnection(); + deleteNow.ExecuteNonQuery("SELECT 'del_now'"); + var deleteNow2 = EnlistOnDataSource.OpenConnection(); + deleteNow2.ExecuteNonQuery("SELECT 'del_now2'"); + delImidiate.Complete(); } + var deleteConn = EnlistOnDataSource.OpenConnection(); + deleteConn.ExecuteNonQuery("SELECT 'delete, this should commit last'"); + deleteOuter.Complete(); + } + + #region Utilities + + // MSDTC is asynchronous, i.e. Commit/Rollback may return before the transaction has actually completed in the database; + // so allow some time for assertions to succeed. 
+ static void Retry(Action action) + { + const int Retries = 50; - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1594")] - public void Bug1594() + for (var i = 0; i < Retries; i++) { - using (new TransactionScope()) + try { - using (var conn = OpenConnection(ConnectionStringEnlistOn)) - using (var innerScope1 = new TransactionScope()) - { - conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"); - innerScope1.Complete(); - } - using (OpenConnection(ConnectionStringEnlistOn)) - using (new TransactionScope()) + action(); + return; + } + catch (AssertionException) + { + if (i == Retries - 1) { - // Don't complete, triggering rollback + throw; } + + Thread.Sleep(100); } } + } - #region Utilities - - void AssertNoPreparedTransactions() - => Assert.That(GetNumberOfPreparedTransactions(), Is.EqualTo(0), "Prepared transactions found"); + void AssertNoPreparedTransactions() + => Assert.That(GetNumberOfPreparedTransactions(), Is.EqualTo(0), "Prepared transactions found"); - int GetNumberOfPreparedTransactions() + int GetNumberOfPreparedTransactions() + { + var dataSource = EnlistOffDataSource; + using (var conn = dataSource.OpenConnection()) + using (var cmd = new NpgsqlCommand("SELECT COUNT(*) FROM pg_prepared_xacts WHERE database = @database", conn)) { - using (var conn = OpenConnection(ConnectionStringEnlistOff)) - using (var cmd = new NpgsqlCommand("SELECT COUNT(*) FROM pg_prepared_xacts WHERE database = @database", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("database", conn.Database)); - return (int)(long)cmd.ExecuteScalar(); - } + cmd.Parameters.Add(new NpgsqlParameter("database", conn.Database)); + return (int)(long)cmd.ExecuteScalar()!; } + } - void AssertNumberOfRows(int expected) - => Assert.That(_controlConn.ExecuteScalar(@"SELECT COUNT(*) FROM data"), Is.EqualTo(expected), "Unexpected data count"); + void AssertNumberOfRows(NpgsqlConnection connection, string table, int expected) + => Assert.That(connection.ExecuteScalar($"SELECT 
COUNT(*) FROM {table}"), Is.EqualTo(expected), "Unexpected data count"); - static void AssertNoDistributedIdentifier() - => Assert.That(Transaction.Current?.TransactionInformation.DistributedIdentifier ?? Guid.Empty, Is.EqualTo(Guid.Empty), "Distributed identifier found"); + static void AssertNoDistributedIdentifier() + => Assert.That(Transaction.Current?.TransactionInformation.DistributedIdentifier ?? Guid.Empty, Is.EqualTo(Guid.Empty), "Distributed identifier found"); - static void AssertHasDistributedIdentifier() - => Assert.That(Transaction.Current?.TransactionInformation.DistributedIdentifier ?? Guid.Empty, Is.Not.EqualTo(Guid.Empty), "Distributed identifier not found"); + static void AssertHasDistributedIdentifier() + => Assert.That(Transaction.Current?.TransactionInformation.DistributedIdentifier ?? Guid.Empty, Is.Not.EqualTo(Guid.Empty), "Distributed identifier not found"); - public string ConnectionStringEnlistOn - => new NpgsqlConnectionStringBuilder(ConnectionString) { Enlist = true }.ToString(); + NpgsqlDataSource EnlistOnDataSource { get; set; } = default!; - public string ConnectionStringEnlistOff - => new NpgsqlConnectionStringBuilder(ConnectionString) { Enlist = false }.ToString(); + NpgsqlDataSource EnlistOffDataSource { get; set; } = default!; - static string FormatEventQueue(ConcurrentQueue eventQueue) - { - eventQueue.Enqueue(new TransactionEvent(@"------------- + static string FormatEventQueue(ConcurrentQueue eventQueue) + { + eventQueue.Enqueue(new TransactionEvent(@"------------- Start formatting event queue, going to sleep a bit for late events -------------")); - Thread.Sleep(20); - var eventsMessage = new StringBuilder(); - foreach (var evt in eventQueue) - { - eventsMessage.AppendLine(evt.Message); - } - return eventsMessage.ToString(); - } - - // Idea from NHibernate test project, DtcFailuresFixture - public class EnlistResource : IEnlistmentNotification + Thread.Sleep(20); + var eventsMessage = new StringBuilder(); + foreach (var evt in 
eventQueue) { - public static int Counter { get; set; } - - readonly bool _shouldRollBack; - readonly string _name; - readonly ConcurrentQueue? _eventQueue; - - public static void EnlistVolatile(ConcurrentQueue eventQueue) - => EnlistVolatile(false, eventQueue); + eventsMessage.AppendLine(evt.Message); + } + return eventsMessage.ToString(); + } - public static void EnlistVolatile(bool shouldRollBack = false, ConcurrentQueue? eventQueue = null) - => Enlist(false, shouldRollBack, eventQueue); + // Idea from NHibernate test project, DtcFailuresFixture + public class EnlistResource : IEnlistmentNotification + { + public static int Counter { get; set; } - public static void EscalateToDistributed(ConcurrentQueue eventQueue) - => EscalateToDistributed(false, eventQueue); + readonly bool _shouldRollBack; + readonly string _name; + readonly ConcurrentQueue? _eventQueue; - public static void EscalateToDistributed(bool shouldRollBack = false, ConcurrentQueue? eventQueue = null) - => Enlist(true, shouldRollBack, eventQueue); + public static void EnlistVolatile(ConcurrentQueue eventQueue) + => EnlistVolatile(false, eventQueue); - static void Enlist(bool durable, bool shouldRollBack, ConcurrentQueue? eventQueue) - { - Counter++; + public static void EnlistVolatile(bool shouldRollBack = false, ConcurrentQueue? eventQueue = null) + => Enlist(false, shouldRollBack, eventQueue); - var name = $"{(durable ? 
"Durable" : "Volatile")} resource {Counter}"; - var resource = new EnlistResource(shouldRollBack, name, eventQueue); - if (durable) - Transaction.Current.EnlistDurable(Guid.NewGuid(), resource, EnlistmentOptions.None); - else - Transaction.Current.EnlistVolatile(resource, EnlistmentOptions.None); + public static void EscalateToDistributed(ConcurrentQueue eventQueue) + => EscalateToDistributed(false, eventQueue); - Transaction.Current.TransactionCompleted += resource.Current_TransactionCompleted; + public static void EscalateToDistributed(bool shouldRollBack = false, ConcurrentQueue? eventQueue = null) + => Enlist(true, shouldRollBack, eventQueue); - eventQueue?.Enqueue(new TransactionEvent(name + ": enlisted")); - } + static void Enlist(bool durable, bool shouldRollBack, ConcurrentQueue? eventQueue) + { + Counter++; - EnlistResource(bool shouldRollBack, string name, ConcurrentQueue? eventQueue) - { - _shouldRollBack = shouldRollBack; - _name = name; - _eventQueue = eventQueue; - } + var name = $"{(durable ? 
"Durable" : "Volatile")} resource {Counter}"; + var resource = new EnlistResource(shouldRollBack, name, eventQueue); + if (durable) + Transaction.Current!.EnlistDurable(Guid.NewGuid(), resource, EnlistmentOptions.None); + else + Transaction.Current!.EnlistVolatile(resource, EnlistmentOptions.None); - public void Prepare(PreparingEnlistment preparingEnlistment) - { - _eventQueue?.Enqueue(new TransactionEvent(_name + ": prepare phase start")); - Thread.Sleep(1); - if (_shouldRollBack) - { - _eventQueue?.Enqueue(new TransactionEvent(_name + ": prepare phase, calling rollback-ed")); - preparingEnlistment.ForceRollback(); - } - else - { - _eventQueue?.Enqueue(new TransactionEvent(_name + ": prepare phase, calling prepared")); - preparingEnlistment.Prepared(); - } - Thread.Sleep(1); - _eventQueue?.Enqueue(new TransactionEvent(_name + ": prepare phase end")); - } + Transaction.Current.TransactionCompleted += resource.Current_TransactionCompleted!; - public void Commit(Enlistment enlistment) - { - _eventQueue?.Enqueue(new TransactionEvent(_name + ": commit phase start")); - Thread.Sleep(1); - _eventQueue?.Enqueue(new TransactionEvent(_name + ": commit phase, calling done")); - enlistment.Done(); - Thread.Sleep(1); - _eventQueue?.Enqueue(new TransactionEvent(_name + ": commit phase end")); - } + eventQueue?.Enqueue(new TransactionEvent(name + ": enlisted")); + } - public void Rollback(Enlistment enlistment) - { - _eventQueue?.Enqueue(new TransactionEvent(_name + ": rollback phase start")); - Thread.Sleep(1); - _eventQueue?.Enqueue(new TransactionEvent(_name + ": rollback phase, calling done")); - enlistment.Done(); - Thread.Sleep(1); - _eventQueue?.Enqueue(new TransactionEvent(_name + ": rollback phase end")); - } + EnlistResource(bool shouldRollBack, string name, ConcurrentQueue? 
eventQueue) + { + _shouldRollBack = shouldRollBack; + _name = name; + _eventQueue = eventQueue; + } - public void InDoubt(Enlistment enlistment) + public void Prepare(PreparingEnlistment preparingEnlistment) + { + _eventQueue?.Enqueue(new TransactionEvent(_name + ": prepare phase start")); + Thread.Sleep(1); + if (_shouldRollBack) { - _eventQueue?.Enqueue(new TransactionEvent(_name + ": in-doubt phase start")); - Thread.Sleep(1); - _eventQueue?.Enqueue(new TransactionEvent(_name + ": in-doubt phase, calling done")); - enlistment.Done(); - Thread.Sleep(1); - _eventQueue?.Enqueue(new TransactionEvent(_name + ": in-doubt phase end")); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": prepare phase, calling rollback-ed")); + preparingEnlistment.ForceRollback(); } - - void Current_TransactionCompleted(object sender, TransactionEventArgs e) + else { - _eventQueue?.Enqueue(new TransactionEvent(_name + ": transaction completed start")); - Thread.Sleep(1); - _eventQueue?.Enqueue(new TransactionEvent(_name + ": transaction completed middle")); - Thread.Sleep(1); - _eventQueue?.Enqueue(new TransactionEvent(_name + ": transaction completed end")); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": prepare phase, calling prepared")); + preparingEnlistment.Prepared(); } + Thread.Sleep(1); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": prepare phase end")); } - public class TransactionEvent + public void Commit(Enlistment enlistment) { - public TransactionEvent(string message) - { - Message = $"{message} (TId {Thread.CurrentThread.ManagedThreadId})"; - } - public string Message { get; } + _eventQueue?.Enqueue(new TransactionEvent(_name + ": commit phase start")); + Thread.Sleep(1); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": commit phase, calling done")); + enlistment.Done(); + Thread.Sleep(1); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": commit phase end")); } - #endregion Utilities - - #region Setup - - NpgsqlConnection _controlConn = 
default!; - - [OneTimeSetUp] - public void OneTimeSetUp() + public void Rollback(Enlistment enlistment) { - using (new TransactionScope(TransactionScopeOption.RequiresNew)) - { - try - { - Transaction.Current.EnlistPromotableSinglePhase(new FakePromotableSinglePhaseNotification()); - } - catch (NotImplementedException) - { - Assert.Ignore("Promotable single phase transactions aren't supported (mono < 3.0.0?)"); - } - } + _eventQueue?.Enqueue(new TransactionEvent(_name + ": rollback phase start")); + Thread.Sleep(1); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": rollback phase, calling done")); + enlistment.Done(); + Thread.Sleep(1); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": rollback phase end")); + } - _controlConn = OpenConnection(); + public void InDoubt(Enlistment enlistment) + { + _eventQueue?.Enqueue(new TransactionEvent(_name + ": in-doubt phase start")); + Thread.Sleep(1); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": in-doubt phase, calling done")); + enlistment.Done(); + Thread.Sleep(1); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": in-doubt phase end")); + } - // Make sure prepared transactions are enabled in postgresql.conf (disabled by default) - if (int.Parse((string)_controlConn.ExecuteScalar("SHOW max_prepared_transactions")) == 0) - { - TestUtil.IgnoreExceptOnBuildServer("max_prepared_transactions is set to 0 in your postgresql.conf"); - _controlConn.Close(); - } + void Current_TransactionCompleted(object sender, TransactionEventArgs e) + { + _eventQueue?.Enqueue(new TransactionEvent(_name + ": transaction completed start")); + Thread.Sleep(1); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": transaction completed middle")); + Thread.Sleep(1); + _eventQueue?.Enqueue(new TransactionEvent(_name + ": transaction completed end")); + } + } - // Rollback any lingering prepared transactions from failed previous runs - var lingeringTrqnsqctions = new List(); - using (var cmd = new NpgsqlCommand("SELECT gid FROM 
pg_prepared_xacts WHERE database=@database", _controlConn)) - { - cmd.Parameters.AddWithValue("database", new NpgsqlConnectionStringBuilder(ConnectionString).Database); - using (var reader = cmd.ExecuteReader()) - { - while (reader.Read()) - lingeringTrqnsqctions.Add(reader.GetString(0)); - } - } - foreach (var xactGid in lingeringTrqnsqctions) - _controlConn.ExecuteNonQuery($"ROLLBACK PREPARED '{xactGid}'"); + public class TransactionEvent + { + public TransactionEvent(string message) + => Message = $"{message} (TId {Thread.CurrentThread.ManagedThreadId})"; + public string Message { get; } + } - // All tests in this fixture should have exclusive access to the database they're running on. - // If we run these tests in parallel (i.e. two builds in parallel) they will interfere. - // Solve this by taking a PostgreSQL advisory lock for the lifetime of the fixture. - _controlConn.ExecuteNonQuery("SELECT pg_advisory_lock(666)"); + #endregion Utilities - _controlConn.ExecuteNonQuery("DROP TABLE IF EXISTS data"); - _controlConn.ExecuteNonQuery("CREATE TABLE data (name TEXT)"); - } + #region Setup - [SetUp] - public void SetUp() + [OneTimeSetUp] + public void OneTimeSetUp() + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - _controlConn.ExecuteNonQuery("TRUNCATE data"); - EnlistResource.Counter = 0; + Assert.Ignore("Distributed transactions are only supported on Windows"); + return; } - [OneTimeTearDown] - public void OneTimeTearDown() + using var connection = OpenConnection(); + + // Make sure prepared transactions are enabled in postgresql.conf (disabled by default) + if (int.Parse((string)connection.ExecuteScalar("SHOW max_prepared_transactions")!) 
== 0) { - _controlConn?.Close(); - _controlConn = null!; + IgnoreExceptOnBuildServer("max_prepared_transactions is set to 0 in your postgresql.conf"); + return; } - class FakePromotableSinglePhaseNotification : IPromotableSinglePhaseNotification + // Roll back any lingering prepared transactions from failed previous runs + var lingeringTransactions = new List(); + using (var cmd = new NpgsqlCommand("SELECT gid FROM pg_prepared_xacts WHERE database=@database", connection)) { - public byte[] Promote() => null!; - public void Initialize() {} - public void SinglePhaseCommit(SinglePhaseEnlistment singlePhaseEnlistment) {} - public void Rollback(SinglePhaseEnlistment singlePhaseEnlistment) {} + cmd.Parameters.AddWithValue("database", new NpgsqlConnectionStringBuilder(ConnectionString).Database!); + using var reader = cmd.ExecuteReader(); + while (reader.Read()) + lingeringTransactions.Add(reader.GetString(0)); } + foreach (var xactGid in lingeringTransactions) + connection.ExecuteNonQuery($"ROLLBACK PREPARED '{xactGid}'"); - #endregion + EnlistOnDataSource = CreateDataSource(csb => csb.Enlist = true); + EnlistOffDataSource = CreateDataSource(csb => csb.Enlist = false); } + + [OneTimeTearDown] + public void OnTimeTearDown() + { + EnlistOnDataSource?.Dispose(); + EnlistOnDataSource = null!; + EnlistOffDataSource?.Dispose(); + EnlistOffDataSource = null!; + } + + [SetUp] + public void SetUp() + => EnlistResource.Counter = 0; + + internal static string CreateTempTable(NpgsqlConnection conn, string columns) + { + var tableName = "temp_table" + Interlocked.Increment(ref _tempTableCounter); + conn.ExecuteNonQuery(@$" +START TRANSACTION; SELECT pg_advisory_xact_lock(0); +DROP TABLE IF EXISTS {tableName} CASCADE; +COMMIT; +CREATE TABLE {tableName} ({columns})"); + return tableName; + } + + #endregion } #endif diff --git a/test/Npgsql.Tests/ExceptionTests.cs b/test/Npgsql.Tests/ExceptionTests.cs index 4333700570..f9f8821c4d 100644 --- a/test/Npgsql.Tests/ExceptionTests.cs +++ 
b/test/Npgsql.Tests/ExceptionTests.cs @@ -9,63 +9,62 @@ using NUnit.Framework; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class ExceptionTests : TestBase { - public class ExceptionTests : TestBase + [Test, Description("Generates a basic server-side exception, checks that it's properly raised and populated")] + public void Basic() { - [Test, Description("Generates a basic server-side exception, checks that it's properly raised and populated")] - public void Basic() + // Make sure messages are in English + using var dataSource = CreateDataSource(csb => csb.Options = "-c lc_messages=en_US.UTF-8"); + using var conn = dataSource.OpenConnection(); + conn.ExecuteNonQuery( +""" +CREATE OR REPLACE FUNCTION pg_temp.emit_exception() RETURNS VOID AS + 'BEGIN RAISE EXCEPTION ''testexception'' USING ERRCODE = ''12345'', DETAIL = ''testdetail''; END;' +LANGUAGE 'plpgsql'; +"""); + + PostgresException ex = null!; + try { - using var conn = OpenConnection(new NpgsqlConnectionStringBuilder(ConnectionString) - { - // Make sure messages are in English - Options = "-c lc_messages=en_US.UTF-8" - }); - conn.ExecuteNonQuery(@" - CREATE OR REPLACE FUNCTION pg_temp.emit_exception() RETURNS VOID AS - 'BEGIN RAISE EXCEPTION ''testexception'' USING ERRCODE = ''12345'', DETAIL = ''testdetail''; END;' - LANGUAGE 'plpgsql'; - "); - - PostgresException ex = null!; - try - { - conn.ExecuteNonQuery("SELECT pg_temp.emit_exception()"); - Assert.Fail("No exception was thrown"); - } - catch (PostgresException e) - { - ex = e; - } - - Assert.That(ex.MessageText, Is.EqualTo("testexception")); - Assert.That(ex.Severity, Is.EqualTo("ERROR")); - Assert.That(ex.InvariantSeverity, Is.EqualTo("ERROR")); - Assert.That(ex.SqlState, Is.EqualTo("12345")); - Assert.That(ex.Position, Is.EqualTo(0)); - Assert.That(ex.Message, Is.EqualTo("12345: testexception")); - - var data = ex.Data; - Assert.That(data[nameof(PostgresException.Severity)], Is.EqualTo("ERROR")); - 
Assert.That(data[nameof(PostgresException.SqlState)], Is.EqualTo("12345")); - Assert.That(data.Contains(nameof(PostgresException.Position)), Is.False); - - var exString = ex.ToString(); - Assert.That(exString, Does.StartWith("Npgsql.PostgresException (0x80004005): 12345: testexception")); - Assert.That(exString, Contains.Substring(nameof(PostgresException.Severity) + ": ERROR")); - Assert.That(exString, Contains.Substring(nameof(PostgresException.SqlState) + ": 12345")); - - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1), "Connection in bad state after an exception"); + conn.ExecuteNonQuery("SELECT pg_temp.emit_exception()"); + Assert.Fail("No exception was thrown"); } - - [Test, Description("Ensures Detail is redacted by default in PostgresException and PostgresNotice")] - public async Task ErrorDetailsAreRedacted() + catch (PostgresException e) { - await using var conn = await OpenConnectionAsync(); - await using var _ = GetTempFunctionName(conn, out var raiseExceptionFunc); - await using var __ = GetTempFunctionName(conn, out var raiseNoticeFunc); + ex = e; + } + + Assert.That(ex.MessageText, Is.EqualTo("testexception")); + Assert.That(ex.Severity, Is.EqualTo("ERROR")); + Assert.That(ex.InvariantSeverity, Is.EqualTo("ERROR")); + Assert.That(ex.SqlState, Is.EqualTo("12345")); + Assert.That(ex.Position, Is.EqualTo(0)); + Assert.That(ex.Message, Does.StartWith("12345: testexception")); + + var data = ex.Data; + Assert.That(data[nameof(PostgresException.Severity)], Is.EqualTo("ERROR")); + Assert.That(data[nameof(PostgresException.SqlState)], Is.EqualTo("12345")); + Assert.That(data.Contains(nameof(PostgresException.Position)), Is.False); + + var exString = ex.ToString(); + Assert.That(exString, Does.StartWith("Npgsql.PostgresException (0x80004005): 12345: testexception")); + Assert.That(exString, Contains.Substring(nameof(PostgresException.Severity) + ": ERROR")); + Assert.That(exString, Contains.Substring(nameof(PostgresException.SqlState) + ": 12345")); 
+ + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1), "Connection in bad state after an exception"); + } + + [Test, Description("Ensures Detail is redacted by default in PostgresException and PostgresNotice")] + public async Task Error_details_are_redacted() + { + await using var conn = await OpenConnectionAsync(); + var raiseExceptionFunc = await GetTempFunctionName(conn); + var raiseNoticeFunc = await GetTempFunctionName(conn); - await conn.ExecuteNonQueryAsync($@" + await conn.ExecuteNonQueryAsync($@" CREATE OR REPLACE FUNCTION {raiseExceptionFunc}() RETURNS VOID AS $$ BEGIN RAISE EXCEPTION 'testexception' USING DETAIL = 'secret'; @@ -78,27 +77,27 @@ await conn.ExecuteNonQueryAsync($@" END; $$ LANGUAGE 'plpgsql';"); - var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync($"SELECT * FROM {raiseExceptionFunc}()")); - Assert.That(ex.Detail, Does.Not.Contain("secret")); - Assert.That(ex.Data[nameof(PostgresException.Detail)], Does.Not.Contain("secret")); - Assert.That(ex.ToString(), Does.Not.Contain("secret")); + var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync($"SELECT * FROM {raiseExceptionFunc}()"))!; + Assert.That(ex.Detail, Does.Not.Contain("secret")); + Assert.That(ex.Message, Does.Not.Contain("secret")); + Assert.That(ex.Data[nameof(PostgresException.Detail)], Does.Not.Contain("secret")); + Assert.That(ex.ToString(), Does.Not.Contain("secret")); - PostgresNotice? notice = null; - conn.Notice += (___, a) => notice = a.Notice; - await conn.ExecuteNonQueryAsync($"SELECT * FROM {raiseNoticeFunc}()"); - Assert.That(notice!.Detail, Does.Not.Contain("secret")); - } + PostgresNotice? 
notice = null; + conn.Notice += (___, a) => notice = a.Notice; + await conn.ExecuteNonQueryAsync($"SELECT * FROM {raiseNoticeFunc}()"); + Assert.That(notice!.Detail, Does.Not.Contain("secret")); + } - [Test] - public async Task IncludeErrorDetails() - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { IncludeErrorDetails = true }; - using var _ = CreateTempPool(builder, out var connectionStringWithDetails); - await using var conn = await OpenConnectionAsync(connectionStringWithDetails); - await using var __ = GetTempFunctionName(conn, out var raiseExceptionFunc); - await using var ___ = GetTempFunctionName(conn, out var raiseNoticeFunc); + [Test] + public async Task IncludeErrorDetail() + { + await using var dataSource = CreateDataSource(csb => csb.IncludeErrorDetail = true); + await using var conn = await dataSource.OpenConnectionAsync(); + var raiseExceptionFunc = await GetTempFunctionName(conn); + var raiseNoticeFunc = await GetTempFunctionName(conn); - await conn.ExecuteNonQueryAsync($@" + await conn.ExecuteNonQueryAsync($@" CREATE OR REPLACE FUNCTION {raiseExceptionFunc}() RETURNS VOID AS $$ BEGIN RAISE EXCEPTION 'testexception' USING DETAIL = 'secret'; @@ -111,208 +110,203 @@ await conn.ExecuteNonQueryAsync($@" END; $$ LANGUAGE 'plpgsql';"); - var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync($"SELECT * FROM {raiseExceptionFunc}()")); - Assert.That(ex.Detail, Does.Contain("secret")); - Assert.That(ex.Data[nameof(PostgresException.Detail)], Does.Contain("secret")); - Assert.That(ex.ToString(), Does.Contain("secret")); + var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync($"SELECT * FROM {raiseExceptionFunc}()"))!; + Assert.That(ex.Detail, Does.Contain("secret")); + Assert.That(ex.Message, Does.Contain("secret")); + Assert.That(ex.Data[nameof(PostgresException.Detail)], Does.Contain("secret")); + Assert.That(ex.ToString(), Does.Contain("secret")); - PostgresNotice? 
notice = null; - conn.Notice += (____, a) => notice = a.Notice; - await conn.ExecuteNonQueryAsync($"SELECT * FROM {raiseNoticeFunc}()"); - Assert.That(notice!.Detail, Does.Contain("secret")); - } + PostgresNotice? notice = null; + conn.Notice += (____, a) => notice = a.Notice; + await conn.ExecuteNonQueryAsync($"SELECT * FROM {raiseNoticeFunc}()"); + Assert.That(notice!.Detail, Does.Contain("secret")); + } - [Test] - public void ExceptionFieldsArePopulated() - { - using (var conn = OpenConnection()) - { - TestUtil.MinimumPgVersion(conn, "9.3.0", "5 error fields haven't been added yet"); - conn.ExecuteNonQuery("CREATE TEMP TABLE uniqueviolation (id INT NOT NULL, CONSTRAINT uniqueviolation_pkey PRIMARY KEY (id))"); - conn.ExecuteNonQuery("INSERT INTO uniqueviolation (id) VALUES(1)"); - try - { - conn.ExecuteNonQuery("INSERT INTO uniqueviolation (id) VALUES(1)"); - } - catch (PostgresException ex) - { - Assert.That(ex.ColumnName, Is.Null, "ColumnName should not be populated for unique violations"); - Assert.That(ex.TableName, Is.EqualTo("uniqueviolation")); - Assert.That(ex.SchemaName, Does.StartWith("pg_temp")); - Assert.That(ex.ConstraintName, Is.EqualTo("uniqueviolation_pkey")); - Assert.That(ex.DataTypeName, Is.Null, "DataTypeName should not be populated for unique violations"); - } - } - } + [Test] + public async Task Error_position() + { + await using var conn = await OpenConnectionAsync(); + + var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync("SELECT 1; SELECT * FROM \"NonExistingTable\""))!; + Assert.That(ex.Message, Does.Contain("POSITION: 15")); + } - [Test] - public void ColumnNameExceptionFieldIsPopulated() + [Test] + public void Exception_fields_are_populated() + { + using var conn = OpenConnection(); + TestUtil.MinimumPgVersion(conn, "9.3.0", "5 error fields haven't been added yet"); + conn.ExecuteNonQuery("CREATE TEMP TABLE uniqueviolation (id INT NOT NULL, CONSTRAINT uniqueviolation_pkey PRIMARY KEY (id))"); + conn.ExecuteNonQuery("INSERT 
INTO uniqueviolation (id) VALUES(1)"); + try { - using (var conn = OpenConnection()) - { - TestUtil.MinimumPgVersion(conn, "9.3.0", "5 error fields haven't been added yet"); - conn.ExecuteNonQuery("CREATE TEMP TABLE notnullviolation (id INT NOT NULL)"); - try - { - conn.ExecuteNonQuery("INSERT INTO notnullviolation (id) VALUES(NULL)"); - } - catch (PostgresException ex) - { - Assert.That(ex.SchemaName, Does.StartWith("pg_temp")); - Assert.That(ex.TableName, Is.EqualTo("notnullviolation")); - Assert.That(ex.ColumnName, Is.EqualTo("id")); - } - } + conn.ExecuteNonQuery("INSERT INTO uniqueviolation (id) VALUES(1)"); } - - [Test] - [NonParallelizable] - public void DataTypeNameExceptionFieldIsPopulated() + catch (PostgresException ex) { - // On reading the source code for PostgreSQL9.3beta1, the only time that the - // datatypename field is populated is when using domain types. So here we'll - // create a domain that simply does not allow NULLs then try and cast NULL - // to it. - const string dropDomain = @"DROP DOMAIN IF EXISTS public.intnotnull"; - const string createDomain = @"CREATE DOMAIN public.intnotnull AS INT NOT NULL"; - const string castStatement = @"SELECT CAST(NULL AS public.intnotnull)"; - - using (var conn = OpenConnection()) - { - TestUtil.MinimumPgVersion(conn, "9.3.0", "5 error fields haven't been added yet"); - try - { - var command = new NpgsqlCommand(dropDomain, conn); - command.ExecuteNonQuery(); - - command = new NpgsqlCommand(createDomain, conn); - command.ExecuteNonQuery(); - - command = new NpgsqlCommand(castStatement, conn); - //Cause the NOT NULL violation - command.ExecuteNonQuery(); - - } - catch (PostgresException ex) - { - Assert.AreEqual("public", ex.SchemaName); - Assert.AreEqual("intnotnull", ex.DataTypeName); - } - } + Assert.That(ex.ColumnName, Is.Null, "ColumnName should not be populated for unique violations"); + Assert.That(ex.TableName, Is.EqualTo("uniqueviolation")); + Assert.That(ex.SchemaName, Does.StartWith("pg_temp")); + 
Assert.That(ex.ConstraintName, Is.EqualTo("uniqueviolation_pkey")); + Assert.That(ex.DataTypeName, Is.Null, "DataTypeName should not be populated for unique violations"); } + } - [Test] - public void NpgsqlExceptionInAsync() + [Test] + public void Column_name_exception_field_is_populated() + { + using var conn = OpenConnection(); + TestUtil.MinimumPgVersion(conn, "9.3.0", "5 error fields haven't been added yet"); + conn.ExecuteNonQuery("CREATE TEMP TABLE notnullviolation (id INT NOT NULL)"); + try { - using (var conn = OpenConnection()) - { - Assert.That(async () => await conn.ExecuteNonQueryAsync("MALFORMED"), - Throws.Exception.TypeOf()); - // Just in case, anything but a PostgresException would trigger the connection breaking, check that - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); - } + conn.ExecuteNonQuery("INSERT INTO notnullviolation (id) VALUES(NULL)"); } - - [Test] - public void NpgsqlExceptionTransience() + catch (PostgresException ex) { - Assert.True(new NpgsqlException("", new IOException()).IsTransient); - Assert.True(new NpgsqlException("", new SocketException()).IsTransient); - Assert.True(new NpgsqlException("", new TimeoutException()).IsTransient); - Assert.False(new NpgsqlException().IsTransient); - Assert.False(new NpgsqlException("", new Exception("Inner Exception")).IsTransient); + Assert.That(ex.SchemaName, Does.StartWith("pg_temp")); + Assert.That(ex.TableName, Is.EqualTo("notnullviolation")); + Assert.That(ex.ColumnName, Is.EqualTo("id")); } + } - [Test] - public void PostgresExceptionTransience() + [Test] + public async Task DataTypeName_is_populated() + { + // On reading the source code for PostgreSQL9.3beta1, the only time that the + // datatypename field is populated is when using domain types. So here we'll + // create a domain that simply does not allow NULLs then try and cast NULL + // to it. 
+ await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "9.3.0", "5 error fields haven't been added yet"); + + var domainName = await GetTempTypeName(conn); + + await conn.ExecuteNonQueryAsync($"CREATE DOMAIN {domainName} AS INT NOT NULL"); + var pgEx = Assert.ThrowsAsync(async () => await conn.ExecuteNonQueryAsync($"SELECT CAST(NULL AS {domainName})"))!; + + Assert.That(pgEx.SqlState, Is.EqualTo(PostgresErrorCodes.NotNullViolation)); + Assert.That(pgEx.SchemaName, Is.EqualTo("public")); + Assert.That(pgEx.DataTypeName, Is.EqualTo(domainName)); + } + + [Test] + public void NpgsqlException_with_async() + { + using var conn = OpenConnection(); + Assert.That(async () => await conn.ExecuteNonQueryAsync("MALFORMED"), + Throws.Exception.TypeOf()); + // Just in case, anything but a PostgresException would trigger the connection breaking, check that + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + } + + [Test] + public void NpgsqlException_IsTransient() + { + Assert.True(new NpgsqlException("", new IOException()).IsTransient); + Assert.True(new NpgsqlException("", new SocketException()).IsTransient); + Assert.True(new NpgsqlException("", new TimeoutException()).IsTransient); + Assert.False(new NpgsqlException().IsTransient); + Assert.False(new NpgsqlException("", new Exception("Inner Exception")).IsTransient); + } + +#pragma warning disable SYSLIB0051 +#pragma warning disable 618 + [Test] + public void PostgresException_IsTransient() + { + Assert.True(CreateWithSqlState("53300").IsTransient); + Assert.False(CreateWithSqlState("0").IsTransient); + + PostgresException CreateWithSqlState(string sqlState) { - Assert.True(CreateWithSqlState("53300").IsTransient); - Assert.False(CreateWithSqlState("0").IsTransient); - - PostgresException CreateWithSqlState(string sqlState) - { - var info = CreateSerializationInfo(); - new Exception().GetObjectData(info, default); - - info.AddValue(nameof(PostgresException.Severity), null); - 
info.AddValue(nameof(PostgresException.InvariantSeverity), null); - info.AddValue(nameof(PostgresException.SqlState), sqlState); - info.AddValue(nameof(PostgresException.MessageText), null); - info.AddValue(nameof(PostgresException.Detail), null); - info.AddValue(nameof(PostgresException.Hint), null); - info.AddValue(nameof(PostgresException.Position), 0); - info.AddValue(nameof(PostgresException.InternalPosition), 0); - info.AddValue(nameof(PostgresException.InternalQuery), null); - info.AddValue(nameof(PostgresException.Where), null); - info.AddValue(nameof(PostgresException.SchemaName), null); - info.AddValue(nameof(PostgresException.TableName), null); - info.AddValue(nameof(PostgresException.ColumnName), null); - info.AddValue(nameof(PostgresException.DataTypeName), null); - info.AddValue(nameof(PostgresException.ConstraintName), null); - info.AddValue(nameof(PostgresException.File), null); - info.AddValue(nameof(PostgresException.Line), null); - info.AddValue(nameof(PostgresException.Routine), null); - - return new PostgresException(info, default); - } + var info = CreateSerializationInfo(); + new Exception().GetObjectData(info, default); + + info.AddValue(nameof(PostgresException.Severity), null); + info.AddValue(nameof(PostgresException.InvariantSeverity), null); + info.AddValue(nameof(PostgresException.SqlState), sqlState); + info.AddValue(nameof(PostgresException.MessageText), null); + info.AddValue(nameof(PostgresException.Detail), null); + info.AddValue(nameof(PostgresException.Hint), null); + info.AddValue(nameof(PostgresException.Position), 0); + info.AddValue(nameof(PostgresException.InternalPosition), 0); + info.AddValue(nameof(PostgresException.InternalQuery), null); + info.AddValue(nameof(PostgresException.Where), null); + info.AddValue(nameof(PostgresException.SchemaName), null); + info.AddValue(nameof(PostgresException.TableName), null); + info.AddValue(nameof(PostgresException.ColumnName), null); + 
info.AddValue(nameof(PostgresException.DataTypeName), null); + info.AddValue(nameof(PostgresException.ConstraintName), null); + info.AddValue(nameof(PostgresException.File), null); + info.AddValue(nameof(PostgresException.Line), null); + info.AddValue(nameof(PostgresException.Routine), null); + + return new PostgresException(info, default); } + } +#pragma warning restore SYSLIB0051 +#pragma warning restore 618 #pragma warning disable SYSLIB0011 +#pragma warning disable SYSLIB0050 #pragma warning disable 618 - [Test] - public void Serialization() - { - var actual = new PostgresException("message text", "high", "high2", "53300", "detail", "hint", 18, 42, "internal query", - "where", "schema", "table", "column", "data type", "constraint", "file", "line", "routine"); - - var formatter = new BinaryFormatter(); - var stream = new MemoryStream(); - - formatter.Serialize(stream, actual); - stream.Seek(0, SeekOrigin.Begin); - - var expected = (PostgresException)formatter.Deserialize(stream); - - Assert.That(expected.Severity, Is.EqualTo(actual.Severity)); - Assert.That(expected.InvariantSeverity, Is.EqualTo(actual.InvariantSeverity)); - Assert.That(expected.SqlState, Is.EqualTo(actual.SqlState)); - Assert.That(expected.MessageText, Is.EqualTo(actual.MessageText)); - Assert.That(expected.Detail, Is.EqualTo(actual.Detail)); - Assert.That(expected.Hint, Is.EqualTo(actual.Hint)); - Assert.That(expected.Position, Is.EqualTo(actual.Position)); - Assert.That(expected.InternalPosition, Is.EqualTo(actual.InternalPosition)); - Assert.That(expected.InternalQuery, Is.EqualTo(actual.InternalQuery)); - Assert.That(expected.Where, Is.EqualTo(actual.Where)); - Assert.That(expected.SchemaName, Is.EqualTo(actual.SchemaName)); - Assert.That(expected.TableName, Is.EqualTo(actual.TableName)); - Assert.That(expected.ColumnName, Is.EqualTo(actual.ColumnName)); - Assert.That(expected.DataTypeName, Is.EqualTo(actual.DataTypeName)); - Assert.That(expected.ConstraintName, 
Is.EqualTo(actual.ConstraintName)); - Assert.That(expected.File, Is.EqualTo(actual.File)); - Assert.That(expected.Line, Is.EqualTo(actual.Line)); - Assert.That(expected.Routine, Is.EqualTo(actual.Routine)); - } + [Test] + public void Serialization() + { + var actual = new PostgresException("message text", "high", "high2", "53300", "detail", "hint", 18, 42, "internal query", + "where", "schema", "table", "column", "data type", "constraint", "file", "line", "routine"); + + var formatter = new BinaryFormatter(); + var stream = new MemoryStream(); + + formatter.Serialize(stream, actual); + stream.Seek(0, SeekOrigin.Begin); + + var expected = (PostgresException)formatter.Deserialize(stream); + + Assert.That(expected.Severity, Is.EqualTo(actual.Severity)); + Assert.That(expected.InvariantSeverity, Is.EqualTo(actual.InvariantSeverity)); + Assert.That(expected.SqlState, Is.EqualTo(actual.SqlState)); + Assert.That(expected.MessageText, Is.EqualTo(actual.MessageText)); + Assert.That(expected.Detail, Is.EqualTo(actual.Detail)); + Assert.That(expected.Hint, Is.EqualTo(actual.Hint)); + Assert.That(expected.Position, Is.EqualTo(actual.Position)); + Assert.That(expected.InternalPosition, Is.EqualTo(actual.InternalPosition)); + Assert.That(expected.InternalQuery, Is.EqualTo(actual.InternalQuery)); + Assert.That(expected.Where, Is.EqualTo(actual.Where)); + Assert.That(expected.SchemaName, Is.EqualTo(actual.SchemaName)); + Assert.That(expected.TableName, Is.EqualTo(actual.TableName)); + Assert.That(expected.ColumnName, Is.EqualTo(actual.ColumnName)); + Assert.That(expected.DataTypeName, Is.EqualTo(actual.DataTypeName)); + Assert.That(expected.ConstraintName, Is.EqualTo(actual.ConstraintName)); + Assert.That(expected.File, Is.EqualTo(actual.File)); + Assert.That(expected.Line, Is.EqualTo(actual.Line)); + Assert.That(expected.Routine, Is.EqualTo(actual.Routine)); + } - SerializationInfo CreateSerializationInfo() => new SerializationInfo(typeof(PostgresException), new 
FormatterConverter()); + SerializationInfo CreateSerializationInfo() => new(typeof(PostgresException), new FormatterConverter()); #pragma warning restore 618 #pragma warning restore SYSLIB0011 +#pragma warning disable SYSLIB0050 - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/3204")] - public void BaseExceptionPropertySerialization() - { - var ex = new PostgresException("the message", "low", "low2", "XX123"); - - var info = CreateSerializationInfo(); - ex.GetObjectData(info, default); - - // Check virtual base properties, which can be incorrectly deserialized if overridden, because the base - // Exception.GetObjectData() method writes the fields, not the properties (e.g. "_message" instead of "Message"). - Assert.That(ex.Data, Is.EquivalentTo((IDictionary?)info.GetValue("Data", typeof(IDictionary)))); - Assert.That(ex.HelpLink, Is.EqualTo(info.GetValue("HelpURL", typeof(string)))); - Assert.That(ex.Message, Is.EqualTo(info.GetValue("Message", typeof(string)))); - Assert.That(ex.Source, Is.EqualTo(info.GetValue("Source", typeof(string)))); - Assert.That(ex.StackTrace, Is.EqualTo(info.GetValue("StackTraceString", typeof(string)))); - } +#pragma warning disable SYSLIB0051 + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/3204")] + public void Base_exception_property_serialization() + { + var ex = new PostgresException("the message", "low", "low2", "XX123"); + + var info = CreateSerializationInfo(); + ex.GetObjectData(info, default); + + // Check virtual base properties, which can be incorrectly deserialized if overridden, because the base + // Exception.GetObjectData() method writes the fields, not the properties (e.g. "_message" instead of "Message"). 
+ Assert.That(ex.Data, Is.EquivalentTo((IDictionary?)info.GetValue("Data", typeof(IDictionary)))); + Assert.That(ex.HelpLink, Is.EqualTo(info.GetValue("HelpURL", typeof(string)))); + Assert.That(ex.Message, Is.EqualTo(info.GetValue("Message", typeof(string)))); + Assert.That(ex.Source, Is.EqualTo(info.GetValue("Source", typeof(string)))); + Assert.That(ex.StackTrace, Is.EqualTo(info.GetValue("StackTraceString", typeof(string)))); } +#pragma warning restore SYSLIB0051 } diff --git a/test/Npgsql.Tests/FunctionTests.cs b/test/Npgsql.Tests/FunctionTests.cs index 1f5a2b7a09..37f203b812 100644 --- a/test/Npgsql.Tests/FunctionTests.cs +++ b/test/Npgsql.Tests/FunctionTests.cs @@ -1,149 +1,525 @@ using System; using System.Data; +using System.Threading.Tasks; +using Npgsql.PostgresTypes; +using NpgsqlTypes; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +/// +/// A fixture for tests which interact with functions. +/// All tests should create functions in the pg_temp schema only to ensure there's no interaction between +/// the tests. +/// +[NonParallelizable] // Manipulates the EnableStoredProcedureCompatMode global flag +public class FunctionTests : TestBase { - /// - /// A fixture for tests which interact with functions. - /// All tests should create functions in the pg_temp schema only to ensure there's no interaction between - /// the tests. 
- /// - public class FunctionTests : TestBase - { - [Test, Description("Simple function with no parameters, results accessed as a resultset")] - public void ResultSet() + [Test, Description("Simple function with no parameters, results accessed as a resultset")] + public async Task Resultset() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + await conn.ExecuteNonQueryAsync($"CREATE FUNCTION {function}() RETURNS integer AS 'SELECT 8' LANGUAGE sql"); + await using var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + } + + [Test, Description("Basic function call with an in parameter")] + public async Task Param_Input() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + await conn.ExecuteNonQueryAsync($"CREATE FUNCTION {function}(IN param text) RETURNS text AS 'SELECT param' LANGUAGE sql"); + await using var cmd = new NpgsqlCommand(function, conn); + cmd.CommandType = CommandType.StoredProcedure; + cmd.Parameters.AddWithValue("@param", "hello"); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("hello")); + } + + [Test, Description("Basic function call with an out parameter")] + public async Task Param_Output() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + await conn.ExecuteNonQueryAsync(@$" +CREATE FUNCTION {function} (IN param_in text, OUT param_out text) AS $$ +BEGIN + param_out=param_in; +END +$$ LANGUAGE plpgsql"); + await using var cmd = new NpgsqlCommand(function, conn); + cmd.CommandType = CommandType.StoredProcedure; + cmd.Parameters.AddWithValue("@param_in", "hello"); + var outParam = new NpgsqlParameter("param_out", DbType.String) { Direction = ParameterDirection.Output }; + cmd.Parameters.Add(outParam); + await cmd.ExecuteNonQueryAsync(); + Assert.That(outParam.Value, 
Is.EqualTo("hello")); + } + + [Test, Description("Basic function call with an in/out parameter")] + public async Task Param_InputOutput() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + await conn.ExecuteNonQueryAsync($@" +CREATE FUNCTION {function} (INOUT param integer) AS $$ +BEGIN + param=param+1; +END +$$ LANGUAGE plpgsql"); + await using var cmd = new NpgsqlCommand(function, conn); + cmd.CommandType = CommandType.StoredProcedure; + var outParam = new NpgsqlParameter("param", DbType.Int32) { - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery(@"CREATE FUNCTION pg_temp.func() RETURNS integer AS 'SELECT 8;' LANGUAGE 'sql'"); - using (var cmd = new NpgsqlCommand("pg_temp.func", conn) { CommandType = CommandType.StoredProcedure }) - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(8)); - } - } + Direction = ParameterDirection.InputOutput, + Value = 8 + }; + cmd.Parameters.Add(outParam); + await cmd.ExecuteNonQueryAsync(); + Assert.That(outParam.Value, Is.EqualTo(9)); + } - [Test, Description("Basic function call with an in parameter")] - public void InParam() + [Test] + public async Task Void() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "9.1.0", "no binary output function available for type void before 9.1.0"); + var command = new NpgsqlCommand("pg_sleep", conn); + command.Parameters.AddWithValue(0); + command.CommandType = CommandType.StoredProcedure; + await command.ExecuteNonQueryAsync(); + } + + [Test] + public async Task Named_parameters() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "9.4.0", "make_timestamp was introduced in 9.4"); + await using var command = new NpgsqlCommand("make_timestamp", conn); + command.CommandType = CommandType.StoredProcedure; + command.Parameters.AddWithValue("year", 2015); + command.Parameters.AddWithValue("month", 8); + command.Parameters.AddWithValue("mday", 1); + 
command.Parameters.AddWithValue("hour", 2); + command.Parameters.AddWithValue("min", 3); + command.Parameters.AddWithValue("sec", 4); + var dt = (DateTime)(await command.ExecuteScalarAsync())!; + + Assert.AreEqual(new DateTime(2015, 8, 1, 2, 3, 4), dt); + + command.Parameters[0].Value = 2014; + command.Parameters[0].ParameterName = ""; // 2014 will be sent as a positional parameter + dt = (DateTime)(await command.ExecuteScalarAsync())!; + Assert.AreEqual(new DateTime(2014, 8, 1, 2, 3, 4), dt); + } + + [Test] + public async Task Too_many_output_params() + { + await using var conn = await OpenConnectionAsync(); + var command = new NpgsqlCommand("VALUES (4,5), (6,7)", conn); + command.Parameters.Add(new NpgsqlParameter("a", DbType.Int32) { - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery(@"CREATE FUNCTION pg_temp.echo(IN param text) RETURNS text AS 'BEGIN RETURN param; END;' LANGUAGE 'plpgsql'"); - using (var cmd = new NpgsqlCommand("pg_temp.echo", conn)) - { - cmd.CommandType = CommandType.StoredProcedure; - cmd.Parameters.AddWithValue("@param", "hello"); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo("hello")); - } - } - } + Direction = ParameterDirection.Output, + Value = -1 + }); + command.Parameters.Add(new NpgsqlParameter("b", DbType.Int32) + { + Direction = ParameterDirection.Output, + Value = -1 + }); + command.Parameters.Add(new NpgsqlParameter("c", DbType.Int32) + { + Direction = ParameterDirection.Output, + Value = -1 + }); + + await command.ExecuteNonQueryAsync(); - [Test, Description("Basic function call with an out parameter")] - public void OutParam() + Assert.That(command.Parameters["a"].Value, Is.EqualTo(4)); + Assert.That(command.Parameters["b"].Value, Is.EqualTo(5)); + Assert.That(command.Parameters["c"].Value, Is.EqualTo(-1)); + } + + [Test] + public async Task CommandBehavior_SchemaOnly_support_function_call() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + + await 
conn.ExecuteNonQueryAsync($"CREATE OR REPLACE FUNCTION {function}() RETURNS SETOF integer as 'SELECT 1;' LANGUAGE 'sql';"); + var command = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + await using var dr = await command.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var i = 0; + while (dr.Read()) + i++; + Assert.AreEqual(0, i); + } + + #region DeriveParameters + + [Test, Description("Tests function parameter derivation with IN, OUT and INOUT parameters")] + public async Task DeriveParameters_function_various() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + + // This function returns record because of the two Out (InOut & Out) parameters + await conn.ExecuteNonQueryAsync($@" +CREATE FUNCTION {function}(IN param1 INT, OUT param2 text, INOUT param3 INT) RETURNS record AS $$ +BEGIN + param2 = 'sometext'; + param3 = param1 + param3; +END; +$$ LANGUAGE plpgsql"); + + await using var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(3)); + Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(cmd.Parameters[0].PostgresType, Is.TypeOf()); + Assert.That(cmd.Parameters[0].DataTypeName, Is.EqualTo("integer")); + Assert.That(cmd.Parameters[0].ParameterName, Is.EqualTo("param1")); + Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); + Assert.That(cmd.Parameters[1].PostgresType, Is.TypeOf()); + Assert.That(cmd.Parameters[1].DataTypeName, Is.EqualTo("text")); + Assert.That(cmd.Parameters[1].ParameterName, Is.EqualTo("param2")); + Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.InputOutput)); + 
Assert.That(cmd.Parameters[2].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(cmd.Parameters[2].PostgresType, Is.TypeOf()); + Assert.That(cmd.Parameters[2].DataTypeName, Is.EqualTo("integer")); + Assert.That(cmd.Parameters[2].ParameterName, Is.EqualTo("param3")); + cmd.Parameters[0].Value = 5; + cmd.Parameters[2].Value = 4; + await cmd.ExecuteNonQueryAsync(); + Assert.That(cmd.Parameters[0].Value, Is.EqualTo(5)); + Assert.That(cmd.Parameters[1].Value, Is.EqualTo("sometext")); + Assert.That(cmd.Parameters[2].Value, Is.EqualTo(9)); + } + + [Test, Description("Tests function parameter derivation with IN-only parameters")] + public async Task DeriveParameters_function_in_only() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + + // This function returns record because of the two Out (InOut & Out) parameters + await conn.ExecuteNonQueryAsync( + $@"CREATE FUNCTION {function}(IN param1 INT, IN param2 INT) RETURNS int AS 'SELECT param1 + param2' LANGUAGE sql"); + + await using var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(2)); + Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Input)); + cmd.Parameters[0].Value = 5; + cmd.Parameters[1].Value = 4; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(9)); + } + + [Test, Description("Tests function parameter derivation with no parameters")] + public async Task DeriveParameters_function_no_params() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + + await conn.ExecuteNonQueryAsync($@"CREATE FUNCTION {function}() RETURNS int AS 'SELECT 4' LANGUAGE sql"); + + await using var cmd = new NpgsqlCommand(function, conn) { CommandType = 
CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Is.Empty); + } + + [Test] + public async Task DeriveParameters_function_with_case_sensitive_name() + { + await using var conn = await OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync( + @"CREATE OR REPLACE FUNCTION ""FunctionCaseSensitive""(int4, text) RETURNS int4 AS 'SELECT 0' LANGUAGE sql"); + + try { - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery(@"CREATE FUNCTION pg_temp.echo (IN param_in text, OUT param_out text) AS 'BEGIN param_out=param_in; END;' LANGUAGE 'plpgsql'"); - using (var cmd = new NpgsqlCommand("pg_temp.echo", conn)) - { - cmd.CommandType = CommandType.StoredProcedure; - cmd.Parameters.AddWithValue("@param_in", "hello"); - var outParam = new NpgsqlParameter("param_out", DbType.String) { Direction = ParameterDirection.Output }; - cmd.Parameters.Add(outParam); - cmd.ExecuteNonQuery(); - Assert.That(outParam.Value, Is.EqualTo("hello")); - } - } + await using var command = new NpgsqlCommand(@"""FunctionCaseSensitive""", conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); + Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); } - - [Test, Description("Basic function call with an in/out parameter")] - public void InOutParam() + finally { - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery(@"CREATE FUNCTION pg_temp.inc (INOUT param integer) AS 'BEGIN param=param+1; END;' LANGUAGE 'plpgsql'"); - using (var cmd = new NpgsqlCommand("pg_temp.inc", conn)) - { - cmd.CommandType = CommandType.StoredProcedure; - var outParam = new NpgsqlParameter("param", DbType.Int32) - { - Direction = ParameterDirection.InputOutput, - Value = 8 - }; - cmd.Parameters.Add(outParam); - cmd.ExecuteNonQuery(); - Assert.That(outParam.Value, Is.EqualTo(9)); - } - } + await 
conn.ExecuteNonQueryAsync(@"DROP FUNCTION ""FunctionCaseSensitive"""); } + } + + [Test, Description("Tests function parameter derivation for quoted functions with double quotes in the name works")] + public async Task DeriveParameters_quote_characters_in_function_name() + { + await using var conn = await OpenConnectionAsync(); + var function = @"""""""FunctionQuote""""CharactersInName"""""""; + await conn.ExecuteNonQueryAsync($"CREATE OR REPLACE FUNCTION {function}(int4, text) RETURNS int4 AS 'SELECT 0' LANGUAGE sql"); - [Test] - public void Void() + try { - using (var conn = OpenConnection()) - { - TestUtil.MinimumPgVersion(conn, "9.1.0", "no binary output function available for type void before 9.1.0"); - var command = new NpgsqlCommand("pg_sleep", conn); - command.Parameters.AddWithValue(0); - command.CommandType = CommandType.StoredProcedure; - command.ExecuteNonQuery(); - } + await using var command = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); + Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); } - - [Test] - public void NamedParameters() + finally { - using (var conn = OpenConnection()) - { - TestUtil.MinimumPgVersion(conn, "9.4.0", "make_timestamp was introduced in 9.4"); - using (var command = new NpgsqlCommand("make_timestamp", conn)) - { - command.CommandType = CommandType.StoredProcedure; - command.Parameters.AddWithValue("year", 2015); - command.Parameters.AddWithValue("month", 8); - command.Parameters.AddWithValue("mday", 1); - command.Parameters.AddWithValue("hour", 2); - command.Parameters.AddWithValue("min", 3); - command.Parameters.AddWithValue("sec", 4); - var dt = (DateTime) command.ExecuteScalar()!; - - Assert.AreEqual(new DateTime(2015, 8, 1, 2, 3, 4), dt); - - command.Parameters[0].Value = 2014; - command.Parameters[0].ParameterName = ""; // 2014 will be sent 
as a positional parameter - dt = (DateTime) command.ExecuteScalar()!; - Assert.AreEqual(new DateTime(2014, 8, 1, 2, 3, 4), dt); - } - } + await conn.ExecuteNonQueryAsync("DROP FUNCTION " + function); } + } - [Test] - public void TooManyOutputParams() + [Test, Description("Tests function parameter derivation for quoted functions with dots in the name works")] + public async Task DeriveParameters_dot_character_in_function_name() + { + await using var conn = await OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync( + @"CREATE OR REPLACE FUNCTION ""My.Dotted.Function""(int4, text) RETURNS int4 AS 'SELECT 0' LANGUAGE sql"); + + try + { + await using var command = new NpgsqlCommand(@"""My.Dotted.Function""", conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); + Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + } + finally { - using (var conn = OpenConnection()) - { - var command = new NpgsqlCommand("VALUES (4,5), (6,7)", conn); - command.Parameters.Add(new NpgsqlParameter("a", DbType.Int32) - { - Direction = ParameterDirection.Output, - Value = -1 - }); - command.Parameters.Add(new NpgsqlParameter("b", DbType.Int32) - { - Direction = ParameterDirection.Output, - Value = -1 - }); - command.Parameters.Add(new NpgsqlParameter("c", DbType.Int32) - { - Direction = ParameterDirection.Output, - Value = -1 - }); - - command.ExecuteNonQuery(); - - Assert.That(command.Parameters["a"].Value, Is.EqualTo(4)); - Assert.That(command.Parameters["b"].Value, Is.EqualTo(5)); - Assert.That(command.Parameters["c"].Value, Is.EqualTo(-1)); - } + await conn.ExecuteNonQueryAsync(@"DROP FUNCTION ""My.Dotted.Function"""); } } + + [Test] + public async Task DeriveParameters_parameter_name_from_function() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + + await 
conn.ExecuteNonQueryAsync( + $"CREATE FUNCTION {function}(x int, y int, out sum int, out product int) AS 'SELECT $1 + $2, $1 * $2' LANGUAGE sql"); + await using var command = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.AreEqual("x", command.Parameters[0].ParameterName); + Assert.AreEqual("y", command.Parameters[1].ParameterName); + } + + [Test] + public async Task DeriveParameters_non_existing_function() + { + await using var conn = await OpenConnectionAsync(); + var invalidCommandName = new NpgsqlCommand("invalidfunctionname", conn) { CommandType = CommandType.StoredProcedure }; + Assert.That(() => NpgsqlCommandBuilder.DeriveParameters(invalidCommandName), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedFunction)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1212")] + public async Task DeriveParameters_function_with_table_parameters() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "9.2.0"); + var function = await GetTempFunctionName(conn); + + // This function returns record because of the two Out (InOut & Out) parameters + await conn.ExecuteNonQueryAsync( + $"CREATE FUNCTION {function}(IN in1 INT) RETURNS TABLE(t1 INT, t2 INT) AS 'SELECT in1, in1+1' LANGUAGE sql"); + + await using var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(3)); + Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.Output)); + cmd.Parameters[0].Value = 5; + await cmd.ExecuteNonQueryAsync(); + Assert.That(cmd.Parameters[1].Value, Is.EqualTo(5)); + 
Assert.That(cmd.Parameters[2].Value, Is.EqualTo(6)); + } + + [Test, Description("Tests if the right function according to search_path is used in function parameter derivation")] + public async Task DeriveParameters_function_correct_schema_resolution() + { + await using var conn = await OpenConnectionAsync(); + var schema1 = await CreateTempSchema(conn); + var schema2 = await CreateTempSchema(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE FUNCTION {schema1}.redundantfunc() RETURNS int AS 'SELECT 1' LANGUAGE sql; +CREATE FUNCTION {schema2}.redundantfunc(IN param1 INT, IN param2 INT) RETURNS int AS 'SELECT param1 + param2' LANGUAGE sql; +SET search_path TO {schema2};"); + await using var command = new NpgsqlCommand("redundantfunc", conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.That(command.Parameters, Has.Count.EqualTo(2)); + Assert.That(command.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(command.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Input)); + command.Parameters[0].Value = 5; + command.Parameters[1].Value = 4; + Assert.That(await command.ExecuteScalarAsync(), Is.EqualTo(9)); + } + + [Test, Description("Tests if function parameter derivation throws an exception if the specified function is not in the search_path")] + public async Task DeriveParameters_throws_for_existing_function_that_is_not_in_search_path() + { + await using var conn = await OpenConnectionAsync(); + var schema = await CreateTempSchema(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE FUNCTION {schema}.schema1func() RETURNS int AS 'SELECT 1' LANGUAGE sql; +RESET search_path;"); + await using var command = new NpgsqlCommand("schema1func", conn) { CommandType = CommandType.StoredProcedure }; + Assert.That(() => NpgsqlCommandBuilder.DeriveParameters(command), + Throws.Exception.TypeOf() + 
.With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedFunction)); + } + + [Test, Description("Tests if an exception is thrown if multiple functions with the specified name are in the search_path")] + public async Task DeriveParameters_throws_for_multiple_function_name_hits_in_search_path() + { + await using var conn = await OpenConnectionAsync(); + var schema1 = await CreateTempSchema(conn); + var schema2 = await CreateTempSchema(conn); + + await conn.ExecuteNonQueryAsync( + $@" +CREATE FUNCTION {schema1}.redundantfunc() RETURNS int AS 'SELECT 1' LANGUAGE sql; +CREATE FUNCTION {schema1}.redundantfunc(IN param1 INT, IN param2 INT) RETURNS int AS 'SELECT param1 + param2' LANGUAGE sql; +SET search_path TO {schema1}, {schema2};"); + var command = new NpgsqlCommand("redundantfunc", conn) { CommandType = CommandType.StoredProcedure }; + Assert.That(() => NpgsqlCommandBuilder.DeriveParameters(command), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.AmbiguousFunction)); + } + + #region Set returning functions + + [Test, Description("Tests parameter derivation for a function that returns SETOF sometype")] + public async Task DeriveParameters_function_returning_setof_type() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "9.2.0"); + + var table = await GetTempTableName(conn); + var function = await GetTempFunctionName(conn); + + // This function returns record because of the two Out (InOut & Out) parameters + await conn.ExecuteNonQueryAsync($@" +CREATE TABLE {table} (fooid int, foosubid int, fooname text); +INSERT INTO {table} VALUES (1, 1, 'Joe'), (1, 2, 'Ed'), (2, 1, 'Mary'); +CREATE FUNCTION {function}(int) RETURNS SETOF {table} AS $$ + SELECT * FROM {table} WHERE {table}.fooid = $1 ORDER BY {table}.foosubid; +$$ LANGUAGE sql"); + + await using var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + 
NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(4)); + Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[3].Direction, Is.EqualTo(ParameterDirection.Output)); + cmd.Parameters[0].Value = 1; + await cmd.ExecuteNonQueryAsync(); + Assert.That(cmd.Parameters[0].Value, Is.EqualTo(1)); + } + + [Test, Description("Tests parameter derivation for a function that returns TABLE")] + public async Task DeriveParameters_function_returning_table() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "9.2.0"); + + var table = await GetTempTableName(conn); + var function = await GetTempFunctionName(conn); + + // This function returns record because of the two Out (InOut & Out) parameters + await conn.ExecuteNonQueryAsync($@" +CREATE TABLE {table} (fooid int, foosubid int, fooname text); +INSERT INTO {table} VALUES (1, 1, 'Joe'), (1, 2, 'Ed'), (2, 1, 'Mary'); +CREATE FUNCTION {function}(int) RETURNS TABLE(fooid int, foosubid int, fooname text) AS $$ + SELECT * FROM {table} WHERE {table}.fooid = $1 ORDER BY {table}.foosubid; +$$ LANGUAGE sql"); + + await using var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(4)); + Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[3].Direction, Is.EqualTo(ParameterDirection.Output)); + cmd.Parameters[0].Value = 1; + await cmd.ExecuteNonQueryAsync(); + Assert.That(cmd.Parameters[0].Value, Is.EqualTo(1)); 
+ } + + [Test, Description("Tests parameter derivation for a function that returns SETOF record")] + public async Task DeriveParameters_function_returning_setof_record() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "9.2.0"); + + var table = await GetTempTableName(conn); + var function = await GetTempFunctionName(conn); + + // This function returns record because of the two Out (InOut & Out) parameters + await conn.ExecuteNonQueryAsync($@" +CREATE TABLE {table} (fooid int, foosubid int, fooname text); +INSERT INTO {table} VALUES (1, 1, 'Joe'), (1, 2, 'Ed'), (2, 1, 'Mary'); +CREATE FUNCTION {function}(int, OUT fooid int, OUT foosubid int, OUT fooname text) RETURNS SETOF record AS $$ + SELECT * FROM {table} WHERE {table}.fooid = $1 ORDER BY {table}.foosubid; +$$ LANGUAGE sql"); + + await using var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(4)); + Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[2].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[3].Direction, Is.EqualTo(ParameterDirection.Output)); + cmd.Parameters[0].Value = 1; + await cmd.ExecuteNonQueryAsync(); + Assert.That(cmd.Parameters[0].Value, Is.EqualTo(1)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2022")] + public async Task DeriveParameters_function_returning_setof_type_with_dropped_column() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "9.2.0"); + + var table = await GetTempTableName(conn); + var function = await GetTempFunctionName(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE TABLE {table} (id serial PRIMARY KEY, t1 text, t2 text); +CREATE FUNCTION {function}() RETURNS SETOF {table} AS 
'SELECT * FROM {table}' LANGUAGE sql; +ALTER TABLE {table} DROP t2;"); + + await using var cmd = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(2)); + Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(cmd.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); + } + + #endregion + + #endregion DeriveParameters + +#if DEBUG + [OneTimeSetUp] + public void OneTimeSetup() => NpgsqlCommand.EnableStoredProcedureCompatMode = true; + + [OneTimeTearDown] + public void OneTimeTeardown() => NpgsqlCommand.EnableStoredProcedureCompatMode = false; +#else + [OneTimeSetUp] + public void OneTimeSetup() + => Assert.Ignore("Cannot test function invocation via CommandType.StoredProcedure since that depends on the global EnableStoredProcedureCompatMode compatibility flag"); +#endif } diff --git a/test/Npgsql.Tests/GlobalTypeMapperTests.cs b/test/Npgsql.Tests/GlobalTypeMapperTests.cs new file mode 100644 index 0000000000..a5c75e41bf --- /dev/null +++ b/test/Npgsql.Tests/GlobalTypeMapperTests.cs @@ -0,0 +1,125 @@ +using System; +using System.Threading.Tasks; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; +using NUnit.Framework; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests; + +#pragma warning disable CS0618 // GlobalTypeMapper is obsolete + +[NonParallelizable] +public class GlobalTypeMapperTests : TestBase +{ + [Test] + public async Task MapEnum() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + NpgsqlConnection.GlobalTypeMapper.MapEnum(type); + + await using var dataSource1 = CreateDataSource(); + + await using (var connection = await 
dataSource1.OpenConnectionAsync()) + { + await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + await connection.ReloadTypesAsync(); + + await AssertType(connection, Mood.Happy, "happy", type, npgsqlDbType: null); + } + + NpgsqlConnection.GlobalTypeMapper.UnmapEnum(type); + + // Global mapping changes have no effect on already-built data sources + await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); + + // But they do affect new data sources + await using var dataSource2 = CreateDataSource(); + await AssertType(dataSource2, "happy", "happy", type, npgsqlDbType: null, isDefault: false); + } + + [Test] + public async Task MapEnum_NonGeneric() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + NpgsqlConnection.GlobalTypeMapper.MapEnum(typeof(Mood), type); + + try + { + await using var dataSource1 = CreateDataSource(); + + await using (var connection = await dataSource1.OpenConnectionAsync()) + { + await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + await connection.ReloadTypesAsync(); + + await AssertType(connection, Mood.Happy, "happy", type, npgsqlDbType: null); + } + + NpgsqlConnection.GlobalTypeMapper.UnmapEnum(typeof(Mood), type); + + // Global mapping changes have no effect on already-built data sources + await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); + + // But they do affect new data sources + await using var dataSource2 = CreateDataSource(); + Assert.ThrowsAsync(() => AssertType(dataSource2, Mood.Happy, "happy", type, npgsqlDbType: null)); + } + finally + { + NpgsqlConnection.GlobalTypeMapper.UnmapEnum(type); + } + } + + [Test] + public async Task Reset() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + NpgsqlConnection.GlobalTypeMapper.MapEnum(type); + + await using var 
dataSource1 = CreateDataSource(); + + await using (var connection = await dataSource1.OpenConnectionAsync()) + { + await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + await connection.ReloadTypesAsync(); + } + + // A global mapping change has no effects on data sources which have already been built + NpgsqlConnection.GlobalTypeMapper.Reset(); + + // Global mapping changes have no effect on already-built data sources + await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); + + // But they do affect new data sources + await using var dataSource2 = CreateDataSource(); + await AssertType(dataSource2, "happy", "happy", type, npgsqlDbType: null, isDefault: false); + } + + [Test] + public void Reset_and_add_resolver() + { + NpgsqlConnection.GlobalTypeMapper.Reset(); + NpgsqlConnection.GlobalTypeMapper.AddTypeInfoResolverFactory(new DummyResolverFactory()); + } + + [TearDown] + public void Teardown() + => NpgsqlConnection.GlobalTypeMapper.Reset(); + + enum Mood { Sad, Ok, Happy } + + class DummyResolverFactory : PgTypeInfoResolverFactory + { + public override IPgTypeInfoResolver CreateResolver() => new DummyResolver(); + public override IPgTypeInfoResolver? CreateArrayResolver() => null; + + class DummyResolver : IPgTypeInfoResolver + { + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) => null; + } + } +} diff --git a/test/Npgsql.Tests/LargeObjectTests.cs b/test/Npgsql.Tests/LargeObjectTests.cs index 5a10f1098a..fb7179abb4 100644 --- a/test/Npgsql.Tests/LargeObjectTests.cs +++ b/test/Npgsql.Tests/LargeObjectTests.cs @@ -2,50 +2,49 @@ using System.Text; using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +#pragma warning disable CS0618 // Large object support is obsolete + +public class LargeObjectTests : TestBase { - public class LargeObjectTests : TestBase + [Test] + public void Test() { - [Test] - public void Test() + using var conn = OpenConnection(); + using var transaction = conn.BeginTransaction(); + var manager = new NpgsqlLargeObjectManager(conn); + var oid = manager.Create(); + using (var stream = manager.OpenReadWrite(oid)) { - using (var conn = OpenConnection()) - using (var transaction = conn.BeginTransaction()) - { - var manager = new NpgsqlLargeObjectManager(conn); - var oid = manager.Create(); - using (var stream = manager.OpenReadWrite(oid)) - { - var buf = Encoding.UTF8.GetBytes("Hello"); - stream.Write(buf, 0, buf.Length); - stream.Seek(0, System.IO.SeekOrigin.Begin); - var buf2 = new byte[buf.Length]; - stream.Read(buf2, 0, buf2.Length); - Assert.That(buf.SequenceEqual(buf2)); - - Assert.AreEqual(5, stream.Position); - - Assert.AreEqual(5, stream.Length); - - stream.Seek(-1, System.IO.SeekOrigin.Current); - Assert.AreEqual((int)'o', stream.ReadByte()); - - manager.MaxTransferBlockSize = 3; - - stream.Write(buf, 0, buf.Length); - stream.Seek(-5, System.IO.SeekOrigin.End); - var buf3 = new byte[100]; - Assert.AreEqual(5, stream.Read(buf3, 0, 100)); - Assert.That(buf.SequenceEqual(buf3.Take(5))); - - stream.SetLength(43); - Assert.AreEqual(43, stream.Length); - } - - manager.Unlink(oid); - - transaction.Rollback(); - } + var buf = Encoding.UTF8.GetBytes("Hello"); + stream.Write(buf, 0, buf.Length); + stream.Seek(0, System.IO.SeekOrigin.Begin); + var buf2 = new 
byte[buf.Length]; + stream.Read(buf2, 0, buf2.Length); + Assert.That(buf.SequenceEqual(buf2)); + + Assert.AreEqual(5, stream.Position); + + Assert.AreEqual(5, stream.Length); + + stream.Seek(-1, System.IO.SeekOrigin.Current); + Assert.AreEqual((int)'o', stream.ReadByte()); + + manager.MaxTransferBlockSize = 3; + + stream.Write(buf, 0, buf.Length); + stream.Seek(-5, System.IO.SeekOrigin.End); + var buf3 = new byte[100]; + Assert.AreEqual(5, stream.Read(buf3, 0, 100)); + Assert.That(buf.SequenceEqual(buf3.Take(5))); + + stream.SetLength(43); + Assert.AreEqual(43, stream.Length); } + + manager.Unlink(oid); + + transaction.Rollback(); } } diff --git a/test/Npgsql.Tests/MultipleHostsTests.cs b/test/Npgsql.Tests/MultipleHostsTests.cs new file mode 100644 index 0000000000..bbd2064504 --- /dev/null +++ b/test/Npgsql.Tests/MultipleHostsTests.cs @@ -0,0 +1,1185 @@ +using Npgsql.Internal; +using Npgsql.Tests.Support; +using NUnit.Framework; +using System; +using System.Collections.Generic; +using System.Data; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using System.Transactions; +using Npgsql.Properties; +using static Npgsql.Tests.Support.MockState; +using static Npgsql.Tests.TestUtil; +using IsolationLevel = System.Transactions.IsolationLevel; +using TransactionStatus = Npgsql.Internal.TransactionStatus; + +namespace Npgsql.Tests; + +public class MultipleHostsTests : TestBase +{ + static readonly object[] MyCases = + { + new object[] { TargetSessionAttributes.Standby, new[] { Primary, Standby }, 1 }, + new object[] { TargetSessionAttributes.Standby, new[] { PrimaryReadOnly, Standby }, 1 }, + new object[] { TargetSessionAttributes.PreferStandby, new[] { Primary, Standby }, 1 }, + new object[] { TargetSessionAttributes.PreferStandby, new[] { PrimaryReadOnly, Standby }, 1 }, + new object[] { TargetSessionAttributes.PreferStandby, new[] { 
Primary, Primary }, 0 }, + new object[] { TargetSessionAttributes.Primary, new[] { Standby, Primary }, 1 }, + new object[] { TargetSessionAttributes.Primary, new[] { Standby, PrimaryReadOnly }, 1 }, + new object[] { TargetSessionAttributes.PreferPrimary, new[] { Standby, Primary }, 1 }, + new object[] { TargetSessionAttributes.PreferPrimary, new[] { Standby, PrimaryReadOnly }, 1 }, + new object[] { TargetSessionAttributes.PreferPrimary, new[] { Standby, Standby }, 0 }, + new object[] { TargetSessionAttributes.Any, new[] { Standby, Primary }, 0 }, + new object[] { TargetSessionAttributes.Any, new[] { Primary, Standby }, 0 }, + new object[] { TargetSessionAttributes.Any, new[] { PrimaryReadOnly, Standby }, 0 }, + new object[] { TargetSessionAttributes.ReadWrite, new[] { Standby, Primary }, 1 }, + new object[] { TargetSessionAttributes.ReadWrite, new[] { PrimaryReadOnly, Primary }, 1 }, + new object[] { TargetSessionAttributes.ReadOnly, new[] { Primary, Standby }, 1 }, + new object[] { TargetSessionAttributes.ReadOnly, new[] { PrimaryReadOnly, Standby }, 0 } + }; + + [Test] + [TestCaseSource(nameof(MyCases))] + public async Task Connect_to_correct_host_pooled(TargetSessionAttributes targetSessionAttributes, MockState[] servers, int expectedServer) + { + var postmasters = servers.Select(s => PgPostmasterMock.Start(state: s)).ToArray(); + await using var __ = new DisposableWrapper(postmasters); + + var connectionStringBuilder = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(postmasters), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + Pooling = true + }; + + await using var dataSource = new NpgsqlDataSourceBuilder(connectionStringBuilder.ConnectionString) + .BuildMultiHost(); + await using var conn = await dataSource.OpenConnectionAsync(targetSessionAttributes); + + Assert.That(conn.Port, Is.EqualTo(postmasters[expectedServer].Port)); + + for (var i = 0; i <= expectedServer; i++) + _ = await postmasters[i].WaitForServerConnection(); + 
} + + [Test] + [TestCaseSource(nameof(MyCases))] + public async Task Connect_to_correct_host_unpooled(TargetSessionAttributes targetSessionAttributes, MockState[] servers, int expectedServer) + { + var postmasters = servers.Select(s => PgPostmasterMock.Start(state: s)).ToArray(); + await using var __ = new DisposableWrapper(postmasters); + + var connectionStringBuilder = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(postmasters), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + Pooling = false + }; + + await using var dataSource = new NpgsqlDataSourceBuilder(connectionStringBuilder.ConnectionString) + .BuildMultiHost(); + await using var conn = await dataSource.OpenConnectionAsync(targetSessionAttributes); + + Assert.That(conn.Port, Is.EqualTo(postmasters[expectedServer].Port)); + + for (var i = 0; i <= expectedServer; i++) + _ = await postmasters[i].WaitForServerConnection(); + } + + [Test] + [TestCaseSource(nameof(MyCases))] + public async Task Connect_to_correct_host_with_available_idle( + TargetSessionAttributes targetSessionAttributes, MockState[] servers, int expectedServer) + { + var postmasters = servers.Select(s => PgPostmasterMock.Start(state: s)).ToArray(); + await using var __ = new DisposableWrapper(postmasters); + + // First, open and close a connection with the TargetSessionAttributes matching the first server. + // This ensures wew have an idle connection in the pool. 
+ var connectionStringBuilder = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(postmasters), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + }; + + await using var dataSource = new NpgsqlDataSourceBuilder(connectionStringBuilder.ConnectionString) + .BuildMultiHost(); + var idleConnTargetSessionAttributes = servers[0] switch + { + Primary => TargetSessionAttributes.ReadWrite, + PrimaryReadOnly => TargetSessionAttributes.ReadOnly, + Standby => TargetSessionAttributes.Standby, + _ => throw new ArgumentOutOfRangeException() + }; + await using (_ = await dataSource.OpenConnectionAsync(idleConnTargetSessionAttributes)) + { + // Do nothing, close to have an idle connection in the pool. + } + + // Now connect with the test TargetSessionAttributes + + await using var conn = await dataSource.OpenConnectionAsync(targetSessionAttributes); + + Assert.That(conn.Port, Is.EqualTo(postmasters[expectedServer].Port)); + + for (var i = 0; i <= expectedServer; i++) + _ = await postmasters[i].WaitForServerConnection(); + } + + [Test] + [TestCase(TargetSessionAttributes.Standby, new[] { Primary, Primary })] + [TestCase(TargetSessionAttributes.Primary, new[] { Standby, Standby })] + [TestCase(TargetSessionAttributes.ReadWrite, new[] { PrimaryReadOnly, Standby })] + [TestCase(TargetSessionAttributes.ReadOnly, new[] { Primary, Primary })] + public async Task Valid_host_not_found(TargetSessionAttributes targetSessionAttributes, MockState[] servers) + { + var postmasters = servers.Select(s => PgPostmasterMock.Start(state: s)).ToArray(); + await using var __ = new DisposableWrapper(postmasters); + + var connectionStringBuilder = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(postmasters), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + }; + + await using var dataSource = new NpgsqlDataSourceBuilder(connectionStringBuilder.ConnectionString) + .BuildMultiHost(); + + var exception = Assert.ThrowsAsync(async () => await 
dataSource.OpenConnectionAsync(targetSessionAttributes))!; + Assert.That(exception.Message, Is.EqualTo("No suitable host was found.")); + Assert.That(exception.InnerException, Is.Null); + + for (var i = 0; i < servers.Length; i++) + _ = await postmasters[i].WaitForServerConnection(); + } + + [Test, Platform(Exclude = "MacOsX", Reason = "#3786")] + public void All_hosts_are_down() + { + // Different exception raised in .NET Core 3.1, skip (NUnit doesn't seem to support detecting .NET Core versions) + if (RuntimeInformation.FrameworkDescription.StartsWith(".NET Core 3.1")) + return; + + var endpoint = new IPEndPoint(IPAddress.Loopback, 0); + + using var socket1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket1.Bind(endpoint); + var localEndPoint1 = (IPEndPoint)socket1.LocalEndPoint!; + + using var socket2 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket2.Bind(endpoint); + var localEndPoint2 = (IPEndPoint)socket2.LocalEndPoint!; + + // Note that we Bind (to reserve the port), but do not Listen - connection attempts will fail. 
+ + var connectionString = new NpgsqlConnectionStringBuilder + { + Host = $"{localEndPoint1.Address}:{localEndPoint1.Port},{localEndPoint2.Address}:{localEndPoint2.Port}" + }.ConnectionString; + using var dataSource = new NpgsqlDataSourceBuilder(connectionString).BuildMultiHost(); + + var exception = Assert.ThrowsAsync(async () => await dataSource.OpenConnectionAsync(TargetSessionAttributes.Any))!; + var aggregateException = (AggregateException)exception.InnerException!; + Assert.That(aggregateException.InnerExceptions, Has.Count.EqualTo(2)); + + for (var i = 0; i < aggregateException.InnerExceptions.Count; i++) + { + Assert.That(aggregateException.InnerExceptions[i], Is.TypeOf() + .With.InnerException.TypeOf() + .With.InnerException.Property(nameof(SocketException.SocketErrorCode)).EqualTo(SocketError.ConnectionRefused)); + } + } + + [Test] + public async Task All_hosts_are_unavailable( + [Values] bool pooling, + [Values(PostgresErrorCodes.InvalidCatalogName, PostgresErrorCodes.CannotConnectNow)] string errorCode) + { + await using var primaryPostmaster = PgPostmasterMock.Start(state: Primary, startupErrorCode: errorCode); + await using var standbyPostmaster = PgPostmasterMock.Start(state: Standby, startupErrorCode: errorCode); + + var builder = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(primaryPostmaster, standbyPostmaster), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + Pooling = pooling, + }; + + await using var dataSource = new NpgsqlDataSourceBuilder(builder.ConnectionString).BuildMultiHost(); + + var ex = Assert.ThrowsAsync(async () => await dataSource.OpenConnectionAsync(TargetSessionAttributes.Any))!; + Assert.That(ex.SqlState, Is.EqualTo(errorCode)); + } + + [Test] + [Platform(Exclude = "MacOsX", Reason = "Flaky in CI on Mac")] + public async Task First_host_is_down() + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + var endpoint = new 
IPEndPoint(IPAddress.Loopback, 0); + socket.Bind(endpoint); + var localEndPoint = (IPEndPoint)socket.LocalEndPoint!; + // Note that we Bind (to reserve the port), but do not Listen - connection attempts will fail. + + await using var postmaster = PgPostmasterMock.Start(state: Primary); + + var connectionString = new NpgsqlConnectionStringBuilder + { + Host = $"{localEndPoint.Address}:{localEndPoint.Port},{postmaster.Host}:{postmaster.Port}", + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading + }.ConnectionString; + + await using var dataSource = new NpgsqlDataSourceBuilder(connectionString).BuildMultiHost(); + + await using var conn = await dataSource.OpenConnectionAsync(TargetSessionAttributes.Any); + Assert.That(conn.Port, Is.EqualTo(postmaster.Port)); + } + + [Test] + [TestCase("any")] + [TestCase("primary")] + [TestCase("standby")] + [TestCase("prefer-primary")] + [TestCase("prefer-standby")] + [TestCase("read-write")] + [TestCase("read-only")] + public async Task TargetSessionAttributes_with_single_host(string targetSessionAttributes) + { + var connectionString = new NpgsqlConnectionStringBuilder(ConnectionString) + { + TargetSessionAttributes = targetSessionAttributes + }.ConnectionString; + + if (targetSessionAttributes == "any") + { + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + using var pool = CreateTempPool(postmasterMock.ConnectionString, out connectionString); + await using var conn = new NpgsqlConnection(connectionString); + await conn.OpenAsync(); + _ = await postmasterMock.WaitForServerConnection(); + } + else + { + Assert.That(() => new NpgsqlConnection(connectionString), Throws.Exception.TypeOf()); + } + } + + [Test] + public void TargetSessionAttributes_default_is_null() + => Assert.That(new NpgsqlConnectionStringBuilder().TargetSessionAttributes, Is.Null); + + [Test] + [NonParallelizable] // Sets environment variable + public async Task TargetSessionAttributes_uses_environment_variable() + { + using 
var envVarResetter = SetEnvironmentVariable("PGTARGETSESSIONATTRS", "prefer-standby"); + + await using var primaryPostmaster = PgPostmasterMock.Start(state: Primary); + await using var standbyPostmaster = PgPostmasterMock.Start(state: Standby); + + var builder = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(primaryPostmaster, standbyPostmaster), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading + }; + + Assert.That(builder.TargetSessionAttributes, Is.Null); + + await using var dataSource = new NpgsqlDataSourceBuilder(builder.ConnectionString) + .BuildMultiHost(); + + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(conn.Port, Is.EqualTo(standbyPostmaster.Port)); + } + + [Test] + public void TargetSessionAttributes_invalid_throws() + => Assert.Throws(() => + new NpgsqlConnectionStringBuilder + { + TargetSessionAttributes = nameof(TargetSessionAttributes_invalid_throws) + }); + + [Test] + public void HostRecheckSeconds_default_value() + { + var builder = new NpgsqlConnectionStringBuilder(); + Assert.That(builder.HostRecheckSeconds, Is.EqualTo(10)); + Assert.That(builder.HostRecheckSecondsTranslated, Is.EqualTo(TimeSpan.FromSeconds(10))); + } + + [Test] + public void HostRecheckSeconds_zero_value() + { + var builder = new NpgsqlConnectionStringBuilder + { + HostRecheckSeconds = 0, + }; + Assert.That(builder.HostRecheckSeconds, Is.EqualTo(0)); + Assert.That(builder.HostRecheckSecondsTranslated, Is.EqualTo(TimeSpan.FromSeconds(-1))); + } + + [Test] + public void HostRecheckSeconds_invalid_throws() + => Assert.Throws(() => + new NpgsqlConnectionStringBuilder + { + HostRecheckSeconds = -1 + }); + + [Test] + public async Task Connect_with_load_balancing() + { + await using var primaryPostmaster = PgPostmasterMock.Start(state: Primary); + await using var standbyPostmaster = PgPostmasterMock.Start(state: Standby); + + var defaultCsb = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(primaryPostmaster, 
standbyPostmaster), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + MaxPoolSize = 1, + LoadBalanceHosts = true, + }; + + await using var dataSource = new NpgsqlDataSourceBuilder(defaultCsb.ConnectionString) + .BuildMultiHost(); + + NpgsqlConnector firstConnector; + NpgsqlConnector secondConnector; + + await using (var firstConnection = await dataSource.OpenConnectionAsync()) + { + firstConnector = firstConnection.Connector!; + } + + await using (var secondConnection = await dataSource.OpenConnectionAsync()) + { + secondConnector = secondConnection.Connector!; + } + + Assert.AreNotSame(firstConnector, secondConnector); + + await using (var firstBalancedConnection = await dataSource.OpenConnectionAsync()) + { + Assert.AreSame(firstConnector, firstBalancedConnection.Connector); + } + + await using (var secondBalancedConnection = await dataSource.OpenConnectionAsync()) + { + Assert.AreSame(secondConnector, secondBalancedConnection.Connector); + } + + await using (var thirdBalancedConnection = await dataSource.OpenConnectionAsync()) + { + Assert.AreSame(firstConnector, thirdBalancedConnection.Connector); + } + } + + [Test] + public async Task Connect_without_load_balancing() + { + await using var primaryPostmaster = PgPostmasterMock.Start(state: Primary); + await using var standbyPostmaster = PgPostmasterMock.Start(state: Standby); + + var defaultCsb = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(primaryPostmaster, standbyPostmaster), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + MaxPoolSize = 1, + LoadBalanceHosts = false, + }; + + await using var dataSource = new NpgsqlDataSourceBuilder(defaultCsb.ConnectionString) + .BuildMultiHost(); + + NpgsqlConnector firstConnector; + NpgsqlConnector secondConnector; + + await using (var firstConnection = await dataSource.OpenConnectionAsync()) + { + firstConnector = firstConnection.Connector!; + } + await using (var secondConnection = await 
dataSource.OpenConnectionAsync()) + { + Assert.AreSame(firstConnector, secondConnection.Connector); + } + await using (var firstConnection = await dataSource.OpenConnectionAsync()) + await using (var secondConnection = await dataSource.OpenConnectionAsync()) + { + secondConnector = secondConnection.Connector!; + } + + Assert.AreNotSame(firstConnector, secondConnector); + + await using (var firstUnbalancedConnection = await dataSource.OpenConnectionAsync()) + { + Assert.AreSame(firstConnector, firstUnbalancedConnection.Connector); + } + + await using (var secondUnbalancedConnection = await dataSource.OpenConnectionAsync()) + { + Assert.AreSame(firstConnector, secondUnbalancedConnection.Connector); + } + } + + [Test] + public async Task Connect_state_changing_hosts([Values] bool alwaysCheckHostState) + { + await using var primaryPostmaster = PgPostmasterMock.Start(state: Primary); + await using var standbyPostmaster = PgPostmasterMock.Start(state: Standby); + + var defaultCsb = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(primaryPostmaster, standbyPostmaster), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + MaxPoolSize = 1, + HostRecheckSeconds = alwaysCheckHostState ? 
0 : int.MaxValue, + NoResetOnClose = true, + }; + + await using var dataSource = new NpgsqlDataSourceBuilder(defaultCsb.ConnectionString) + .BuildMultiHost(); + + NpgsqlConnector firstConnector; + NpgsqlConnector secondConnector; + var firstServerTask = Task.Run(async () => + { + var server = await primaryPostmaster.WaitForServerConnection(); + if (!alwaysCheckHostState) + return; + + // If we always check the host, we will send the request for the state + // even though we got one while opening the connection + await server.SendMockState(Primary); + + // Update the state after a 'failover' + await server.SendMockState(Standby); + }); + var secondServerTask = Task.Run(async () => + { + var server = await standbyPostmaster.WaitForServerConnection(); + if (!alwaysCheckHostState) + return; + + // If we always check the host, we will send the request for the state + // even though we got one while opening the connection + await server.SendMockState(Standby); + + // As TargetSessionAttributes is 'prefer', it does another cycle for the 'unpreferred' + await server.SendMockState(Standby); + // Update the state after a 'failover' + await server.SendMockState(Primary); + }); + + await using (var firstConnection = await dataSource.OpenConnectionAsync(TargetSessionAttributes.PreferPrimary)) + await using (var secondConnection = await dataSource.OpenConnectionAsync(TargetSessionAttributes.PreferPrimary)) + { + firstConnector = firstConnection.Connector!; + secondConnector = secondConnection.Connector!; + } + + await using var thirdConnection = await dataSource.OpenConnectionAsync(TargetSessionAttributes.PreferPrimary); + Assert.AreSame(alwaysCheckHostState ? 
secondConnector : firstConnector, thirdConnection.Connector); + + await firstServerTask; + await secondServerTask; + } + + [Test] + public void Database_state_cache_basic() + { + using var dataSource = CreateDataSource(); + var timeStamp = DateTime.UtcNow; + + dataSource.UpdateDatabaseState(DatabaseState.PrimaryReadWrite, timeStamp, TimeSpan.Zero); + Assert.AreEqual(DatabaseState.PrimaryReadWrite, dataSource.GetDatabaseState()); + + // Update with the same timestamp - shouldn't change anything + dataSource.UpdateDatabaseState(DatabaseState.Standby, timeStamp, TimeSpan.Zero); + Assert.AreEqual(DatabaseState.PrimaryReadWrite, dataSource.GetDatabaseState()); + + // Update with a new timestamp + timeStamp = timeStamp.AddSeconds(1); + dataSource.UpdateDatabaseState(DatabaseState.PrimaryReadOnly, timeStamp, TimeSpan.Zero); + Assert.AreEqual(DatabaseState.PrimaryReadOnly, dataSource.GetDatabaseState()); + + // Expired state returns as Unknown (depending on ignoreExpiration) + timeStamp = timeStamp.AddSeconds(1); + dataSource.UpdateDatabaseState(DatabaseState.PrimaryReadWrite, timeStamp, TimeSpan.FromSeconds(-1)); + Assert.AreEqual(DatabaseState.Unknown, dataSource.GetDatabaseState(ignoreExpiration: false)); + Assert.AreEqual(DatabaseState.PrimaryReadWrite, dataSource.GetDatabaseState(ignoreExpiration: true)); + } + + [Test] + public async Task Offline_state_on_connection_failure() + { + await using var server = PgPostmasterMock.Start(ConnectionString, startupErrorCode: PostgresErrorCodes.ConnectionFailure); + await using var dataSource = server.CreateDataSource(); + await using var conn = dataSource.CreateConnection(); + + var ex = Assert.ThrowsAsync(conn.OpenAsync)!; + Assert.That(ex.SqlState, Is.EqualTo(PostgresErrorCodes.ConnectionFailure)); + + var state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Offline)); + } + + [Test] + public async Task Unknown_state_on_connection_authentication_failure() + { + await using var server 
= PgPostmasterMock.Start(ConnectionString, startupErrorCode: PostgresErrorCodes.InvalidAuthorizationSpecification); + await using var dataSource = server.CreateDataSource(); + await using var conn = dataSource.CreateConnection(); + + var ex = Assert.ThrowsAsync(conn.OpenAsync)!; + Assert.That(ex.SqlState, Is.EqualTo(PostgresErrorCodes.InvalidAuthorizationSpecification)); + + var state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + } + + [Test] + public async Task Offline_state_on_query_execution_pg_critical_failure() + { + await using var postmaster = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = postmaster.CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var anotherConn = await dataSource.OpenConnectionAsync(); + await anotherConn.CloseAsync(); + + var state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(2)); + + var server = await postmaster.WaitForServerConnection(); + await server.WriteErrorResponse(PostgresErrorCodes.CrashShutdown).FlushAsync(); + + var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync("SELECT 1"))!; + Assert.That(ex.SqlState, Is.EqualTo(PostgresErrorCodes.CrashShutdown)); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + + state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Offline)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(0)); + } + + [Test, NonParallelizable] + public async Task Offline_state_on_query_execution_pg_non_critical_failure() + { + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + + // Starting with PG14 we get the cluster's state from PG automatically + var expectedState = conn.PostgreSqlVersion.Major > 13 ? 
DatabaseState.PrimaryReadWrite : DatabaseState.Unknown; + + var state = dataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(expectedState)); + Assert.That(dataSource.Statistics.Total, Is.EqualTo(1)); + + var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync("SELECT abc"))!; + Assert.That(ex.SqlState, Is.EqualTo(PostgresErrorCodes.UndefinedColumn)); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + + state = dataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(expectedState)); + Assert.That(dataSource.Statistics.Total, Is.EqualTo(1)); + } + + [Test] + public async Task Offline_state_on_query_execution_IOException() + { + await using var postmaster = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = postmaster.CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var anotherConn = await dataSource.OpenConnectionAsync(); + await anotherConn.CloseAsync(); + + var state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(2)); + + var server = await postmaster.WaitForServerConnection(); + server.Close(); + + var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync("SELECT 1"))!; + Assert.That(ex.InnerException, Is.InstanceOf()); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + + state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Offline)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(0)); + } + + [Test] + public async Task Offline_state_on_query_execution_TimeoutException() + { + await using var postmaster = PgPostmasterMock.Start(ConnectionString); + var dataSourceBuilder = postmaster.GetDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.CommandTimeout = 1; + dataSourceBuilder.ConnectionStringBuilder.CancellationTimeout = 1; + await using var dataSource = 
dataSourceBuilder.Build(); + + await using var conn = await dataSource.OpenConnectionAsync(); + await using var anotherConn = await dataSource.OpenConnectionAsync(); + await anotherConn.CloseAsync(); + + var state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(2)); + + var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync("SELECT 1"))!; + Assert.That(ex.InnerException, Is.TypeOf()); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + + state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Offline)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(0)); + } + + [Test] + public async Task Unknown_state_on_query_execution_TimeoutException_with_disabled_cancellation() + { + await using var postmaster = PgPostmasterMock.Start(ConnectionString); + var dataSourceBuilder = postmaster.GetDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.CommandTimeout = 1; + dataSourceBuilder.ConnectionStringBuilder.CancellationTimeout = -1; + await using var dataSource = dataSourceBuilder.Build(); + + await using var conn = await dataSource.OpenConnectionAsync(); + await using var anotherConn = await dataSource.OpenConnectionAsync(); + await anotherConn.CloseAsync(); + + var state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(2)); + + var ex = Assert.ThrowsAsync(() => conn.ExecuteNonQueryAsync("SELECT 1"))!; + Assert.That(ex.InnerException, Is.TypeOf()); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + + state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(1)); + } + + [Test] + public async Task 
Unknown_state_on_query_execution_cancellation_with_disabled_cancellation_timeout() + { + await using var postmaster = PgPostmasterMock.Start(ConnectionString); + var dataSourceBuilder = postmaster.GetDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.CommandTimeout = 30; + dataSourceBuilder.ConnectionStringBuilder.CancellationTimeout = -1; + await using var dataSource = dataSourceBuilder.Build(); + + await using var conn = await dataSource.OpenConnectionAsync(); + await using var anotherConn = await dataSource.OpenConnectionAsync(); + await anotherConn.CloseAsync(); + + var state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(2)); + + using var cts = new CancellationTokenSource(); + + var query = conn.ExecuteNonQueryAsync("SELECT 1", cancellationToken: cts.Token); + cts.Cancel(); + var ex = Assert.ThrowsAsync(async () => await query)!; + Assert.That(ex.InnerException, Is.TypeOf()); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + + state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(1)); + } + + [Test] + public async Task Unknown_state_on_query_execution_TimeoutException_with_cancellation_failure() + { + await using var postmaster = PgPostmasterMock.Start(ConnectionString); + var dataSourceBuilder = postmaster.GetDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.CommandTimeout = 1; + dataSourceBuilder.ConnectionStringBuilder.CancellationTimeout = 0; + await using var dataSource = dataSourceBuilder.Build(); + + await using var conn = await dataSource.OpenConnectionAsync(); + + var state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(1)); + + var server = await 
postmaster.WaitForServerConnection(); + + var query = conn.ExecuteNonQueryAsync("SELECT 1"); + + await postmaster.WaitForCancellationRequest(); + await server.WriteCancellationResponse().WriteReadyForQuery().FlushAsync(); + + var ex = Assert.ThrowsAsync(async () => await query)!; + Assert.That(ex.InnerException, Is.TypeOf()); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + + state = conn.NpgsqlDataSource.GetDatabaseState(); + Assert.That(state, Is.EqualTo(DatabaseState.Unknown)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(1)); + } + + [Test] + public async Task Clear_pool_one_host_only_on_admin_shutdown() + { + await using var primaryPostmaster = PgPostmasterMock.Start(ConnectionString, state: Primary); + await using var standbyPostmaster = PgPostmasterMock.Start(ConnectionString, state: Standby); + var dataSourceBuilder = new NpgsqlDataSourceBuilder + { + ConnectionStringBuilder = + { + Host = MultipleHosts(primaryPostmaster, standbyPostmaster), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + MaxPoolSize = 2 + } + }; + await using var multiHostDataSource = dataSourceBuilder.BuildMultiHost(); + await using var preferPrimaryDataSource = multiHostDataSource.WithTargetSession(TargetSessionAttributes.PreferPrimary); + + await using var primaryConn = await preferPrimaryDataSource.OpenConnectionAsync(); + await using var anotherPrimaryConn = await preferPrimaryDataSource.OpenConnectionAsync(); + await using var standbyConn = await preferPrimaryDataSource.OpenConnectionAsync(); + var primaryDataSource = primaryConn.Connector!.DataSource; + var standbyDataSource = standbyConn.Connector!.DataSource; + await anotherPrimaryConn.CloseAsync(); + await standbyConn.CloseAsync(); + + Assert.That(primaryDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.PrimaryReadWrite)); + Assert.That(standbyDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Standby)); + Assert.That(primaryConn.NpgsqlDataSource.Statistics.Total, 
Is.EqualTo(3)); + + var server = await primaryPostmaster.WaitForServerConnection(); + await server.WriteErrorResponse(PostgresErrorCodes.AdminShutdown).FlushAsync(); + + var ex = Assert.ThrowsAsync(() => primaryConn.ExecuteNonQueryAsync("SELECT 1"))!; + Assert.That(ex.SqlState, Is.EqualTo(PostgresErrorCodes.AdminShutdown)); + Assert.That(primaryConn.State, Is.EqualTo(ConnectionState.Closed)); + + Assert.That(primaryDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Offline)); + Assert.That(standbyDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Standby)); + Assert.That(primaryConn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(1)); + + multiHostDataSource.ClearDatabaseStates(); + Assert.That(primaryDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Unknown)); + Assert.That(standbyDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Unknown)); + } + + [Test] + [TestCase("any", true)] + [TestCase("primary", true)] + [TestCase("standby", false)] + [TestCase("prefer-primary", true)] + [TestCase("prefer-standby", false)] + [TestCase("read-write", true)] + [TestCase("read-only", false)] + public async Task Transaction_enlist_reuses_connection(string targetSessionAttributes, bool primary) + { + await using var primaryPostmaster = PgPostmasterMock.Start(ConnectionString, state: Primary); + await using var standbyPostmaster = PgPostmasterMock.Start(ConnectionString, state: Standby); + var csb = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(primaryPostmaster, standbyPostmaster), + TargetSessionAttributes = targetSessionAttributes, + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + MaxPoolSize = 10, + }; + + using var _ = CreateTempPool(csb, out var connString); + + using var scope = new TransactionScope(TransactionScopeOption.Required, + new TransactionOptions { IsolationLevel = IsolationLevel.ReadCommitted }, TransactionScopeAsyncFlowOption.Enabled); + + var query1Task = Query(connString); + + var server = primary + ? 
await primaryPostmaster.WaitForServerConnection() + : await standbyPostmaster.WaitForServerConnection(); + + await server + .WriteCommandComplete() + .WriteReadyForQuery(TransactionStatus.InTransactionBlock) + .WriteParseComplete() + .WriteBindComplete() + .WriteNoData() + .WriteCommandComplete() + .WriteReadyForQuery(TransactionStatus.InTransactionBlock) + .FlushAsync(); + await query1Task; + + var query2Task = Query(connString); + await server + .WriteParseComplete() + .WriteBindComplete() + .WriteNoData() + .WriteCommandComplete() + .WriteReadyForQuery(TransactionStatus.InTransactionBlock) + .FlushAsync(); + await query2Task; + + await server + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + scope.Complete(); + + async Task Query(string connectionString) + { + await using var conn = new NpgsqlConnection(connectionString); + await conn.OpenAsync(); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 1"; + await cmd.ExecuteNonQueryAsync(); + } + } + + [Test] + public async Task Primary_host_failover_can_connect() + { + await using var firstPostmaster = PgPostmasterMock.Start(ConnectionString, state: Primary); + await using var secondPostmaster = PgPostmasterMock.Start(ConnectionString, state: Standby); + var dataSourceBuilder = new NpgsqlDataSourceBuilder + { + ConnectionStringBuilder = + { + Host = MultipleHosts(firstPostmaster, secondPostmaster), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + HostRecheckSeconds = 5 + } + }; + await using var multiHostDataSource = dataSourceBuilder.BuildMultiHost(); + var (firstDataSource, secondDataSource) = (multiHostDataSource.Pools[0], multiHostDataSource.Pools[1]); + await using var primaryDataSource = multiHostDataSource.WithTargetSession(TargetSessionAttributes.Primary); + + await using var conn = await primaryDataSource.OpenConnectionAsync(); + Assert.That(conn.Port, Is.EqualTo(firstPostmaster.Port)); + var firstServer = await 
firstPostmaster.WaitForServerConnection(); + await firstServer + .WriteErrorResponse(PostgresErrorCodes.AdminShutdown) + .FlushAsync(); + + var failoverEx = Assert.ThrowsAsync(async () => await conn.ExecuteNonQueryAsync("SELECT 1"))!; + Assert.That(failoverEx.SqlState, Is.EqualTo(PostgresErrorCodes.AdminShutdown)); + + var noHostFoundEx = Assert.ThrowsAsync(async () => await conn.OpenAsync())!; + Assert.That(noHostFoundEx.Message, Is.EqualTo("No suitable host was found.")); + + Assert.That(firstDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Offline)); + Assert.That(secondDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Standby)); + + firstPostmaster.State = Standby; + secondPostmaster.State = Primary; + var secondServer = await secondPostmaster.WaitForServerConnection(); + await secondServer.SendMockState(Primary); + + await Task.Delay(TimeSpan.FromSeconds(10)); + Assert.That(firstDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Unknown)); + Assert.That(secondDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Unknown)); + + await conn.OpenAsync(); + Assert.That(conn.Port, Is.EqualTo(secondPostmaster.Port)); + Assert.That(firstDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.Standby)); + Assert.That(secondDataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.PrimaryReadWrite)); + } + + [Test, NonParallelizable] + public void IntegrationTest([Values] bool loadBalancing, [Values] bool alwaysCheckHostState) + { + PoolManager.Reset(); + + var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString) + { + ConnectionStringBuilder = + { + Host = "localhost,127.0.0.1", + Pooling = true, + MaxPoolSize = 2, + LoadBalanceHosts = loadBalancing, + HostRecheckSeconds = alwaysCheckHostState ? 
0 : 10, + } + }; + using var dataSource = dataSourceBuilder.BuildMultiHost(); + + var queriesDone = 0; + + var clientsTask = Task.WhenAll( + Client(dataSource, TargetSessionAttributes.Any), + Client(dataSource, TargetSessionAttributes.Primary), + Client(dataSource, TargetSessionAttributes.PreferPrimary), + Client(dataSource, TargetSessionAttributes.PreferStandby), + Client(dataSource, TargetSessionAttributes.ReadWrite)); + + var onlyStandbyClient = Client(dataSource, TargetSessionAttributes.Standby); + var readOnlyClient = Client(dataSource, TargetSessionAttributes.ReadOnly); + + Assert.DoesNotThrowAsync(() => clientsTask); + Assert.ThrowsAsync(() => onlyStandbyClient); + Assert.ThrowsAsync(() => readOnlyClient); + Assert.AreEqual(125, queriesDone); + + Task Client(NpgsqlMultiHostDataSource multiHostDataSource, TargetSessionAttributes targetSessionAttributes) + { + var dataSource = multiHostDataSource.WithTargetSession(targetSessionAttributes); + var tasks = new List(5); + + for (var i = 0; i < 5; i++) + { + tasks.Add(Task.Run(() => Query(dataSource))); + } + + return Task.WhenAll(tasks); + } + + async Task Query(NpgsqlDataSource dataSource) + { + await using var conn = dataSource.CreateConnection(); + for (var i = 0; i < 5; i++) + { + await conn.OpenAsync(); + await conn.ExecuteNonQueryAsync("SELECT 1"); + await conn.CloseAsync(); + Interlocked.Increment(ref queriesDone); + } + } + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/5055")] + [NonParallelizable] // Disables sql rewriting + public async Task Multiple_hosts_with_disabled_sql_rewriting() + { + using var _ = DisableSqlRewriting(); + + var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString) + { + ConnectionStringBuilder = + { + Host = "localhost,127.0.0.1", + Pooling = true, + HostRecheckSeconds = 0 + } + }; + await using var dataSource = dataSourceBuilder.BuildMultiHost(); + await using var conn = await dataSource.OpenConnectionAsync(); + } + + [Test] + public async Task 
DataSource_with_wrappers() + { + await using var primaryPostmasterMock = PgPostmasterMock.Start(state: Primary); + await using var standbyPostmasterMock = PgPostmasterMock.Start(state: Standby); + + var builder = new NpgsqlDataSourceBuilder + { + ConnectionStringBuilder = + { + Host = MultipleHosts(primaryPostmasterMock, standbyPostmasterMock), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + } + }; + + await using var dataSource = builder.BuildMultiHost(); + await using var primaryDataSource = dataSource.WithTargetSession(TargetSessionAttributes.Primary); + await using var standbyDataSource = dataSource.WithTargetSession(TargetSessionAttributes.Standby); + + await using var primaryConnection = await primaryDataSource.OpenConnectionAsync(); + Assert.That(primaryConnection.Port, Is.EqualTo(primaryPostmasterMock.Port)); + + await using var standbyConnection = await standbyDataSource.OpenConnectionAsync(); + Assert.That(standbyConnection.Port, Is.EqualTo(standbyPostmasterMock.Port)); + } + + [Test] + public async Task DataSource_without_wrappers() + { + await using var primaryPostmasterMock = PgPostmasterMock.Start(state: Primary); + await using var standbyPostmasterMock = PgPostmasterMock.Start(state: Standby); + + var builder = new NpgsqlDataSourceBuilder + { + ConnectionStringBuilder = + { + Host = MultipleHosts(primaryPostmasterMock, standbyPostmasterMock), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + } + }; + + await using var dataSource = builder.BuildMultiHost(); + + await using var primaryConnection = await dataSource.OpenConnectionAsync(TargetSessionAttributes.Primary); + Assert.That(primaryConnection.Port, Is.EqualTo(primaryPostmasterMock.Port)); + + await using var standbyConnection = await dataSource.OpenConnectionAsync(TargetSessionAttributes.Standby); + Assert.That(standbyConnection.Port, Is.EqualTo(standbyPostmasterMock.Port)); + } + + [Test] + public void 
DataSource_with_TargetSessionAttributes_is_not_supported() + { + var builder = new NpgsqlDataSourceBuilder("Host=foo,bar;Target Session Attributes=primary"); + + Assert.That(() => builder.BuildMultiHost(), Throws.Exception.TypeOf() + .With.Message.EqualTo(NpgsqlStrings.CannotSpecifyTargetSessionAttributes)); + } + + [Test] + public async Task BuildMultiHost_with_single_host_is_supported() + { + var builder = new NpgsqlDataSourceBuilder(ConnectionString); + await using var dataSource = builder.BuildMultiHost(); + await using var connection = await dataSource.OpenConnectionAsync(); + Assert.That(await connection.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task Build_with_multiple_hosts_is_supported() + { + await using var primaryPostmasterMock = PgPostmasterMock.Start(state: Primary); + await using var standbyPostmasterMock = PgPostmasterMock.Start(state: Standby); + + var builder = new NpgsqlDataSourceBuilder + { + ConnectionStringBuilder = + { + Host = MultipleHosts(primaryPostmasterMock, standbyPostmasterMock), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + } + }; + + await using var dataSource = builder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4181")] + [Explicit("Fails until #4181 is fixed.")] + public async Task LoadBalancing_is_fair_if_first_host_is_down([Values]TargetSessionAttributes targetSessionAttributes) + { + await using var pDown = PgPostmasterMock.Start(state: Primary, startupErrorCode: PostgresErrorCodes.CannotConnectNow); + await using var pRw1 = PgPostmasterMock.Start(state: Primary); + await using var pR1 = PgPostmasterMock.Start(state: PrimaryReadOnly); + await using var s1 = PgPostmasterMock.Start(state: Standby); + await using var pRw2 = PgPostmasterMock.Start(state: Primary); + await using var pR2 = PgPostmasterMock.Start(state: PrimaryReadOnly); + await using var s2 = 
PgPostmasterMock.Start(state: Standby); + + var hostList = $"{pDown.Host}:{pDown.Port}," + + $"{pRw1.Host}:{pRw1.Port}," + + $"{pR1.Host}:{pR1.Port}," + + $"{s1.Host}:{s1.Port}," + + $"{pRw2.Host}:{pRw2.Port}," + + $"{pR2.Host}:{pR2.Port}," + + $"{s2.Host}:{s2.Port}"; + + await using var dataSource = CreateDataSource(builder => + { + builder.Host = hostList; + builder.ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading; + builder.LoadBalanceHosts = true; + builder.TargetSessionAttributesParsed = targetSessionAttributes; + + }); + var connections = Enumerable.Repeat(0, 12).Select(_ => dataSource.OpenConnection()).ToArray(); + await using var __ = new DisposableWrapper(connections); + + switch (targetSessionAttributes) + { + case TargetSessionAttributes.Any: + Assert.That(connections[0].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[1].Port, Is.EqualTo(pR1.Port)); + Assert.That(connections[2].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[3].Port, Is.EqualTo(pRw2.Port)); + Assert.That(connections[4].Port, Is.EqualTo(pR2.Port)); + Assert.That(connections[5].Port, Is.EqualTo(s2.Port)); + Assert.That(connections[6].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[7].Port, Is.EqualTo(pR1.Port)); + Assert.That(connections[8].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[9].Port, Is.EqualTo(pRw2.Port)); + Assert.That(connections[10].Port, Is.EqualTo(pR2.Port)); + Assert.That(connections[11].Port, Is.EqualTo(s2.Port)); + break; + case TargetSessionAttributes.ReadWrite: + Assert.That(connections[0].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[1].Port, Is.EqualTo(pRw2.Port)); + Assert.That(connections[2].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[3].Port, Is.EqualTo(pRw2.Port)); + Assert.That(connections[4].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[5].Port, Is.EqualTo(pRw2.Port)); + Assert.That(connections[6].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[7].Port, Is.EqualTo(pRw2.Port)); 
+ Assert.That(connections[8].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[9].Port, Is.EqualTo(pRw2.Port)); + Assert.That(connections[10].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[11].Port, Is.EqualTo(pRw2.Port)); + break; + case TargetSessionAttributes.ReadOnly: + Assert.That(connections[0].Port, Is.EqualTo(pR1.Port)); + Assert.That(connections[1].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[2].Port, Is.EqualTo(pR2.Port)); + Assert.That(connections[3].Port, Is.EqualTo(s2.Port)); + Assert.That(connections[4].Port, Is.EqualTo(pR1.Port)); + Assert.That(connections[5].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[6].Port, Is.EqualTo(pR2.Port)); + Assert.That(connections[7].Port, Is.EqualTo(s2.Port)); + Assert.That(connections[8].Port, Is.EqualTo(pR1.Port)); + Assert.That(connections[9].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[10].Port, Is.EqualTo(pR2.Port)); + Assert.That(connections[11].Port, Is.EqualTo(s2.Port)); + break; + case TargetSessionAttributes.Primary: + case TargetSessionAttributes.PreferPrimary: + Assert.That(connections[0].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[1].Port, Is.EqualTo(pR1.Port)); + Assert.That(connections[2].Port, Is.EqualTo(pRw2.Port)); + Assert.That(connections[3].Port, Is.EqualTo(pR2.Port)); + Assert.That(connections[4].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[5].Port, Is.EqualTo(pR1.Port)); + Assert.That(connections[6].Port, Is.EqualTo(pRw2.Port)); + Assert.That(connections[7].Port, Is.EqualTo(pR2.Port)); + Assert.That(connections[8].Port, Is.EqualTo(pRw1.Port)); + Assert.That(connections[9].Port, Is.EqualTo(pR1.Port)); + Assert.That(connections[10].Port, Is.EqualTo(pRw2.Port)); + Assert.That(connections[11].Port, Is.EqualTo(pR2.Port)); + break; + case TargetSessionAttributes.Standby: + case TargetSessionAttributes.PreferStandby: + Assert.That(connections[0].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[1].Port, Is.EqualTo(s2.Port)); + 
Assert.That(connections[2].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[3].Port, Is.EqualTo(s2.Port)); + Assert.That(connections[4].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[5].Port, Is.EqualTo(s2.Port)); + Assert.That(connections[6].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[7].Port, Is.EqualTo(s2.Port)); + Assert.That(connections[8].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[9].Port, Is.EqualTo(s2.Port)); + Assert.That(connections[10].Port, Is.EqualTo(s1.Port)); + Assert.That(connections[11].Port, Is.EqualTo(s2.Port)); + break; + } + } + + static string MultipleHosts(params PgPostmasterMock[] postmasters) + => string.Join(",", postmasters.Select(p => $"{p.Host}:{p.Port}")); + + class DisposableWrapper : IAsyncDisposable + { + readonly IEnumerable _disposables; + + public DisposableWrapper(IEnumerable disposables) => _disposables = disposables; + + public async ValueTask DisposeAsync() + { + foreach (var disposable in _disposables) + await disposable.DisposeAsync(); + } + } +} diff --git a/test/Npgsql.Tests/MultiplexingTestBase.cs b/test/Npgsql.Tests/MultiplexingTestBase.cs deleted file mode 100644 index bd736c8512..0000000000 --- a/test/Npgsql.Tests/MultiplexingTestBase.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System.Collections.Concurrent; -using NUnit.Framework; - -namespace Npgsql.Tests -{ - [TestFixture(MultiplexingMode.NonMultiplexing)] - [TestFixture(MultiplexingMode.Multiplexing)] - public abstract class MultiplexingTestBase : TestBase - { - protected bool IsMultiplexing => MultiplexingMode == MultiplexingMode.Multiplexing; - - protected MultiplexingMode MultiplexingMode { get; } - - readonly ConcurrentDictionary<(string ConnString, bool IsMultiplexing), string> _connStringCache - = new ConcurrentDictionary<(string ConnString, bool IsMultiplexing), string>(); - - public override string ConnectionString { get; } - - protected MultiplexingTestBase(MultiplexingMode multiplexingMode) - { - MultiplexingMode = 
multiplexingMode; - - // If the test requires multiplexing to be on or off, use a small cache to avoid reparsing and - // regenerating the connection string every time - ConnectionString = _connStringCache.GetOrAdd((base.ConnectionString, IsMultiplexing), - tup => new NpgsqlConnectionStringBuilder(tup.ConnString) - { - Multiplexing = tup.IsMultiplexing - }.ToString()); - } - } - - public enum MultiplexingMode - { - NonMultiplexing, - Multiplexing - } -} diff --git a/test/Npgsql.Tests/NestedDataReaderTests.cs b/test/Npgsql.Tests/NestedDataReaderTests.cs new file mode 100644 index 0000000000..72553a6b5e --- /dev/null +++ b/test/Npgsql.Tests/NestedDataReaderTests.cs @@ -0,0 +1,233 @@ +using NUnit.Framework; +using System; +using System.Threading.Tasks; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests; + +public class NestedDataReaderTests : TestBase +{ + [Test] + public async Task Basic() + { + await using var conn = await OpenConnectionAsync(); + await using var command = new NpgsqlCommand(@"SELECT ARRAY[ROW(1, 2, 3), ROW(4, 5, 6)] + UNION ALL + SELECT ARRAY[ROW(7, 8, 9), ROW(10, 11, 12)]", conn); + await using var reader = await command.ExecuteReaderAsync(); + for (var i = 0; i < 2; i++) + { + await reader.ReadAsync(); + using var nestedReader = reader.GetData(0); + Assert.That(nestedReader.HasRows, Is.True); + + for (var j = 0; j < 2; j++) + { + Assert.That(nestedReader.Read(), Is.True); + Assert.That(nestedReader.FieldCount, Is.EqualTo(3)); + Assert.That(nestedReader.GetFieldType(0), Is.EqualTo(typeof(int))); + Assert.That(nestedReader.GetDataTypeName(0), Is.EqualTo("integer")); + Assert.That(nestedReader.GetName(0), Is.EqualTo("?column?")); + Assert.Throws(() => nestedReader.GetOrdinal("c0")); + for (var k = 0; k < 3; k++) + { + Assert.That(nestedReader.GetInt32(k), Is.EqualTo(1 + 6 * i + j * 3 + k)); + Assert.That(nestedReader.GetValue(k), Is.EqualTo(1 + 6 * i + j * 3 + k)); + } + } + if (i == 0) + Assert.That(nestedReader.Read(), Is.False); + + 
Assert.That(nestedReader.NextResult(), Is.False); + Assert.That(nestedReader.HasRows, Is.False); + } + } + + [Test] + public async Task Different_field_count() + { + await using var conn = await OpenConnectionAsync(); + await using var command = new NpgsqlCommand(@"SELECT ARRAY[ROW(1), ROW(), ROW('2'::TEXT, 3), ROW(4)]", conn); + await using var reader = await command.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync(), Is.True); + using var nestedReader = reader.GetData(0); + Assert.That(nestedReader.Read(), Is.True); + Assert.That(nestedReader.FieldCount, Is.EqualTo(1)); + Assert.That(nestedReader.GetFieldType(0), Is.EqualTo(typeof(int))); + Assert.That(nestedReader.GetInt32(0), Is.EqualTo(1)); + Assert.That(nestedReader.Read(), Is.True); + Assert.That(nestedReader.FieldCount, Is.EqualTo(0)); + Assert.That(nestedReader.Read(), Is.True); + Assert.That(nestedReader.FieldCount, Is.EqualTo(2)); + Assert.That(nestedReader.GetFieldType(0), Is.EqualTo(typeof(string))); + Assert.That(nestedReader.GetFieldType(1), Is.EqualTo(typeof(int))); + Assert.That(nestedReader.GetString(0), Is.EqualTo("2")); + Assert.That(nestedReader.GetInt32(1), Is.EqualTo(3)); + Assert.That(nestedReader.Read(), Is.True); + Assert.That(nestedReader.GetFieldType(0), Is.EqualTo(typeof(int))); + Assert.That(nestedReader.GetInt32(0), Is.EqualTo(4)); + Assert.That(nestedReader.Read(), Is.False); + } + + [Test] + public async Task Nested() + { + await using var conn = await OpenConnectionAsync(); + await using var command = new NpgsqlCommand(@"SELECT + ARRAY[ + ROW( + ARRAY[ + ROW('row000'::TEXT, NULL::TEXT), + ROW('row010'::TEXT, 'row011'::TEXT) + ] + ), + ROW( + ARRAY[ + ROW('row100'::TEXT, NULL::TEXT), + ROW('row110'::TEXT, 'row111'::TEXT) + ] + ) + ], 2", conn); + await using var reader = await command.ExecuteReaderAsync(); + + for (var i = 0; i < 1; i++) + { + await reader.ReadAsync(); + using var nestedReader = reader.GetData(0); + for (var j = 0; j < 2; j++) + { + 
Assert.That(nestedReader.Read(), Is.True); + var nestedReader2 = nestedReader.GetData(0); + for (var k = 0; k < 2; k++) + { + Assert.That(nestedReader2.Read(), Is.True); + for (var l = 0; l < 2; l++) + { + if (k == 0 && l == 1) + { + Assert.That(nestedReader2.IsDBNull(l), Is.True); + Assert.That(nestedReader2.GetValue(l), Is.EqualTo(DBNull.Value)); + Assert.That(nestedReader2.GetProviderSpecificValue(l), Is.EqualTo(DBNull.Value)); + } + else + { + Assert.That(nestedReader2.GetString(l), Is.EqualTo("row" + j + k + l)); + } + } + } + } + Assert.That(reader.GetInt32(1), Is.EqualTo(2)); + } + } + + [Test] + public async Task Single_row() + { + await using var conn = await OpenConnectionAsync(); + await using var command = new NpgsqlCommand("SELECT ROW(1, ARRAY[ROW(2), ROW(3)])", conn); + await using var reader = await command.ExecuteReaderAsync(); + await reader.ReadAsync(); + using var nestedReader = reader.GetData(0); + Assert.That(nestedReader.Read(), Is.True); + Assert.That(nestedReader.FieldCount, Is.EqualTo(2)); + Assert.That(nestedReader.GetInt32(0), Is.EqualTo(1)); + using var nestedReader2 = nestedReader.GetData(1); + for (var i = 0; i < 2; i++) + { + Assert.That(nestedReader2.Read(), Is.True); + Assert.That(nestedReader2.FieldCount, Is.EqualTo(1)); + Assert.That(nestedReader2.GetInt32(0), Is.EqualTo(2 + i)); + } + Assert.That(nestedReader2.Read(), Is.False); + } + + [Test] + public async Task Empty_array() + { + await using var conn = await OpenConnectionAsync(); + await using var command = new NpgsqlCommand("SELECT ARRAY[]::RECORD[]", conn); + await using var reader = await command.ExecuteReaderAsync(); + await reader.ReadAsync(); + using var nestedReader = reader.GetData(0); + Assert.That(nestedReader.Read(), Is.False); + Assert.That(nestedReader.NextResult(), Is.False); + } + + [Test] + public async Task Composite() + { + await using var conn = await OpenConnectionAsync(); + var typeName = await GetTempTypeName(conn); + await 
conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS (c0 integer, c1 text)"); + conn.ReloadTypes(); + var sqls = new string[] + { + $"SELECT ROW('1', '2')::{typeName}", + $"SELECT ARRAY[ROW('1', '2')::{typeName}]" + }; + foreach (var sql in sqls) + { + await using var command = new NpgsqlCommand(sql, conn); + await using var reader = await command.ExecuteReaderAsync(); + await reader.ReadAsync(); + using var nestedReader = reader.GetData(0); + nestedReader.Read(); + Assert.That(nestedReader.GetDataTypeName(0), Is.EqualTo("integer")); + Assert.That(nestedReader.GetDataTypeName(1), Is.EqualTo("text")); + Assert.That(nestedReader.GetInt32(0), Is.EqualTo(1)); + Assert.That(nestedReader.GetString(1), Is.EqualTo("2")); + Assert.That(nestedReader.GetName(0), Is.EqualTo("c0")); + Assert.That(nestedReader.GetName(1), Is.EqualTo("c1")); + Assert.That(nestedReader.GetOrdinal("C1"), Is.EqualTo(1)); + Assert.That(nestedReader["C1"], Is.EqualTo("2")); + Assert.Throws(() => nestedReader.GetOrdinal("ABC")); + } + } + + [Test] + public void GetBytes() + { + using var conn = OpenConnection(); + using var command = new NpgsqlCommand(@"SELECT ROW('\x010203'::BYTEA, NULL::BYTEA)", conn); + using var reader = command.ExecuteReader(); + reader.Read(); + using var nestedReader = reader.GetData(0); + nestedReader.Read(); + Assert.That(nestedReader.GetFieldType(0), Is.EqualTo(typeof(byte[]))); + var buf = new byte[4]; + Assert.That(nestedReader.GetBytes(0, 0, null, 0, 3), Is.EqualTo(3)); + Assert.That(nestedReader.GetBytes(0, 0, null, 0, 4), Is.EqualTo(3)); + Assert.That(nestedReader.GetBytes(0, 0, buf, 0, 3), Is.EqualTo(3)); + Assert.That(nestedReader.GetBytes(0, 0, buf, 0, 4), Is.EqualTo(3)); + CollectionAssert.AreEqual(new byte[] { 1, 2, 3, 0 }, buf); + buf = new byte[2]; + Assert.That(nestedReader.GetBytes(0, 0, buf, 0, 2), Is.EqualTo(2)); + CollectionAssert.AreEqual(new byte[] { 1, 2 }, buf); + buf = new byte[2]; + Assert.That(nestedReader.GetBytes(0, 1, buf, 1, 1), Is.EqualTo(1)); + 
CollectionAssert.AreEqual(new byte[] { 0, 2 }, buf); + Assert.That(nestedReader.GetBytes(0, 2, buf, 1, 1), Is.EqualTo(1)); + CollectionAssert.AreEqual(new byte[] { 0, 3 }, buf); + Assert.Throws(() => nestedReader.GetBytes(1, 0, buf, 0, 1)); + Assert.Throws(() => nestedReader.GetBytes(0, 4, buf, 0, 1)); + } + + [Test] + public async Task Throw_after_next_row() + { + await using var conn = await OpenConnectionAsync(); + await using var command = new NpgsqlCommand(@"SELECT ROW(1) UNION ALL SELECT ROW(2) UNION ALL SELECT ROW(3)", conn); + await using var reader = await command.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync(), Is.True); + var nestedReader = reader.GetData(0); + nestedReader.Read(); + await reader.ReadAsync(); + Assert.Throws(() => nestedReader.IsDBNull(0)); + nestedReader = reader.GetData(0); + reader.Read(); + Assert.Throws(() => nestedReader.Read()); + nestedReader = reader.GetData(0); + nestedReader.Read(); + Assert.That(nestedReader.IsDBNull(0), Is.False); + } +} diff --git a/test/Npgsql.Tests/NotificationTests.cs b/test/Npgsql.Tests/NotificationTests.cs index d8ccc3dce0..9df9aba44d 100644 --- a/test/Npgsql.Tests/NotificationTests.cs +++ b/test/Npgsql.Tests/NotificationTests.cs @@ -1,218 +1,230 @@ -using System; +using NUnit.Framework; +using System; using System.Data; using System.Threading; using System.Threading.Tasks; -using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class NotificationTests : TestBase { - public class NotificationTests : TestBase + [Test, Description("Simple LISTEN/NOTIFY scenario")] + public void Notification() { - [Test, Description("Simple LISTEN/NOTIFY scenario")] - public void Notification() - { - using (var conn = OpenConnection()) - { - var receivedNotification = false; - conn.ExecuteNonQuery("LISTEN notifytest"); - conn.Notification += (o, e) => receivedNotification = true; - conn.ExecuteNonQuery("NOTIFY notifytest"); - 
Assert.IsTrue(receivedNotification); - } - } + var notify = GetUniqueIdentifier(nameof(NotificationTests)); + + using var conn = OpenConnection(); + var receivedNotification = false; + conn.ExecuteNonQuery($"LISTEN {notify}"); + conn.Notification += (o, e) => receivedNotification = true; + conn.ExecuteNonQuery($"NOTIFY {notify}"); + Assert.IsTrue(receivedNotification); + } + + [Test, Description("Generates a notification that arrives after reader data that is already being read")] + [IssueLink("https://github.com/npgsql/npgsql/issues/252")] + public async Task Notification_after_data() + { + var notify = GetUniqueIdentifier(nameof(NotificationTests)); - //[Test, Description("Generates a notification that arrives after reader data that is already being read")] - [IssueLink("https://github.com/npgsql/npgsql/issues/252")] - public void NotificationAfterData() + var receivedNotification = false; + using var conn = OpenConnection(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = $"LISTEN {notify}"; + cmd.ExecuteNonQuery(); + conn.Notification += (o, e) => receivedNotification = true; + + cmd.CommandText = "SELECT generate_series(1,10000)"; + using (var reader = cmd.ExecuteReader()) { - var receivedNotification = false; - using (var conn = OpenConnection()) - using (var cmd = conn.CreateCommand()) + //After "notify notifytest1", a notification message will be sent to client, + //And so the notification message will stick with the last response message of "select generate_series(1,10000)" in Npgsql's tcp receiving buffer. 
+ using (var conn2 = new NpgsqlConnection(ConnectionString)) { - cmd.CommandText = "LISTEN notifytest1"; - cmd.ExecuteNonQuery(); - conn.Notification += (o, e) => receivedNotification = true; - - cmd.CommandText = "SELECT generate_series(1,10000)"; - using (var reader = cmd.ExecuteReader()) + conn2.Open(); + using (var command = conn2.CreateCommand()) { - //After "notify notifytest1", a notification message will be sent to client, - //And so the notification message will stick with the last response message of "select generate_series(1,10000)" in Npgsql's tcp receiving buffer. - using (var conn2 = new NpgsqlConnection(ConnectionString)) - { - conn2.Open(); - using (var command = conn2.CreateCommand()) - { - command.CommandText = "NOTIFY notifytest1"; - command.ExecuteNonQuery(); - } - } - - // Allow some time for the notification to get delivered - Thread.Sleep(2000); - - Assert.IsTrue(reader.Read()); - Assert.AreEqual(1, reader.GetValue(0)); + command.CommandText = $"NOTIFY {notify}"; + command.ExecuteNonQuery(); } - - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - Assert.IsTrue(receivedNotification); } - } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1024")] - [Timeout(10000)] - public void Wait() - { - using (var conn = OpenConnection()) - using (var notifyingConn = OpenConnection()) - { - var receivedNotification = false; - conn.ExecuteNonQuery("LISTEN notifytest"); - notifyingConn.ExecuteNonQuery("NOTIFY notifytest"); - conn.Notification += (o, e) => receivedNotification = true; - Assert.That(conn.Wait(0), Is.EqualTo(true)); - Assert.IsTrue(receivedNotification); - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - } - } + // Allow some time for the notification to get delivered + await Task.Delay(2000); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1024")] - //[Timeout(10000)] - public void WaitWithTimeout() - { - using (var conn = OpenConnection()) - { - Assert.That(conn.Wait(100), Is.EqualTo(false)); - 
Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - } + Assert.IsTrue(reader.Read()); + Assert.AreEqual(1, reader.GetValue(0)); } - [Test] - public void WaitWithPrependedMessage() - { - using (OpenConnection()) {} // A DISCARD ALL is now prepended in the connection's write buffer - using (var conn = OpenConnection()) - Assert.That(conn.Wait(100), Is.EqualTo(false)); - } + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); + Assert.IsTrue(receivedNotification); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1024")] + public void Wait() + { + var notify = GetUniqueIdentifier(nameof(NotificationTests)); + + using var conn = OpenConnection(); + using var notifyingConn = OpenConnection(); + var receivedNotification = false; + conn.ExecuteNonQuery($"LISTEN {notify}"); + notifyingConn.ExecuteNonQuery($"NOTIFY {notify}"); + conn.Notification += (o, e) => receivedNotification = true; + Assert.That(conn.Wait(0), Is.EqualTo(true)); + Assert.IsTrue(receivedNotification); + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1024")] + public void Wait_with_timeout() + { + using var conn = OpenConnection(); + Assert.That(conn.Wait(100), Is.EqualTo(false)); + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public void Wait_with_prepended_message() + { + using var dataSource = CreateDataSource(); + using (dataSource.OpenConnection()) {} // A DISCARD ALL is now prepended in the connection's write buffer + using var conn = dataSource.OpenConnection(); + Assert.That(conn.Wait(100), Is.EqualTo(false)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1024")] + public async Task WaitAsync() + { + var notify = GetUniqueIdentifier(nameof(NotificationTests)); + + await using var conn = await OpenConnectionAsync(); + await using var notifyingConn = await OpenConnectionAsync(); + var receivedNotification = false; + await 
conn.ExecuteNonQueryAsync($"LISTEN {notify}"); + await notifyingConn.ExecuteNonQueryAsync($"NOTIFY {notify}"); + conn.Notification += (o, e) => receivedNotification = true; + await conn.WaitAsync(0); + Assert.IsTrue(receivedNotification); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public void WaitAsync_with_timeout() + { + using var conn = OpenConnection(); + Assert.That(async () => await conn.WaitAsync(100), Is.EqualTo(false)); + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1024")] - [Timeout(10000)] - public async Task WaitAsync() + [Test] + public void Wait_with_keepalive() + { + var notify = GetUniqueIdentifier(nameof(NotificationTests)); + + using var dataSource = CreateDataSource(csb => { - using (var conn = OpenConnection()) - using (var notifyingConn = OpenConnection()) - { - var receivedNotification = false; - conn.ExecuteNonQuery("LISTEN notifytest"); - notifyingConn.ExecuteNonQuery("NOTIFY notifytest"); - conn.Notification += (o, e) => receivedNotification = true; - await conn.WaitAsync(0); - Assert.IsTrue(receivedNotification); - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - } - } + csb.KeepAlive = 1; + csb.Pooling = false; + }); + using var conn = dataSource.OpenConnection(); + using var notifyingConn = dataSource.OpenConnection(); + conn.ExecuteNonQuery($"LISTEN {notify}"); + var notificationTask = Task.Delay(2000).ContinueWith(t => notifyingConn.ExecuteNonQuery($"NOTIFY {notify}")); + conn.Wait(); + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); + // A safeguard against closing an active connection + notificationTask.GetAwaiter().GetResult(); + //Assert.That(TestLoggerSink.Records, Has.Some.With.Property("EventId").EqualTo(new EventId(NpgsqlEventId.Keepalive))); + } + + [Test] + public async Task WaitAsync_with_keepalive() + { + var notify = GetUniqueIdentifier(nameof(NotificationTests)); - [Test] 
- public void WaitAsyncWithTimeout() + await using var dataSource = CreateDataSource(csb => { - using var conn = OpenConnection(); - Assert.That(async () => await conn.WaitAsync(100), Is.EqualTo(false)); - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - } + csb.KeepAlive = 1; + csb.Pooling = false; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var notifyingConn = await dataSource.OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync($"LISTEN {notify}"); + var notificationTask = Task.Delay(2000).ContinueWith(t => notifyingConn.ExecuteNonQuery($"NOTIFY {notify}")); + await conn.WaitAsync(); + //Assert.That(TestLoggerSink.Records, Has.Some.With.Property("EventId").EqualTo(new EventId(NpgsqlEventId.Keepalive))); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + // A safeguard against closing an active connection + await notificationTask; + } + + [Test] + public void WaitAsync_cancellation() + { + var notify = GetUniqueIdentifier(nameof(NotificationTests)); - [Test] - public async Task WaitWithKeepalive() + using (var conn = OpenConnection()) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - KeepAlive = 1, - Pooling = false - }; - using (var conn = OpenConnection(csb)) - using (var notifyingConn = OpenConnection()) - { - conn.ExecuteNonQuery("LISTEN notifytest"); - var notificationTask = Task.Delay(2000).ContinueWith(t => notifyingConn.ExecuteNonQuery("NOTIFY notifytest")); - conn.Wait(); - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - // A safeguard against closing an active connection - await notificationTask; - } - //Assert.That(TestLoggerSink.Records, Has.Some.With.Property("EventId").EqualTo(new EventId(NpgsqlEventId.Keepalive))); + Assert.ThrowsAsync(async () => await conn.WaitAsync(new CancellationToken(true))); + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); } - [Test] - public async Task WaitAsyncWithKeepalive() + using (var conn 
= OpenConnection()) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - KeepAlive = 1, - Pooling = false - }; - using (var conn = OpenConnection(csb)) - using (var notifyingConn = OpenConnection()) - { - conn.ExecuteNonQuery("LISTEN notifytest"); - var notificationTask = Task.Delay(2000).ContinueWith(t => notifyingConn.ExecuteNonQuery("NOTIFY notifytest")); - await conn.WaitAsync(); - //Assert.That(TestLoggerSink.Records, Has.Some.With.Property("EventId").EqualTo(new EventId(NpgsqlEventId.Keepalive))); - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - // A safeguard against closing an active connection - await notificationTask; - } + conn.ExecuteNonQuery($"LISTEN {notify}"); + var cts = new CancellationTokenSource(1000); + Assert.ThrowsAsync(async () => await conn.WaitAsync(cts.Token)); + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); } + } - [Test] - public void WaitAsyncCancellation() + [Test] + public void Wait_breaks_connection() + { + using var dataSource = CreateDataSource(); + using var conn = dataSource.OpenConnection(); + Task.Delay(1000).ContinueWith(t => { - using (var conn = OpenConnection()) - { - Assert.That(async () => await conn.WaitAsync(new CancellationToken(true)), - Throws.Exception.TypeOf()); - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - } + using var conn2 = OpenConnection(); + conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); + }); - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery("LISTEN notifytest"); - var cts = new CancellationTokenSource(1000); - Assert.That(async () => await conn.WaitAsync(cts.Token), - Throws.Exception.TypeOf()); - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - } - } + var pgEx = Assert.Throws(conn.Wait)!; + Assert.That(pgEx.SqlState, Is.EqualTo(PostgresErrorCodes.AdminShutdown)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - [Test] - public void WaitBreaksConnection() + [Test] 
+ public void WaitAsync_breaks_connection() + { + using var dataSource = CreateDataSource(); + using var conn = dataSource.OpenConnection(); + Task.Delay(1000).ContinueWith(t => { - using (var conn = OpenConnection()) - { - Task.Delay(1000).ContinueWith(t => - { - using (var conn2 = OpenConnection()) - conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); - }); + using var conn2 = OpenConnection(); + conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); + }); - Assert.That(() => conn.Wait(), Throws.Exception.TypeOf()); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } - } + var pgEx = Assert.ThrowsAsync(async () => await conn.WaitAsync())!; + Assert.That(pgEx.SqlState, Is.EqualTo(PostgresErrorCodes.AdminShutdown)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - [Test] - public void WaitAsyncBreaksConnection() - { - using (var conn = OpenConnection()) - { - Task.Delay(1000).ContinueWith(t => - { - using (var conn2 = OpenConnection()) - conn2.ExecuteNonQuery($"SELECT pg_terminate_backend({conn.ProcessID})"); - }); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4911")] + public async Task Big_notice_while_loading_types() + { + await using var adminConn = await OpenConnectionAsync(); + // Max notification payload is 8000 + await using var dataSource = CreateDataSource(csb => csb.ReadBufferSize = 4096); + await using var conn = await dataSource.OpenConnectionAsync(); - Assert.That(async () => await conn.WaitAsync(), Throws.Exception.TypeOf()); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } - } + var notify = GetUniqueIdentifier(nameof(Big_notice_while_loading_types)); + await conn.ExecuteNonQueryAsync($"LISTEN {notify}"); + var payload = new string('a', 5000); + await adminConn.ExecuteNonQueryAsync($"NOTIFY {notify}, '{payload}'"); + + await conn.ReloadTypesAsync(); } } diff --git a/test/Npgsql.Tests/Npgsql.Tests.csproj 
b/test/Npgsql.Tests/Npgsql.Tests.csproj index 933629f602..6b7baca8ad 100644 --- a/test/Npgsql.Tests/Npgsql.Tests.csproj +++ b/test/Npgsql.Tests/Npgsql.Tests.csproj @@ -1,7 +1,8 @@  - + + @@ -9,4 +10,9 @@ + + true + $(NoWarn);NPG9001 + $(NoWarn);NPG9002 + diff --git a/test/Npgsql.Tests/NpgsqlEventSourceTests.cs b/test/Npgsql.Tests/NpgsqlEventSourceTests.cs index 5d8ab0371c..c1659e6fba 100644 --- a/test/Npgsql.Tests/NpgsqlEventSourceTests.cs +++ b/test/Npgsql.Tests/NpgsqlEventSourceTests.cs @@ -3,54 +3,53 @@ using System.Diagnostics.Tracing; using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +[NonParallelizable] // Events +public class NpgsqlEventSourceTests : TestBase { - [NonParallelizable] - public class NpgsqlEventSourceTests : TestBase + [Test] + public void Command_start_stop() { - [Test] - public void CommandStartStop() + using (var conn = OpenConnection()) { - using (var conn = OpenConnection()) - { - // There is a new pool created, which sends a few queries to load pg types - ClearEvents(); - conn.ExecuteScalar("SELECT 1"); - } - - var commandStart = _events.Single(e => e.EventId == NpgsqlEventSource.CommandStartId); - Assert.That(commandStart.EventName, Is.EqualTo("CommandStart")); - - var commandStop = _events.Single(e => e.EventId == NpgsqlEventSource.CommandStopId); - Assert.That(commandStop.EventName, Is.EqualTo("CommandStop")); + // There is a new pool created, which sends a few queries to load pg types + ClearEvents(); + conn.ExecuteScalar("SELECT 1"); } - [OneTimeSetUp] - public void EnableEventSource() - { - _listener = new TestEventListener(_events); - _listener.EnableEvents(NpgsqlSqlEventSource.Log, EventLevel.Informational); - } + var commandStart = _events.Single(e => e.EventId == NpgsqlEventSource.CommandStartId); + Assert.That(commandStart.EventName, Is.EqualTo("CommandStart")); - [OneTimeTearDown] - public void DisableEventSource() - { - _listener.DisableEvents(NpgsqlSqlEventSource.Log); - _listener.Dispose(); - } + var 
commandStop = _events.Single(e => e.EventId == NpgsqlEventSource.CommandStopId); + Assert.That(commandStop.EventName, Is.EqualTo("CommandStop")); + } - [SetUp] - public void ClearEvents() => _events.Clear(); + [OneTimeSetUp] + public void EnableEventSource() + { + _listener = new TestEventListener(_events); + _listener.EnableEvents(NpgsqlSqlEventSource.Log, EventLevel.Informational); + } - TestEventListener _listener = null!; + [OneTimeTearDown] + public void DisableEventSource() + { + _listener.DisableEvents(NpgsqlSqlEventSource.Log); + _listener.Dispose(); + } - readonly List _events = new List(); + [SetUp] + public void ClearEvents() => _events.Clear(); - class TestEventListener : EventListener - { - readonly List _events; - public TestEventListener(List events) => _events = events; - protected override void OnEventWritten(EventWrittenEventArgs eventData) => _events.Add(eventData); - } + TestEventListener _listener = null!; + + readonly List _events = new(); + + class TestEventListener : EventListener + { + readonly List _events; + public TestEventListener(List events) => _events = events; + protected override void OnEventWritten(EventWrittenEventArgs eventData) => _events.Add(eventData); } } diff --git a/test/Npgsql.Tests/NpgsqlParameterCollectionTests.cs b/test/Npgsql.Tests/NpgsqlParameterCollectionTests.cs new file mode 100644 index 0000000000..6c09b7b708 --- /dev/null +++ b/test/Npgsql.Tests/NpgsqlParameterCollectionTests.cs @@ -0,0 +1,354 @@ +using NpgsqlTypes; +using NUnit.Framework; +using System; +using System.Data; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Tests; + +[TestFixture(CompatMode.OnePass)] +#if DEBUG +[TestFixture(CompatMode.TwoPass)] +[NonParallelizable] // This test class has global effects on case sensitive matching in param collection. 
+#endif +public class NpgsqlParameterCollectionTests +{ + readonly CompatMode _compatMode; + const int LookupThreshold = NpgsqlParameterCollection.LookupThreshold; + + [Test] + public void Can_only_add_NpgsqlParameter() + { + using var command = new NpgsqlCommand(); + Assert.That(() => command.Parameters.Add("hello"), Throws.Exception.TypeOf()); + Assert.That(() => command.Parameters.Add(new SomeOtherDbParameter()), Throws.Exception.TypeOf()); + Assert.That(() => command.Parameters.Add(null!), Throws.Exception.TypeOf()); + } + + /// + /// Test which validates that Clear() indeed cleans up the parameters in a command so they can be added to other commands safely. + /// + [Test] + public void Clear() + { + var p = new NpgsqlParameter(); + var c1 = new NpgsqlCommand(); + var c2 = new NpgsqlCommand(); + c1.Parameters.Add(p); + Assert.AreEqual(1, c1.Parameters.Count); + Assert.AreEqual(0, c2.Parameters.Count); + c1.Parameters.Clear(); + Assert.AreEqual(0, c1.Parameters.Count); + c2.Parameters.Add(p); + Assert.AreEqual(0, c1.Parameters.Count); + Assert.AreEqual(1, c2.Parameters.Count); + } + + [Test] + public void Hash_lookup_parameter_rename_bug() + { + if (_compatMode == CompatMode.TwoPass) + return; + + using var command = new NpgsqlCommand(); + // Put plenty of parameters in the collection to turn on hash lookup functionality. + for (var i = 0; i < LookupThreshold; i++) + { + command.Parameters.AddWithValue(string.Format("p{0:00}", i + 1), NpgsqlDbType.Text, string.Format("String parameter value {0}", i + 1)); + } + + // Make sure hash lookup is generated. + Assert.AreEqual(command.Parameters["p03"].ParameterName, "p03"); + + // Rename the target parameter. + command.Parameters["p03"].ParameterName = "a_new_name"; + + // Try to exploit the hash lookup bug. + // If the bug exists, the hash lookups will be out of sync with the list, and be unable + // to find the parameter by its new name. 
+ Assert.That(command.Parameters.IndexOf("a_new_name"), Is.GreaterThanOrEqualTo(0)); + } + + [Test] + public void Remove_duplicate_parameter([Values(LookupThreshold, LookupThreshold - 2)] int count) + { + if (_compatMode == CompatMode.OnePass) + return; + + using var command = new NpgsqlCommand(); + // Put plenty of parameters in the collection to turn on hash lookup functionality. + for (var i = 0; i < count; i++) + { + command.Parameters.AddWithValue(string.Format("p{0:00}", i + 1), NpgsqlDbType.Text, + string.Format("String parameter value {0}", i + 1)); + } + + // Make sure lookup is generated. + Assert.AreEqual(command.Parameters["p02"].ParameterName, "p02"); + + // Add uppercased version causing a list to be created. + command.Parameters.AddWithValue("P02", NpgsqlDbType.Text, "String parameter value 2"); + + // Remove the original parameter by its name causing the multivalue to use a single value again. + command.Parameters.Remove(command.Parameters["p02"]); + + // Test whether we can still find the last added parameter, and if its index is correctly shifted in the lookup. + Assert.IsTrue(command.Parameters.IndexOf("p02") == count - 1); + Assert.IsTrue(command.Parameters.IndexOf("P02") == count - 1); + // And finally test whether other parameters were also correctly shifted. + Assert.IsTrue(command.Parameters.IndexOf("p03") == 1); + } + + [Test] + public void Remove_parameter([Values(LookupThreshold, LookupThreshold - 2)] int count) + { + using var command = new NpgsqlCommand(); + // Put plenty of parameters in the collection to turn on hash lookup functionality. + for (var i = 0; i < count; i++) + { + command.Parameters.AddWithValue(string.Format("p{0:00}", i + 1), NpgsqlDbType.Text, + string.Format("String parameter value {0}", i + 1)); + } + + // Remove the parameter by its name + command.Parameters.Remove(command.Parameters["p02"]); + + // Make sure we cannot find it, also not case insensitively. 
+ Assert.IsTrue(command.Parameters.IndexOf("p02") == -1); + Assert.IsTrue(command.Parameters.IndexOf("P02") == -1); + } + + [Test] + public void Remove_case_differing_parameter([Values(LookupThreshold, LookupThreshold - 2)] int count) + { + // We add two case-differing parameters which will match as well, before adding the others. + using var command = new NpgsqlCommand(); + command.Parameters.Add(new NpgsqlParameter("PP0", 1)); + command.Parameters.Add(new NpgsqlParameter("Pp0", 1)); + for (var i = 0; i < count - 2; i++) + command.Parameters.Add(new NpgsqlParameter($"pp{i}", i)); + + // Removing Pp0. + command.Parameters.RemoveAt(1); + + // Exact match to pp0 or case insensitive match to PP0 depending on mode. + Assert.That(command.Parameters.IndexOf("pp0"), Is.EqualTo(_compatMode == CompatMode.TwoPass ? 1 : 0)); + // Exact match to PP0. + Assert.That(command.Parameters.IndexOf("PP0"), Is.EqualTo(0)); + // Case insensitive match to PP0. + Assert.That(command.Parameters.IndexOf("Pp0"), Is.EqualTo(0)); + } + + + [Test] + public void Correct_index_returned_for_duplicate_ParameterName([Values(LookupThreshold, LookupThreshold - 2)] int count) + { + if (_compatMode == CompatMode.OnePass) + return; + + using var command = new NpgsqlCommand(); + // Put plenty of parameters in the collection to turn on hash lookup functionality. + for (var i = 0; i < count; i++) + { + command.Parameters.AddWithValue(string.Format("parameter{0:00}", i + 1), NpgsqlDbType.Text, string.Format("String parameter value {0}", i + 1)); + } + + // Make sure lookup is generated. + Assert.AreEqual(command.Parameters["parameter02"].ParameterName, "parameter02"); + + // Add uppercased version. + command.Parameters.AddWithValue("Parameter02", NpgsqlDbType.Text, "String parameter value 2"); + + // Insert another case insensitive before the original. 
+ command.Parameters.Insert(0, new NpgsqlParameter("ParameteR02", NpgsqlDbType.Text) { Value = "String parameter value 2" }); + + // Try to find the exact index. + Assert.IsTrue(command.Parameters.IndexOf("parameter02") == 2); + Assert.IsTrue(command.Parameters.IndexOf("Parameter02") == command.Parameters.Count - 1); + Assert.IsTrue(command.Parameters.IndexOf("ParameteR02") == 0); + // This name does not exist so we expect the first case insensitive match to be returned. + Assert.IsTrue(command.Parameters.IndexOf("ParaMeteR02") == 0); + + // And finally test whether other parameters were also correctly shifted. + Assert.IsTrue(command.Parameters.IndexOf("parameter03") == 3); + } + + [Test] + public void Finds_case_insensitive_lookups([Values(LookupThreshold, LookupThreshold - 2)] int count) + { + using var command = new NpgsqlCommand(); + var parameters = command.Parameters; + for (var i = 0; i < count; i++) + parameters.Add(new NpgsqlParameter($"p{i}", i)); + + Assert.That(command.Parameters.IndexOf("P1"), Is.EqualTo(1)); + } + + [Test] + public void Finds_case_sensitive_lookups([Values(LookupThreshold, LookupThreshold - 2)] int count) + { + using var command = new NpgsqlCommand(); + var parameters = command.Parameters; + for (var i = 0; i < count; i++) + parameters.Add(new NpgsqlParameter($"p{i}", i)); + + Assert.That(command.Parameters.IndexOf("p1"), Is.EqualTo(1)); + } + + [Test] + public void Throws_on_indexer_mismatch([Values(LookupThreshold, LookupThreshold - 2)] int count) + { + using var command = new NpgsqlCommand(); + var parameters = command.Parameters; + for (var i = 0; i < count; i++) + parameters.Add(new NpgsqlParameter($"p{i}", i)); + + Assert.DoesNotThrow(() => + { + command.Parameters["p1"] = new NpgsqlParameter("p1", 1); + command.Parameters["p1"] = new NpgsqlParameter("P1", 1); + }); + + Assert.Throws(() => + { + command.Parameters["p1"] = new NpgsqlParameter("p2", 1); + }); + } + + [Test] + public void 
Positional_parameter_lookup_returns_first_match([Values(LookupThreshold, LookupThreshold - 2)] int count) + { + using var command = new NpgsqlCommand(); + var parameters = command.Parameters; + for (var i = 0; i < count; i++) + parameters.Add(new NpgsqlParameter(NpgsqlParameter.PositionalName, i)); + + Assert.That(command.Parameters.IndexOf(""), Is.EqualTo(0)); + } + + [Test] + public void Throw_multiple_positions_same_instance() + { + using var cmd = new NpgsqlCommand("SELECT $1, $2"); + var p = new NpgsqlParameter("", "Hello world"); + cmd.Parameters.Add(p); + Assert.Throws(() => cmd.Parameters.Add(p)); + } + + [Test] + public void IndexOf_falls_back_to_first_insensitive_match([Values] bool manyParams) + { + if (_compatMode == CompatMode.OnePass) + return; + + using var command = new NpgsqlCommand(); + var parameters = command.Parameters; + + parameters.Add(new NpgsqlParameter("foo", 8)); + parameters.Add(new NpgsqlParameter("bar", 8)); + parameters.Add(new NpgsqlParameter("BAR", 8)); + Assert.That(parameters, Has.Count.LessThan(LookupThreshold)); + + if (manyParams) + for (var i = 0; i < LookupThreshold; i++) + parameters.Add(new NpgsqlParameter($"p{i}", i)); + + Assert.That(parameters.IndexOf("Bar"), Is.EqualTo(1)); + } + + [Test] + public void IndexOf_prefers_case_sensitive_match([Values] bool manyParams) + { + if (_compatMode == CompatMode.OnePass) + return; + + using var command = new NpgsqlCommand(); + var parameters = command.Parameters; + + parameters.Add(new NpgsqlParameter("FOO", 8)); + parameters.Add(new NpgsqlParameter("foo", 8)); + Assert.That(parameters, Has.Count.LessThan(LookupThreshold)); + + if (manyParams) + for (var i = 0; i < LookupThreshold; i++) + parameters.Add(new NpgsqlParameter($"p{i}", i)); + + Assert.That(parameters.IndexOf("foo"), Is.EqualTo(1)); + } + + [Test] + public void IndexOf_matches_all_parameter_syntaxes() + { + using var command = new NpgsqlCommand(); + var parameters = command.Parameters; + + parameters.Add(new 
NpgsqlParameter("@foo0", 8)); + parameters.Add(new NpgsqlParameter(":foo1", 8)); + parameters.Add(new NpgsqlParameter("foo2", 8)); + + for (var i = 0; i < parameters.Count; i++) + { + Assert.That(parameters.IndexOf("foo" + i), Is.EqualTo(i)); + Assert.That(parameters.IndexOf("@foo" + i), Is.EqualTo(i)); + Assert.That(parameters.IndexOf(":foo" + i), Is.EqualTo(i)); + } + } + + [Test] + public void Cloning_succeeds([Values(LookupThreshold, LookupThreshold - 2)] int count) + { + var command = new NpgsqlCommand(); + for (var i = 0; i < count; i++) + { + command.Parameters.Add(new NpgsqlParameter()); + } + Assert.DoesNotThrow(() => command.Clone()); + } + + [Test] + public void Clean_name() + { + var param = new NpgsqlParameter(); + var command = new NpgsqlCommand(); + command.Parameters.Add(param); + + param.ParameterName = null; + + // These should not throw exceptions + Assert.AreEqual(0, command.Parameters.IndexOf(param.ParameterName)); + Assert.AreEqual(NpgsqlParameter.PositionalName, param.ParameterName); + } + + public NpgsqlParameterCollectionTests(CompatMode compatMode) + { + _compatMode = compatMode; + +#if DEBUG + NpgsqlParameterCollection.TwoPassCompatMode = compatMode == CompatMode.TwoPass; +#else + if (compatMode == CompatMode.TwoPass) + Assert.Ignore("Cannot test case-insensitive NpgsqlParameterCollection behavior in RELEASE"); +#endif + } + + class SomeOtherDbParameter : DbParameter + { + public override void ResetDbType() {} + + public override DbType DbType { get; set; } + public override ParameterDirection Direction { get; set; } + public override bool IsNullable { get; set; } + [AllowNull] public override string ParameterName { get; set; } = ""; + [AllowNull] public override string SourceColumn { get; set; } = ""; + public override object? 
Value { get; set; } + public override bool SourceColumnNullMapping { get; set; } + public override int Size { get; set; } + } +} + +public enum CompatMode +{ + TwoPass, + OnePass +} diff --git a/test/Npgsql.Tests/NpgsqlParameterTests.cs b/test/Npgsql.Tests/NpgsqlParameterTests.cs index 4850ea1289..4f8d89a9f0 100644 --- a/test/Npgsql.Tests/NpgsqlParameterTests.cs +++ b/test/Npgsql.Tests/NpgsqlParameterTests.cs @@ -1,801 +1,874 @@ -#define NET_2_0 - using NpgsqlTypes; using NUnit.Framework; using System; using System.Data; using System.Data.Common; +using System.Threading.Tasks; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Tests; -namespace Npgsql.Tests +public class NpgsqlParameterTest : TestBase { - [TestFixture] - public class NpgsqlParameterTest : TestBase + [Test, Description("Makes sure that when NpgsqlDbType or Value/NpgsqlValue are set, DbType and NpgsqlDbType are set accordingly")] + public void Implicit_setting_of_DbType() { - [Test, Description("Makes sure that when NpgsqlDbType or Value/NpgsqlValue are set, DbType and NpgsqlDbType are set accordingly")] - public void ImplicitSettingOfDbTypes() - { - var p = new NpgsqlParameter("p", DbType.Int32); - Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + var p = new NpgsqlParameter("p", DbType.Int32); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); - // As long as NpgsqlDbType/DbType aren't set explicitly, infer them from Value - p = new NpgsqlParameter("p", 8); - Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); - Assert.That(p.DbType, Is.EqualTo(DbType.Int32)); + // As long as NpgsqlDbType/DbType aren't set explicitly, infer them from Value + p = new NpgsqlParameter("p", 8); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(p.DbType, Is.EqualTo(DbType.Int32)); - p.Value = 3.0; - Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Double)); - Assert.That(p.DbType, Is.EqualTo(DbType.Double)); + p.Value = 3.0; + 
Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Double)); + Assert.That(p.DbType, Is.EqualTo(DbType.Double)); - p.NpgsqlDbType = NpgsqlDbType.Bytea; - Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); - Assert.That(p.DbType, Is.EqualTo(DbType.Binary)); + p.NpgsqlDbType = NpgsqlDbType.Bytea; + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); + Assert.That(p.DbType, Is.EqualTo(DbType.Binary)); - p.Value = "dont_change"; - Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); - Assert.That(p.DbType, Is.EqualTo(DbType.Binary)); + p.Value = "dont_change"; + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); + Assert.That(p.DbType, Is.EqualTo(DbType.Binary)); - p = new NpgsqlParameter("p", new int[0]); - Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Integer)); - Assert.That(p.DbType, Is.EqualTo(DbType.Object)); - } + p = new NpgsqlParameter("p", new int[0]); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Integer)); + Assert.That(p.DbType, Is.EqualTo(DbType.Object)); + } - [Test] - public void TypeName() - { - using (var conn = OpenConnection()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var p1 = new NpgsqlParameter { ParameterName = "p", Value = 8, DataTypeName = "integer" }; - cmd.Parameters.Add(p1); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(8)); - // Purposefully try to send int as string, which should fail. This makes sure - // the above doesn't work simply because of type inference from the CLR type. - p1.DataTypeName = "text"; - Assert.That(() => cmd.ExecuteScalar(), Throws.Exception.TypeOf()); - - cmd.Parameters.Clear(); - - var p2 = new NpgsqlParameter { ParameterName = "p", TypedValue = 8, DataTypeName = "integer" }; - cmd.Parameters.Add(p2); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(8)); - // Purposefully try to send int as string, which should fail. 
This makes sure - // the above doesn't work simply because of type inference from the CLR type. - p2.DataTypeName = "text"; - Assert.That(() => cmd.ExecuteScalar(), Throws.Exception.TypeOf()); - } - } + [Test] + public void DataTypeName() + { + using var conn = OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT @p", conn); + var p1 = new NpgsqlParameter { ParameterName = "p", Value = 8, DataTypeName = "integer" }; + cmd.Parameters.Add(p1); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(8)); + // Purposefully try to send int as string, which should fail. This makes sure + // the above doesn't work simply because of type inference from the CLR type. + p1.DataTypeName = "text"; + Assert.That(() => cmd.ExecuteScalar(), Throws.Exception.TypeOf()); + + cmd.Parameters.Clear(); + + var p2 = new NpgsqlParameter { ParameterName = "p", TypedValue = 8, DataTypeName = "integer" }; + cmd.Parameters.Add(p2); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(8)); + // Purposefully try to send int as string, which should fail. This makes sure + // the above doesn't work simply because of type inference from the CLR type. 
+ p2.DataTypeName = "text"; + Assert.That(() => cmd.ExecuteScalar(), Throws.Exception.TypeOf()); + } - [Test] - public void SettingDbTypeSetsNpgsqlDbType() - { - var p = new NpgsqlParameter(); - p.DbType = DbType.Binary; - Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); - } + [Test] + public void Positional_parameter_is_positional() + { + var p = new NpgsqlParameter(NpgsqlParameter.PositionalName, 1); + Assert.That(p.IsPositional, Is.True); - [Test] - public void SettingNpgsqlDbTypeSetsDbType() - { - var p = new NpgsqlParameter(); - p.NpgsqlDbType = NpgsqlDbType.Bytea; - Assert.That(p.DbType, Is.EqualTo(DbType.Binary)); - } + var p2 = new NpgsqlParameter(null, 1); + Assert.That(p2.IsPositional, Is.True); + } - [Test] - public void SettingValueDoesNotChangeDbType() - { - var p = new NpgsqlParameter { DbType = DbType.String, NpgsqlDbType = NpgsqlDbType.Bytea }; - p.Value = 8; - Assert.That(p.DbType, Is.EqualTo(DbType.Binary)); - Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); - } + [Test] + public void Infer_data_type_name_from_NpgsqlDbType() + { + var p = new NpgsqlParameter("par_field1", NpgsqlDbType.Varchar, 50); + Assert.That(p.DataTypeName, Is.EqualTo("character varying")); + } - // Older tests + [Test] + public void Infer_data_type_name_from_DbType() + { + var p = new NpgsqlParameter("par_field1", DbType.String , 50); + Assert.That(p.DataTypeName, Is.EqualTo("text")); + } - /// - /// Test which validates that Clear() indeed cleans up the parameters in a command so they can be added to other commands safely. 
- /// - [Test] - public void NpgsqlParameterCollectionClearTest() - { - var p = new NpgsqlParameter(); - var c1 = new NpgsqlCommand(); - var c2 = new NpgsqlCommand(); - c1.Parameters.Add(p); - Assert.AreEqual(1, c1.Parameters.Count); - Assert.AreEqual(0, c2.Parameters.Count); - c1.Parameters.Clear(); - Assert.AreEqual(0, c1.Parameters.Count); - c2.Parameters.Add(p); - Assert.AreEqual(0, c1.Parameters.Count); - Assert.AreEqual(1, c2.Parameters.Count); - } + [Test] + public void Infer_data_type_name_from_NpgsqlDbType_for_array() + { + var p = new NpgsqlParameter("int_array", NpgsqlDbType.Array | NpgsqlDbType.Integer); + Assert.That(p.DataTypeName, Is.EqualTo("integer[]")); + } - #region Constructors + [Test] + public void Infer_data_type_name_from_NpgsqlDbType_for_built_in_range() + { + var p = new NpgsqlParameter("numeric_range", NpgsqlDbType.Range | NpgsqlDbType.Numeric); + Assert.That(p.DataTypeName, Is.EqualTo("numrange")); + } - [Test] - public void Constructor1() - { - var p = new NpgsqlParameter(); - Assert.AreEqual(DbType.Object, p.DbType, "DbType"); - Assert.AreEqual(ParameterDirection.Input, p.Direction, "Direction"); - Assert.IsFalse(p.IsNullable, "IsNullable"); - Assert.AreEqual(string.Empty, p.ParameterName, "ParameterName"); - Assert.AreEqual(0, p.Precision, "Precision"); - Assert.AreEqual(0, p.Scale, "Scale"); - Assert.AreEqual(0, p.Size, "Size"); - Assert.AreEqual(string.Empty, p.SourceColumn, "SourceColumn"); - Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "NpgsqlDbType"); - Assert.IsNull(p.Value, "Value"); - } + [Test] + public void Cannot_infer_data_type_name_from_NpgsqlDbType_for_unknown_range() + { + var p = new NpgsqlParameter("text_range", NpgsqlDbType.Range | NpgsqlDbType.Text); + Assert.That(p.DataTypeName, Is.EqualTo(null)); + } - [Test] - public void Constructor2_Value_DateTime() - { - var value = new DateTime(2004, 8, 24); - - var p = new 
NpgsqlParameter("address", value); - Assert.AreEqual(DbType.DateTime, p.DbType, "B:DbType"); - Assert.AreEqual(ParameterDirection.Input, p.Direction, "B:Direction"); - Assert.IsFalse(p.IsNullable, "B:IsNullable"); - Assert.AreEqual("address", p.ParameterName, "B:ParameterName"); - Assert.AreEqual(0, p.Precision, "B:Precision"); - Assert.AreEqual(0, p.Scale, "B:Scale"); - //Assert.AreEqual (0, p.Size, "B:Size"); - Assert.AreEqual(string.Empty, p.SourceColumn, "B:SourceColumn"); - Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "B:SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "B:NpgsqlDbType"); - Assert.AreEqual(value, p.Value, "B:Value"); - } + [Test] + public void Infer_data_type_name_from_ClrType() + { + var p = new NpgsqlParameter("p1", Array.Empty()); + Assert.That(p.DataTypeName, Is.EqualTo("bytea")); + } - [Test] - public void Constructor2_Value_DBNull() - { - var p = new NpgsqlParameter("address", DBNull.Value); - Assert.AreEqual(DbType.Object, p.DbType, "B:DbType"); - Assert.AreEqual(ParameterDirection.Input, p.Direction, "B:Direction"); - Assert.IsFalse(p.IsNullable, "B:IsNullable"); - Assert.AreEqual("address", p.ParameterName, "B:ParameterName"); - Assert.AreEqual(0, p.Precision, "B:Precision"); - Assert.AreEqual(0, p.Scale, "B:Scale"); - Assert.AreEqual(0, p.Size, "B:Size"); - Assert.AreEqual(string.Empty, p.SourceColumn, "B:SourceColumn"); - Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "B:SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "B:NpgsqlDbType"); - Assert.AreEqual(DBNull.Value, p.Value, "B:Value"); - } + [Test] + public void Setting_DbType_sets_NpgsqlDbType() + { + var p = new NpgsqlParameter(); + p.DbType = DbType.Binary; + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); + } - [Test] - public void Constructor2_Value_Null() - { - var p = new NpgsqlParameter("address", null); - Assert.AreEqual(DbType.Object, p.DbType, "A:DbType"); - 
Assert.AreEqual(ParameterDirection.Input, p.Direction, "A:Direction"); - Assert.IsFalse(p.IsNullable, "A:IsNullable"); - Assert.AreEqual("address", p.ParameterName, "A:ParameterName"); - Assert.AreEqual(0, p.Precision, "A:Precision"); - Assert.AreEqual(0, p.Scale, "A:Scale"); - Assert.AreEqual(0, p.Size, "A:Size"); - Assert.AreEqual(string.Empty, p.SourceColumn, "A:SourceColumn"); - Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "A:SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "A:NpgsqlDbType"); - Assert.IsNull(p.Value, "A:Value"); - } + [Test] + public void Setting_NpgsqlDbType_sets_DbType() + { + var p = new NpgsqlParameter(); + p.NpgsqlDbType = NpgsqlDbType.Bytea; + Assert.That(p.DbType, Is.EqualTo(DbType.Binary)); + } - [Test] - //.ctor (String, NpgsqlDbType, Int32, String, ParameterDirection, bool, byte, byte, DataRowVersion, object) - public void Constructor7() + [Test] + public void Setting_value_does_not_change_DbType() + { + var p = new NpgsqlParameter { DbType = DbType.String, NpgsqlDbType = NpgsqlDbType.Bytea }; + p.Value = 8; + Assert.That(p.DbType, Is.EqualTo(DbType.Binary)); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); + } + + // Older tests + + #region Constructors + + [Test] + public void Constructor1() + { + var p = new NpgsqlParameter(); + Assert.AreEqual(DbType.Object, p.DbType, "DbType"); + Assert.AreEqual(ParameterDirection.Input, p.Direction, "Direction"); + Assert.IsFalse(p.IsNullable, "IsNullable"); + Assert.AreEqual(string.Empty, p.ParameterName, "ParameterName"); + Assert.AreEqual(0, p.Precision, "Precision"); + Assert.AreEqual(0, p.Scale, "Scale"); + Assert.AreEqual(0, p.Size, "Size"); + Assert.AreEqual(string.Empty, p.SourceColumn, "SourceColumn"); + Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "SourceVersion"); + Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "NpgsqlDbType"); + Assert.IsNull(p.Value, "Value"); + } + + [Test] + public void 
Constructor2_Value_DateTime() + { + var value = new DateTime(2004, 8, 24); + + var p = new NpgsqlParameter("address", value); + Assert.AreEqual(DbType.DateTime2, p.DbType, "B:DbType"); + Assert.AreEqual(ParameterDirection.Input, p.Direction, "B:Direction"); + Assert.IsFalse(p.IsNullable, "B:IsNullable"); + Assert.AreEqual("address", p.ParameterName, "B:ParameterName"); + Assert.AreEqual(0, p.Precision, "B:Precision"); + Assert.AreEqual(0, p.Scale, "B:Scale"); + //Assert.AreEqual (0, p.Size, "B:Size"); + Assert.AreEqual(string.Empty, p.SourceColumn, "B:SourceColumn"); + Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "B:SourceVersion"); + Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "B:NpgsqlDbType"); + Assert.AreEqual(value, p.Value, "B:Value"); + } + + [Test] + public void Constructor2_Value_DBNull() + { + var p = new NpgsqlParameter("address", DBNull.Value); + Assert.AreEqual(DbType.Object, p.DbType, "B:DbType"); + Assert.AreEqual(ParameterDirection.Input, p.Direction, "B:Direction"); + Assert.IsFalse(p.IsNullable, "B:IsNullable"); + Assert.AreEqual("address", p.ParameterName, "B:ParameterName"); + Assert.AreEqual(0, p.Precision, "B:Precision"); + Assert.AreEqual(0, p.Scale, "B:Scale"); + Assert.AreEqual(0, p.Size, "B:Size"); + Assert.AreEqual(string.Empty, p.SourceColumn, "B:SourceColumn"); + Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "B:SourceVersion"); + Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "B:NpgsqlDbType"); + Assert.AreEqual(DBNull.Value, p.Value, "B:Value"); + } + + [Test] + public void Constructor2_Value_null() + { + var p = new NpgsqlParameter("address", null); + Assert.AreEqual(DbType.Object, p.DbType, "A:DbType"); + Assert.AreEqual(ParameterDirection.Input, p.Direction, "A:Direction"); + Assert.IsFalse(p.IsNullable, "A:IsNullable"); + Assert.AreEqual("address", p.ParameterName, "A:ParameterName"); + Assert.AreEqual(0, p.Precision, "A:Precision"); + Assert.AreEqual(0, p.Scale, "A:Scale"); + 
Assert.AreEqual(0, p.Size, "A:Size"); + Assert.AreEqual(string.Empty, p.SourceColumn, "A:SourceColumn"); + Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "A:SourceVersion"); + Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "A:NpgsqlDbType"); + Assert.IsNull(p.Value, "A:Value"); + } + + [Test] + //.ctor (String, NpgsqlDbType, Int32, String, ParameterDirection, bool, byte, byte, DataRowVersion, object) + public void Constructor7() + { + var p1 = new NpgsqlParameter("p1Name", NpgsqlDbType.Varchar, 20, + "srcCol", ParameterDirection.InputOutput, false, 0, 0, + DataRowVersion.Original, "foo"); + Assert.AreEqual(DbType.String, p1.DbType, "DbType"); + Assert.AreEqual(ParameterDirection.InputOutput, p1.Direction, "Direction"); + Assert.AreEqual(false, p1.IsNullable, "IsNullable"); + //Assert.AreEqual (999, p1.LocaleId, "#"); + Assert.AreEqual("p1Name", p1.ParameterName, "ParameterName"); + Assert.AreEqual(0, p1.Precision, "Precision"); + Assert.AreEqual(0, p1.Scale, "Scale"); + Assert.AreEqual(20, p1.Size, "Size"); + Assert.AreEqual("srcCol", p1.SourceColumn, "SourceColumn"); + Assert.AreEqual(false, p1.SourceColumnNullMapping, "SourceColumnNullMapping"); + Assert.AreEqual(DataRowVersion.Original, p1.SourceVersion, "SourceVersion"); + Assert.AreEqual(NpgsqlDbType.Varchar, p1.NpgsqlDbType, "NpgsqlDbType"); + //Assert.AreEqual (3210, p1.NpgsqlValue, "#"); + Assert.AreEqual("foo", p1.Value, "Value"); + //Assert.AreEqual ("database", p1.XmlSchemaCollectionDatabase, "XmlSchemaCollectionDatabase"); + //Assert.AreEqual ("name", p1.XmlSchemaCollectionName, "XmlSchemaCollectionName"); + //Assert.AreEqual ("schema", p1.XmlSchemaCollectionOwningSchema, "XmlSchemaCollectionOwningSchema"); + } + + [Test] + public void Clone() + { + var expected = new NpgsqlParameter { - var p1 = new NpgsqlParameter("p1Name", NpgsqlDbType.Varchar, 20, - "srcCol", ParameterDirection.InputOutput, false, 0, 0, - DataRowVersion.Original, "foo"); - Assert.AreEqual(DbType.String, p1.DbType, 
"DbType"); - Assert.AreEqual(ParameterDirection.InputOutput, p1.Direction, "Direction"); - Assert.AreEqual(false, p1.IsNullable, "IsNullable"); - //Assert.AreEqual (999, p1.LocaleId, "#"); - Assert.AreEqual("p1Name", p1.ParameterName, "ParameterName"); - Assert.AreEqual(0, p1.Precision, "Precision"); - Assert.AreEqual(0, p1.Scale, "Scale"); - Assert.AreEqual(20, p1.Size, "Size"); - Assert.AreEqual("srcCol", p1.SourceColumn, "SourceColumn"); - Assert.AreEqual(false, p1.SourceColumnNullMapping, "SourceColumnNullMapping"); - Assert.AreEqual(DataRowVersion.Original, p1.SourceVersion, "SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Varchar, p1.NpgsqlDbType, "NpgsqlDbType"); - //Assert.AreEqual (3210, p1.NpgsqlValue, "#"); - Assert.AreEqual("foo", p1.Value, "Value"); - //Assert.AreEqual ("database", p1.XmlSchemaCollectionDatabase, "XmlSchemaCollectionDatabase"); - //Assert.AreEqual ("name", p1.XmlSchemaCollectionName, "XmlSchemaCollectionName"); - //Assert.AreEqual ("schema", p1.XmlSchemaCollectionOwningSchema, "XmlSchemaCollectionOwningSchema"); - } + Value = 42, + ParameterName = "TheAnswer", + + DbType = DbType.Int32, + NpgsqlDbType = NpgsqlDbType.Integer, + DataTypeName = "integer", + + Direction = ParameterDirection.InputOutput, + IsNullable = true, + Precision = 1, + Scale = 2, + Size = 4, + + SourceVersion = DataRowVersion.Proposed, + SourceColumn = "source", + SourceColumnNullMapping = true, + }; + var actual = expected.Clone(); + + Assert.AreEqual(expected.Value, actual.Value); + Assert.AreEqual(expected.ParameterName, actual.ParameterName); + + Assert.AreEqual(expected.DbType, actual.DbType); + Assert.AreEqual(expected.NpgsqlDbType, actual.NpgsqlDbType); + Assert.AreEqual(expected.DataTypeName, actual.DataTypeName); + + Assert.AreEqual(expected.Direction, actual.Direction); + Assert.AreEqual(expected.IsNullable, actual.IsNullable); + Assert.AreEqual(expected.Precision, actual.Precision); + Assert.AreEqual(expected.Scale, actual.Scale); + 
Assert.AreEqual(expected.Size, actual.Size); + + Assert.AreEqual(expected.SourceVersion, actual.SourceVersion); + Assert.AreEqual(expected.SourceColumn, actual.SourceColumn); + Assert.AreEqual(expected.SourceColumnNullMapping, actual.SourceColumnNullMapping); + } - #endregion + [Test] + public void Clone_generic() + { + var expected = new NpgsqlParameter + { + TypedValue = 42, + ParameterName = "TheAnswer", + + DbType = DbType.Int32, + NpgsqlDbType = NpgsqlDbType.Integer, + DataTypeName = "integer", + + Direction = ParameterDirection.InputOutput, + IsNullable = true, + Precision = 1, + Scale = 2, + Size = 4, + + SourceVersion = DataRowVersion.Proposed, + SourceColumn ="source", + SourceColumnNullMapping = true, + }; + var actual = (NpgsqlParameter)expected.Clone(); + + Assert.AreEqual(expected.Value, actual.Value); + Assert.AreEqual(expected.TypedValue, actual.TypedValue); + Assert.AreEqual(expected.ParameterName, actual.ParameterName); + + Assert.AreEqual(expected.DbType, actual.DbType); + Assert.AreEqual(expected.NpgsqlDbType, actual.NpgsqlDbType); + Assert.AreEqual(expected.DataTypeName, actual.DataTypeName); + + Assert.AreEqual(expected.Direction, actual.Direction); + Assert.AreEqual(expected.IsNullable, actual.IsNullable); + Assert.AreEqual(expected.Precision, actual.Precision); + Assert.AreEqual(expected.Scale, actual.Scale); + Assert.AreEqual(expected.Size, actual.Size); + + Assert.AreEqual(expected.SourceVersion, actual.SourceVersion); + Assert.AreEqual(expected.SourceColumn, actual.SourceColumn); + Assert.AreEqual(expected.SourceColumnNullMapping, actual.SourceColumnNullMapping); + } -#if NeedsPorting + #endregion - [Test] -#if NET_2_0 - [Category ("NotWorking")] -#endif - public void InferType_Char() + [Test] + [Ignore("")] + public void InferType_invalid_throws() + { + var notsupported = new object[] { - Char value = 'X'; - -#if NET_2_0 - String string_value = "X"; - - NpgsqlParameter p = new NpgsqlParameter (); - p.Value = value; - Assert.AreEqual 
(NpgsqlDbType.Text, p.NpgsqlDbType, "#A:NpgsqlDbType"); - Assert.AreEqual (DbType.String, p.DbType, "#A:DbType"); - Assert.AreEqual (string_value, p.Value, "#A:Value"); - - p = new NpgsqlParameter (); - p.Value = value; - Assert.AreEqual (value, p.Value, "#B:Value1"); - Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#B:NpgsqlDbType"); - Assert.AreEqual (string_value, p.Value, "#B:Value2"); - - p = new NpgsqlParameter (); - p.Value = value; - Assert.AreEqual (value, p.Value, "#C:Value1"); - Assert.AreEqual (DbType.String, p.DbType, "#C:DbType"); - Assert.AreEqual (string_value, p.Value, "#C:Value2"); - - p = new NpgsqlParameter ("name", value); - Assert.AreEqual (value, p.Value, "#D:Value1"); - Assert.AreEqual (DbType.String, p.DbType, "#D:DbType"); - Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#D:NpgsqlDbType"); - Assert.AreEqual (string_value, p.Value, "#D:Value2"); - - p = new NpgsqlParameter ("name", 5); - p.Value = value; - Assert.AreEqual (value, p.Value, "#E:Value1"); - Assert.AreEqual (DbType.String, p.DbType, "#E:DbType"); - Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#E:NpgsqlDbType"); - Assert.AreEqual (string_value, p.Value, "#E:Value2"); - - p = new NpgsqlParameter ("name", NpgsqlDbType.Text); - p.Value = value; - Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#F:NpgsqlDbType"); - Assert.AreEqual (value, p.Value, "#F:Value"); -#else - NpgsqlParameter p = new NpgsqlParameter(); - try - { - p.Value = value; - Assert.Fail("#1"); - } - catch (ArgumentException ex) - { - // The parameter data type of Char is invalid - Assert.AreEqual(typeof(ArgumentException), ex.GetType(), "#2"); - Assert.IsNull(ex.InnerException, "#3"); - Assert.IsNotNull(ex.Message, "#4"); - Assert.IsNull(ex.ParamName, "#5"); - } -#endif - } + ushort.MaxValue, + uint.MaxValue, + ulong.MaxValue, + sbyte.MaxValue, + new NpgsqlParameter() + }; - [Test] -#if NET_2_0 - [Category ("NotWorking")] -#endif - public void InferType_CharArray() + var param = new 
NpgsqlParameter(); + + for (var i = 0; i < notsupported.Length; i++) { - Char[] value = new Char[] { 'A', 'X' }; - -#if NET_2_0 - String string_value = "AX"; - - NpgsqlParameter p = new NpgsqlParameter (); - p.Value = value; - Assert.AreEqual (value, p.Value, "#A:Value1"); - Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#A:NpgsqlDbType"); - Assert.AreEqual (DbType.String, p.DbType, "#A:DbType"); - Assert.AreEqual (string_value, p.Value, "#A:Value2"); - - p = new NpgsqlParameter (); - p.Value = value; - Assert.AreEqual (value, p.Value, "#B:Value1"); - Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#B:NpgsqlDbType"); - Assert.AreEqual (string_value, p.Value, "#B:Value2"); - - p = new NpgsqlParameter (); - p.Value = value; - Assert.AreEqual (value, p.Value, "#C:Value1"); - Assert.AreEqual (DbType.String, p.DbType, "#C:DbType"); - Assert.AreEqual (string_value, p.Value, "#C:Value2"); - - p = new NpgsqlParameter ("name", value); - Assert.AreEqual (value, p.Value, "#D:Value1"); - Assert.AreEqual (DbType.String, p.DbType, "#D:DbType"); - Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#D:NpgsqlDbType"); - Assert.AreEqual (string_value, p.Value, "#D:Value2"); - - p = new NpgsqlParameter ("name", 5); - p.Value = value; - Assert.AreEqual (value, p.Value, "#E:Value1"); - Assert.AreEqual (DbType.String, p.DbType, "#E:DbType"); - Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#E:NpgsqlDbType"); - Assert.AreEqual (string_value, p.Value, "#E:Value2"); - - p = new NpgsqlParameter ("name", NpgsqlDbType.Text); - p.Value = value; - Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#F:NpgsqlDbType"); - Assert.AreEqual (value, p.Value, "#F:Value"); -#else - NpgsqlParameter p = new NpgsqlParameter(); try { - p.Value = value; - Assert.Fail("#1"); + param.Value = notsupported[i]; + Assert.Fail("#A1:" + i); } catch (FormatException) { - // appears to be bug in .NET 1.1 while constructing - // exception message + // appears to be bug in .NET 1.1 while + // 
constructing exception message } catch (ArgumentException ex) { - // The parameter data type of Char[] is invalid - Assert.AreEqual(typeof(ArgumentException), ex.GetType(), "#2"); - Assert.IsNull(ex.InnerException, "#3"); - Assert.IsNotNull(ex.Message, "#4"); - Assert.IsNull(ex.ParamName, "#5"); + // The parameter data type of ... is invalid + Assert.AreEqual(typeof(ArgumentException), ex.GetType(), "#A2"); + Assert.IsNull(ex.InnerException, "#A3"); + Assert.IsNotNull(ex.Message, "#A4"); + Assert.IsNull(ex.ParamName, "#A5"); } -#endif } + } -#endif - - [Test] - [Ignore("")] - public void InferType_Invalid() - { - var notsupported = new object[] - { - ushort.MaxValue, - uint.MaxValue, - ulong.MaxValue, - sbyte.MaxValue, - new NpgsqlParameter() - }; - - var param = new NpgsqlParameter(); - - for (var i = 0; i < notsupported.Length; i++) - { - try - { - param.Value = notsupported[i]; - Assert.Fail("#A1:" + i); - } - catch (FormatException) - { - // appears to be bug in .NET 1.1 while - // constructing exception message - } - catch (ArgumentException ex) - { - // The parameter data type of ... is invalid - Assert.AreEqual(typeof(ArgumentException), ex.GetType(), "#A2"); - Assert.IsNull(ex.InnerException, "#A3"); - Assert.IsNotNull(ex.Message, "#A4"); - Assert.IsNull(ex.ParamName, "#A5"); - } - } - } + [Test] // bug #320196 + public void Parameter_null() + { + var param = new NpgsqlParameter("param", NpgsqlDbType.Numeric); + Assert.AreEqual(0, param.Scale, "#A1"); + param.Value = DBNull.Value; + Assert.AreEqual(0, param.Scale, "#A2"); + + param = new NpgsqlParameter("param", NpgsqlDbType.Integer); + Assert.AreEqual(0, param.Scale, "#B1"); + param.Value = DBNull.Value; + Assert.AreEqual(0, param.Scale, "#B2"); + } -#if NeedsPorting - [Test] - public void InferType_Object() - { - Object value = new Object(); + [Test] + [Ignore("")] + public void Parameter_type() + { + NpgsqlParameter p; + + // If Type is not set, then type is inferred from the value + // assigned. 
The Type should be inferred everytime Value is assigned + // If value is null or DBNull, then the current Type should be reset to Text. + p = new NpgsqlParameter(); + Assert.AreEqual(DbType.String, p.DbType, "#A1"); + Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#A2"); + p.Value = DBNull.Value; + Assert.AreEqual(DbType.String, p.DbType, "#B1"); + Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#B2"); + p.Value = 1; + Assert.AreEqual(DbType.Int32, p.DbType, "#C1"); + Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#C2"); + p.Value = DBNull.Value; + Assert.AreEqual(DbType.String, p.DbType, "#D1"); + Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#D2"); + p.Value = new byte[] { 0x0a }; + Assert.AreEqual(DbType.Binary, p.DbType, "#E1"); + Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#E2"); + p.Value = null; + Assert.AreEqual(DbType.String, p.DbType, "#F1"); + Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#F2"); + p.Value = DateTime.Now; + Assert.AreEqual(DbType.DateTime, p.DbType, "#G1"); + Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#G2"); + p.Value = null; + Assert.AreEqual(DbType.String, p.DbType, "#H1"); + Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#H2"); + + // If DbType is set, then the NpgsqlDbType should not be + // inferred from the value assigned. + p = new NpgsqlParameter(); + p.DbType = DbType.DateTime; + Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I1"); + p.Value = 1; + Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I2"); + p.Value = null; + Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I3"); + p.Value = DBNull.Value; + Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I4"); + + // If NpgsqlDbType is set, then the DbType should not be + // inferred from the value assigned. 
+ p = new NpgsqlParameter(); + p.NpgsqlDbType = NpgsqlDbType.Bytea; + Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J1"); + p.Value = 1; + Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J2"); + p.Value = null; + Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J3"); + p.Value = DBNull.Value; + Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J4"); + } - NpgsqlParameter param = new NpgsqlParameter(); - param.Value = value; - Assert.AreEqual(NpgsqlDbType.Variant, param.NpgsqlDbType, "#1"); - Assert.AreEqual(DbType.Object, param.DbType, "#2"); - } -#endif + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5428")] + public async Task Match_param_index_case_insensitively() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p,@P", conn); + cmd.Parameters.AddWithValue("p", "Hello world"); + await cmd.ExecuteNonQueryAsync(); + } -#if NeedsPorting -#if NET_2_0 - [Test] - public void LocaleId () - { - NpgsqlParameter parameter = new NpgsqlParameter (); - Assert.AreEqual (0, parameter.LocaleId, "#1"); - parameter.LocaleId = 15; - Assert.AreEqual(15, parameter.LocaleId, "#2"); - } -#endif -#endif + [Test] + [Ignore("")] + public void ParameterName() + { + var p = new NpgsqlParameter(); + p.ParameterName = "name"; + Assert.AreEqual("name", p.ParameterName, "#A:ParameterName"); + Assert.AreEqual(string.Empty, p.SourceColumn, "#A:SourceColumn"); + + p.ParameterName = null; + Assert.AreEqual(string.Empty, p.ParameterName, "#B:ParameterName"); + Assert.AreEqual(string.Empty, p.SourceColumn, "#B:SourceColumn"); + + p.ParameterName = " "; + Assert.AreEqual(" ", p.ParameterName, "#C:ParameterName"); + Assert.AreEqual(string.Empty, p.SourceColumn, "#C:SourceColumn"); + + p.ParameterName = " name "; + Assert.AreEqual(" name ", p.ParameterName, "#D:ParameterName"); + Assert.AreEqual(string.Empty, p.SourceColumn, "#D:SourceColumn"); + + p.ParameterName = string.Empty; + 
Assert.AreEqual(string.Empty, p.ParameterName, "#E:ParameterName"); + Assert.AreEqual(string.Empty, p.SourceColumn, "#E:SourceColumn"); + } - [Test] // bug #320196 - public void ParameterNullTest() - { - var param = new NpgsqlParameter("param", NpgsqlDbType.Numeric); - Assert.AreEqual(0, param.Scale, "#A1"); - param.Value = DBNull.Value; - Assert.AreEqual(0, param.Scale, "#A2"); - - param = new NpgsqlParameter("param", NpgsqlDbType.Integer); - Assert.AreEqual(0, param.Scale, "#B1"); - param.Value = DBNull.Value; - Assert.AreEqual(0, param.Scale, "#B2"); - } + [Test] + public void ResetDbType() + { + NpgsqlParameter p; + + //Parameter with an assigned value but no DbType specified + p = new NpgsqlParameter("foo", 42); + p.ResetDbType(); + Assert.AreEqual(DbType.Int32, p.DbType, "#A:DbType"); + Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#A:NpgsqlDbType"); + Assert.AreEqual(42, p.Value, "#A:Value"); + + p.DbType = DbType.DateTime; //assigning a DbType + Assert.AreEqual(DbType.DateTime, p.DbType, "#B:DbType1"); + Assert.AreEqual(NpgsqlDbType.TimestampTz, p.NpgsqlDbType, "#B:SqlDbType1"); + p.ResetDbType(); + Assert.AreEqual(DbType.Int32, p.DbType, "#B:DbType2"); + Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#B:SqlDbtype2"); + + //Parameter with an assigned NpgsqlDbType but no specified value + p = new NpgsqlParameter("foo", NpgsqlDbType.Integer); + p.ResetDbType(); + Assert.AreEqual(DbType.Object, p.DbType, "#C:DbType"); + Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#C:NpgsqlDbType"); + + p.NpgsqlDbType = NpgsqlDbType.TimestampTz; //assigning a NpgsqlDbType + Assert.AreEqual(DbType.DateTime, p.DbType, "#D:DbType1"); + Assert.AreEqual(NpgsqlDbType.TimestampTz, p.NpgsqlDbType, "#D:SqlDbType1"); + p.ResetDbType(); + Assert.AreEqual(DbType.Object, p.DbType, "#D:DbType2"); + Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#D:SqlDbType2"); + + p = new NpgsqlParameter(); + p.Value = DateTime.MaxValue; + 
Assert.AreEqual(DbType.DateTime2, p.DbType, "#E:DbType1"); + Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#E:SqlDbType1"); + p.Value = null; + p.ResetDbType(); + Assert.AreEqual(DbType.Object, p.DbType, "#E:DbType2"); + Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#E:SqlDbType2"); + + p = new NpgsqlParameter("foo", NpgsqlDbType.Varchar); + p.Value = DateTime.MaxValue; + p.ResetDbType(); + Assert.AreEqual(DbType.DateTime2, p.DbType, "#F:DbType"); + Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#F:NpgsqlDbType"); + Assert.AreEqual(DateTime.MaxValue, p.Value, "#F:Value"); + + p = new NpgsqlParameter("foo", NpgsqlDbType.Varchar); + p.Value = DBNull.Value; + p.ResetDbType(); + Assert.AreEqual(DbType.Object, p.DbType, "#G:DbType"); + Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#G:NpgsqlDbType"); + Assert.AreEqual(DBNull.Value, p.Value, "#G:Value"); + + p = new NpgsqlParameter("foo", NpgsqlDbType.Varchar); + p.Value = null; + p.ResetDbType(); + Assert.AreEqual(DbType.Object, p.DbType, "#G:DbType"); + Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#G:NpgsqlDbType"); + Assert.IsNull(p.Value, "#G:Value"); + } - [Test] - [Ignore("")] - public void ParameterType() - { - NpgsqlParameter p; - - // If Type is not set, then type is inferred from the value - // assigned. The Type should be inferred everytime Value is assigned - // If value is null or DBNull, then the current Type should be reset to Text. 
- p = new NpgsqlParameter(); - Assert.AreEqual(DbType.String, p.DbType, "#A1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#A2"); - p.Value = DBNull.Value; - Assert.AreEqual(DbType.String, p.DbType, "#B1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#B2"); - p.Value = 1; - Assert.AreEqual(DbType.Int32, p.DbType, "#C1"); - Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#C2"); - p.Value = DBNull.Value; -#if NET_2_0 - Assert.AreEqual(DbType.String, p.DbType, "#D1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#D2"); -#else - Assert.AreEqual(DbType.Int32, p.DbType, "#D1"); - Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#D2"); -#endif - p.Value = new byte[] { 0x0a }; - Assert.AreEqual(DbType.Binary, p.DbType, "#E1"); - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#E2"); - p.Value = null; -#if NET_2_0 - Assert.AreEqual(DbType.String, p.DbType, "#F1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#F2"); -#else - Assert.AreEqual(DbType.Binary, p.DbType, "#F1"); - Assert.AreEqual(NpgsqlDbType.VarBinary, p.NpgsqlDbType, "#F2"); -#endif - p.Value = DateTime.Now; - Assert.AreEqual(DbType.DateTime, p.DbType, "#G1"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#G2"); - p.Value = null; -#if NET_2_0 - Assert.AreEqual(DbType.String, p.DbType, "#H1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#H2"); -#else - Assert.AreEqual(DbType.DateTime, p.DbType, "#H1"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#H2"); -#endif + [Test] + public void ParameterName_retains_prefix() + => Assert.That(new NpgsqlParameter("@p", DbType.String).ParameterName, Is.EqualTo("@p")); - // If DbType is set, then the NpgsqlDbType should not be - // inferred from the value assigned. 
- p = new NpgsqlParameter(); - p.DbType = DbType.DateTime; - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I1"); - p.Value = 1; - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I2"); - p.Value = null; - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I3"); - p.Value = DBNull.Value; - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I4"); - - // If NpgsqlDbType is set, then the DbType should not be - // inferred from the value assigned. - p = new NpgsqlParameter(); - p.NpgsqlDbType = NpgsqlDbType.Bytea; - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J1"); - p.Value = 1; - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J2"); - p.Value = null; - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J3"); - p.Value = DBNull.Value; - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J4"); - } + [Test] + [Ignore("")] + public void SourceColumn() + { + var p = new NpgsqlParameter(); + p.SourceColumn = "name"; + Assert.AreEqual(string.Empty, p.ParameterName, "#A:ParameterName"); + Assert.AreEqual("name", p.SourceColumn, "#A:SourceColumn"); + + p.SourceColumn = null; + Assert.AreEqual(string.Empty, p.ParameterName, "#B:ParameterName"); + Assert.AreEqual(string.Empty, p.SourceColumn, "#B:SourceColumn"); + + p.SourceColumn = " "; + Assert.AreEqual(string.Empty, p.ParameterName, "#C:ParameterName"); + Assert.AreEqual(" ", p.SourceColumn, "#C:SourceColumn"); + + p.SourceColumn = " name "; + Assert.AreEqual(string.Empty, p.ParameterName, "#D:ParameterName"); + Assert.AreEqual(" name ", p.SourceColumn, "#D:SourceColumn"); + + p.SourceColumn = string.Empty; + Assert.AreEqual(string.Empty, p.ParameterName, "#E:ParameterName"); + Assert.AreEqual(string.Empty, p.SourceColumn, "#E:SourceColumn"); + } - [Test] - [Ignore("")] - public void ParameterName() - { - var p = new NpgsqlParameter(); - p.ParameterName = "name"; - Assert.AreEqual("name", p.ParameterName, "#A:ParameterName"); - Assert.AreEqual(string.Empty, 
p.SourceColumn, "#A:SourceColumn"); - - p.ParameterName = null; - Assert.AreEqual(string.Empty, p.ParameterName, "#B:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#B:SourceColumn"); - - p.ParameterName = " "; - Assert.AreEqual(" ", p.ParameterName, "#C:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#C:SourceColumn"); - - p.ParameterName = " name "; - Assert.AreEqual(" name ", p.ParameterName, "#D:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#D:SourceColumn"); - - p.ParameterName = string.Empty; - Assert.AreEqual(string.Empty, p.ParameterName, "#E:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#E:SourceColumn"); - } + [Test] + public void Bug1011100_NpgsqlDbType() + { + var p = new NpgsqlParameter(); + p.Value = DBNull.Value; + Assert.AreEqual(DbType.Object, p.DbType, "#A:DbType"); + Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#A:NpgsqlDbType"); -#if NET_2_0 - [Test] - public void ResetDbType() - { - NpgsqlParameter p; - - //Parameter with an assigned value but no DbType specified - p = new NpgsqlParameter("foo", 42); - p.ResetDbType(); - Assert.AreEqual(DbType.Int32, p.DbType, "#A:DbType"); - Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#A:NpgsqlDbType"); - Assert.AreEqual(42, p.Value, "#A:Value"); - - p.DbType = DbType.DateTime; //assigning a DbType - Assert.AreEqual(DbType.DateTime, p.DbType, "#B:DbType1"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#B:SqlDbType1"); - p.ResetDbType(); - Assert.AreEqual(DbType.Int32, p.DbType, "#B:DbType2"); - Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#B:SqlDbtype2"); - - //Parameter with an assigned NpgsqlDbType but no specified value - p = new NpgsqlParameter("foo", NpgsqlDbType.Integer); - p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#C:DbType"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#C:NpgsqlDbType"); - - p.DbType = DbType.DateTime; //assigning a NpgsqlDbType - 
Assert.AreEqual(DbType.DateTime, p.DbType, "#D:DbType1"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#D:SqlDbType1"); - p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#D:DbType2"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#D:SqlDbType2"); - - p = new NpgsqlParameter(); - p.Value = DateTime.MaxValue; - Assert.AreEqual(DbType.DateTime, p.DbType, "#E:DbType1"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#E:SqlDbType1"); - p.Value = null; - p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#E:DbType2"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#E:SqlDbType2"); - - p = new NpgsqlParameter("foo", NpgsqlDbType.Varchar); - p.Value = DateTime.MaxValue; - p.ResetDbType(); - Assert.AreEqual(DbType.DateTime, p.DbType, "#F:DbType"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#F:NpgsqlDbType"); - Assert.AreEqual(DateTime.MaxValue, p.Value, "#F:Value"); - - p = new NpgsqlParameter("foo", NpgsqlDbType.Varchar); - p.Value = DBNull.Value; - p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#G:DbType"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#G:NpgsqlDbType"); - Assert.AreEqual(DBNull.Value, p.Value, "#G:Value"); - - p = new NpgsqlParameter("foo", NpgsqlDbType.Varchar); - p.Value = null; - p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#G:DbType"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#G:NpgsqlDbType"); - Assert.IsNull(p.Value, "#G:Value"); - } + // Now change parameter value. + // Note that as we didn't explicitly specified a dbtype, the dbtype property should change when + // the value changes... 
-#endif + p.Value = 8; - [Test] - public void ParameterNameRetainsPrefix() - => Assert.That(new NpgsqlParameter("@p", DbType.String).ParameterName, Is.EqualTo("@p")); + Assert.AreEqual(DbType.Int32, p.DbType, "#A:DbType"); + Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#A:NpgsqlDbType"); - [Test] - [Ignore("")] - public void SourceColumn() - { - var p = new NpgsqlParameter(); - p.SourceColumn = "name"; - Assert.AreEqual(string.Empty, p.ParameterName, "#A:ParameterName"); - Assert.AreEqual("name", p.SourceColumn, "#A:SourceColumn"); - - p.SourceColumn = null; - Assert.AreEqual(string.Empty, p.ParameterName, "#B:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#B:SourceColumn"); - - p.SourceColumn = " "; - Assert.AreEqual(string.Empty, p.ParameterName, "#C:ParameterName"); - Assert.AreEqual(" ", p.SourceColumn, "#C:SourceColumn"); - - p.SourceColumn = " name "; - Assert.AreEqual(string.Empty, p.ParameterName, "#D:ParameterName"); - Assert.AreEqual(" name ", p.SourceColumn, "#D:SourceColumn"); - - p.SourceColumn = string.Empty; - Assert.AreEqual(string.Empty, p.ParameterName, "#E:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#E:SourceColumn"); - } + //Assert.AreEqual(3510, p.Value, "#A:Value"); + //p.NpgsqlDbType = NpgsqlDbType.Varchar; + //Assert.AreEqual(DbType.String, p.DbType, "#B:DbType"); + //Assert.AreEqual(NpgsqlDbType.Varchar, p.NpgsqlDbType, "#B:NpgsqlDbType"); + //Assert.AreEqual(3510, p.Value, "#B:Value"); + } - [Test] - public void Bug1011100NpgsqlDbTypeTest() - { - var p = new NpgsqlParameter(); - p.Value = DBNull.Value; - Assert.AreEqual(DbType.Object, p.DbType, "#A:DbType"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#A:NpgsqlDbType"); + [Test] + public void NpgsqlParameter_Clone() + { + var param = new NpgsqlParameter(); + + param.Value = 5; + param.Precision = 1; + param.Scale = 1; + param.Size = 1; + param.Direction = ParameterDirection.Input; + param.IsNullable = true; + 
param.ParameterName = "parameterName"; + param.SourceColumn = "source_column"; + param.SourceVersion = DataRowVersion.Current; + param.NpgsqlValue = 5; + param.SourceColumnNullMapping = false; + + var newParam = param.Clone(); + + Assert.AreEqual(param.Value, newParam.Value); + Assert.AreEqual(param.Precision, newParam.Precision); + Assert.AreEqual(param.Scale, newParam.Scale); + Assert.AreEqual(param.Size, newParam.Size); + Assert.AreEqual(param.Direction, newParam.Direction); + Assert.AreEqual(param.IsNullable, newParam.IsNullable); + Assert.AreEqual(param.ParameterName, newParam.ParameterName); + Assert.AreEqual(param.TrimmedName, newParam.TrimmedName); + Assert.AreEqual(param.SourceColumn, newParam.SourceColumn); + Assert.AreEqual(param.SourceVersion, newParam.SourceVersion); + Assert.AreEqual(param.NpgsqlValue, newParam.NpgsqlValue); + Assert.AreEqual(param.SourceColumnNullMapping, newParam.SourceColumnNullMapping); + Assert.AreEqual(param.NpgsqlValue, newParam.NpgsqlValue); - // Now change parameter value. - // Note that as we didn't explicitly specified a dbtype, the dbtype property should change when - // the value changes... 
+ } - p.Value = 8; + [Test] + public void Precision_via_interface() + { + var parameter = new NpgsqlParameter(); + var paramIface = (IDbDataParameter)parameter; - Assert.AreEqual(DbType.Int32, p.DbType, "#A:DbType"); - Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#A:NpgsqlDbType"); + paramIface.Precision = 42; - //Assert.AreEqual(3510, p.Value, "#A:Value"); - //p.NpgsqlDbType = NpgsqlDbType.Varchar; - //Assert.AreEqual(DbType.String, p.DbType, "#B:DbType"); - //Assert.AreEqual(NpgsqlDbType.Varchar, p.NpgsqlDbType, "#B:NpgsqlDbType"); - //Assert.AreEqual(3510, p.Value, "#B:Value"); - } + Assert.AreEqual((byte)42, paramIface.Precision); + } - [Test] - public void ParameterCollectionHashLookupParameterRenameBug() - { - using (var command = new NpgsqlCommand()) - { - // Put plenty of parameters in the collection to turn on hash lookup functionality. - for (var i = 0; i < 10; i++) - { - command.Parameters.AddWithValue(string.Format("p{0:00}", i + 1), NpgsqlDbType.Text, string.Format("String parameter value {0}", i + 1)); - } - - // Make sure both hash lookups have been generated. - Assert.AreEqual(command.Parameters["p03"].ParameterName, "p03"); - Assert.AreEqual(command.Parameters["P03"].ParameterName, "p03"); - - // Rename the target parameter. - command.Parameters["p03"].ParameterName = "a_new_name"; - - try - { - // Try to exploit the hash lookup bug. - // If the bug exists, the hash lookups will be out of sync with the list, and be unable - // to find the parameter by its new name. 
- Assert.IsTrue(command.Parameters.IndexOf("a_new_name") >= 0); - } - catch (Exception e) - { - throw new Exception("NpgsqlParameterCollection hash lookup/parameter rename bug detected", e); - } - } - } + [Test] + public void Precision_via_base_class() + { + var parameter = new NpgsqlParameter(); + var paramBase = (DbParameter)parameter; - [Test] - public void NpgsqlParameterCloneTest() - { - var param = new NpgsqlParameter(); - - param.Value = 5; - param.Precision = 1; - param.Scale = 1; - param.Size = 1; - param.Direction = ParameterDirection.Input; - param.IsNullable = true; - param.ParameterName = "parameterName"; - param.SourceColumn = "source_column"; - param.SourceVersion = DataRowVersion.Current; - param.NpgsqlValue = 5; - param.SourceColumnNullMapping = false; - - var newParam = param.Clone(); - - Assert.AreEqual(param.Value, newParam.Value); - Assert.AreEqual(param.Precision, newParam.Precision); - Assert.AreEqual(param.Scale, newParam.Scale); - Assert.AreEqual(param.Size, newParam.Size); - Assert.AreEqual(param.Direction, newParam.Direction); - Assert.AreEqual(param.IsNullable, newParam.IsNullable); - Assert.AreEqual(param.ParameterName, newParam.ParameterName); - Assert.AreEqual(param.TrimmedName, newParam.TrimmedName); - Assert.AreEqual(param.SourceColumn, newParam.SourceColumn); - Assert.AreEqual(param.SourceVersion, newParam.SourceVersion); - Assert.AreEqual(param.NpgsqlValue, newParam.NpgsqlValue); - Assert.AreEqual(param.SourceColumnNullMapping, newParam.SourceColumnNullMapping); - Assert.AreEqual(param.NpgsqlValue, newParam.NpgsqlValue); + paramBase.Precision = 42; - } + Assert.AreEqual((byte)42, paramBase.Precision); + } - [Test] - public void CleanName() - { - var param = new NpgsqlParameter(); - var command = new NpgsqlCommand(); - command.Parameters.Add(param); + [Test] + public void Scale_via_interface() + { + var parameter = new NpgsqlParameter(); + var paramIface = (IDbDataParameter)parameter; - param.ParameterName = ""; + paramIface.Scale 
= 42; - // These should not throw exceptions - Assert.AreEqual(0, command.Parameters.IndexOf("")); - Assert.AreEqual("", param.ParameterName); - } + Assert.AreEqual((byte)42, paramIface.Scale); + } - [Test] - public void PrecisionViaInterface() - { - var parameter = new NpgsqlParameter(); - var paramIface = (IDbDataParameter)parameter; + [Test] + public void Scale_via_base_class() + { + var parameter = new NpgsqlParameter(); + var paramBase = (DbParameter)parameter; - paramIface.Precision = 42; + paramBase.Scale = 42; - Assert.AreEqual((byte)42, paramIface.Precision); - } + Assert.AreEqual((byte)42, paramBase.Scale); + } - [Test] - public void PrecisionViaBaseClass() + [Test] + public void Null_value_throws() + { + using var connection = OpenConnection(); + using var command = new NpgsqlCommand("SELECT @p", connection) { - var parameter = new NpgsqlParameter(); - var paramBase = (DbParameter)parameter; - - paramBase.Precision = 42; + Parameters = { new NpgsqlParameter("p", null) } + }; - Assert.AreEqual((byte)42, paramBase.Precision); - } + Assert.That(() => command.ExecuteReader(), Throws.InvalidOperationException); + } - [Test] - public void ScaleViaInterface() + [Test] + public void Null_value_with_nullable_type() + { + using var connection = OpenConnection(); + using var command = new NpgsqlCommand("SELECT @p", connection) { - var parameter = new NpgsqlParameter(); - var paramIface = (IDbDataParameter)parameter; + Parameters = { new NpgsqlParameter("p", null) } + }; + using var reader = command.ExecuteReader(); - paramIface.Scale = 42; + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetFieldValue(0), Is.Null); + } - Assert.AreEqual((byte)42, paramIface.Scale); - } + [Test] + public void DBNull_reuses_type_info([Values]bool generic) + { + // Bootstrap datasource. + using (var _ = OpenConnection()) {} + + var param = generic ? 
new NpgsqlParameter { Value = "value" } : new NpgsqlParameter { Value = "value" }; + param.ResolveTypeInfo(DataSource.SerializerOptions); + param.GetResolutionInfo(out var typeInfo, out _, out _); + Assert.That(typeInfo, Is.Not.Null); + + // Make sure we don't reset the type info when setting DBNull. + param.Value = DBNull.Value; + param.GetResolutionInfo(out var secondTypeInfo, out _, out _); + Assert.That(typeInfo, Is.SameAs(secondTypeInfo)); + + // Make sure we don't resolve a different type info either. + param.ResolveTypeInfo(DataSource.SerializerOptions); + param.GetResolutionInfo(out var thirdTypeInfo, out _, out _); + Assert.That(secondTypeInfo, Is.SameAs(thirdTypeInfo)); + } - [Test] - public void ScaleViaBaseClass() - { - var parameter = new NpgsqlParameter(); - var paramBase = (DbParameter)parameter; + [Test] + public void DBNull_followed_by_non_null_reresolves([Values]bool generic) + { + // Bootstrap datasource. + using (var _ = OpenConnection()) {} + + var param = generic ? new NpgsqlParameter { Value = DBNull.Value } : new NpgsqlParameter { Value = DBNull.Value }; + param.ResolveTypeInfo(DataSource.SerializerOptions); + param.GetResolutionInfo(out var typeInfo, out _, out var pgTypeId); + Assert.That(typeInfo, Is.Not.Null); + Assert.That(pgTypeId.IsUnspecified, Is.True); + + param.Value = "value"; + param.GetResolutionInfo(out var secondTypeInfo, out _, out _); + Assert.That(secondTypeInfo, Is.Null); + + // Make sure we don't resolve the same type info either. + param.ResolveTypeInfo(DataSource.SerializerOptions); + param.GetResolutionInfo(out var thirdTypeInfo, out _, out _); + Assert.That(typeInfo, Is.Not.SameAs(thirdTypeInfo)); + } - paramBase.Scale = 42; + [Test] + public void Changing_value_type_reresolves([Values]bool generic) + { + // Bootstrap datasource. + using (var _ = OpenConnection()) {} + + var param = generic ? 
new NpgsqlParameter { Value = "value" } : new NpgsqlParameter { Value = "value" }; + param.ResolveTypeInfo(DataSource.SerializerOptions); + param.GetResolutionInfo(out var typeInfo, out _, out _); + Assert.That(typeInfo, Is.Not.Null); + + param.Value = 1; + param.GetResolutionInfo(out var secondTypeInfo, out _, out _); + Assert.That(secondTypeInfo, Is.Null); + + // Make sure we don't resolve a different type info either. + param.ResolveTypeInfo(DataSource.SerializerOptions); + param.GetResolutionInfo(out var thirdTypeInfo, out _, out _); + Assert.That(typeInfo, Is.Not.SameAs(thirdTypeInfo)); + } - Assert.AreEqual((byte)42, paramBase.Scale); - } +#if NeedsPorting + [Test] + [Category ("NotWorking")] + public void InferType_Char() + { + Char value = 'X'; + + String string_value = "X"; + + NpgsqlParameter p = new NpgsqlParameter (); + p.Value = value; + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#A:NpgsqlDbType"); + Assert.AreEqual (DbType.String, p.DbType, "#A:DbType"); + Assert.AreEqual (string_value, p.Value, "#A:Value"); + + p = new NpgsqlParameter (); + p.Value = value; + Assert.AreEqual (value, p.Value, "#B:Value1"); + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#B:NpgsqlDbType"); + Assert.AreEqual (string_value, p.Value, "#B:Value2"); + + p = new NpgsqlParameter (); + p.Value = value; + Assert.AreEqual (value, p.Value, "#C:Value1"); + Assert.AreEqual (DbType.String, p.DbType, "#C:DbType"); + Assert.AreEqual (string_value, p.Value, "#C:Value2"); + + p = new NpgsqlParameter ("name", value); + Assert.AreEqual (value, p.Value, "#D:Value1"); + Assert.AreEqual (DbType.String, p.DbType, "#D:DbType"); + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#D:NpgsqlDbType"); + Assert.AreEqual (string_value, p.Value, "#D:Value2"); + + p = new NpgsqlParameter ("name", 5); + p.Value = value; + Assert.AreEqual (value, p.Value, "#E:Value1"); + Assert.AreEqual (DbType.String, p.DbType, "#E:DbType"); + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, 
"#E:NpgsqlDbType"); + Assert.AreEqual (string_value, p.Value, "#E:Value2"); + + p = new NpgsqlParameter ("name", NpgsqlDbType.Text); + p.Value = value; + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#F:NpgsqlDbType"); + Assert.AreEqual (value, p.Value, "#F:Value"); + } - [Test] - public void ResolveHandler_NullValue_ThrowsInvalidOperationException() - { - using var connection = OpenConnection(); - using var command = new NpgsqlCommand("SELECT @p", connection) - { - Parameters = { new NpgsqlParameter("p", null) } - }; + [Test] + [Category ("NotWorking")] + public void InferType_CharArray() + { + Char[] value = new Char[] { 'A', 'X' }; + + String string_value = "AX"; + + NpgsqlParameter p = new NpgsqlParameter (); + p.Value = value; + Assert.AreEqual (value, p.Value, "#A:Value1"); + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#A:NpgsqlDbType"); + Assert.AreEqual (DbType.String, p.DbType, "#A:DbType"); + Assert.AreEqual (string_value, p.Value, "#A:Value2"); + + p = new NpgsqlParameter (); + p.Value = value; + Assert.AreEqual (value, p.Value, "#B:Value1"); + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#B:NpgsqlDbType"); + Assert.AreEqual (string_value, p.Value, "#B:Value2"); + + p = new NpgsqlParameter (); + p.Value = value; + Assert.AreEqual (value, p.Value, "#C:Value1"); + Assert.AreEqual (DbType.String, p.DbType, "#C:DbType"); + Assert.AreEqual (string_value, p.Value, "#C:Value2"); + + p = new NpgsqlParameter ("name", value); + Assert.AreEqual (value, p.Value, "#D:Value1"); + Assert.AreEqual (DbType.String, p.DbType, "#D:DbType"); + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#D:NpgsqlDbType"); + Assert.AreEqual (string_value, p.Value, "#D:Value2"); + + p = new NpgsqlParameter ("name", 5); + p.Value = value; + Assert.AreEqual (value, p.Value, "#E:Value1"); + Assert.AreEqual (DbType.String, p.DbType, "#E:DbType"); + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#E:NpgsqlDbType"); + Assert.AreEqual (string_value, p.Value, 
"#E:Value2"); + + p = new NpgsqlParameter ("name", NpgsqlDbType.Text); + p.Value = value; + Assert.AreEqual (NpgsqlDbType.Text, p.NpgsqlDbType, "#F:NpgsqlDbType"); + Assert.AreEqual (value, p.Value, "#F:Value"); + } - Assert.That(() => command.ExecuteReader(), Throws.InvalidOperationException); - } + [Test] + public void InferType_Object() + { + Object value = new Object(); - [Test] - public void ResolveHandler_NullableValue_Succeeds() - { - using var connection = OpenConnection(); - using var command = new NpgsqlCommand("SELECT @p", connection) - { - Parameters = { new NpgsqlParameter("p", null) } - }; - using var reader = command.ExecuteReader(); + NpgsqlParameter param = new NpgsqlParameter(); + param.Value = value; + Assert.AreEqual(NpgsqlDbType.Variant, param.NpgsqlDbType, "#1"); + Assert.AreEqual(DbType.Object, param.DbType, "#2"); + } - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetFieldValue(0), Is.Null); - } + [Test] + public void LocaleId () + { + NpgsqlParameter parameter = new NpgsqlParameter (); + Assert.AreEqual (0, parameter.LocaleId, "#1"); + parameter.LocaleId = 15; + Assert.AreEqual(15, parameter.LocaleId, "#2"); } +#endif } diff --git a/test/Npgsql.Tests/PgPassEntryTests.cs b/test/Npgsql.Tests/PgPassEntryTests.cs index 4ddf9fb49f..9db518aabc 100644 --- a/test/Npgsql.Tests/PgPassEntryTests.cs +++ b/test/Npgsql.Tests/PgPassEntryTests.cs @@ -2,101 +2,99 @@ using NUnit.Framework; using NUnit.Framework.Constraints; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class PgPassEntryTests { - [TestFixture] - public class PgPassEntryTests + [Test] + public void Parses_well_formed_entry() + { + var input = "test:1234:test2:test3:test4"; + var entry = PgPassFile.Entry.Parse(input); + + Assert.That(entry, Is.Not.Null); + Assert.That("test", Is.EqualTo(entry.Host)); + Assert.That(1234, Is.EqualTo(entry.Port)); + Assert.That("test2", Is.EqualTo(entry.Database)); + Assert.That("test3", Is.EqualTo(entry.Username)); + Assert.That("test4", 
Is.EqualTo(entry.Password)); + } + + [Test] + [TestCase("test:1234:test2:test3")] + [TestCase("test:myport:test2:test3:test4")] + public void Bad_entry_throws(string input) + { + ActualValueDelegate createDelegate = () => PgPassFile.Entry.Parse(input); + Assert.That(createDelegate, Throws.TypeOf()); + } + + [Test] + public void Escaped_characters() + { + var input = "t\\:est:1234:test2:test3:test\\\\4"; + var entry = PgPassFile.Entry.Parse(input); + + Assert.That(entry, Is.Not.Null); + Assert.That("t:est", Is.EqualTo(entry.Host)); + Assert.That(1234, Is.EqualTo(entry.Port)); + Assert.That("test2", Is.EqualTo(entry.Database)); + Assert.That("test3", Is.EqualTo(entry.Username)); + Assert.That("test\\4", Is.EqualTo(entry.Password)); + } + + [Test] + public void Match_true_for_exact_match() + { + var input = "test:1234:test2:test3:test4"; + var entry = PgPassFile.Entry.Parse(input); + + var isMatch = entry.IsMatch("test", 1234, "test2", "test3"); + + Assert.That(isMatch, Is.True); + } + + [Test] + public void Match_true_for_wildcard_entry() + { + var input = "*:1234:test2:test3:test4"; + var entry = PgPassFile.Entry.Parse(input); + + var isMatch = entry.IsMatch("test", 1234, "test2", "test3"); + + Assert.That(isMatch, Is.True); + } + + [Test] + public void Match_true_for_wildcard_query() + { + var input = "test:1234:test2:test3:test4"; + var entry = PgPassFile.Entry.Parse(input); + + var isMatch = entry.IsMatch(null, 1234, "test2", "test3"); + + Assert.That(isMatch, Is.True); + } + + [Test] + public void Match_false_for_bad_query() + { + var input = "test:1234:test2:test3:test4"; + var entry = PgPassFile.Entry.Parse(input); + + var isMatch = entry.IsMatch("notamatch", 1234, "test2", "test3"); + + Assert.That(isMatch, Is.False); + } + + [Test] + public void Match_true_for_null_query() { - [Test] - public void ParsesWellFormedEntry() - { - var input = "test:1234:test2:test3:test4"; - var entry = PgPassFile.Entry.Parse(input); - - Assert.That(entry, Is.Not.Null); - 
Assert.That("test", Is.EqualTo(entry.Host)); - Assert.That(1234, Is.EqualTo(entry.Port)); - Assert.That("test2", Is.EqualTo(entry.Database)); - Assert.That("test3", Is.EqualTo(entry.Username)); - Assert.That("test4", Is.EqualTo(entry.Password)); - } - - [Test] - [TestCase("test:1234:test2:test3")] - [TestCase("test:myport:test2:test3:test4")] - public void ThrowFormatExceptionForBadEntry(string input) - { - ActualValueDelegate createDelegate = () => PgPassFile.Entry.Parse(input); - Assert.That(createDelegate, Throws.TypeOf()); - } - - [Test] - public void HandleEscapedCharacters() - { - var input = "t\\:est:1234:test2:test3:test\\\\4"; - var entry = PgPassFile.Entry.Parse(input); - - Assert.That(entry, Is.Not.Null); - Assert.That("t:est", Is.EqualTo(entry.Host)); - Assert.That(1234, Is.EqualTo(entry.Port)); - Assert.That("test2", Is.EqualTo(entry.Database)); - Assert.That("test3", Is.EqualTo(entry.Username)); - Assert.That("test\\4", Is.EqualTo(entry.Password)); - } - - [Test] - public void MatchTrueForExactMatch() - { - var input = "test:1234:test2:test3:test4"; - var entry = PgPassFile.Entry.Parse(input); - - var isMatch = entry.IsMatch("test", 1234, "test2", "test3"); - - Assert.That(isMatch, Is.True); - } - - [Test] - public void MatchTrueForWildcardEntry() - { - var input = "*:1234:test2:test3:test4"; - var entry = PgPassFile.Entry.Parse(input); - - var isMatch = entry.IsMatch("test", 1234, "test2", "test3"); - - Assert.That(isMatch, Is.True); - } - - [Test] - public void MatchTrueForWildcardQuery() - { - var input = "test:1234:test2:test3:test4"; - var entry = PgPassFile.Entry.Parse(input); - - var isMatch = entry.IsMatch(null, 1234, "test2", "test3"); - - Assert.That(isMatch, Is.True); - } - - [Test] - public void MatchFalseForBadQuery() - { - var input = "test:1234:test2:test3:test4"; - var entry = PgPassFile.Entry.Parse(input); - - var isMatch = entry.IsMatch("notamatch", 1234, "test2", "test3"); - - Assert.That(isMatch, Is.False); - } - - [Test] - public 
void MatchTrueForNullQuery() - { - var input = "test:1234:test2:test3:test4"; - var entry = PgPassFile.Entry.Parse(input); - - var isMatch = entry.IsMatch(null, 1234, "test2", "test3"); - - Assert.That(isMatch, Is.True); - } + var input = "test:1234:test2:test3:test4"; + var entry = PgPassFile.Entry.Parse(input); + + var isMatch = entry.IsMatch(null, 1234, "test2", "test3"); + + Assert.That(isMatch, Is.True); } } diff --git a/test/Npgsql.Tests/PgPassFileTests.cs b/test/Npgsql.Tests/PgPassFileTests.cs index 3e906aff0b..593e522e89 100644 --- a/test/Npgsql.Tests/PgPassFileTests.cs +++ b/test/Npgsql.Tests/PgPassFileTests.cs @@ -2,54 +2,52 @@ using System.Linq; using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class PgPassFileTests { - [TestFixture] - public class PgPassFileTests + [Test] + public void Should_parse_all_entries() + { + var file = new PgPassFile(_pgpassFile); + var entries = file.Entries.ToList(); + Assert.That(entries.Count, Is.EqualTo(3)); + } + + [Test] + public void Should_find_first_entry_when_multiple_match() + { + var file = new PgPassFile(_pgpassFile); + var entry = file.GetFirstMatchingEntry("testhost")!; + Assert.That(entry.Password, Is.EqualTo("testpass")); + } + + [Test] + public void Should_find_default_for_no_matches() { - [Test] - public void ShouldParseAllEntries() - { - var file = new PgPassFile(_pgpassFile); - var entries = file.Entries.ToList(); - Assert.That(entries.Count, Is.EqualTo(3)); - } - - [Test] - public void ShouldFindFirstEntryWhenMultipleMatch() - { - var file = new PgPassFile(_pgpassFile); - var entry = file.GetFirstMatchingEntry("testhost")!; - Assert.That(entry.Password, Is.EqualTo("testpass")); - } - - [Test] - public void ShouldFindDefaultForNoMatches() - { - var file = new PgPassFile(_pgpassFile); - var entry = file.GetFirstMatchingEntry("notarealhost")!; - Assert.That(entry.Password, Is.EqualTo("defaultpass")); - } - - readonly string _pgpassFile = Path.GetTempFileName(); - - 
[OneTimeSetUp] - public void CreateTestFile() - { - // set up pgpass file with fake content that can be used for this test - const string content = @"testhost:1234:testdatabase:testuser:testpass + var file = new PgPassFile(_pgpassFile); + var entry = file.GetFirstMatchingEntry("notarealhost")!; + Assert.That(entry.Password, Is.EqualTo("defaultpass")); + } + + readonly string _pgpassFile = Path.GetTempFileName(); + + [OneTimeSetUp] + public void CreateTestFile() + { + // set up pgpass file with fake content that can be used for this test + const string content = @"testhost:1234:testdatabase:testuser:testpass testhost:*:*:*:testdefaultpass # helpful comment goes here *:*:*:*:defaultpass"; - File.WriteAllText(_pgpassFile, content); - } + File.WriteAllText(_pgpassFile, content); + } - [OneTimeTearDown] - public void DeleteTestFile() - { - if (File.Exists(_pgpassFile)) - File.Delete(_pgpassFile); - } + [OneTimeTearDown] + public void DeleteTestFile() + { + if (File.Exists(_pgpassFile)) + File.Delete(_pgpassFile); } } diff --git a/test/Npgsql.Tests/PoolManagerTests.cs b/test/Npgsql.Tests/PoolManagerTests.cs index 0513a7617b..afd716dab5 100644 --- a/test/Npgsql.Tests/PoolManagerTests.cs +++ b/test/Npgsql.Tests/PoolManagerTests.cs @@ -1,77 +1,79 @@ using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +[NonParallelizable] +class PoolManagerTests : TestBase { - [NonParallelizable] - class PoolManagerTests : TestBase + [Test] + public void With_canonical_connection_string() { - [Test] - public void WithCanonicalConnString() + var connString = new NpgsqlConnectionStringBuilder(ConnectionString).ToString(); + using (var conn = new NpgsqlConnection(connString)) + conn.Open(); + var connString2 = new NpgsqlConnectionStringBuilder(ConnectionString) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString).ToString(); - using (var conn = new NpgsqlConnection(connString)) - conn.Open(); - var connString2 = new 
NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = "Another connstring" - }.ToString(); - using (var conn = new NpgsqlConnection(connString2)) - conn.Open(); - } + ApplicationName = "Another connstring" + }.ToString(); + using (var conn = new NpgsqlConnection(connString2)) + conn.Open(); + } #if DEBUG - [Test] - public void ManyPools() + [Test] + public void Many_pools() + { + PoolManager.Reset(); + for (var i = 0; i < 15; i++) { - PoolManager.Reset(); - for (var i = 0; i < PoolManager.InitialPoolsSize + 1; i++) + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = "App" + i - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) - conn.Open(); - } - PoolManager.Reset(); + ApplicationName = "App" + i + }.ToString(); + using var conn = new NpgsqlConnection(connString); + conn.Open(); } + PoolManager.Reset(); + } #endif - [Test] - public void ClearAll() - { - using (OpenConnection()) {} - // Now have one connection in the pool - Assert.That(PoolManager.TryGetValue(ConnectionString, out var pool), Is.True); - Assert.That(pool!.Statistics.Idle, Is.EqualTo(1)); + [Test] + public void ClearAllPools() + { + using (var conn = new NpgsqlConnection(ConnectionString)) + conn.Open(); + // Now have one connection in the pool + Assert.That(PoolManager.Pools.TryGetValue(ConnectionString, out var pool), Is.True); + Assert.That(pool!.Statistics.Idle, Is.EqualTo(1)); - NpgsqlConnection.ClearAllPools(); - Assert.That(pool.Statistics.Idle, Is.Zero); - Assert.That(pool.Statistics.Total, Is.Zero); - } + NpgsqlConnection.ClearAllPools(); + Assert.That(pool.Statistics.Idle, Is.Zero); + Assert.That(pool.Statistics.Total, Is.Zero); + } - [Test] - public void ClearAllWithBusy() + [Test] + public void ClearAllPools_with_busy() + { + NpgsqlDataSource? pool; + using (var conn = new NpgsqlConnection(ConnectionString)) { - ConnectorPool? 
pool; - using (OpenConnection()) - { - using (OpenConnection()) { } - // We have one idle, one busy + conn.Open(); + using (var anotherConn = new NpgsqlConnection(ConnectionString)) + anotherConn.Open(); + // We have one idle, one busy - NpgsqlConnection.ClearAllPools(); - Assert.That(PoolManager.TryGetValue(ConnectionString, out pool), Is.True); - Assert.That(pool!.Statistics.Idle, Is.Zero); - Assert.That(pool.Statistics.Total, Is.EqualTo(1)); - } - Assert.That(pool.Statistics.Idle, Is.Zero); - Assert.That(pool.Statistics.Total, Is.Zero); + NpgsqlConnection.ClearAllPools(); + Assert.That(PoolManager.Pools.TryGetValue(ConnectionString, out pool), Is.True); + Assert.That(pool!.Statistics.Idle, Is.Zero); + Assert.That(pool.Statistics.Total, Is.EqualTo(1)); } + Assert.That(pool.Statistics.Idle, Is.Zero); + Assert.That(pool.Statistics.Total, Is.Zero); + } - [SetUp] - public void Setup() => PoolManager.Reset(); + [SetUp] + public void Setup() => PoolManager.Reset(); - [TearDown] - public void Teardown() => PoolManager.Reset(); - } + [TearDown] + public void Teardown() => PoolManager.Reset(); } diff --git a/test/Npgsql.Tests/PoolTests.cs b/test/Npgsql.Tests/PoolTests.cs index f8f8f7e891..d9024dd0dd 100644 --- a/test/Npgsql.Tests/PoolTests.cs +++ b/test/Npgsql.Tests/PoolTests.cs @@ -1,574 +1,496 @@ using System; -using System.Collections.Generic; using System.Linq; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; using NUnit.Framework; -using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +class PoolTests : TestBase { - [NonParallelizable] - class PoolTests : TestBase + [Test] + public async Task MinPoolSize_equals_MaxPoolSize() { - [Test] - public void MinPoolSizeEqualsMaxPoolSize() + await using var dataSource = CreateDataSource(csb => { - using (var conn = CreateConnection(new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(MinPoolSizeEqualsMaxPoolSize), - MinPoolSize = 30, 
- MaxPoolSize = 30 - }.ToString())) - { - conn.Open(); - } - } + csb.MinPoolSize = 30; + csb.MaxPoolSize = 30; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + } - [Test] - public void MinPoolSizeLargerThanMaxPoolSize() + [Test] + public void MinPoolSize_bigger_than_MaxPoolSize_throws() + => Assert.ThrowsAsync(async () => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) + await using var dataSource = CreateDataSource(csb => { - ApplicationName = nameof(MinPoolSizeLargerThanMaxPoolSize), - MinPoolSize = 2, - MaxPoolSize = 1 - }.ToString(); + csb.MinPoolSize = 2; + csb.MaxPoolSize = 1; + }); + }); - Assert.That(() => CreateConnection(connString), Throws.Exception.TypeOf()); - } + [Test] + public async Task Reuse_connector_before_creating_new() + { + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + var backendId = conn.Connector!.BackendProcessId; + await conn.CloseAsync(); + await conn.OpenAsync(); + Assert.That(conn.Connector.BackendProcessId, Is.EqualTo(backendId)); + } - [Test] - public void ReuseConnectorBeforeCreatingNew() + [Test] + public async Task Get_connector_from_exhausted_pool([Values(true, false)] bool async) + { + await using var dataSource = CreateDataSource(csb => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ReuseConnectorBeforeCreatingNew), - }.ToString(); + csb.MaxPoolSize = 1; + csb.Timeout = 0; + }); - using (var conn = CreateConnection(connString)) - { - conn.Open(); - var backendId = conn.Connector!.BackendProcessId; - conn.Close(); - conn.Open(); - Assert.That(conn.Connector.BackendProcessId, Is.EqualTo(backendId)); - } - } + await using var conn1 = await dataSource.OpenConnectionAsync(); - [Test, Timeout(10000)] - public void GetConnectorFromExhaustedPool() + // Pool is exhausted + await using var conn2 = dataSource.CreateConnection(); + _ = 
Task.Delay(1000).ContinueWith(async _ => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(GetConnectorFromExhaustedPool), - MaxPoolSize = 1, - Timeout = 0 - }.ToString(); - - using (var conn1 = CreateConnection(connString)) - { - conn1.Open(); - - // Pool is exhausted - using (var conn2 = CreateConnection(connString)) - { - new Timer(o => conn1.Close(), null, 1000, Timeout.Infinite); - conn2.Open(); - } - } - } - - //[Test, Explicit, Timeout(10000)] - public async Task GetConnectorFromExhaustedPoolAsync() - { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(GetConnectorFromExhaustedPoolAsync), - MaxPoolSize = 1, - Timeout = 0 - }.ToString(); - - using (var conn1 = CreateConnection(connString)) - { - await conn1.OpenAsync(); - - // Pool is exhausted - using (var conn2 = CreateConnection(connString)) - using (new Timer(o => conn1.Close(), null, 1000, Timeout.Infinite)) - await conn2.OpenAsync(); - } - } + if (async) + await conn1.CloseAsync(); + else + conn1.Close(); + }); + if (async) + await conn2.OpenAsync(); + else + conn2.Open(); + } - [Test] - public void TimeoutGettingConnectorFromExhaustedPool() + [Test] + public async Task Timeout_getting_connector_from_exhausted_pool([Values(true, false)] bool async) + { + await using var dataSource = CreateDataSource(csb => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(TimeoutGettingConnectorFromExhaustedPool), - MaxPoolSize = 1, - Timeout = 2 - }.ToString(); + csb.MaxPoolSize = 1; + csb.Timeout = 2; + }); - using (var conn1 = CreateConnection(connString)) - { - conn1.Open(); - // Pool is exhausted - using (var conn2 = CreateConnection(connString)) - Assert.That(() => conn2.Open(), Throws.Exception.TypeOf()); - } - // conn1 should now be back in the pool as idle - using (var conn3 = CreateConnection(connString)) - conn3.Open(); - } - - [Test] - public async 
Task TimeoutGettingConnectorFromExhaustedPoolAsync() + await using (var conn1 = dataSource.CreateConnection()) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(TimeoutGettingConnectorFromExhaustedPoolAsync), - MaxPoolSize = 1, - Timeout = 2 - }.ToString(); + await conn1.OpenAsync(); + // Pool is now exhausted - using (var conn1 = CreateConnection(connString)) - { - await conn1.OpenAsync(); + await using var conn2 = dataSource.CreateConnection(); + var e = async + ? Assert.ThrowsAsync(async () => await conn2.OpenAsync())! + : Assert.Throws(() => conn2.Open())!; - // Pool is exhausted - using (var conn2 = CreateConnection(connString)) - Assert.That(async () => await conn2.OpenAsync(), Throws.Exception.TypeOf()); - } - // conn1 should now be back in the pool as idle - using (var conn3 = CreateConnection(connString)) - conn3.Open(); + Assert.That(e.InnerException, Is.TypeOf()); } - [Test, Timeout(10000)] - [Explicit("Timing-based")] - public async Task CancelOpenAsync() - { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(CancelOpenAsync), - MaxPoolSize = 1, - }.ToString(); - - using (var conn1 = CreateConnection(connString)) - { - await conn1.OpenAsync(); - - Assert.True(PoolManager.TryGetValue(connString, out var pool)); - AssertPoolState(pool, open: 1, idle: 0); + // conn1 should now be back in the pool as idle + await using var conn3 = await dataSource.OpenConnectionAsync(); + } - // Pool is exhausted - using (var conn2 = CreateConnection(connString)) - { - var cts = new CancellationTokenSource(1000); - var openTask = conn2.OpenAsync(cts.Token); - AssertPoolState(pool, open: 1, idle: 0); - Assert.That(async () => await openTask, Throws.Exception.TypeOf()); - } + [Test] + [Explicit("Timing-based")] + public async Task OpenAsync_cancel() + { + await using var dataSource = CreateDataSource(csb => csb.MaxPoolSize = 1); + await using var conn1 = await 
dataSource.OpenConnectionAsync(); - AssertPoolState(pool, open: 1, idle: 0); - using (var conn2 = CreateConnection(connString)) - using (new Timer(o => conn1.Close(), null, 1000, Timeout.Infinite)) - { - await conn2.OpenAsync(); - AssertPoolState(pool, open: 1, idle: 0); - } - AssertPoolState(pool, open: 1, idle: 1); - } - } + AssertPoolState(dataSource, open: 1, idle: 0); - [Test, Description("Makes sure that when a pooled connection is closed it's properly reset, and that parameter settings aren't leaked")] - public void ResetOnClose() + // Pool is exhausted + await using (var conn2 = dataSource.CreateConnection()) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ResetOnClose), - SearchPath = "public" - }.ToString(); - using (var conn = CreateConnection(connString)) - { - conn.Open(); - Assert.That(conn.ExecuteScalar("SHOW search_path"), Is.Not.Contains("pg_temp")); - var backendId = conn.Connector!.BackendProcessId; - conn.ExecuteNonQuery("SET search_path=pg_temp"); - conn.Close(); - - conn.Open(); - Assert.That(conn.Connector.BackendProcessId, Is.EqualTo(backendId)); - Assert.That(conn.ExecuteScalar("SHOW search_path"), Is.EqualTo("public")); - } + var cts = new CancellationTokenSource(1000); + var openTask = conn2.OpenAsync(cts.Token); + AssertPoolState(dataSource, open: 1, idle: 0); + Assert.That(async () => await openTask, Throws.Exception.TypeOf()); } - [Test] - public void ArgumentExceptionOnZeroPruningInterval() + AssertPoolState(dataSource, open: 1, idle: 0); + await using (var conn2 = dataSource.CreateConnection()) + await using (new Timer(o => conn1.Close(), null, 1000, Timeout.Infinite)) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ArgumentExceptionOnZeroPruningInterval), - ConnectionPruningInterval = 0 - }.ToString(); - - Assert.Throws(() => OpenConnection(connString)); + await conn2.OpenAsync(); + AssertPoolState(dataSource, open: 1, 
idle: 0); } + AssertPoolState(dataSource, open: 1, idle: 1); + } - [Test] - public void ArgumentExceptionOnPruningIntervalLargerThanIdleLifetime() - { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ArgumentExceptionOnPruningIntervalLargerThanIdleLifetime), - ConnectionIdleLifetime = 1, - ConnectionPruningInterval = 2 - }.ToString(); + [Test, Description("Makes sure that when a pooled connection is closed it's properly reset, and that parameter settings aren't leaked")] + public async Task ResetOnClose() + { + await using var dataSource = CreateDataSource(csb => csb.SearchPath = "public"); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(await conn.ExecuteScalarAsync("SHOW search_path"), Is.Not.Contains("pg_temp")); + var backendId = conn.Connector!.BackendProcessId; + await conn.ExecuteNonQueryAsync("SET search_path=pg_temp"); + await conn.CloseAsync(); + + await conn.OpenAsync(); + Assert.That(conn.Connector.BackendProcessId, Is.EqualTo(backendId)); + Assert.That(await conn.ExecuteScalarAsync("SHOW search_path"), Is.EqualTo("public")); + } - Assert.Throws(() => OpenConnection(connString)); - } + [Test] + public void ConnectionPruningInterval_zero_throws() + => Assert.ThrowsAsync(async () => + { + await using var dataSource = CreateDataSource(csb => csb.ConnectionPruningInterval = 0); + }); - [Theory, Explicit("Slow, and flaky under pressure, based on timing")] - [TestCase(0, 2, 1, 2)] // min pool size 0, sample twice - [TestCase(1, 2, 1, 2)] // min pool size 1, sample twice - [TestCase(2, 2, 1, 2)] // min pool size 2, sample twice - [TestCase(2, 3, 2, 2)] // test rounding up, should sample twice. - [TestCase(2, 1, 1, 1)] // test sample once. - [TestCase(2, 20, 3, 7)] // test high samples. 
- public void PruneIdleConnectors(int minPoolSize, int connectionIdleLifeTime, int connectionPruningInterval, int samples) + [Test] + public void ConnectionPruningInterval_bigger_than_ConnectionIdleLifetime_throws() + => Assert.ThrowsAsync(async () => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(PruneIdleConnectors), - MinPoolSize = minPoolSize, - ConnectionIdleLifetime = connectionIdleLifeTime, - ConnectionPruningInterval = connectionPruningInterval - }.ToString(); + await using var dataSource = CreateDataSource(csb => + { + csb.ConnectionIdleLifetime = 1; + csb.ConnectionPruningInterval = 2; + }); + }); + + [Theory, Explicit("Slow, and flaky under pressure, based on timing")] + [TestCase(0, 2, 1, 2)] // min pool size 0, sample twice + [TestCase(1, 2, 1, 2)] // min pool size 1, sample twice + [TestCase(2, 2, 1, 2)] // min pool size 2, sample twice + [TestCase(2, 3, 2, 2)] // test rounding up, should sample twice. + [TestCase(2, 1, 1, 1)] // test sample once. + [TestCase(2, 20, 3, 7)] // test high samples. 
+ public async Task Prune_idle_connectors(int minPoolSize, int connectionIdleLifeTime, int connectionPruningInterval, int samples) + { + await using var dataSource = CreateDataSource(csb => + { + csb.MinPoolSize = minPoolSize; + csb.ConnectionIdleLifetime = connectionIdleLifeTime; + csb.ConnectionPruningInterval = connectionPruningInterval; + }); - var connectionPruningIntervalMs = connectionPruningInterval * 1000; + var connectionPruningIntervalMs = connectionPruningInterval * 1000; - using (var conn1 = OpenConnection(connString)) - using (var conn2 = OpenConnection(connString)) - using (var conn3 = OpenConnection(connString)) - { - Assert.True(PoolManager.TryGetValue(connString, out var pool)); + await using var conn1 = await dataSource.OpenConnectionAsync(); + await using var conn2 = await dataSource.OpenConnectionAsync(); + await using var conn3 = await dataSource.OpenConnectionAsync(); - conn1.Close(); - conn2.Close(); - AssertPoolState(pool!, open: 3, idle: 2); + await conn1.CloseAsync(); + await conn2.CloseAsync(); + AssertPoolState(dataSource!, open: 3, idle: 2); - var paddingMs = 100; // 100ms - var sleepInterval = connectionPruningIntervalMs + paddingMs; - var total = 0; + var paddingMs = 100; // 100ms + var sleepInterval = connectionPruningIntervalMs + paddingMs; + var total = 0; - for (var i = 0; i < samples - 1; i++) - { - total += sleepInterval; - Thread.Sleep(sleepInterval); - // ConnectionIdleLifetime not yet reached. - AssertPoolState(pool, open: 3, idle: 2); - } + for (var i = 0; i < samples - 1; i++) + { + total += sleepInterval; + Thread.Sleep(sleepInterval); + // ConnectionIdleLifetime not yet reached. + AssertPoolState(dataSource, open: 3, idle: 2); + } - // final cycle to do pruning. - Thread.Sleep(Math.Max(sleepInterval, (connectionIdleLifeTime * 1000) - total)); + // final cycle to do pruning. 
+ Thread.Sleep(Math.Max(sleepInterval, (connectionIdleLifeTime * 1000) - total)); - // ConnectionIdleLifetime reached, we still have one connection open minimum, - // and as a result we have minPoolSize - 1 idle connections. - AssertPoolState(pool, open: Math.Max(1, minPoolSize), idle: Math.Max(0, minPoolSize - 1)); - } - } + // ConnectionIdleLifetime reached, we still have one connection open minimum, + // and as a result we have minPoolSize - 1 idle connections. + AssertPoolState(dataSource, open: Math.Max(1, minPoolSize), idle: Math.Max(0, minPoolSize - 1)); + } - [Test, Description("Makes sure that when a waiting async open is is given a connection, the continuation is executed in the TP rather than on the closing thread")] - public void CloseReleasesWaiterOnAnotherThread() + [Test] + [Explicit("Timing-based")] + public async Task Prune_counts_max_lifetime_exceeded() + { + await using var dataSource = CreateDataSource(csb => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(CloseReleasesWaiterOnAnotherThread), - MaxPoolSize = 1 - }.ToString(); - var conn1 = CreateConnection(connString); - try - { - conn1.Open(); // Pool is now exhausted + csb.MinPoolSize = 0; + // Idle lifetime 2 seconds, 2 samples + csb.ConnectionIdleLifetime = 2; + csb.ConnectionPruningInterval = 1; + csb.ConnectionLifetime = 5; + }); + + // conn1 will exceed max lifetime + await using var conn1 = await dataSource.OpenConnectionAsync(); + + // make conn1 4 seconds older than the others, so it exceeds max lifetime + Thread.Sleep(4000); + + await using var conn2 = await dataSource.OpenConnectionAsync(); + await using var conn3 = await dataSource.OpenConnectionAsync(); + + await conn1.CloseAsync(); + await conn2.CloseAsync(); + AssertPoolState(dataSource, open: 3, idle: 2); + + // wait for 1 sample + Thread.Sleep(1000); + // ConnectionIdleLifetime not yet reached. 
+ AssertPoolState(dataSource, open: 3, idle: 2); + + // close conn3, so we can see if too many connectors get pruned + await conn3.CloseAsync(); + + // wait for last sample + a bit more time for reliability + Thread.Sleep(1500); + + // ConnectionIdleLifetime reached + // - conn1 should have been closed due to max lifetime (but this should count as pruning) + // - conn2 or conn3 should have been closed due to idle pruning + // - conn3 or conn2 should remain + AssertPoolState(dataSource, open: 1, idle: 1); + } - Assert.True(PoolManager.TryGetValue(connString, out var pool)); - AssertPoolState(pool, open: 1, idle: 0); + [Test, Description("Makes sure that when a waiting async open is is given a connection, the continuation is executed in the TP rather than on the closing thread")] + public async Task Close_releases_waiter_on_another_thread() + { + await using var dataSource = CreateDataSource(csb => csb.MaxPoolSize = 1); + await using var conn1 = await dataSource.OpenConnectionAsync(); // Pool is now exhausted - Func> asyncOpener = async () => - { - using (var conn2 = CreateConnection(connString)) - { - await conn2.OpenAsync(); - AssertPoolState(pool, open: 1, idle: 0); - } - AssertPoolState(pool, open: 1, idle: 1); - return Thread.CurrentThread.ManagedThreadId; - }; - - // Start an async open which will not complete as the pool is exhausted. 
- var asyncOpenerTask = asyncOpener(); - conn1.Close(); // Complete the async open by closing conn1 - var asyncOpenerThreadId = asyncOpenerTask.GetAwaiter().GetResult(); - AssertPoolState(pool, open: 1, idle: 1); + AssertPoolState(dataSource, open: 1, idle: 0); - Assert.That(asyncOpenerThreadId, Is.Not.EqualTo(Thread.CurrentThread.ManagedThreadId)); - } - finally + Func> asyncOpener = async () => + { + using (var conn2 = dataSource.CreateConnection()) { - conn1.Close(); - NpgsqlConnection.ClearPool(conn1); + await conn2.OpenAsync(); + AssertPoolState(dataSource, open: 1, idle: 0); } - } + AssertPoolState(dataSource, open: 1, idle: 1); + return Environment.CurrentManagedThreadId; + }; + + // Start an async open which will not complete as the pool is exhausted. + var asyncOpenerTask = asyncOpener(); + conn1.Close(); // Complete the async open by closing conn1 + var asyncOpenerThreadId = asyncOpenerTask.GetAwaiter().GetResult(); + AssertPoolState(dataSource, open: 1, idle: 1); - [Test] - public void ReleaseWaiterOnConnectionFailure() + Assert.That(asyncOpenerThreadId, Is.Not.EqualTo(Environment.CurrentManagedThreadId)); + } + + [Test] //TODO: parallelize + public async Task Release_waiter_on_connection_failure() + { + await using var dataSource = CreateDataSource(csb => { - var connectionString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ReleaseWaiterOnConnectionFailure), - Port = 9999, - MaxPoolSize = 1 - }.ToString(); + csb.Port = 9999; + csb.MaxPoolSize = 1; + }); - try - { - var tasks = Enumerable.Range(0, 2).Select(i => Task.Run(async () => - { - using var conn = CreateConnection(connectionString); - await conn.OpenAsync(); - })).ToArray(); + var tasks = Enumerable.Range(0, 2).Select(i => Task.Run(async () => + { + await using var conn = await dataSource.OpenConnectionAsync(); + })).ToArray(); - try - { - Task.WaitAll(tasks); - } - catch (AggregateException e) - { - foreach (var inner in e.InnerExceptions) - 
Assert.That(inner, Is.TypeOf()); - return; - } - Assert.Fail(); - } - finally - { - NpgsqlConnection.ClearPool(CreateConnection(connectionString)); - } - } + var ex = Assert.Throws(() => Task.WaitAll(tasks))!; + Assert.That(ex.InnerExceptions, Has.Count.EqualTo(2)); + foreach (var inner in ex.InnerExceptions) + Assert.That(inner, Is.TypeOf()); + } - [Test] - [TestCase(1)] - [TestCase(2)] - public void ClearPool(int iterations) + [Test] + [TestCase(1)] + [TestCase(2)] + public void ClearPool(int iterations) + { + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ClearPool) - }.ToString(); + ApplicationName = nameof(ClearPool) + iterations + }.ToString(); - NpgsqlConnection conn; + NpgsqlConnection? conn = null; + try + { for (var i = 0; i < iterations; i++) { - using (conn = OpenConnection(connString)) { } + using (conn = new NpgsqlConnection(connString)) + { + conn.Open(); + } + // Now have one connection in the pool - Assert.True(PoolManager.TryGetValue(connString, out var pool)); + Assert.True(PoolManager.Pools.TryGetValue(connString, out var pool)); AssertPoolState(pool, open: 1, idle: 1); NpgsqlConnection.ClearPool(conn); AssertPoolState(pool, open: 0, idle: 0); } } + finally + { + if (conn is not null) + NpgsqlConnection.ClearPool(conn); + } + } - [Test] - public void ClearWithBusy() + [Test] + public void ClearPool_with_busy() + { + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ClearWithBusy) - }.ToString(); + ApplicationName = nameof(ClearPool_with_busy) + }.ToString(); - ConnectorPool? pool; - using (var conn = OpenConnection(connString)) + var conn = new NpgsqlConnection(connString); + try + { + NpgsqlDataSource? 
pool; + using (conn) { + conn.Open(); NpgsqlConnection.ClearPool(conn); // conn is still busy but should get closed when returned to the pool - Assert.True(PoolManager.TryGetValue(connString, out pool)); + Assert.True(PoolManager.Pools.TryGetValue(connString, out pool)); AssertPoolState(pool, open: 1, idle: 0); } + AssertPoolState(pool, open: 0, idle: 0); } - - [Test] - public void ClearWithNoPool() + finally { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ClearWithNoPool) - }.ToString(); - using (var conn = CreateConnection(connString)) - NpgsqlConnection.ClearPool(conn); + NpgsqlConnection.ClearPool(conn); } + } - [Test, Description("https://github.com/npgsql/npgsql/commit/45e33ecef21f75f51a625c7b919a50da3ed8e920#r28239653")] - public void PhysicalOpenFailure() + [Test] + public void ClearPool_with_no_pool() + { + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(PhysicalOpenFailure), - Port = 44444, - MaxPoolSize = 1 - }.ToString(); - using (var conn = CreateConnection(connString)) - { - for (var i = 0; i < 1; i++) - Assert.That(() => conn.Open(), Throws.Exception - .TypeOf() - .With.InnerException.TypeOf()); - Assert.True(PoolManager.TryGetValue(connString, out var pool)); - AssertPoolState(pool, open: 0, idle: 0); - } - } + ApplicationName = nameof(ClearPool_with_no_pool) + }.ToString(); + using var conn = new NpgsqlConnection(connString); + NpgsqlConnection.ClearPool(conn); + } - //[Test, Explicit] - //[TestCase(10, 10, 30, true)] - //[TestCase(10, 10, 30, false)] - //[TestCase(10, 20, 30, true)] - //[TestCase(10, 20, 30, false)] - public void ExercisePool(int maxPoolSize, int numTasks, int seconds, bool async) + [Test, Description("https://github.com/npgsql/npgsql/commit/45e33ecef21f75f51a625c7b919a50da3ed8e920#r28239653")] + public void Open_physical_failure() + { + using var 
dataSource = CreateDataSource(csb => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ExercisePool), - MaxPoolSize = maxPoolSize - }.ToString(); + csb.Port = 44444; + csb.MaxPoolSize = 1; + }); + using var conn = dataSource.CreateConnection(); + for (var i = 0; i < 1; i++) + Assert.That(() => conn.Open(), Throws.Exception + .TypeOf() + .With.InnerException.TypeOf()); + AssertPoolState(dataSource, open: 0, idle: 0); + } - Console.WriteLine($"Spinning up {numTasks} parallel tasks for {seconds} seconds (MaxPoolSize={maxPoolSize})..."); - StopFlag = 0; - var tasks = Enumerable.Range(0, numTasks).Select(i => Task.Run(async () => - { - while (StopFlag == 0) - using (var conn = CreateConnection(connString)) - { - if (async) - await conn.OpenAsync(); - else - conn.Open(); - } - })).ToArray(); - - Thread.Sleep(seconds * 1000); - Interlocked.Exchange(ref StopFlag, 1); - Console.WriteLine("Stopped. Waiting for all tasks to stop..."); - Task.WaitAll(tasks); - Console.WriteLine("Done"); - } + //[Test, Explicit] + //[TestCase(10, 10, 30, true)] + //[TestCase(10, 10, 30, false)] + //[TestCase(10, 20, 30, true)] + //[TestCase(10, 20, 30, false)] + public async Task Exercise_pool(int maxPoolSize, int numTasks, int seconds, bool async) + { + await using var dataSource = CreateDataSource(csb => csb.MaxPoolSize = maxPoolSize); - [Test] - public async Task ConnectionLifetime() + Console.WriteLine($"Spinning up {numTasks} parallel tasks for {seconds} seconds (MaxPoolSize={maxPoolSize})..."); + StopFlag = 0; + var tasks = Enumerable.Range(0, numTasks).Select(i => Task.Run(async () => { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) + while (StopFlag == 0) { - ConnectionLifetime = 1 - }; - - using var _ = CreateTempPool(builder, out var connectionString); - await using var conn = new NpgsqlConnection(connectionString); - await conn.OpenAsync(); - var processId = conn.ProcessID; - await conn.CloseAsync(); + await 
using var conn = dataSource.CreateConnection(); + if (async) + await conn.OpenAsync(); + else + conn.Open(); + } + })).ToArray(); - await Task.Delay(2000); + Thread.Sleep(seconds * 1000); + Interlocked.Exchange(ref StopFlag, 1); + Console.WriteLine("Stopped. Waiting for all tasks to stop..."); + Task.WaitAll(tasks); + Console.WriteLine("Done"); + } - await conn.OpenAsync(); - Assert.That(conn.ProcessID, Is.Not.EqualTo(processId)); - } + [Test] + public async Task ConnectionLifetime() + { + await using var dataSource = CreateDataSource(csb => csb.ConnectionLifetime = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + var processId = conn.ProcessID; + await conn.CloseAsync(); - #region Support + await Task.Delay(2000); - volatile int StopFlag; + await conn.OpenAsync(); + Assert.That(conn.ProcessID, Is.Not.EqualTo(processId)); + } - void AssertPoolState(ConnectorPool? pool, int open, int idle) - { - if (pool == null) - throw new ArgumentNullException(nameof(pool)); + #region Support - var (openState, idleState, _) = pool.Statistics; - Assert.That(openState, Is.EqualTo(open), $"Open should be {open} but is {openState}"); - Assert.That(idleState, Is.EqualTo(idle), $"Idle should be {idle} but is {idleState}"); - } + volatile int StopFlag; - // With MaxPoolSize=1, opens many connections in parallel and executes a simple SELECT. Since there's only one - // physical connection, all operations will be completely serialized - [Test] - public Task OnePhysicalConnectionManyCommands() - { - const int numParallelCommands = 10000; + void AssertPoolState(NpgsqlDataSource? 
pool, int open, int idle) + { + if (pool == null) + throw new ArgumentNullException(nameof(pool)); - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxPoolSize = 1, - MaxAutoPrepare = 5, - AutoPrepareMinUsages = 5, - Timeout = 0 - }.ToString(); - - return Task.WhenAll(Enumerable.Range(0, numParallelCommands) - .Select(async i => - { - using var conn = new NpgsqlConnection(connString); - await conn.OpenAsync(); - using var cmd = new NpgsqlCommand("SELECT " + i, conn); - var result = await cmd.ExecuteScalarAsync(); - Assert.That(result, Is.EqualTo(i)); - })); - } + var (openState, idleState, _) = pool.Statistics; + Assert.That(openState, Is.EqualTo(open), $"Open should be {open} but is {openState}"); + Assert.That(idleState, Is.EqualTo(idle), $"Idle should be {idle} but is {idleState}"); + } - // When multiplexing, and the pool is totally saturated (at Max Pool Size and 0 idle connectors), we select - // the connector with the least commands in flight and execute on it. We must never select a connector with - // a pending transaction on it. - // TODO: Test not tested - [Test] - [Ignore("Multiplexing: fails")] - public void MultiplexedCommandDoesntGetExecutedOnTransactionedConnector() - { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxPoolSize = 1, - Timeout = 1 - }.ToString(); - - using var connWithTx = OpenConnection(connString); - using var tx = connWithTx.BeginTransaction(); - // connWithTx should now be bound with the only physical connector available. - // Any commands execute should timeout - - using var conn2 = OpenConnection(connString); - using var cmd = new NpgsqlCommand("SELECT 1", conn2); - Assert.ThrowsAsync(() => cmd.ExecuteScalarAsync()); - } + // With MaxPoolSize=1, opens many connections in parallel and executes a simple SELECT. 
Since there's only one + // physical connection, all operations will be completely serialized + [Test] + public async Task OnePhysicalConnectionManyCommands() + { + const int numParallelCommands = 10000; - protected override NpgsqlConnection CreateConnection(string? connectionString = null) + await using var dataSource = CreateDataSource(csb => { - var conn = base.CreateConnection(connectionString); - _cleanup.Add(conn); - return conn; - } - - readonly List _cleanup = new List(); + csb.MaxPoolSize = 1; + csb.MaxAutoPrepare = 5; + csb.AutoPrepareMinUsages = 5; + csb.Timeout = 0; + }); + + await Task.WhenAll(Enumerable.Range(0, numParallelCommands) + .Select(async i => + { + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT " + i, conn); + var result = await cmd.ExecuteScalarAsync(); + Assert.That(result, Is.EqualTo(i)); + })); + } - [TearDown] - public void Cleanup() + // When multiplexing, and the pool is totally saturated (at Max Pool Size and 0 idle connectors), we select + // the connector with the least commands in flight and execute on it. We must never select a connector with + // a pending transaction on it. + // TODO: Test not tested + [Test] + [Ignore("Multiplexing: fails")] + public async Task MultiplexedCommandDoesntGetExecutedOnTransactionedConnector() + { + await using var dataSource = CreateDataSource(csb => { - foreach (var c in _cleanup) - { - NpgsqlConnection.ClearPool(c); - } - _cleanup.Clear(); - } - - #endregion + csb.MaxPoolSize = 1; + csb.Timeout = 1; + }); + + await using var connWithTx = await dataSource.OpenConnectionAsync(); + await using var tx = await connWithTx.BeginTransactionAsync(); + // connWithTx should now be bound with the only physical connector available. 
+ // Any commands execute should timeout + + await using var conn2 = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn2); + Assert.ThrowsAsync(() => cmd.ExecuteScalarAsync()); } + + #endregion } diff --git a/test/Npgsql.Tests/PostgresTypeTests.cs b/test/Npgsql.Tests/PostgresTypeTests.cs new file mode 100644 index 0000000000..056830cf32 --- /dev/null +++ b/test/Npgsql.Tests/PostgresTypeTests.cs @@ -0,0 +1,74 @@ +using System.Linq; +using System.Threading.Tasks; +using Npgsql.Internal; +using NUnit.Framework; + +namespace Npgsql.Tests; + +public class PostgresTypeTests : TestBase +{ + [Test] + public async Task Base() + { + var databaseInfo = await GetDatabaseInfo(); + + var text = databaseInfo.BaseTypes.Single(a => a.Name == "text"); + Assert.That(text.DisplayName, Is.EqualTo("text")); + Assert.That(text.Namespace, Is.EqualTo("pg_catalog")); + Assert.That(text.FullName, Is.EqualTo("pg_catalog.text")); + } + + [Test] + public async Task Array() + { + var databaseInfo = await GetDatabaseInfo(); + + var textArray = databaseInfo.ArrayTypes.Single(a => a.Name == "text[]"); + Assert.That(textArray.DisplayName, Is.EqualTo("text[]")); + Assert.That(textArray.Namespace, Is.EqualTo("pg_catalog")); + Assert.That(textArray.FullName, Is.EqualTo("pg_catalog.text[]")); + + var text = databaseInfo.BaseTypes.Single(a => a.Name == "text"); + Assert.That(textArray.Element, Is.SameAs(text)); + Assert.That(text.Array, Is.SameAs(textArray)); + } + + [Test] + public async Task Range() + { + var databaseInfo = await GetDatabaseInfo(); + + var intRange = databaseInfo.RangeTypes.Single(a => a.Name == "int4range"); + Assert.That(intRange.DisplayName, Is.EqualTo("int4range")); + Assert.That(intRange.Namespace, Is.EqualTo("pg_catalog")); + Assert.That(intRange.FullName, Is.EqualTo("pg_catalog.int4range")); + + var integer = databaseInfo.BaseTypes.Single(a => a.Name == "integer"); + Assert.That(intRange.Subtype, Is.SameAs(integer)); + 
Assert.That(integer.Range, Is.SameAs(intRange)); + } + + [Test] + public async Task Multirange() + { + await using (var conn = await OpenConnectionAsync()) + TestUtil.MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + + var databaseInfo = await GetDatabaseInfo(); + + var intMultirange = databaseInfo.MultirangeTypes.Single(a => a.Name == "int4multirange"); + Assert.That(intMultirange.DisplayName, Is.EqualTo("int4multirange")); + Assert.That(intMultirange.Namespace, Is.EqualTo("pg_catalog")); + Assert.That(intMultirange.FullName, Is.EqualTo("pg_catalog.int4multirange")); + + var intRange = databaseInfo.RangeTypes.Single(a => a.Name == "int4range"); + Assert.That(intMultirange.Subrange, Is.SameAs(intRange)); + Assert.That(intRange.Multirange, Is.SameAs(intMultirange)); + } + + async Task GetDatabaseInfo() + { + await using var conn = await OpenConnectionAsync(); + return conn.NpgsqlDataSource.DatabaseInfo; + } +} diff --git a/test/Npgsql.Tests/PrepareTests.cs b/test/Npgsql.Tests/PrepareTests.cs index 67c557e5d1..1d9c6dde85 100644 --- a/test/Npgsql.Tests/PrepareTests.cs +++ b/test/Npgsql.Tests/PrepareTests.cs @@ -3,695 +3,920 @@ using System.Data; using System.Linq; using System.Text; +using System.Threading; using System.Threading.Tasks; +using Npgsql.BackendMessages; +using Npgsql.Internal.Postgres; +using Npgsql.Tests.Support; using NpgsqlTypes; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class PrepareTests: TestBase { - public class PrepareTests: TestBase + static uint Int4Oid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Int4).Value; + + [Test] + public void Basic() { - [Test] - public void Basic() + using var conn = OpenConnectionAndUnprepare(); + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - using (var conn = OpenConnectionAndUnprepare()) - { - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - 
AssertNumPreparedStatements(conn, 0); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - Assert.That(cmd.IsPrepared, Is.False); - - cmd.Prepare(); - AssertNumPreparedStatements(conn, 1); - Assert.That(cmd.IsPrepared, Is.True); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - } - AssertNumPreparedStatements(conn, 1); - conn.UnprepareAll(); - } + AssertNumPreparedStatements(conn, 0); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + Assert.That(cmd.IsPrepared, Is.False); + + cmd.Prepare(); + AssertNumPreparedStatements(conn, 1); + Assert.That(cmd.IsPrepared, Is.True); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); } + AssertNumPreparedStatements(conn, 1); + conn.UnprepareAll(); + } - [Test] - public async Task Async() + [Test] + public async Task Async() + { + using var conn = OpenConnectionAndUnprepare(); + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - using (var conn = OpenConnectionAndUnprepare()) - { - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - AssertNumPreparedStatements(conn, 0); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - Assert.That(cmd.IsPrepared, Is.False); - - await cmd.PrepareAsync(); - AssertNumPreparedStatements(conn, 1); - Assert.That(cmd.IsPrepared, Is.True); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - } - AssertNumPreparedStatements(conn, 1); - conn.UnprepareAll(); - } + AssertNumPreparedStatements(conn, 0); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + Assert.That(cmd.IsPrepared, Is.False); + + await cmd.PrepareAsync(); + AssertNumPreparedStatements(conn, 1); + Assert.That(cmd.IsPrepared, Is.True); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); } + AssertNumPreparedStatements(conn, 1); + conn.UnprepareAll(); + } - [Test] - public void Unprepare() - => Unprepare(false).GetAwaiter().GetResult(); - - [Test] - public Task UnprepareAsync() - => Unprepare(true); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3443")] + public void Bug3443() + { + using var conn = 
OpenConnectionAndUnprepare(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + AssertNumPreparedStatements(conn, 0); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + Assert.That(cmd.IsPrepared, Is.False); + + Assert.ThrowsAsync(() => cmd.PrepareAsync(new(canceled: true))); + AssertNumPreparedStatements(conn, 0); + Assert.That(cmd.IsPrepared, Is.False); + + using var cmd2 = new NpgsqlCommand("SELECT 1", conn); + cmd2.Prepare(); + Assert.That(cmd2.ExecuteScalar(), Is.EqualTo(1)); + AssertNumPreparedStatements(conn, 1); + Assert.That(cmd2.IsPrepared, Is.True); + } - private async Task Unprepare(bool async) + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4209")] + public async Task Async_cancel_NullReferenceException() + { + for (var i = 0; i < 10; i++) { - using (var conn = OpenConnectionAndUnprepare()) + using var conn = OpenConnectionAndUnprepare(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + using var cts = new CancellationTokenSource(); + using var mre = new ManualResetEventSlim(); + var cancelTask = Task.Run(() => { - AssertNumPreparedStatements(conn, 0); - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - if(async) - await cmd.PrepareAsync(); - else - cmd.Prepare(); + mre.Wait(); + cts.Cancel(); + }); + try + { + mre.Set(); + await cmd.PrepareAsync(cts.Token); + } + catch (OperationCanceledException) + { + // There is a race between us checking the cancellation token and the cancellation itself. + // If the cancellation happens first, we get OperationCancelledException. + // In other case, PrepareAsync will not be cancelled and shouldn't throw any exceptions. 
+ } + await cancelTask; - AssertNumPreparedStatements(conn, 1); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + } + } - if (async) - await cmd.UnprepareAsync(); - else - cmd.Unprepare(); + [Test] + public void Unprepare() + => Unprepare(false).GetAwaiter().GetResult(); - AssertNumPreparedStatements(conn, 0); - Assert.That(cmd.IsPrepared, Is.False); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - } - } - } + [Test] + public Task UnprepareAsync() + => Unprepare(true); + + async Task Unprepare(bool async) + { + using var conn = OpenConnectionAndUnprepare(); + AssertNumPreparedStatements(conn, 0); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + if(async) + await cmd.PrepareAsync(); + else + cmd.Prepare(); + + AssertNumPreparedStatements(conn, 1); + + if (async) + await cmd.UnprepareAsync(); + else + cmd.Unprepare(); + + AssertNumPreparedStatements(conn, 0); + Assert.That(cmd.IsPrepared, Is.False); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + } + + [Test] + public void Named_parameters() + { + using var conn = OpenConnectionAndUnprepare(); - [Test] - public void Parameters() + for (var i = 0; i < 2; i++) { - using (var conn = OpenConnectionAndUnprepare()) - using (var command = new NpgsqlCommand("SELECT @a, @b", conn)) - { - command.Parameters.Add(new NpgsqlParameter("a", DbType.Int32)); - command.Parameters.Add(new NpgsqlParameter("b", 8)); - command.Prepare(); - command.Parameters[0].Value = 3; - command.Parameters[1].Value = 5; - using (var reader = command.ExecuteReader()) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(3)); - Assert.That(reader.GetInt64(1), Is.EqualTo(5)); - } - command.Unprepare(); + using var command = new NpgsqlCommand("SELECT @a, @b", conn); + command.Parameters.Add(new NpgsqlParameter("a", DbType.Int32)); + command.Parameters.Add(new NpgsqlParameter("b", 8)); + command.Prepare(); + command.Parameters[0].Value = 3; + command.Parameters[1].Value = 5; + + using (var 
reader = command.ExecuteReader()) + { + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(3)); + Assert.That(reader.GetInt64(1), Is.EqualTo(5)); } + + command.Unprepare(); } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1207")] - public void DoublePrepareSameSql() + [Test] + public void Positional_parameters() + { + using var conn = OpenConnectionAndUnprepare(); + + for (var i = 0; i < 2; i++) { - using (var conn = OpenConnectionAndUnprepare()) - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(); - cmd.Prepare(); - AssertNumPreparedStatements(conn, 1); - cmd.Unprepare(); - AssertNumPreparedStatements(conn, 0); + using var command = new NpgsqlCommand("SELECT $1, $2", conn); + command.Parameters.Add(new NpgsqlParameter { DbType = DbType.Int32 }); + command.Parameters.Add(new NpgsqlParameter { Value = 8 }); + command.Prepare(); + command.Parameters[0].Value = 3; + command.Parameters[1].Value = 5; + + using (var reader = command.ExecuteReader()) + { + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(3)); + Assert.That(reader.GetInt64(1), Is.EqualTo(5)); } + + command.Unprepare(); } + } - [Test] - public void DoublePrepareDifferentSql() - { - using (var conn = OpenConnectionAndUnprepare()) - using (var cmd = new NpgsqlCommand()) - { - cmd.Connection = conn; + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1207")] + public void Double_prepare_same_sql() + { + using var conn = OpenConnectionAndUnprepare(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Prepare(); + cmd.Prepare(); + AssertNumPreparedStatements(conn, 1); + cmd.Unprepare(); + AssertNumPreparedStatements(conn, 0); + } - cmd.CommandText = "SELECT 1"; - cmd.Prepare(); - cmd.ExecuteNonQuery(); + [Test] + public void Double_prepare_different_sql() + { + using var conn = OpenConnectionAndUnprepare(); + using var cmd = new NpgsqlCommand(); + cmd.Connection = conn; - cmd.CommandText = 
"SELECT 2"; - cmd.Prepare(); - AssertNumPreparedStatements(conn, 2); - cmd.ExecuteNonQuery(); + cmd.CommandText = "SELECT 1"; + cmd.Prepare(); + cmd.ExecuteNonQuery(); - conn.UnprepareAll(); - } - } + cmd.CommandText = "SELECT 2"; + cmd.Prepare(); + AssertNumPreparedStatements(conn, 2); + cmd.ExecuteNonQuery(); + + conn.UnprepareAll(); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/395")] + public void Across_close_open_same_connector() + { + using var dataSource = CreateDataSource(); + using var conn = dataSource.OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Prepare(); + Assert.That(cmd.IsPrepared, Is.True); + var processId = conn.ProcessID; + conn.Close(); + conn.Open(); + Assert.That(processId, Is.EqualTo(conn.ProcessID)); + Assert.That(cmd.IsPrepared, Is.True); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + cmd.Prepare(); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + } + + [Test] + public void Across_close_open_different_connector() + { + using var dataSource = CreateDataSource(); + using var conn1 = dataSource.CreateConnection(); + using var conn2 = dataSource.CreateConnection(); + using var cmd = new NpgsqlCommand("SELECT 1", conn1); + conn1.Open(); + cmd.Prepare(); + Assert.That(cmd.IsPrepared, Is.True); + var processId = conn1.ProcessID; + conn1.Close(); + conn2.Open(); + conn1.Open(); + Assert.That(conn1.ProcessID, Is.Not.EqualTo(processId)); + Assert.That(cmd.IsPrepared, Is.False); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); // Execute unprepared + cmd.Prepare(); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/395")] - public void AcrossCloseOpenSameConnector() + [Test] + public void Reuse_prepared_statement() + { + using var dataSource = CreateDataSource(); + using var conn1 = dataSource.OpenConnection(); + var preparedStatement = Array.Empty(); + using (var cmd1 = new NpgsqlCommand("SELECT @p", conn1)) { - var csb = 
new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(PrepareTests) + '.' + nameof(AcrossCloseOpenSameConnector) - }; - using (var conn = OpenConnectionAndUnprepare(csb)) - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(); - Assert.That(cmd.IsPrepared, Is.True); - var processId = conn.ProcessID; - conn.Close(); - conn.Open(); - Assert.That(processId, Is.EqualTo(conn.ProcessID)); - Assert.That(cmd.IsPrepared, Is.True); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - cmd.Prepare(); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - NpgsqlConnection.ClearPool(conn); - } + cmd1.Parameters.AddWithValue("p", 8); + cmd1.Prepare(); + Assert.That(cmd1.IsPrepared, Is.True); + Assert.That(cmd1.ExecuteScalar(), Is.EqualTo(8)); + preparedStatement = cmd1.InternalBatchCommands[0].PreparedStatement!.Name!; } - [Test] - public void AcrossCloseOpenDifferentConnector() + using (var cmd2 = new NpgsqlCommand("SELECT @p", conn1)) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(PrepareTests) + '.' 
+ nameof(AcrossCloseOpenDifferentConnector) - }.ToString(); - using (var conn1 = new NpgsqlConnection(connString)) - using (var conn2 = new NpgsqlConnection(connString)) - using (var cmd = new NpgsqlCommand("SELECT 1", conn1)) - { - conn1.Open(); - cmd.Prepare(); - Assert.That(cmd.IsPrepared, Is.True); - var processId = conn1.ProcessID; - conn1.Close(); - conn2.Open(); - conn1.Open(); - Assert.That(conn1.ProcessID, Is.Not.EqualTo(processId)); - Assert.That(cmd.IsPrepared, Is.False); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); // Execute unprepared - cmd.Prepare(); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - NpgsqlConnection.ClearPool(conn1); - } + cmd2.Parameters.AddWithValue("p", 8); + cmd2.Prepare(); + Assert.That(cmd2.IsPrepared, Is.True); + Assert.That(cmd2.InternalBatchCommands[0].PreparedStatement!.Name, Is.EqualTo(preparedStatement)); + Assert.That(cmd2.ExecuteScalar(), Is.EqualTo(8)); } + } - [Test] - public void ReusePreparedStatement() + [Test] + public void Legacy_batching() + { + using var conn = OpenConnectionAndUnprepare(); + using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(PrepareTests) + '.' 
+ nameof(ReusePreparedStatement) - }.ToString(); - using (var conn1 = OpenConnection(connString)) + cmd.Prepare(); + using (var reader = cmd.ExecuteReader()) { - var preparedStatement = ""; - using (var cmd1 = new NpgsqlCommand("SELECT @p", conn1)) - { - cmd1.Parameters.AddWithValue("p", 8); - cmd1.Prepare(); - Assert.That(cmd1.IsPrepared, Is.True); - Assert.That(cmd1.ExecuteScalar(), Is.EqualTo(8)); - preparedStatement = cmd1.Statements[0].PreparedStatement!.Name!; - } - - using (var cmd2 = new NpgsqlCommand("SELECT @p", conn1)) - { - cmd2.Parameters.AddWithValue("p", 8); - cmd2.Prepare(); - Assert.That(cmd2.IsPrepared, Is.True); - Assert.That(cmd2.Statements[0].PreparedStatement!.Name, Is.EqualTo(preparedStatement)); - Assert.That(cmd2.ExecuteScalar(), Is.EqualTo(8)); - } - NpgsqlConnection.ClearPool(conn1); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + reader.NextResult(); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(2)); } } - [Test] - public void Multistatement() + AssertNumPreparedStatements(conn, 2); + + using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) { - using (var conn = OpenConnectionAndUnprepare()) + cmd.Prepare(); + using (var reader = cmd.ExecuteReader()) { - using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) - { - cmd.Prepare(); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - reader.NextResult(); - reader.Read(); - Assert.That(reader.GetInt32(0), Is.EqualTo(2)); - } - } - - AssertNumPreparedStatements(conn, 2); - - using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) - { - cmd.Prepare(); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - reader.NextResult(); - reader.Read(); - Assert.That(reader.GetInt32(0), Is.EqualTo(2)); - } - } - - AssertNumPreparedStatements(conn, 2); - conn.UnprepareAll(); + reader.Read(); + 
Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + reader.NextResult(); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(2)); } } - [Test] - public void OneCommandSameSqlTwice() + AssertNumPreparedStatements(conn, 2); + conn.UnprepareAll(); + } + + [Test] + public void Batch() + { + using var conn = OpenConnectionAndUnprepare(); + using (var batch = new NpgsqlBatch(conn) { BatchCommands = { new("SELECT 1"), new("SELECT 2") } }) { - using (var conn = OpenConnectionAndUnprepare()) - using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 1", conn)) - { - cmd.Prepare(); - AssertNumPreparedStatements(conn, 1); - cmd.ExecuteNonQuery(); - cmd.Unprepare(); + batch.Prepare(); + using (var reader = batch.ExecuteReader()) + { + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + reader.NextResult(); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(2)); } } - [Test] - public void OneCommandSameSqlAutoPrepare() + using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - MaxAutoPrepare = 5, - AutoPrepareMinUsages = 2 - }; - using (var conn = OpenConnectionAndUnprepare(csb)) + cmd.Prepare(); + using (var reader = cmd.ExecuteReader()) { - var sql = new StringBuilder(); - for (var i = 0; i < 2 + 1; i++) - sql.Append("SELECT 1;"); - using (var cmd = new NpgsqlCommand(sql.ToString(), conn)) - cmd.ExecuteNonQuery(); - AssertNumPreparedStatements(conn, 1); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + reader.NextResult(); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(2)); } } - [Test] - public void OneCommandSameSqlTwiceWithParams() + AssertNumPreparedStatements(conn, 2); + + using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) { - using (var conn = OpenConnectionAndUnprepare()) - using (var cmd = new NpgsqlCommand("SELECT @p1; SELECT @p2", conn)) + cmd.Prepare(); + using (var reader = cmd.ExecuteReader()) { - 
cmd.Parameters.Add("p1", NpgsqlDbType.Integer); - cmd.Parameters.Add("p2", NpgsqlDbType.Integer); - cmd.Prepare(); - AssertNumPreparedStatements(conn, 1); - - cmd.Parameters[0].Value = 8; - cmd.Parameters[1].Value = 9; - using (var reader = cmd.ExecuteReader()) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(8)); - Assert.That(reader.NextResult(), Is.True); - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(9)); - Assert.That(reader.NextResult(), Is.False); - } - - cmd.Unprepare(); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + reader.NextResult(); + reader.Read(); + Assert.That(reader.GetInt32(0), Is.EqualTo(2)); } } - [Test] - public void UnprepareViaDifferentCommand() - { - using (var conn = OpenConnectionAndUnprepare()) - using (var cmd1 = new NpgsqlCommand("SELECT 1; SELECT 2", conn)) - using (var cmd2 = new NpgsqlCommand("SELECT 2; SELECT 3", conn)) - { - cmd1.Prepare(); - cmd2.Prepare(); - // Both commands reference the same prepared statement - AssertNumPreparedStatements(conn, 3); - cmd2.Unprepare(); - AssertNumPreparedStatements(conn, 1); - Assert.That(cmd1.IsPrepared, Is.False); // Only partially prepared, so no - cmd1.ExecuteNonQuery(); - cmd1.Unprepare(); - AssertNumPreparedStatements(conn, 0); - Assert.That(cmd1.IsPrepared, Is.False); - cmd1.ExecuteNonQuery(); + AssertNumPreparedStatements(conn, 2); + conn.UnprepareAll(); + } - conn.UnprepareAll(); - } - } + [Test] + public void One_command_same_sql_twice() + { + using var conn = OpenConnectionAndUnprepare(); + using var cmd = new NpgsqlCommand("SELECT 1; SELECT 1", conn); + cmd.Prepare(); + AssertNumPreparedStatements(conn, 1); + cmd.ExecuteNonQuery(); + cmd.Unprepare(); + } - [Test, Description("Prepares the same SQL with different parameters (overloading)")] - public void OverloadedSql() + [Test] + public void One_command_same_sql_auto_prepare() + { + using var dataSource = CreateDataSource(csb => { - 
using (var conn = OpenConnectionAndUnprepare()) - { - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add("p", NpgsqlDbType.Integer); - cmd.Prepare(); - Assert.That(cmd.IsPrepared, Is.True); - } - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Text, "foo"); - cmd.Prepare(); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo("foo")); - Assert.That(cmd.IsPrepared, Is.False); - } + csb.MaxAutoPrepare = 5; + csb.AutoPrepareMinUsages = 2; + }); + using var conn = dataSource.OpenConnection(); + var sql = new StringBuilder(); + for (var i = 0; i < 2 + 1; i++) + sql.Append("SELECT 1;"); + using (var cmd = new NpgsqlCommand(sql.ToString(), conn)) + cmd.ExecuteNonQuery(); + AssertNumPreparedStatements(conn, 1); + } - // SQL overloading is a pretty rare/exotic scenario. Handling it properly would involve keying - // prepared statements not just by SQL but also by the parameter types, which would pointlessly - // increase allocations. 
Instead, the second execution simply reuns unprepared - AssertNumPreparedStatements(conn, 1); - conn.UnprepareAll(); - } + [Test] + public void One_command_same_sql_twice_with_params() + { + using var conn = OpenConnectionAndUnprepare(); + using var cmd = new NpgsqlCommand("SELECT @p1; SELECT @p2", conn); + cmd.Parameters.Add("p1", NpgsqlDbType.Integer); + cmd.Parameters.Add("p2", NpgsqlDbType.Integer); + cmd.Prepare(); + AssertNumPreparedStatements(conn, 1); + + cmd.Parameters[0].Value = 8; + cmd.Parameters[1].Value = 9; + using (var reader = cmd.ExecuteReader()) + { + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(8)); + Assert.That(reader.NextResult(), Is.True); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(9)); + Assert.That(reader.NextResult(), Is.False); } - [Test] - public void ManyStatementsOnUnprepare() + cmd.Unprepare(); + } + + [Test] + public void Unprepare_via_different_command() + { + using var conn = OpenConnectionAndUnprepare(); + using var cmd1 = new NpgsqlCommand("SELECT 1; SELECT 2", conn); + using var cmd2 = new NpgsqlCommand("SELECT 2; SELECT 3", conn); + cmd1.Prepare(); + cmd2.Prepare(); + // Both commands reference the same prepared statement + AssertNumPreparedStatements(conn, 3); + cmd2.Unprepare(); + AssertNumPreparedStatements(conn, 1); + Assert.That(cmd1.IsPrepared, Is.False); // Only partially prepared, so no + cmd1.ExecuteNonQuery(); + cmd1.Unprepare(); + AssertNumPreparedStatements(conn, 0); + Assert.That(cmd1.IsPrepared, Is.False); + cmd1.ExecuteNonQuery(); + + conn.UnprepareAll(); + } + + [Test, Description("Prepares the same SQL with different parameters (overloading)")] + public void Overloaded_sql() + { + using var conn = OpenConnectionAndUnprepare(); + using (var cmd = new NpgsqlCommand("SELECT @p", conn)) { - using (var conn = OpenConnectionAndUnprepare()) - using (var cmd = new NpgsqlCommand()) - { - cmd.Connection = conn; - var sb = new StringBuilder(); 
- for (var i = 0; i < conn.Settings.WriteBufferSize; i++) - sb.Append("SELECT 1;"); - cmd.CommandText = sb.ToString(); - cmd.Prepare(); - cmd.Unprepare(); - } + cmd.Parameters.Add("p", NpgsqlDbType.Integer); + cmd.Prepare(); + Assert.That(cmd.IsPrepared, Is.True); } - - [Test] - public void IsPreparedIsFalseAfterChangingCommandText() + using (var cmd = new NpgsqlCommand("SELECT @p", conn)) { - using (var conn = OpenConnectionAndUnprepare()) - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(); - AssertNumPreparedStatements(conn, 1); - cmd.CommandText = "SELECT 2"; - Assert.That(cmd.IsPrepared, Is.False); - cmd.ExecuteNonQuery(); - Assert.That(cmd.IsPrepared, Is.False); - AssertNumPreparedStatements(conn, 1); - cmd.Unprepare(); - } + cmd.Parameters.AddWithValue("p", NpgsqlDbType.Text, "foo"); + cmd.Prepare(); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo("foo")); + Assert.That(cmd.IsPrepared, Is.False); } - [Test, Description("Basic persistent prepared system scenario. Checks that statement is not deallocated in the backend after command dispose.")] - public void PersistentAcrossCommands() - { - using (var conn = OpenConnectionAndUnprepare()) - { - AssertNumPreparedStatements(conn, 0); + // SQL overloading is a pretty rare/exotic scenario. Handling it properly would involve keying + // prepared statements not just by SQL but also by the parameter types, which would pointlessly + // increase allocations. 
Instead, the second execution simply reuns unprepared + AssertNumPreparedStatements(conn, 1); + conn.UnprepareAll(); + } - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(); - AssertNumPreparedStatements(conn, 1); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - } - AssertNumPreparedStatements(conn, 1); + [Test] + public void Many_statements_on_unprepare() + { + using var conn = OpenConnectionAndUnprepare(); + using var cmd = new NpgsqlCommand(); + cmd.Connection = conn; + var sb = new StringBuilder(); + for (var i = 0; i < conn.Settings.WriteBufferSize; i++) + sb.Append("SELECT 1;"); + cmd.CommandText = sb.ToString(); + cmd.Prepare(); + cmd.Unprepare(); + } - var stmtName = GetPreparedStatements(conn).Single(); + [Test] + public void IsPrepared_is_false_after_changing_CommandText() + { + using var conn = OpenConnectionAndUnprepare(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Prepare(); + AssertNumPreparedStatements(conn, 1); + cmd.CommandText = "SELECT 2"; + Assert.That(cmd.IsPrepared, Is.False); + cmd.ExecuteNonQuery(); + Assert.That(cmd.IsPrepared, Is.False); + AssertNumPreparedStatements(conn, 1); + cmd.Unprepare(); + } - // Rerun the test using the persistent prepared statement - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(); - Assert.That(cmd.IsPrepared, Is.True); - AssertNumPreparedStatements(conn, 1); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - } - AssertNumPreparedStatements(conn, 1); - Assert.That(GetPreparedStatements(conn).Single(), Is.EqualTo(stmtName)); - conn.UnprepareAll(); - } - } + [Test, Description("Basic persistent prepared system scenario. Checks that statement is not deallocated in the backend after command dispose.")] + public void Persistent_across_commands() + { + using var conn = OpenConnectionAndUnprepare(); + AssertNumPreparedStatements(conn, 0); - [Test, Description("Basic persistent prepared system scenario. 
Checks that statement is not deallocated in the backend after connection close.")] - public void PersistentAcrossConnections() + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - var connSettings = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(PersistentAcrossConnections) - }; + cmd.Prepare(); + AssertNumPreparedStatements(conn, 1); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + } + AssertNumPreparedStatements(conn, 1); - using (var conn = OpenConnectionAndUnprepare(connSettings)) - { - var processId = conn.ProcessID; + var stmtName = GetPreparedStatements(conn).Single(); - AssertNumPreparedStatements(conn, 0); - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - cmd.Prepare(); + // Rerun the test using the persistent prepared statement + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + { + cmd.Prepare(); + Assert.That(cmd.IsPrepared, Is.True); + AssertNumPreparedStatements(conn, 1); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + } + AssertNumPreparedStatements(conn, 1); + Assert.That(GetPreparedStatements(conn).Single(), Is.EqualTo(stmtName)); + conn.UnprepareAll(); + } - var stmtName = GetPreparedStatements(conn).Single(); - conn.Close(); + [Test, Description("Basic persistent prepared system scenario. 
Checks that statement is not deallocated in the backend after connection close.")] + public void Persistent_across_connections() + { + using var dataSource = CreateDataSource(); + using var conn = dataSource.OpenConnection(); + var processId = conn.ProcessID; - conn.Open(); - Assert.That(conn.ProcessID, Is.EqualTo(processId), "Unexpected connection received from the pool"); + AssertNumPreparedStatements(conn, 0); + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + cmd.Prepare(); - AssertNumPreparedStatements(conn, 1, "Prepared statement deallocated"); - Assert.That(GetPreparedStatements(conn).Single(), Is.EqualTo(stmtName), "Prepared statement name changed unexpectedly"); + var stmtName = GetPreparedStatements(conn).Single(); + conn.Close(); - // Rerun the test using the persistent prepared statement - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - } - AssertNumPreparedStatements(conn, 1, "Prepared statement deallocated"); - Assert.That(GetPreparedStatements(conn).Single(), Is.EqualTo(stmtName), "Prepared statement name changed unexpectedly"); + conn.Open(); + Assert.That(conn.ProcessID, Is.EqualTo(processId), "Unexpected connection received from the pool"); - NpgsqlConnection.ClearPool(conn); - } - } + AssertNumPreparedStatements(conn, 1, "Prepared statement deallocated"); + Assert.That(GetPreparedStatements(conn).Single(), Is.EqualTo(stmtName), "Prepared statement name changed unexpectedly"); - [Test, Description("Makes sure that calling Prepare() twice on a command does not deallocate or make a new one after the first prepared statement when command does not change")] - public void PersistentDoublePrepareCommandUnchanged() + // Rerun the test using the persistent prepared statement + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - using (var conn = OpenConnectionAndUnprepare()) - { - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(); - 
cmd.ExecuteNonQuery(); - var stmtName = GetPreparedStatements(conn).Single(); - cmd.Prepare(); - cmd.ExecuteNonQuery(); - AssertNumPreparedStatements(conn, 1, "Unexpected count of prepared statements"); - Assert.That(GetPreparedStatements(conn).Single(), Is.EqualTo(stmtName), "Persistent prepared statement name changed unexpectedly"); - } - AssertNumPreparedStatements(conn, 1, "Persistent prepared statement deallocated"); - conn.UnprepareAll(); - } + cmd.Prepare(); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); } + AssertNumPreparedStatements(conn, 1, "Prepared statement deallocated"); + Assert.That(GetPreparedStatements(conn).Single(), Is.EqualTo(stmtName), "Prepared statement name changed unexpectedly"); + } - [Test] - public void PersistentDoublePrepareCommandChanged() + [Test, Description("Makes sure that calling Prepare() twice on a command does not deallocate or make a new one after the first prepared statement when command does not change")] + public void Persistent_double_prepare_command_unchanged() + { + using var conn = OpenConnectionAndUnprepare(); + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - using (var conn = OpenConnectionAndUnprepare()) - { - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(); - cmd.ExecuteNonQuery(); - cmd.CommandText = "SELECT 2"; - AssertNumPreparedStatements(conn, 1); - cmd.Prepare(); - AssertNumPreparedStatements(conn, 2); - cmd.ExecuteNonQuery(); - } - AssertNumPreparedStatements(conn, 2); - conn.UnprepareAll(); - } + cmd.Prepare(); + cmd.ExecuteNonQuery(); + var stmtName = GetPreparedStatements(conn).Single(); + cmd.Prepare(); + cmd.ExecuteNonQuery(); + AssertNumPreparedStatements(conn, 1, "Unexpected count of prepared statements"); + Assert.That(GetPreparedStatements(conn).Single(), Is.EqualTo(stmtName), "Persistent prepared statement name changed unexpectedly"); } + AssertNumPreparedStatements(conn, 1, "Persistent prepared statement deallocated"); + conn.UnprepareAll(); + } - [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/2665")] - public void PreparedCommandFailure() + [Test] + public void Persistent_double_prepare_command_changed() + { + using var conn = OpenConnectionAndUnprepare(); + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - using var conn = OpenConnection(); + cmd.Prepare(); + cmd.ExecuteNonQuery(); + cmd.CommandText = "SELECT 2"; + AssertNumPreparedStatements(conn, 1); + cmd.Prepare(); + AssertNumPreparedStatements(conn, 2); + cmd.ExecuteNonQuery(); + } + AssertNumPreparedStatements(conn, 2); + conn.UnprepareAll(); + } - using (var command = new NpgsqlCommand("INSERT INTO test_table (id) VALUES (1)", conn)) - Assert.Throws(() => command.Prepare()); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2665")] + public void Prepared_command_failure() + { + using var conn = OpenConnection(); - conn.ExecuteNonQuery("CREATE TEMP TABLE test_table (id integer)"); + using (var command = new NpgsqlCommand("INSERT INTO test_table (id) VALUES (1)", conn)) + Assert.Throws(() => command.Prepare()); - using (var command = new NpgsqlCommand("INSERT INTO test_table (id) VALUES (1)", conn)) - { - command.Prepare(); - command.ExecuteNonQuery(); - } - } + conn.ExecuteNonQuery("CREATE TEMP TABLE test_table (id integer)"); - /* - [Test] - public void Unpersist() + using (var command = new NpgsqlCommand("INSERT INTO test_table (id) VALUES (1)", conn)) { - using (var conn = OpenConnectionAndUnprepare()) - { - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - cmd.Prepare(true); - - // Unpersist via a different command - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(true); - cmd.Unpersist(); - AssertNumPreparedStatements(conn, 0); - } - - // Repersist - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(true); - Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); - cmd.Unpersist(); - AssertNumPreparedStatements(conn, 0); - } + command.Prepare(); + command.ExecuteNonQuery(); + } + 
} - // Unpersist via an unprepared command - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - cmd.Prepare(true); - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - cmd.Unpersist(); - AssertNumPreparedStatements(conn, 0); + /* + [Test] + public void Unpersist() + { + using (var conn = OpenConnectionAndUnprepare()) + { + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + cmd.Prepare(true); - // Unpersist via a prepared but unpersisted command - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - cmd.Prepare(true); - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - { - cmd.Prepare(false); - cmd.Unpersist(); - } + // Unpersist via a different command + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + { + cmd.Prepare(true); + cmd.Unpersist(); AssertNumPreparedStatements(conn, 0); } - } - [Test] - public void SameSqlDifferentParams() - { - using (var conn = OpenConnectionAndUnprepare()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) + // Repersist + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - throw new NotImplementedException("Problem: currentl setting NpgsqlParameter.Value clears/invalidates..."); - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Integer)); cmd.Prepare(true); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(1)); + cmd.Unpersist(); + AssertNumPreparedStatements(conn, 0); + } - cmd.Parameters[0].NpgsqlDbType = NpgsqlDbType.Text; - Assert.That(cmd.IsPrepared, Is.False); + // Unpersist via an unprepared command + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) cmd.Prepare(true); - using (var crapCmd = new NpgsqlCommand("SELECT name,statement,parameter_types::TEXT[] FROM pg_prepared_statements", conn)) - using (var reader = crapCmd.ExecuteReader()) - { - while (reader.Read()) - { - Console.WriteLine($"Statement: {reader.GetString(0)}, {reader.GetString(1)}"); - foreach (var p in reader.GetFieldValue(2)) - { - Console.WriteLine(" Param: " + p); - } - } - } - 
//AssertNumPreparedStatements(conn, 2); - cmd.Parameters[0].Value = "hello"; - Console.WriteLine(cmd.ExecuteScalar()); - } - } - */ + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + cmd.Unpersist(); + AssertNumPreparedStatements(conn, 0); - [Test] - public void InvalidStatement() - { - using (var conn = OpenConnection()) + // Unpersist via a prepared but unpersisted command + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + cmd.Prepare(true); + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - var cmd = new NpgsqlCommand("sele", conn); - Assert.That(() => cmd.Prepare(), Throws.Exception.TypeOf()); + cmd.Prepare(false); + cmd.Unpersist(); } + AssertNumPreparedStatements(conn, 0); } + } - [Test] - public void PrepareMultipleCommandsWithParameters() + [Test] + public void Same_sql_different_params() + { + using (var conn = OpenConnectionAndUnprepare()) + using (var cmd = new NpgsqlCommand("SELECT @p", conn)) { - using (var conn = OpenConnection()) + throw new NotImplementedException("Problem: currentl setting NpgsqlParameter.Value clears/invalidates..."); + cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Integer)); + cmd.Prepare(true); + + cmd.Parameters[0].NpgsqlDbType = NpgsqlDbType.Text; + Assert.That(cmd.IsPrepared, Is.False); + cmd.Prepare(true); + using (var crapCmd = new NpgsqlCommand("SELECT name,statement,parameter_types::TEXT[] FROM pg_prepared_statements", conn)) + using (var reader = crapCmd.ExecuteReader()) { - using (var cmd1 = new NpgsqlCommand("SELECT @p1;", conn)) - using (var cmd2 = new NpgsqlCommand("SELECT @p1; SELECT @p2;", conn)) + while (reader.Read()) { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Integer); - var p21 = new NpgsqlParameter("p1", NpgsqlDbType.Text); - var p22 = new NpgsqlParameter("p2", NpgsqlDbType.Text); - cmd1.Parameters.Add(p1); - cmd2.Parameters.Add(p21); - cmd2.Parameters.Add(p22); - cmd1.Prepare(); - cmd2.Prepare(); - p1.Value = 8; - p21.Value = "foo"; - p22.Value = "bar"; - using (var 
reader1 = cmd1.ExecuteReader()) + Console.WriteLine($"Statement: {reader.GetString(0)}, {reader.GetString(1)}"); + foreach (var p in reader.GetFieldValue(2)) { - Assert.That(reader1.Read(), Is.True); - Assert.That(reader1.GetInt32(0), Is.EqualTo(8)); - } - using (var reader2 = cmd2.ExecuteReader()) - { - Assert.That(reader2.Read(), Is.True); - Assert.That(reader2.GetString(0), Is.EqualTo("foo")); - Assert.That(reader2.NextResult(), Is.True); - Assert.That(reader2.Read(), Is.True); - Assert.That(reader2.GetString(0), Is.EqualTo("bar")); + Console.WriteLine(" Param: " + p); } } } + //AssertNumPreparedStatements(conn, 2); + cmd.Parameters[0].Value = "hello"; + Console.WriteLine(cmd.ExecuteScalar()); } + } + */ - [Test] - public void MultiplexingNotSupported() - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { Multiplexing = true }; - using var conn = OpenConnection(builder); - using var cmd = new NpgsqlCommand("SELECT 1", conn); + [Test] + public void Invalid_statement() + { + using var conn = OpenConnection(); + var cmd = new NpgsqlCommand("sele", conn); + Assert.That(() => cmd.Prepare(), Throws.Exception.TypeOf()); + } - Assert.That(() => cmd.Prepare(), Throws.Exception.TypeOf()); - Assert.That(() => conn.UnprepareAll(), Throws.Exception.TypeOf()); + [Test] + public void Prepare_multiple_commands_with_parameters() + { + using var conn = OpenConnection(); + using var cmd1 = new NpgsqlCommand("SELECT @p1;", conn); + using var cmd2 = new NpgsqlCommand("SELECT @p1; SELECT @p2;", conn); + var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Integer); + var p21 = new NpgsqlParameter("p1", NpgsqlDbType.Text); + var p22 = new NpgsqlParameter("p2", NpgsqlDbType.Text); + cmd1.Parameters.Add(p1); + cmd2.Parameters.Add(p21); + cmd2.Parameters.Add(p22); + cmd1.Prepare(); + cmd2.Prepare(); + p1.Value = 8; + p21.Value = "foo"; + p22.Value = "bar"; + using (var reader1 = cmd1.ExecuteReader()) + { + Assert.That(reader1.Read(), Is.True); + 
Assert.That(reader1.GetInt32(0), Is.EqualTo(8)); } - - NpgsqlConnection OpenConnectionAndUnprepare(string? connectionString = null) + using (var reader2 = cmd2.ExecuteReader()) { - var conn = OpenConnection(connectionString); - conn.UnprepareAll(); - return conn; + Assert.That(reader2.Read(), Is.True); + Assert.That(reader2.GetString(0), Is.EqualTo("foo")); + Assert.That(reader2.NextResult(), Is.True); + Assert.That(reader2.Read(), Is.True); + Assert.That(reader2.GetString(0), Is.EqualTo("bar")); } + } - NpgsqlConnection OpenConnectionAndUnprepare(NpgsqlConnectionStringBuilder csb) - => OpenConnectionAndUnprepare(csb.ToString()); + [Test] + public void Multiplexing_not_supported() + { + using var dataSource = CreateDataSource(csb => csb.Multiplexing = true); + using var conn = dataSource.OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); - void AssertNumPreparedStatements(NpgsqlConnection conn, int expected) - => Assert.That(conn.ExecuteScalar("SELECT COUNT(*) FROM pg_prepared_statements WHERE statement NOT LIKE '%FROM pg_prepared_statements%'"), Is.EqualTo(expected)); + Assert.That(() => cmd.Prepare(), Throws.Exception.TypeOf()); + Assert.That(() => conn.UnprepareAll(), Throws.Exception.TypeOf()); + } - void AssertNumPreparedStatements(NpgsqlConnection conn, int expected, string message) - => Assert.That(conn.ExecuteScalar("SELECT COUNT(*) FROM pg_prepared_statements WHERE statement NOT LIKE '%FROM pg_prepared_statements%'"), Is.EqualTo(expected), message); + [Test] + public async Task Explicitly_prepared_statement_invalidation() + { + await using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + await using var connection = await dataSource.OpenConnectionAsync(); + var table = await CreateTempTable(connection, "foo int"); + + await using var command = new NpgsqlCommand($"SELECT * FROM {table}", connection); + await command.PrepareAsync(); + + await 
connection.ExecuteNonQueryAsync($"ALTER TABLE {table} RENAME COLUMN foo TO bar"); + + // Since we've changed the table schema, the next execution of the prepared statement will error with 0A000 + var exception = Assert.ThrowsAsync(() => command.ExecuteNonQueryAsync())!; + Assert.That(exception.SqlState, Is.EqualTo(PostgresErrorCodes.FeatureNotSupported)); // cached plan must not change result type + + // However, Npgsql should invalidate the prepared statement in this case, so the next execution should work + Assert.DoesNotThrowAsync(() => command.ExecuteNonQueryAsync()); + + // The command is unprepared, though. It's the user's responsibility to re-prepare if they wish. + Assert.False(command.IsPrepared); + } - List GetPreparedStatements(NpgsqlConnection conn) + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4920")] + public async Task Explicit_prepare_unprepare_many_queries() + { + // Set a specific buffer's size to trigger #4920 + await using var dataSource = CreateDataSource(csb => csb.WriteBufferSize = 5002); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + + cmd.CommandText = string.Join(';', Enumerable.Range(1, 500).Select(x => $"SELECT {x}")); + await cmd.PrepareAsync(); + await cmd.UnprepareAsync(); + } + + [Test] + public async Task Explicitly_prepared_batch_sends_prepared_queries() + { + await using var postmaster = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmaster.ConnectionString); + + await using var conn = await dataSource.OpenConnectionAsync(); + var server = await postmaster.WaitForServerConnection(); + + await using var batch = new NpgsqlBatch(conn) { - var statements = new List(); - using (var cmd = new NpgsqlCommand("SELECT name FROM pg_prepared_statements WHERE statement NOT LIKE '%FROM pg_prepared_statement%'", conn)) - using (var reader = cmd.ExecuteReader()) - { - while (reader.Read()) - 
statements.Add(reader.GetString(0)); - } - return statements; + BatchCommands = { new("SELECT 1"), new("SELECT 2") } + }; + + var prepareTask = batch.PrepareAsync(); + + await server.ExpectMessages( + FrontendMessageCode.Parse, FrontendMessageCode.Describe, + FrontendMessageCode.Parse, FrontendMessageCode.Describe, + FrontendMessageCode.Sync); + + await server + .WriteParseComplete() + .WriteParameterDescription(new FieldDescription(Int4Oid)) + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteParseComplete() + .WriteParameterDescription(new FieldDescription(Int4Oid)) + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteReadyForQuery() + .FlushAsync(); + + await prepareTask; + + for (var i = 0; i < 2; i++) + await ExecutePreparedBatch(batch, server); + + async Task ExecutePreparedBatch(NpgsqlBatch batch, PgServerMock server) + { + var executeBatchTask = batch.ExecuteNonQueryAsync(); + + await server.ExpectMessages( + FrontendMessageCode.Bind, FrontendMessageCode.Execute, + FrontendMessageCode.Bind, FrontendMessageCode.Execute, + FrontendMessageCode.Sync); + + await server + .WriteBindComplete() + .WriteCommandComplete() + .WriteBindComplete() + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + await executeBatchTask; } } + + [Test] + public async Task Auto_prepared_batch_sends_prepared_queries() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + AutoPrepareMinUsages = 1, + MaxAutoPrepare = 10 + }; + await using var postmaster = PgPostmasterMock.Start(csb.ConnectionString); + await using var dataSource = CreateDataSource(postmaster.ConnectionString); + + await using var conn = await dataSource.OpenConnectionAsync(); + var server = await postmaster.WaitForServerConnection(); + + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1"), new("SELECT 2") } + }; + + var firstBatchExecuteTask = batch.ExecuteNonQueryAsync(); + + await server.ExpectMessages( + 
FrontendMessageCode.Parse, FrontendMessageCode.Bind, FrontendMessageCode.Describe, FrontendMessageCode.Execute, + FrontendMessageCode.Parse, FrontendMessageCode.Bind, FrontendMessageCode.Describe, FrontendMessageCode.Execute, + FrontendMessageCode.Sync); + + await server + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteCommandComplete() + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + await firstBatchExecuteTask; + + for (var i = 0; i < 2; i++) + await ExecutePreparedBatch(batch, server); + + async Task ExecutePreparedBatch(NpgsqlBatch batch, PgServerMock server) + { + var executeBatchTask = batch.ExecuteNonQueryAsync(); + + await server.ExpectMessages( + FrontendMessageCode.Bind, FrontendMessageCode.Execute, + FrontendMessageCode.Bind, FrontendMessageCode.Execute, + FrontendMessageCode.Sync); + + await server + .WriteBindComplete() + .WriteCommandComplete() + .WriteBindComplete() + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + await executeBatchTask; + } + } + + NpgsqlConnection OpenConnectionAndUnprepare() + { + var conn = OpenConnection(); + conn.UnprepareAll(); + return conn; + } + + void AssertNumPreparedStatements(NpgsqlConnection conn, int expected) + => Assert.That(conn.ExecuteScalar("SELECT COUNT(*) FROM pg_prepared_statements WHERE statement NOT LIKE '%FROM pg_prepared_statements%'"), Is.EqualTo(expected)); + + void AssertNumPreparedStatements(NpgsqlConnection conn, int expected, string message) + => Assert.That(conn.ExecuteScalar("SELECT COUNT(*) FROM pg_prepared_statements WHERE statement NOT LIKE '%FROM pg_prepared_statements%'"), Is.EqualTo(expected), message); + + List GetPreparedStatements(NpgsqlConnection conn) + { + var statements = new List(); + using var cmd = new NpgsqlCommand("SELECT name FROM pg_prepared_statements WHERE statement NOT LIKE 
'%FROM pg_prepared_statement%'", conn); + using var reader = cmd.ExecuteReader(); + while (reader.Read()) + statements.Add(reader.GetString(0)); + return statements; + } } diff --git a/test/Npgsql.Tests/Properties/AssemblyInfo.cs b/test/Npgsql.Tests/Properties/AssemblyInfo.cs index 716edf739f..f7cdcd188d 100644 --- a/test/Npgsql.Tests/Properties/AssemblyInfo.cs +++ b/test/Npgsql.Tests/Properties/AssemblyInfo.cs @@ -1,4 +1,11 @@ -using NUnit.Framework; +using System.Runtime.CompilerServices; +using NUnit.Framework; -[assembly: NonParallelizable] +[assembly: Parallelizable(ParallelScope.Children), Timeout(30000)] +[assembly: InternalsVisibleTo("Npgsql.PluginTests, PublicKey=" + +"0024000004800000940000000602000000240000525341310004000001000100" + +"2b3c590b2a4e3d347e6878dc0ff4d21eb056a50420250c6617044330701d35c9" + +"8078a5df97a62d83c9a2db2d072523a8fc491398254c6b89329b8c1dcef43a1e" + +"7aa16153bcea2ae9a471145624826f60d7c8e71cd025b554a0177bd935a78096" + +"29f0a7afc778ebb4ad033e1bf512c1a9c6ceea26b077bc46cac93800435e77ee")] diff --git a/test/Npgsql.Tests/ReadBufferTests.cs b/test/Npgsql.Tests/ReadBufferTests.cs index 2f64cd3746..af48af5223 100644 --- a/test/Npgsql.Tests/ReadBufferTests.cs +++ b/test/Npgsql.Tests/ReadBufferTests.cs @@ -1,176 +1,178 @@ -using System; +using Npgsql.Internal; +using NUnit.Framework; +using System; using System.IO; -using System.Text; using System.Threading; using System.Threading.Tasks; -using Npgsql.Util; -using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +[FixtureLifeCycle(LifeCycle.InstancePerTestCase)] // Parallel access to a single buffer +class ReadBufferTests { - class ReadBufferTests + [Test] + public void Skip() { - [Test] - public void Skip() - { - for (byte i = 0; i < 50; i++) - Writer.WriteByte(i); - - ReadBuffer.Ensure(10); - ReadBuffer.Skip(7); - Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(7)); - ReadBuffer.Skip(10); - ReadBuffer.Ensure(1); - Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(18)); - 
ReadBuffer.Skip(20); - ReadBuffer.Ensure(1); - Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(39)); - } + for (byte i = 0; i < 50; i++) + Writer.WriteByte(i); + + ReadBuffer.Ensure(10); + ReadBuffer.Skip(7); + Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(7)); + ReadBuffer.Skip(10); + ReadBuffer.Ensure(1); + Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(18)); + ReadBuffer.Skip(20); + ReadBuffer.Ensure(1); + Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(39)); + } - [Test] - public void ReadSingle() - { - const float expected = 8.7f; - var bytes = BitConverter.GetBytes(expected); - Array.Reverse(bytes); - Writer.Write(bytes); + [Test] + public void ReadSingle() + { + const float expected = 8.7f; + var bytes = BitConverter.GetBytes(expected); + Array.Reverse(bytes); + Writer.Write(bytes); - ReadBuffer.Ensure(4); - Assert.That(ReadBuffer.ReadSingle(), Is.EqualTo(expected)); - } + ReadBuffer.Ensure(4); + Assert.That(ReadBuffer.ReadSingle(), Is.EqualTo(expected)); + } - [Test] - public void ReadDouble() - { - const double expected = 8.7; - var bytes = BitConverter.GetBytes(expected); - Array.Reverse(bytes); - Writer.Write(bytes); + [Test] + public void ReadDouble() + { + const double expected = 8.7; + var bytes = BitConverter.GetBytes(expected); + Array.Reverse(bytes); + Writer.Write(bytes); - ReadBuffer.Ensure(8); - Assert.That(ReadBuffer.ReadDouble(), Is.EqualTo(expected)); - } + ReadBuffer.Ensure(8); + Assert.That(ReadBuffer.ReadDouble(), Is.EqualTo(expected)); + } - [Test] - public void ReadNullTerminatedString_buffered_only() - { - Writer - .Write(PGUtil.UTF8Encoding.GetBytes(new string("foo"))) - .WriteByte(0) - .Write(PGUtil.UTF8Encoding.GetBytes(new string("bar"))) - .WriteByte(0); - - Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("foo")); - Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("bar")); - } + [Test] + public void ReadNullTerminatedString_buffered_only() + { + Writer + .Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new 
string("foo"))) + .WriteByte(0) + .Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("bar"))) + .WriteByte(0); - [Test] - public async Task ReadNullTerminatedString_with_io() - { - Writer.Write(PGUtil.UTF8Encoding.GetBytes(new string("Chunked "))); - var task = ReadBuffer.ReadNullTerminatedString(async: true); - Assert.That(!task.IsCompleted); - - Writer - .Write(PGUtil.UTF8Encoding.GetBytes(new string("string"))) - .WriteByte(0) - .Write(PGUtil.UTF8Encoding.GetBytes(new string("bar"))) - .WriteByte(0); - Assert.That(task.IsCompleted); - Assert.That(await task, Is.EqualTo("Chunked string")); - Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("bar")); - } + ReadBuffer.Ensure(1); + + Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("foo")); + Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("bar")); + } + + [Test] + public async Task ReadNullTerminatedString_with_io() + { + Writer.Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("Chunked "))); + await ReadBuffer.Ensure(1, async: true); + var task = ReadBuffer.ReadNullTerminatedString(async: true); + Assert.That(!task.IsCompleted); + + Writer + .Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("string"))) + .WriteByte(0) + .Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("bar"))) + .WriteByte(0); + Assert.That(task.IsCompleted); + Assert.That(await task, Is.EqualTo("Chunked string")); + Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("bar")); + } #pragma warning disable CS8625 - [SetUp] - public void SetUp() - { - var stream = new MockStream(); - ReadBuffer = new NpgsqlReadBuffer(null, stream, null, NpgsqlReadBuffer.DefaultSize, PGUtil.UTF8Encoding, PGUtil.RelaxedUTF8Encoding); - Writer = stream.Writer; - } + [SetUp] + public void SetUp() + { + var stream = new MockStream(); + ReadBuffer = new NpgsqlReadBuffer(null, stream, null, NpgsqlReadBuffer.DefaultSize, NpgsqlWriteBuffer.UTF8Encoding, NpgsqlWriteBuffer.RelaxedUTF8Encoding); + 
Writer = stream.Writer; + } #pragma warning restore CS8625 - // ReSharper disable once InconsistentNaming - NpgsqlReadBuffer ReadBuffer = default!; - // ReSharper disable once InconsistentNaming - MockStream.MockStreamWriter Writer = default!; + // ReSharper disable once InconsistentNaming + NpgsqlReadBuffer ReadBuffer = default!; + // ReSharper disable once InconsistentNaming + MockStream.MockStreamWriter Writer = default!; - class MockStream : Stream - { - const int Size = 8192; + class MockStream : Stream + { + const int Size = 8192; - internal MockStreamWriter Writer { get; } + internal MockStreamWriter Writer { get; } - public MockStream() => Writer = new MockStreamWriter(this); + public MockStream() => Writer = new MockStreamWriter(this); - TaskCompletionSource _tcs = new TaskCompletionSource(); - readonly byte[] _data = new byte[Size]; - int _filled; + TaskCompletionSource _tcs = new(); + readonly byte[] _data = new byte[Size]; + int _filled; - public override int Read(byte[] buffer, int offset, int count) - => Read(buffer, offset, count, async: false).GetAwaiter().GetResult(); + public override int Read(byte[] buffer, int offset, int count) + => Read(buffer, offset, count, async: false).GetAwaiter().GetResult(); - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - => Read(buffer, offset, count, async: true); + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => Read(buffer, offset, count, async: true); - async Task Read(byte[] buffer, int offset, int count, bool async) + async Task Read(byte[] buffer, int offset, int count, bool async) + { + if (_filled == 0) { - if (_filled == 0) - { - _tcs = new TaskCompletionSource(); - if (async) - await _tcs.Task; - else - _tcs.Task.Wait(); - } - - count = Math.Min(count, _filled); - new Span(_data, 0, count).CopyTo(new Span(buffer, offset, count)); - new Span(_data, count, _filled - count).CopyTo(_data); 
- _filled -= count; - return count; + _tcs = new TaskCompletionSource(); + if (async) + await _tcs.Task; + else + _tcs.Task.Wait(); } - internal class MockStreamWriter + count = Math.Min(count, _filled); + new Span(_data, 0, count).CopyTo(new Span(buffer, offset, count)); + new Span(_data, count, _filled - count).CopyTo(_data); + _filled -= count; + return count; + } + + internal class MockStreamWriter + { + readonly MockStream _stream; + + public MockStreamWriter(MockStream stream) => _stream = stream; + + public MockStreamWriter WriteByte(byte b) { - readonly MockStream _stream; - - public MockStreamWriter(MockStream stream) => _stream = stream; - - public MockStreamWriter WriteByte(byte b) - { - Span bytes = stackalloc byte[1]; - bytes[0] = b; - Write(bytes); - return this; - } - - public MockStreamWriter Write(ReadOnlySpan bytes) - { - if (_stream._filled + bytes.Length > Size) - throw new Exception("Mock stream overrun"); - bytes.CopyTo(new Span(_stream._data, _stream._filled, bytes.Length)); - _stream._filled += bytes.Length; - _stream._tcs.TrySetResult(); - return this; - } + Span bytes = stackalloc byte[1]; + bytes[0] = b; + Write(bytes); + return this; } - public override void SetLength(long value) => throw new NotSupportedException(); - public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); - public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); - public override void Flush() => throw new NotSupportedException(); - - public override bool CanRead => true; - public override bool CanSeek => false; - public override bool CanWrite => false; - public override long Length => throw new NotSupportedException(); - public override long Position + public MockStreamWriter Write(ReadOnlySpan bytes) { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); + if (_stream._filled + bytes.Length > Size) + throw new Exception("Mock stream overrun"); + 
bytes.CopyTo(new Span(_stream._data, _stream._filled, bytes.Length)); + _stream._filled += bytes.Length; + _stream._tcs.TrySetResult(new()); + return this; } } + + public override void SetLength(long value) => throw new NotSupportedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void Flush() => throw new NotSupportedException(); + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + public override long Length => throw new NotSupportedException(); + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } } } diff --git a/test/Npgsql.Tests/ReaderNewSchemaTests.cs b/test/Npgsql.Tests/ReaderNewSchemaTests.cs index df28c5d272..01e46cdd06 100644 --- a/test/Npgsql.Tests/ReaderNewSchemaTests.cs +++ b/test/Npgsql.Tests/ReaderNewSchemaTests.cs @@ -1,73 +1,73 @@ -using System; -using System.Collections.ObjectModel; +using System.Collections.ObjectModel; using System.Data; +using System.Data.Common; using System.Linq; using System.Threading.Tasks; using Npgsql.PostgresTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests -{ - /// - /// This tests the new CoreCLR schema/metadata API, which returns ReadOnlyCollection<DbColumn>. - /// Note that this API is also available on .NET Framework. - /// For the old DataTable-based API, see . 
- /// - public class ReaderNewSchemaTests : SyncOrAsyncTestBase - { - // ReSharper disable once InconsistentNaming - [Test] - public async Task AllowDBNull() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "nullable INTEGER, non_nullable INTEGER NOT NULL", out var table); - - using var cmd = new NpgsqlCommand($"SELECT nullable,non_nullable,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].AllowDBNull, Is.True); - Assert.That(columns[1].AllowDBNull, Is.False); - Assert.That(columns[2].AllowDBNull, Is.Null); - } - - [Test] - public async Task BaseCatalogName() - { - var dbName = new NpgsqlConnectionStringBuilder(ConnectionString).Database; - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); - - using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].BaseCatalogName, Is.EqualTo(dbName)); - Assert.That(columns[1].BaseCatalogName, Is.EqualTo(dbName)); - } +namespace Npgsql.Tests; - [Test] - public async Task BaseColumnName() - { - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); +/// +/// This tests the new CoreCLR schema/metadata API, which returns ReadOnlyCollection<DbColumn>. +/// Note that this API is also available on .NET Framework. +/// For the old DataTable-based API, see . 
+/// +public class ReaderNewSchemaTests : SyncOrAsyncTestBase +{ + // ReSharper disable once InconsistentNaming + [Test] + public async Task Allow_DBNull() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "nullable INTEGER, non_nullable INTEGER NOT NULL"); + + using var cmd = new NpgsqlCommand($"SELECT nullable,non_nullable,8 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].AllowDBNull, Is.True); + Assert.That(columns[1].AllowDBNull, Is.False); + Assert.That(columns[2].AllowDBNull, Is.Null); + } - using var cmd = new NpgsqlCommand($"SELECT foo, foo AS foobar, 8 AS bar, 8, '8'::VARCHAR(10) FROM {table}", conn); - await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + [Test] + public async Task BaseCatalogName() + { + var dbName = new NpgsqlConnectionStringBuilder(ConnectionString).Database; + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].BaseCatalogName, Is.EqualTo(dbName)); + Assert.That(columns[1].BaseCatalogName, Is.EqualTo(dbName)); + } - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].BaseColumnName, Is.EqualTo("foo")); - Assert.That(columns[1].BaseColumnName, Is.EqualTo("foo")); - Assert.That(columns[2].BaseColumnName, Is.Null); - Assert.That(columns[3].BaseColumnName, Is.Null); - Assert.That(columns[4].BaseColumnName, Is.Null); - } + [Test] + public async Task BaseColumnName() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = 
new NpgsqlCommand($"SELECT foo, foo AS foobar, 8 AS bar, 8, '8'::VARCHAR(10) FROM {table}", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].BaseColumnName, Is.EqualTo("foo")); + Assert.That(columns[1].BaseColumnName, Is.EqualTo("foo")); + Assert.That(columns[2].BaseColumnName, Is.Null); + Assert.That(columns[3].BaseColumnName, Is.Null); + Assert.That(columns[4].BaseColumnName, Is.Null); + } - [Test] - public async Task BaseColumnNameWithColumnAliases() - { - using var conn = OpenConnection(); + [Test] + public async Task BaseColumnName_with_column_aliases() + { + using var conn = OpenConnection(); - conn.ExecuteNonQuery(@" + conn.ExecuteNonQuery(@" CREATE TEMP TABLE data ( Cod varchar(5) NOT NULL, Descr varchar(40), @@ -76,695 +76,725 @@ CONSTRAINT PK_test_Cod PRIMARY KEY (Cod) ); "); - var cmd = new NpgsqlCommand("SELECT Cod as CodAlias, Descr as DescrAlias, Date, NULL AS Generated FROM data", conn); - - using var dr = cmd.ExecuteReader(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var cols = await GetColumnSchema(dr); + var cmd = new NpgsqlCommand("SELECT Cod as CodAlias, Descr as DescrAlias, Date, NULL AS Generated FROM data", conn); - Assert.That(cols[0].BaseColumnName, Is.EqualTo("cod")); - Assert.That(cols[0].ColumnName, Is.EqualTo("codalias")); - Assert.That(cols[0].IsAliased, Is.True); + using var dr = cmd.ExecuteReader(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var cols = await GetColumnSchema(dr); - Assert.That(cols[1].BaseColumnName, Is.EqualTo("descr")); - Assert.That(cols[1].ColumnName, Is.EqualTo("descralias")); - Assert.That(cols[1].IsAliased, Is.True); + Assert.That(cols[0].BaseColumnName, Is.EqualTo("cod")); + Assert.That(cols[0].ColumnName, Is.EqualTo("codalias")); + Assert.That(cols[0].IsAliased, Is.True); - Assert.That(cols[2].BaseColumnName, Is.EqualTo("date")); - 
Assert.That(cols[2].ColumnName, Is.EqualTo("date")); - Assert.That(cols[2].IsAliased, Is.False); + Assert.That(cols[1].BaseColumnName, Is.EqualTo("descr")); + Assert.That(cols[1].ColumnName, Is.EqualTo("descralias")); + Assert.That(cols[1].IsAliased, Is.True); - Assert.That(cols[3].BaseColumnName, Is.Null); - Assert.That(cols[3].ColumnName, Is.EqualTo("generated")); - Assert.That(cols[3].IsAliased, Is.Null); - } - - [Test] - public async Task BaseSchemaName() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); + Assert.That(cols[2].BaseColumnName, Is.EqualTo("date")); + Assert.That(cols[2].ColumnName, Is.EqualTo("date")); + Assert.That(cols[2].IsAliased, Is.False); - using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].BaseSchemaName, Is.EqualTo("public")); - Assert.That(columns[1].BaseSchemaName, Is.Null); - } - - [Test] - public async Task BaseServerName() - { - var host = new NpgsqlConnectionStringBuilder(ConnectionString).Host; - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); - - using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].BaseServerName, Is.EqualTo(host)); - Assert.That(columns[1].BaseServerName, Is.EqualTo(host)); - } + Assert.That(cols[3].BaseColumnName, Is.Null); + Assert.That(cols[3].ColumnName, Is.EqualTo("generated")); + Assert.That(cols[3].IsAliased, Is.Null); + } - [Test] - public async Task BaseTableName() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out 
var table); + [Test] + public async Task BaseSchemaName() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].BaseSchemaName, Is.EqualTo("public")); + Assert.That(columns[1].BaseSchemaName, Is.Null); + } - using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].BaseTableName, Does.StartWith("temp_table")); - Assert.That(columns[1].BaseTableName, Is.Null); - } + [Test] + public async Task BaseServerName() + { + var host = new NpgsqlConnectionStringBuilder(ConnectionString).Host; + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].BaseServerName, Is.EqualTo(host)); + Assert.That(columns[1].BaseServerName, Is.EqualTo(host)); + } - [Test] - public async Task ColumnName() - { - await using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); - - using var cmd = new NpgsqlCommand($"SELECT foo, foo AS foobar, 8 AS bar, 8, '8'::VARCHAR(10) FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].ColumnName, Is.EqualTo("foo")); - Assert.That(columns[1].ColumnName, Is.EqualTo("foobar")); - 
Assert.That(columns[2].ColumnName, Is.EqualTo("bar")); - Assert.That(columns[3].ColumnName, Is.EqualTo("?column?")); - Assert.That(columns[4].ColumnName, Is.EqualTo("varchar")); - } - - // See https://github.com/npgsql/npgsql/issues/1676 - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "col TEXT", out var table); - - var behavior = CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo; - //var behavior = CommandBehavior.SchemaOnly; - using var command = new NpgsqlCommand($"SELECT col AS col_alias FROM {table}", conn); - using var reader = command.ExecuteReader(behavior); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].ColumnName, Is.EqualTo("col_alias")); - } - } + [Test] + public async Task BaseTableName() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].BaseTableName, Does.StartWith("temp_table")); + Assert.That(columns[1].BaseTableName, Is.Null); + } - [Test] - public async Task ColumnOrdinal() + [Test] + public async Task ColumnName() + { + await using (var conn = await OpenConnectionAsync()) { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "first INTEGER, second INTEGER", out var table); + var table = await CreateTempTable(conn, "foo INTEGER"); - using var cmd = new NpgsqlCommand($"SELECT second,first FROM {table}", conn); + using var cmd = new NpgsqlCommand($"SELECT foo, foo AS foobar, 8 AS bar, 8, '8'::VARCHAR(10) FROM {table}", conn); using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].ColumnName, Is.EqualTo("second")); - 
Assert.That(columns[0].ColumnOrdinal, Is.EqualTo(0)); - Assert.That(columns[1].ColumnName, Is.EqualTo("first")); - Assert.That(columns[1].ColumnOrdinal, Is.EqualTo(1)); - } - [Test] - public async Task ColumnAttributeNumber() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "first INTEGER, second INTEGER", out var table); - - using var cmd = new NpgsqlCommand($"SELECT second,first FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); var columns = await GetColumnSchema(reader); - Assert.That(columns[0].ColumnName, Is.EqualTo("second")); - Assert.That(columns[0].ColumnAttributeNumber, Is.EqualTo(2)); - Assert.That(columns[1].ColumnName, Is.EqualTo("first")); - Assert.That(columns[1].ColumnAttributeNumber, Is.EqualTo(1)); - } - - [Test] - public async Task ColumnSize() - { - if (IsRedshift) - Assert.Ignore("Column size is never unlimited on Redshift"); - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "bounded VARCHAR(30), unbounded VARCHAR", out var table); - - using var cmd = new NpgsqlCommand($"SELECT bounded,unbounded,'a'::VARCHAR(10),'b'::VARCHAR FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].ColumnSize, Is.EqualTo(30)); - Assert.That(columns[1].ColumnSize, Is.Null); - Assert.That(columns[2].ColumnSize, Is.EqualTo(10)); - Assert.That(columns[3].ColumnSize, Is.Null); + Assert.That(columns[0].ColumnName, Is.EqualTo("foo")); + Assert.That(columns[1].ColumnName, Is.EqualTo("foobar")); + Assert.That(columns[2].ColumnName, Is.EqualTo("bar")); + Assert.That(columns[3].ColumnName, Is.EqualTo("?column?")); + Assert.That(columns[4].ColumnName, Is.EqualTo("varchar")); } - [Test] - public async Task IsAutoIncrement() + // See https://github.com/npgsql/npgsql/issues/1676 + using 
(var conn = await OpenConnectionAsync()) { - if (IsRedshift) - Assert.Ignore("Serial columns not support on Redshift"); - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "inc SERIAL, non_inc INT", out var table); + var table = await CreateTempTable(conn, "col TEXT"); - using var cmd = new NpgsqlCommand($"SELECT inc,non_inc,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var behavior = CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo; + //var behavior = CommandBehavior.SchemaOnly; + using var command = new NpgsqlCommand($"SELECT col AS col_alias FROM {table}", conn); + using var reader = command.ExecuteReader(behavior); var columns = await GetColumnSchema(reader); - Assert.That(columns[0].IsAutoIncrement, Is.True, "Serial not identified as autoincrement"); - Assert.That(columns[1].IsAutoIncrement, Is.False, "Regular int column identified as autoincrement"); - Assert.That(columns[2].IsAutoIncrement, Is.Null, "Literal int identified as autoincrement"); + Assert.That(columns[0].ColumnName, Is.EqualTo("col_alias")); } + } - [Test] - public async Task IsAutoIncrementIdentity() - { - using var conn = await OpenConnectionAsync(); - if (conn.PostgreSqlVersion < new Version(10, 0)) - Assert.Ignore("IDENTITY introduced in PostgreSQL 10"); + [Test] + public async Task ColumnOrdinal() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "first INTEGER, second INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT second,first FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].ColumnName, Is.EqualTo("second")); + Assert.That(columns[0].ColumnOrdinal, Is.EqualTo(0)); + Assert.That(columns[1].ColumnName, Is.EqualTo("first")); + Assert.That(columns[1].ColumnOrdinal, Is.EqualTo(1)); + } 
- await using var _ = await CreateTempTable( - conn, - "inc SERIAL, identity INT GENERATED BY DEFAULT AS IDENTITY, non_inc INT", - out var table); + [Test] + public async Task ColumnAttributeNumber() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "first INTEGER, second INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT second,first FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].ColumnName, Is.EqualTo("second")); + Assert.That(columns[0].ColumnAttributeNumber, Is.EqualTo(2)); + Assert.That(columns[1].ColumnName, Is.EqualTo("first")); + Assert.That(columns[1].ColumnAttributeNumber, Is.EqualTo(1)); + } - using var cmd = new NpgsqlCommand($"SELECT inc,identity,non_inc,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].IsAutoIncrement, Is.True, "Serial not identified as autoincrement"); - Assert.That(columns[1].IsAutoIncrement, Is.True, "PG 10 IDENTITY not identified as autoincrement"); - Assert.That(columns[2].IsAutoIncrement, Is.False, "Regular int column identified as autoincrement"); - Assert.That(columns[3].IsAutoIncrement, Is.Null, "Literal int identified as autoincrement"); - } + [Test] + public async Task ColumnSize() + { + using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Column size is never unlimited on Redshift"); + var table = await CreateTempTable(conn, "bounded VARCHAR(30), unbounded VARCHAR"); + + using var cmd = new NpgsqlCommand($"SELECT bounded,unbounded,'a'::VARCHAR(10),'b'::VARCHAR FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].ColumnSize, 
Is.EqualTo(30)); + Assert.That(columns[1].ColumnSize, Is.Null); + Assert.That(columns[2].ColumnSize, Is.EqualTo(10)); + Assert.That(columns[3].ColumnSize, Is.Null); + } - [Test] - public async Task IsKey() - { - if (IsRedshift) - Assert.Ignore("Key not supported in reader schema on Redshift"); - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "id INT PRIMARY KEY, non_id INT, uniq INT UNIQUE", out var table); + [Test] + public async Task IsAutoIncrement() + { + await using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Serial columns not support on Redshift"); - using var cmd = new NpgsqlCommand($"SELECT id,non_id,uniq,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].IsKey, Is.True); - Assert.That(columns[1].IsKey, Is.False); + var table = await CreateTempTable(conn, "serial SERIAL, int INT"); - // Note: according to the old API docs any unique column is considered key. 
- // https://msdn.microsoft.com/en-us/library/system.data.sqlclient.sqldatareader.getschematable(v=vs.110).aspx - // But in the new API we have a separate IsUnique so IsKey should be false - Assert.That(columns[2].IsKey, Is.False); + await using var cmd = new NpgsqlCommand($"SELECT serial, int, 8 FROM {table}", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].IsAutoIncrement, Is.True, "Serial not identified as autoincrement"); + Assert.That(columns[1].IsAutoIncrement, Is.False, "Regular int column identified as autoincrement"); + Assert.That(columns[2].IsAutoIncrement, Is.Null, "Literal int identified as autoincrement"); + } - Assert.That(columns[3].IsKey, Is.Null); - } + [Test] + public async Task IsAutoIncrement_identity() + { + await using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Identity columns not support on Redshift"); + MinimumPgVersion(conn, "10.0", "IDENTITY introduced in PostgreSQL 10"); + + var table = + await CreateTempTable(conn, "identity1 INT GENERATED BY DEFAULT AS IDENTITY, identity2 INT GENERATED ALWAYS AS IDENTITY"); + + await using var cmd = new NpgsqlCommand($"SELECT identity1, identity2 FROM {table}", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].IsAutoIncrement, Is.True, "PG 10 IDENTITY not identified as autoincrement"); + Assert.That(columns[1].IsAutoIncrement, Is.True, "PG 10 IDENTITY not identified as autoincrement"); + } - [Test] - public async Task IsKeyComposite() - { - if (IsRedshift) - Assert.Ignore("Key not supported in reader schema on Redshift"); - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "id1 INT, id2 INT, PRIMARY KEY (id1, id2)", out var table); + [Test] + public async 
Task IsIdentity() + { + await using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Identity columns not support on Redshift"); + MinimumPgVersion(conn, "10.0", "IDENTITY introduced in PostgreSQL 10"); + var table = await CreateTempTable( + conn, + "identity1 INT GENERATED BY DEFAULT AS IDENTITY, identity2 INT GENERATED ALWAYS AS IDENTITY, serial SERIAL, int INT"); + + await using var cmd = new NpgsqlCommand($"SELECT identity1, identity2, serial, int, 8 FROM {table}", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].IsIdentity, Is.True, "PG 10 IDENTITY not identified as identity"); + Assert.That(columns[1].IsIdentity, Is.True, "PG 10 IDENTITY not identified as identity"); + Assert.That(columns[2].IsIdentity, Is.False, "Serial identified as identity"); + Assert.That(columns[3].IsIdentity, Is.False, "Regular int column identified as identity"); + Assert.That(columns[4].IsIdentity, Is.False, "Literal int identified as identity"); + } - using var cmd = new NpgsqlCommand($"SELECT id1,id2 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].IsKey, Is.True); - Assert.That(columns[1].IsKey, Is.True); - } + [Test] + public async Task IsKey() + { + using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Key not supported in reader schema on Redshift"); + var table = await CreateTempTable(conn, "id INT PRIMARY KEY, non_id INT, uniq INT UNIQUE"); + + using var cmd = new NpgsqlCommand($"SELECT id,non_id,uniq,8 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].IsKey, Is.True); + Assert.That(columns[1].IsKey, 
Is.False); + + // Note: according to the old API docs any unique column is considered key. + // https://msdn.microsoft.com/en-us/library/system.data.sqlclient.sqldatareader.getschematable(v=vs.110).aspx + // But in the new API we have a separate IsUnique so IsKey should be false + Assert.That(columns[2].IsKey, Is.False); + + Assert.That(columns[3].IsKey, Is.Null); + } - [Test] - public async Task IsLong() - { - if (IsRedshift) - Assert.Ignore("bytea not supported on Redshift"); - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "long BYTEA, non_long INT", out var table); + [Test] + public async Task IsKey_composite() + { + using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Key not supported in reader schema on Redshift"); + var table = await CreateTempTable(conn, "id1 INT, id2 INT, PRIMARY KEY (id1, id2)"); + + using var cmd = new NpgsqlCommand($"SELECT id1,id2 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].IsKey, Is.True); + Assert.That(columns[1].IsKey, Is.True); + } - using var cmd = new NpgsqlCommand($"SELECT long, non_long, 8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].IsLong, Is.True); - Assert.That(columns[1].IsLong, Is.False); - Assert.That(columns[2].IsLong, Is.False); - } + [Test] + public async Task IsLong() + { + using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "bytea not supported on Redshift"); + var table = await CreateTempTable(conn, "long BYTEA, non_long INT"); + + using var cmd = new NpgsqlCommand($"SELECT long, non_long, 8 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + 
Assert.That(columns[0].IsLong, Is.True); + Assert.That(columns[1].IsLong, Is.False); + Assert.That(columns[2].IsLong, Is.False); + } - [Test] - public async Task IsReadOnlyOnView() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await GetTempViewName(conn, out var view); - await using var __ = await GetTempTableName(conn, out var table); + [Test] + public async Task IsReadOnly_on_view() + { + using var conn = await OpenConnectionAsync(); + var view = await GetTempViewName(conn); + var table = await GetTempTableName(conn); - await conn.ExecuteNonQueryAsync($@" + await conn.ExecuteNonQueryAsync($@" CREATE VIEW {view} AS SELECT 8 AS foo; CREATE TABLE {table} (bar INTEGER)"); - using var cmd = new NpgsqlCommand($"SELECT foo,bar FROM {view},{table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].IsReadOnly, Is.True); - Assert.That(columns[1].IsReadOnly, Is.False); - } + using var cmd = new NpgsqlCommand($"SELECT foo,bar FROM {view},{table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].IsReadOnly, Is.True); + Assert.That(columns[1].IsReadOnly, Is.False); + } - [Test] - public async Task IsReadOnlyOnNonColumn() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT 8", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns.Single().IsReadOnly, Is.True); - } + [Test] + public async Task IsReadOnly_on_non_column() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 8", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await 
GetColumnSchema(reader); + Assert.That(columns.Single().IsReadOnly, Is.True); + } - [Test] - public async Task IsUnique() - { - if (IsRedshift) - Assert.Ignore("Unique not supported in reader schema on Redshift"); - using var conn = await OpenConnectionAsync(); - await using var __ = await GetTempTableName(conn, out var table); + [Test] + public async Task IsUnique() + { + using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Unique not supported in reader schema on Redshift"); + var table = await GetTempTableName(conn); - await conn.ExecuteNonQueryAsync($@" + await conn.ExecuteNonQueryAsync($@" CREATE TABLE {table} (id INT PRIMARY KEY, non_id INT, uniq INT UNIQUE, non_id_second INT, non_id_third INT); CREATE UNIQUE INDEX idx_{table} ON {table} (non_id_second, non_id_third)"); - using var cmd = new NpgsqlCommand($"SELECT id,non_id,uniq,8,non_id_second,non_id_third FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].IsUnique, Is.True); - Assert.That(columns[1].IsUnique, Is.False); - Assert.That(columns[2].IsUnique, Is.True); - Assert.That(columns[3].IsUnique, Is.Null); - Assert.That(columns[4].IsUnique, Is.False); - Assert.That(columns[5].IsUnique, Is.False); - } - - [Test] - public async Task NumericPrecision() - { - if (IsRedshift) - Assert.Ignore("Precision is never unlimited on Redshift"); - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "a NUMERIC(8), b NUMERIC, c INTEGER", out var table); + using var cmd = new NpgsqlCommand($"SELECT id,non_id,uniq,8,non_id_second,non_id_third FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].IsUnique, Is.True); + Assert.That(columns[1].IsUnique, Is.False); + 
Assert.That(columns[2].IsUnique, Is.True); + Assert.That(columns[3].IsUnique, Is.Null); + Assert.That(columns[4].IsUnique, Is.False); + Assert.That(columns[5].IsUnique, Is.False); + } - using var cmd = new NpgsqlCommand($"SELECT a,b,c,8.3::NUMERIC(8) FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].NumericPrecision, Is.EqualTo(8)); - Assert.That(columns[1].NumericPrecision, Is.Null); - Assert.That(columns[2].NumericPrecision, Is.Null); - Assert.That(columns[3].NumericPrecision, Is.EqualTo(8)); - } + [Test] + public async Task NumericPrecision() + { + using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Precision is never unlimited on Redshift"); + var table = await CreateTempTable(conn, "a NUMERIC(8), b NUMERIC, c INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT a,b,c,8.3::NUMERIC(8) FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].NumericPrecision, Is.EqualTo(8)); + Assert.That(columns[1].NumericPrecision, Is.Null); + Assert.That(columns[2].NumericPrecision, Is.Null); + Assert.That(columns[3].NumericPrecision, Is.EqualTo(8)); + } - [Test] - public async Task NumericScale() - { - if (IsRedshift) - Assert.Ignore("Scale is never unlimited on Redshift"); - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "a NUMERIC(8,5), b NUMERIC, c INTEGER", out var table); + [Test] + public async Task NumericScale() + { + using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Scale is never unlimited on Redshift"); + var table = await CreateTempTable(conn, "a NUMERIC(8,5), b NUMERIC, c INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT a,b,c,8.3::NUMERIC(8,5) FROM {table}", conn); + using var reader = await 
cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].NumericScale, Is.EqualTo(5)); + Assert.That(columns[1].NumericScale, Is.Null); + Assert.That(columns[2].NumericScale, Is.Null); + Assert.That(columns[3].NumericScale, Is.EqualTo(5)); + } - using var cmd = new NpgsqlCommand($"SELECT a,b,c,8.3::NUMERIC(8,5) FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].NumericScale, Is.EqualTo(5)); - Assert.That(columns[1].NumericScale, Is.Null); - Assert.That(columns[2].NumericScale, Is.Null); - Assert.That(columns[3].NumericScale, Is.EqualTo(5)); - } + [Test] + public async Task DataType() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT foo,8::INTEGER FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].DataType, Is.SameAs(typeof(int))); + Assert.That(columns[1].DataType, Is.SameAs(typeof(int))); + } - [Test] - public async Task DataType() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1305")] + public async Task DataType_unknown_type() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT foo::INTEGER FROM {table}", conn); + cmd.AllResultTypesAreUnknown = true; + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].DataType, Is.SameAs(typeof(int))); + } - using var cmd = new NpgsqlCommand($"SELECT foo,8::INTEGER FROM 
{table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].DataType, Is.SameAs(typeof(int))); - Assert.That(columns[1].DataType, Is.SameAs(typeof(int))); - } + [Test] + public async Task DataType_with_composite() + { + await using var adminConnection = await OpenConnectionAsync(); + IgnoreOnRedshift(adminConnection, "Composite types not support on Redshift"); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (foo int)"); + var tableName = await CreateTempTable(adminConnection, $"comp {type}"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await using var cmd = new NpgsqlCommand($"SELECT comp,'(4)'::{type} FROM {tableName}", connection); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].DataType, Is.SameAs(typeof(SomeComposite))); + Assert.That(columns[0].UdtAssemblyQualifiedName, Is.EqualTo(typeof(SomeComposite).AssemblyQualifiedName)); + Assert.That(columns[1].DataType, Is.SameAs(typeof(SomeComposite))); + Assert.That(columns[1].UdtAssemblyQualifiedName, Is.EqualTo(typeof(SomeComposite).AssemblyQualifiedName)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1305")] - public async Task DataTypeUnknownType() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); + [Test] + public async Task UdtAssemblyQualifiedName() + { + // Also see DataTypeWithComposite + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT foo,8 
FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].UdtAssemblyQualifiedName, Is.Null); + Assert.That(columns[1].UdtAssemblyQualifiedName, Is.Null); + } - using var cmd = new NpgsqlCommand($"SELECT foo::INTEGER FROM {table}", conn); - cmd.AllResultTypesAreUnknown = true; - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].DataType, Is.SameAs(typeof(int))); - } + [Test] + public async Task PostgresType() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + var intType = columns[0].PostgresType; + Assert.That(columns[1].PostgresType, Is.SameAs(intType)); + Assert.That(intType.Name, Is.EqualTo("integer")); + Assert.That(intType.InternalName, Is.EqualTo("int4")); + } - [Test, NonParallelizable] - public async Task DataTypeWithComposite() - { - if (IsRedshift) - Assert.Ignore("Composite types not support on Redshift"); - // if (IsMultiplexing) - // Assert.Ignore("Multiplexing: ReloadTypes"); - - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(DataTypeWithComposite), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; - - using var conn = await OpenConnectionAsync(csb); - await conn.ExecuteNonQueryAsync("CREATE TYPE pg_temp.some_composite AS (foo int)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite(); - await conn.ExecuteNonQueryAsync("CREATE TEMP TABLE data (comp pg_temp.some_composite)"); - - using var cmd = new NpgsqlCommand("SELECT comp,'(4)'::some_composite FROM data", conn); - using var reader = await 
cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].DataType, Is.SameAs(typeof(SomeComposite))); - Assert.That(columns[0].UdtAssemblyQualifiedName, Is.EqualTo(typeof(SomeComposite).AssemblyQualifiedName)); - Assert.That(columns[1].DataType, Is.SameAs(typeof(SomeComposite))); - Assert.That(columns[1].UdtAssemblyQualifiedName, Is.EqualTo(typeof(SomeComposite).AssemblyQualifiedName)); - } + [Test] + public async Task ColumnSchema_with_and_without_KeyInfo() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); - [Test] - public async Task UdtAssemblyQualifiedName() + using var cmd = new NpgsqlCommand($"SELECT foo, foo AS foobar, 8 AS bar, 8, '8'::VARCHAR(10) FROM {table}", conn); + await using (var reader = await cmd.ExecuteReaderAsync()) { - // Also see DataTypeWithComposite - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); - - using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); var columns = await GetColumnSchema(reader); - Assert.That(columns[0].UdtAssemblyQualifiedName, Is.Null); - Assert.That(columns[1].UdtAssemblyQualifiedName, Is.Null); - } - - [Test] - public async Task PostgresType() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); - using var cmd = new NpgsqlCommand($"SELECT foo,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - var intType = columns[0].PostgresType; - Assert.That(columns[1].PostgresType, Is.SameAs(intType)); - Assert.That(intType.Name, Is.EqualTo("integer")); - Assert.That(intType.InternalName, Is.EqualTo("int4")); + Assert.That(columns[0].ColumnName, 
Is.EqualTo("foo")); + Assert.That(columns[0].BaseColumnName, Is.Null); + Assert.That(columns[0].BaseTableName, Is.Null); + Assert.That(columns[0].BaseSchemaName, Is.Null); + Assert.That(columns[0].IsAliased, Is.Null); + Assert.That(columns[0].IsKey, Is.Null); + Assert.That(columns[0].IsUnique, Is.Null); + Assert.That(columns[1].ColumnName, Is.EqualTo("foobar")); + Assert.That(columns[1].BaseColumnName, Is.Null); + Assert.That(columns[1].BaseTableName, Is.Null); + Assert.That(columns[1].BaseSchemaName, Is.Null); + Assert.That(columns[1].IsAliased, Is.Null); + Assert.That(columns[1].IsKey, Is.Null); + Assert.That(columns[1].IsUnique, Is.Null); + Assert.That(columns[2].ColumnName, Is.EqualTo("bar")); + Assert.That(columns[2].BaseColumnName, Is.Null); + Assert.That(columns[2].BaseTableName, Is.Null); + Assert.That(columns[2].BaseSchemaName, Is.Null); + Assert.That(columns[2].IsAliased, Is.Null); + Assert.That(columns[2].IsKey, Is.Null); + Assert.That(columns[2].IsUnique, Is.Null); + Assert.That(columns[3].ColumnName, Is.EqualTo("?column?")); + Assert.That(columns[3].BaseColumnName, Is.Null); + Assert.That(columns[3].BaseTableName, Is.Null); + Assert.That(columns[3].BaseSchemaName, Is.Null); + Assert.That(columns[3].IsAliased, Is.Null); + Assert.That(columns[3].IsKey, Is.Null); + Assert.That(columns[3].IsUnique, Is.Null); + Assert.That(columns[4].ColumnName, Is.EqualTo("varchar")); + Assert.That(columns[4].BaseColumnName, Is.Null); + Assert.That(columns[4].BaseTableName, Is.Null); + Assert.That(columns[4].BaseSchemaName, Is.Null); + Assert.That(columns[4].IsAliased, Is.Null); + Assert.That(columns[4].IsKey, Is.Null); + Assert.That(columns[4].IsUnique, Is.Null); + + } + + await using (var readerInfo = await cmd.ExecuteReaderAsync(CommandBehavior.KeyInfo)) + { + var columnsInfo = await GetColumnSchema(readerInfo); + + Assert.That(columnsInfo[0].ColumnName, Is.EqualTo("foo")); + Assert.That(columnsInfo[0].BaseColumnName, Is.EqualTo("foo")); + 
Assert.That(columnsInfo[0].BaseSchemaName, Is.EqualTo("public")); + Assert.That(columnsInfo[0].IsAliased, Is.EqualTo(false)); + Assert.That(columnsInfo[0].IsKey, Is.EqualTo(false)); + Assert.That(columnsInfo[0].IsUnique, Is.EqualTo(false)); + Assert.That(columnsInfo[1].ColumnName, Is.EqualTo("foobar")); + Assert.That(columnsInfo[1].BaseColumnName, Is.EqualTo("foo")); + Assert.That(columnsInfo[1].BaseSchemaName, Is.EqualTo("public")); + Assert.That(columnsInfo[1].IsAliased, Is.EqualTo(true)); + Assert.That(columnsInfo[1].IsKey, Is.EqualTo(false)); + Assert.That(columnsInfo[1].IsUnique, Is.EqualTo(false)); + Assert.That(columnsInfo[2].ColumnName, Is.EqualTo("bar")); + Assert.That(columnsInfo[2].BaseColumnName, Is.Null); + Assert.That(columnsInfo[2].BaseSchemaName, Is.Null); + Assert.That(columnsInfo[2].IsAliased, Is.Null); + Assert.That(columnsInfo[2].IsKey, Is.Null); + Assert.That(columnsInfo[2].IsUnique, Is.Null); + Assert.That(columnsInfo[3].ColumnName, Is.EqualTo("?column?")); + Assert.That(columnsInfo[3].BaseColumnName, Is.Null); + Assert.That(columnsInfo[3].BaseSchemaName, Is.Null); + Assert.That(columnsInfo[3].IsAliased, Is.Null); + Assert.That(columnsInfo[3].IsKey, Is.Null); + Assert.That(columnsInfo[3].IsUnique, Is.Null); + Assert.That(columnsInfo[4].ColumnName, Is.EqualTo("varchar")); + Assert.That(columnsInfo[4].BaseColumnName, Is.Null); + Assert.That(columnsInfo[4].BaseSchemaName, Is.Null); + Assert.That(columnsInfo[4].IsAliased, Is.Null); + Assert.That(columnsInfo[4].IsKey, Is.Null); + Assert.That(columnsInfo[4].IsUnique, Is.Null); } + } - [Test] - public async Task ColumnSchemaWithoutWithNoKeyInfo() - { - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); + /// + [Test] + [TestCase("integer")] + [TestCase("real")] + [TestCase("integer[]")] + [TestCase("character varying(10)", 10)] + [TestCase("character varying")] + [TestCase("character varying(10)[]", 10)] + 
[TestCase("character(10)", 10)] + [TestCase("character", 1)] + [TestCase("character(1)", 1)] + [TestCase("numeric(1000, 2)", null, 1000, 2)] + [TestCase("numeric(1000)", null, 1000, null)] + [TestCase("numeric")] + [TestCase("timestamp without time zone")] + [TestCase("timestamp(2) without time zone", null, 2)] + [TestCase("timestamp(2) with time zone", null, 2)] + [TestCase("time without time zone")] + [TestCase("time(2) without time zone", null, 2)] + [TestCase("time(2) with time zone", null, 2)] + [TestCase("interval")] + [TestCase("interval(2)", null, 2)] + [TestCase("bit", 1)] + [TestCase("bit(3)", 3)] + [TestCase("bit varying")] + [TestCase("bit varying(3)", 3)] + public async Task DataTypeName(string typeName, int? size = null, int? precision = null, int? scale = null) + { + var openingParen = typeName.IndexOf('('); + var typeNameWithoutFacets = openingParen == -1 + ? typeName + : typeName.Substring(0, openingParen) + typeName.Substring(typeName.IndexOf(')') + 1); + + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, $"foo {typeName}"); + + using var cmd = new NpgsqlCommand($"SELECT foo,NULL::{typeName} FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + var tableColumn = columns[0]; + var nonTableColumn = columns[1]; + Assert.That(tableColumn.DataTypeName, Is.EqualTo(typeNameWithoutFacets)); + Assert.That(tableColumn.ColumnSize, Is.EqualTo(size)); + Assert.That(tableColumn.NumericPrecision, Is.EqualTo(precision)); + Assert.That(tableColumn.NumericScale, Is.EqualTo(scale)); + Assert.That(nonTableColumn.DataTypeName, Is.EqualTo(typeNameWithoutFacets)); + Assert.That(nonTableColumn.ColumnSize, Is.EqualTo(size)); + Assert.That(nonTableColumn.NumericPrecision, Is.EqualTo(precision)); + Assert.That(nonTableColumn.NumericScale, Is.EqualTo(scale)); + } - using var cmd = new NpgsqlCommand($"SELECT foo, foo 
AS foobar, 8 AS bar, 8, '8'::VARCHAR(10) FROM {table}", conn); - await using (var reader = await cmd.ExecuteReaderAsync()) - { - var columns = await GetColumnSchema(reader); - - Assert.That(columns[0].ColumnName, Is.EqualTo("foo")); - Assert.That(columns[0].BaseColumnName, Is.Null); - Assert.That(columns[0].BaseTableName, Is.Null); - Assert.That(columns[0].BaseSchemaName, Is.Null); - Assert.That(columns[0].IsAliased, Is.Null); - Assert.That(columns[0].IsKey, Is.Null); - Assert.That(columns[0].IsUnique, Is.Null); - Assert.That(columns[1].ColumnName, Is.EqualTo("foobar")); - Assert.That(columns[1].BaseColumnName, Is.Null); - Assert.That(columns[1].BaseTableName, Is.Null); - Assert.That(columns[1].BaseSchemaName, Is.Null); - Assert.That(columns[1].IsAliased, Is.Null); - Assert.That(columns[1].IsKey, Is.Null); - Assert.That(columns[1].IsUnique, Is.Null); - Assert.That(columns[2].ColumnName, Is.EqualTo("bar")); - Assert.That(columns[2].BaseColumnName, Is.Null); - Assert.That(columns[2].BaseTableName, Is.Null); - Assert.That(columns[2].BaseSchemaName, Is.Null); - Assert.That(columns[2].IsAliased, Is.Null); - Assert.That(columns[2].IsKey, Is.Null); - Assert.That(columns[2].IsUnique, Is.Null); - Assert.That(columns[3].ColumnName, Is.EqualTo("?column?")); - Assert.That(columns[3].BaseColumnName, Is.Null); - Assert.That(columns[3].BaseTableName, Is.Null); - Assert.That(columns[3].BaseSchemaName, Is.Null); - Assert.That(columns[3].IsAliased, Is.Null); - Assert.That(columns[3].IsKey, Is.Null); - Assert.That(columns[3].IsUnique, Is.Null); - Assert.That(columns[4].ColumnName, Is.EqualTo("varchar")); - Assert.That(columns[4].BaseColumnName, Is.Null); - Assert.That(columns[4].BaseTableName, Is.Null); - Assert.That(columns[4].BaseSchemaName, Is.Null); - Assert.That(columns[4].IsAliased, Is.Null); - Assert.That(columns[4].IsKey, Is.Null); - Assert.That(columns[4].IsUnique, Is.Null); - - } - - await using (var readerInfo = await cmd.ExecuteReaderAsync(CommandBehavior.KeyInfo)) - { - 
var columnsInfo = await GetColumnSchema(readerInfo); - - Assert.That(columnsInfo[0].ColumnName, Is.EqualTo("foo")); - Assert.That(columnsInfo[0].BaseColumnName, Is.EqualTo("foo")); - Assert.That(columnsInfo[0].BaseSchemaName, Is.EqualTo("public")); - Assert.That(columnsInfo[0].IsAliased, Is.EqualTo(false)); - Assert.That(columnsInfo[0].IsKey, Is.EqualTo(false)); - Assert.That(columnsInfo[0].IsUnique, Is.EqualTo(false)); - Assert.That(columnsInfo[1].ColumnName, Is.EqualTo("foobar")); - Assert.That(columnsInfo[1].BaseColumnName, Is.EqualTo("foo")); - Assert.That(columnsInfo[1].BaseSchemaName, Is.EqualTo("public")); - Assert.That(columnsInfo[1].IsAliased, Is.EqualTo(true)); - Assert.That(columnsInfo[1].IsKey, Is.EqualTo(false)); - Assert.That(columnsInfo[1].IsUnique, Is.EqualTo(false)); - Assert.That(columnsInfo[2].ColumnName, Is.EqualTo("bar")); - Assert.That(columnsInfo[2].BaseColumnName, Is.Null); - Assert.That(columnsInfo[2].BaseSchemaName, Is.Null); - Assert.That(columnsInfo[2].IsAliased, Is.Null); - Assert.That(columnsInfo[2].IsKey, Is.Null); - Assert.That(columnsInfo[2].IsUnique, Is.Null); - Assert.That(columnsInfo[3].ColumnName, Is.EqualTo("?column?")); - Assert.That(columnsInfo[3].BaseColumnName, Is.Null); - Assert.That(columnsInfo[3].BaseSchemaName, Is.Null); - Assert.That(columnsInfo[3].IsAliased, Is.Null); - Assert.That(columnsInfo[3].IsKey, Is.Null); - Assert.That(columnsInfo[3].IsUnique, Is.Null); - Assert.That(columnsInfo[4].ColumnName, Is.EqualTo("varchar")); - Assert.That(columnsInfo[4].BaseColumnName, Is.Null); - Assert.That(columnsInfo[4].BaseSchemaName, Is.Null); - Assert.That(columnsInfo[4].IsAliased, Is.Null); - Assert.That(columnsInfo[4].IsKey, Is.Null); - Assert.That(columnsInfo[4].IsUnique, Is.Null); - } - } + [Test] + public async Task DefaultValue() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "with_default INTEGER DEFAULT(8), without_default INTEGER"); + + using var cmd = new 
NpgsqlCommand($"SELECT with_default,without_default,8 FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].DefaultValue, Is.EqualTo("8")); + Assert.That(columns[1].DefaultValue, Is.Null); + Assert.That(columns[2].DefaultValue, Is.Null); + } - /// - [Test] - [TestCase("integer")] - [TestCase("real")] - [TestCase("integer[]")] - [TestCase("character varying(10)", 10)] - [TestCase("character varying")] - [TestCase("character varying(10)[]", 10)] - [TestCase("character(10)", 10)] - [TestCase("character", 1)] - [TestCase("numeric(1000, 2)", null, 1000, 2)] - [TestCase("numeric(1000)", null, 1000, null)] - [TestCase("numeric")] - [TestCase("timestamp without time zone")] - [TestCase("timestamp(2) without time zone", null, 2)] - [TestCase("timestamp(2) with time zone", null, 2)] - [TestCase("time without time zone")] - [TestCase("time(2) without time zone", null, 2)] - [TestCase("time(2) with time zone", null, 2)] - [TestCase("interval")] - [TestCase("interval(2)", null, 2)] - [TestCase("bit", 1)] - [TestCase("bit(3)", 3)] - [TestCase("bit varying")] - [TestCase("bit varying(3)", 3)] - public async Task DataTypeName(string typeName, int? size = null, int? precision = null, int? scale = null) - { - var openingParen = typeName.IndexOf('('); - var typeNameWithoutFacets = openingParen == -1 - ? 
typeName - : typeName.Substring(0, openingParen) + typeName.Substring(typeName.IndexOf(')') + 1); + [Test] + public async Task Same_column_name() + { + using var conn = await OpenConnectionAsync(); + var table1 = await GetTempTableName(conn); + var table2 = await GetTempTableName(conn); - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, $"foo {typeName}", out var table); + await conn.ExecuteNonQueryAsync($@" +CREATE TABLE {table1} (foo INTEGER); +CREATE TABLE {table2} (foo INTEGER)"); - using var cmd = new NpgsqlCommand($"SELECT foo,NULL::{typeName} FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - var tableColumn = columns[0]; - var nonTableColumn = columns[1]; - Assert.That(tableColumn.DataTypeName, Is.EqualTo(typeNameWithoutFacets)); - Assert.That(tableColumn.ColumnSize, Is.EqualTo(size)); - Assert.That(tableColumn.NumericPrecision, Is.EqualTo(precision)); - Assert.That(tableColumn.NumericScale, Is.EqualTo(scale)); - Assert.That(nonTableColumn.DataTypeName, Is.EqualTo(typeNameWithoutFacets)); - Assert.That(nonTableColumn.ColumnSize, Is.EqualTo(size)); - Assert.That(nonTableColumn.NumericPrecision, Is.EqualTo(precision)); - Assert.That(nonTableColumn.NumericScale, Is.EqualTo(scale)); - } + using var cmd = new NpgsqlCommand($"SELECT {table1}.foo,{table2}.foo FROM {table1},{table2}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].ColumnName, Is.EqualTo("foo")); + Assert.That(columns[0].BaseTableName, Does.StartWith("temp_table")); + Assert.That(columns[1].ColumnName, Is.EqualTo("foo")); + Assert.That(columns[1].BaseTableName, Does.StartWith("temp_table")); + Assert.That(columns[0].BaseTableName, Is.Not.EqualTo(columns[1].BaseTableName)); + } - [Test] - public 
async Task DefaultValue() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable( - conn, "with_default INTEGER DEFAULT(8), without_default INTEGER", out var table); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1553")] + public async Task Domain_type() + { + // if (IsMultiplexing) + // Assert.Ignore("Multiplexing: ReloadTypes"); + using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Domain types not support on Redshift"); + + const string domainTypeName = "my_domain"; + var schema = await CreateTempSchema(conn); + var tableName = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($"CREATE DOMAIN {schema}.{domainTypeName} AS varchar(2)"); + conn.ReloadTypes(); + await conn.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (domain {schema}.{domainTypeName})"); + using var cmd = new NpgsqlCommand($"SELECT domain FROM {tableName}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await GetColumnSchema(reader); + var domainSchema = columns.Single(c => c.ColumnName == "domain"); + Assert.That(domainSchema.ColumnSize, Is.EqualTo(2)); + var pgType = domainSchema.PostgresType; + Assert.That(pgType, Is.InstanceOf()); + Assert.That(((PostgresDomainType)pgType).BaseType.Name, Is.EqualTo("character varying")); + } - using var cmd = new NpgsqlCommand($"SELECT with_default,without_default,8 FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].DefaultValue, Is.EqualTo("8")); - Assert.That(columns[1].DefaultValue, Is.Null); - Assert.That(columns[2].DefaultValue, Is.Null); - } + [Test] + public async Task NpgsqlDbType() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); + + using var cmd = new NpgsqlCommand($"SELECT 
foo,8::INTEGER FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].NpgsqlDbType, Is.EqualTo(NpgsqlTypes.NpgsqlDbType.Integer)); + Assert.That(columns[1].NpgsqlDbType, Is.EqualTo(NpgsqlTypes.NpgsqlDbType.Integer)); + } - [Test] - public async Task SameColumnName() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await GetTempTableName(conn, out var table1); - await using var __ = await GetTempTableName(conn, out var table2); + [Test] + [NonParallelizable] + public async Task NpgsqlDbType_extension() + { + using var conn = await OpenConnectionAsync(); + await EnsureExtensionAsync(conn, "hstore", "9.1"); + + using var cmd = new NpgsqlCommand("SELECT NULL::HSTORE", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + // The full datatype name for PostGIS is public.geometry (unlike int4 which is in pg_catalog). 
+ Assert.That(columns[0].NpgsqlDbType, Is.EqualTo(NpgsqlTypes.NpgsqlDbType.Hstore)); + } - await conn.ExecuteNonQueryAsync($@" -CREATE TABLE {table1} (foo INTEGER); -CREATE TABLE {table2} (foo INTEGER)"); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1950")] + public async Task No_resultset() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("COMMIT", conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + await GetColumnSchema(reader); + } - using var cmd = new NpgsqlCommand($"SELECT {table1}.foo,{table2}.foo FROM {table1},{table2}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].ColumnName, Is.EqualTo("foo")); - Assert.That(columns[0].BaseTableName, Does.StartWith("temp_table")); - Assert.That(columns[1].ColumnName, Is.EqualTo("foo")); - Assert.That(columns[1].BaseTableName, Does.StartWith("temp_table")); - Assert.That(columns[0].BaseTableName, Is.Not.EqualTo(columns[1].BaseTableName)); - } + [Test] + public async Task IsAliased() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1553")] - public async Task DomainTypes() - { - if (IsRedshift) - Assert.Ignore("Domain types not support on Redshift"); - // if (IsMultiplexing) - // Assert.Ignore("Multiplexing: ReloadTypes"); - using var conn = await OpenConnectionAsync(); - await conn.ExecuteNonQueryAsync("CREATE DOMAIN pg_temp.mydomain AS varchar(2)"); - conn.ReloadTypes(); - await conn.ExecuteNonQueryAsync("CREATE TEMP TABLE data (domain mydomain)"); - using var cmd = new NpgsqlCommand("SELECT domain FROM data", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - var 
domainSchema = columns.Single(c => c.ColumnName == "domain"); - Assert.That(domainSchema.ColumnSize, Is.EqualTo(2)); - var pgType = domainSchema.PostgresType; - Assert.That(pgType, Is.InstanceOf()); - Assert.That(((PostgresDomainType)pgType).BaseType.Name, Is.EqualTo("character varying")); - } + using var cmd = new NpgsqlCommand($"SELECT foo, foo AS bar, NULL AS foobar FROM {table}", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - [Test] - public async Task NpgsqlDbType() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].IsAliased, Is.False); + Assert.That(columns[1].IsAliased, Is.True); + Assert.That(columns[2].IsAliased, Is.Null); + } - using var cmd = new NpgsqlCommand($"SELECT foo,8::INTEGER FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].NpgsqlDbType, Is.EqualTo(NpgsqlTypes.NpgsqlDbType.Integer)); - Assert.That(columns[1].NpgsqlDbType, Is.EqualTo(NpgsqlTypes.NpgsqlDbType.Integer)); - } + [Test] // #4672 + public async Task With_parameter_without_value() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); - [Test] - public async Task NpgsqlDbTypeExtension() + using var cmd = new NpgsqlCommand($"SELECT foo FROM {table} WHERE foo > @p", conn) { - using var conn = await OpenConnectionAsync(); - await EnsureExtensionAsync(conn, "hstore", "9.1"); + Parameters = { new() { ParameterName = "p", NpgsqlDbType = NpgsqlTypes.NpgsqlDbType.Integer } } + }; + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - using var cmd = new NpgsqlCommand("SELECT NULL::HSTORE", conn); - using var reader = await 
cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - var columns = await GetColumnSchema(reader); - // The full datatype name for PostGIS is public.geometry (unlike int4 which is in pg_catalog). - Assert.That(columns[0].NpgsqlDbType, Is.EqualTo(NpgsqlTypes.NpgsqlDbType.Hstore)); - } + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].ColumnName, Is.EqualTo("foo")); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1950")] - public async Task NoResultset() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("COMMIT", conn); - using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - await GetColumnSchema(reader); - } + [Test] + public async Task GetColumnSchema_via_interface() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); - [Test] - public async Task IsAliased() + using var cmd = new NpgsqlCommand($"SELECT foo FROM {table} WHERE foo > @p", conn) { - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); - - using var cmd = new NpgsqlCommand($"SELECT foo, foo AS bar, NULL AS foobar FROM {table}", conn); - await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + Parameters = { new() { ParameterName = "p", NpgsqlDbType = NpgsqlTypes.NpgsqlDbType.Integer } } + }; + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var columns = await GetColumnSchema(reader); - Assert.That(columns[0].IsAliased, Is.False); - Assert.That(columns[1].IsAliased, Is.True); - Assert.That(columns[2].IsAliased, Is.Null); - } - - #region Not supported - - [Test] - public async Task IsExpression() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); + var iface = 
(IDbColumnSchemaGenerator)reader; + var schema = iface.GetColumnSchema(); + Assert.NotNull(schema); + Assert.AreEqual(1, schema.Count); + Assert.NotNull(schema[0]); + } - using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - Assert.That(reader.GetColumnSchema().Single().IsExpression, Is.False); - } + #region Not supported - [Test] - public async Task IsHidden() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); + [Test] + public async Task IsExpression() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); - using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - Assert.That(reader.GetColumnSchema().Single().IsHidden, Is.False); - } + using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + Assert.That(reader.GetColumnSchema().Single().IsExpression, Is.False); + } - [Test] - public async Task IsIdentity() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "foo INTEGER", out var table); + [Test] + public async Task IsHidden() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER"); - using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - Assert.That(reader.GetColumnSchema().Single().IsIdentity, Is.False); - } + using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + Assert.That(reader.GetColumnSchema().Single().IsHidden, Is.False); + } - #endregion + #endregion - class 
SomeComposite - { - public int Foo { get; set; } - } + class SomeComposite + { + public int Foo { get; set; } + } - public ReaderNewSchemaTests(SyncOrAsync syncOrAsync) : base(syncOrAsync) { } + public ReaderNewSchemaTests(SyncOrAsync syncOrAsync) : base(syncOrAsync) { } - private async Task> GetColumnSchema(NpgsqlDataReader reader) - => IsAsync ? await reader.GetColumnSchemaAsync() : reader.GetColumnSchema(); - } + async Task> GetColumnSchema(NpgsqlDataReader reader) + => IsAsync ? await reader.GetColumnSchemaAsync() : reader.GetColumnSchema(); } diff --git a/test/Npgsql.Tests/ReaderOldSchemaTests.cs b/test/Npgsql.Tests/ReaderOldSchemaTests.cs index 5e2b790f11..edbeb15842 100644 --- a/test/Npgsql.Tests/ReaderOldSchemaTests.cs +++ b/test/Npgsql.Tests/ReaderOldSchemaTests.cs @@ -5,205 +5,220 @@ using NUnit.Framework; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +/// +/// This tests the .NET Framework DbDataReader schema/metadata API, which returns DataTable. +/// For the new CoreCLR API, see . +/// +public class ReaderOldSchemaTests : SyncOrAsyncTestBase { - /// - /// This tests the .NET Framework DbDataReader schema/metadata API, which returns DataTable. - /// For the new CoreCLR API, see . 
- /// - public class ReaderOldSchemaTests : SyncOrAsyncTestBase + [Test] + public async Task Primary_key_composite() { - [Test] - public async Task PrimaryKeyFieldsMetadataSupport() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await GetTempTableName(conn, out var table); + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); - await conn.ExecuteNonQueryAsync($@" + await conn.ExecuteNonQueryAsync($@" CREATE TABLE {table} ( field_pk1 INT2 NOT NULL, field_pk2 INT2 NOT NULL, field_serial SERIAL, - CONSTRAINT data2_pkey PRIMARY KEY (field_pk1, field_pk2) + CONSTRAINT {table}_pk PRIMARY KEY (field_pk1, field_pk2) )"); - using var command = new NpgsqlCommand($"SELECT * FROM {table}", conn); - using var dr = command.ExecuteReader(CommandBehavior.KeyInfo); - dr.Read(); - var dataTable = await GetSchemaTable(dr); -#pragma warning disable 8602 // Warning should be removable after rc2 (https://github.com/dotnet/runtime/pull/42215) - DataRow[] keyColumns = - dataTable!.Rows.Cast().Where(r => (bool)r["IsKey"]).ToArray()!; -#pragma warning restore 8602 - Assert.That(keyColumns, Has.Length.EqualTo(2)); - Assert.That(keyColumns.Count(c => (string)c["ColumnName"] == "field_pk1"), Is.EqualTo(1)); - Assert.That(keyColumns.Count(c => (string)c["ColumnName"] == "field_pk2"), Is.EqualTo(1)); - } + using var command = new NpgsqlCommand($"SELECT * FROM {table}", conn); + using var dr = command.ExecuteReader(CommandBehavior.KeyInfo); + dr.Read(); + var dataTable = await GetSchemaTable(dr); + var keyColumns = dataTable!.Rows.Cast().Where(r => (bool)r["IsKey"]).ToArray()!; + Assert.That(keyColumns, Has.Length.EqualTo(2)); + Assert.That(keyColumns.Count(c => (string)c["ColumnName"] == "field_pk1"), Is.EqualTo(1)); + Assert.That(keyColumns.Count(c => (string)c["ColumnName"] == "field_pk2"), Is.EqualTo(1)); + } - [Test] - public async Task PrimaryKeyFieldMetadataSupport() - { - using var conn = await OpenConnectionAsync(); - await 
using var _ = await CreateTempTable(conn, "id SERIAL PRIMARY KEY, serial SERIAL", out var table); - - using var command = new NpgsqlCommand($"SELECT * FROM {table}", conn); - using var dr = command.ExecuteReader(CommandBehavior.KeyInfo); - dr.Read(); - var metadata = await GetSchemaTable(dr); -#pragma warning disable 8602 // Warning should be removable after rc2 (https://github.com/dotnet/runtime/pull/42215) - var key = metadata!.Rows.Cast().Single(r => (bool)r["IsKey"])!; -#pragma warning restore 8602 - Assert.That(key["ColumnName"], Is.EqualTo("id")); - } + [Test] + public async Task Primary_key() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id SERIAL PRIMARY KEY, serial SERIAL"); + + using var command = new NpgsqlCommand($"SELECT * FROM {table}", conn); + using var dr = command.ExecuteReader(CommandBehavior.KeyInfo); + dr.Read(); + var metadata = await GetSchemaTable(dr); + var key = metadata!.Rows.Cast().Single(r => (bool)r["IsKey"])!; + Assert.That(key["ColumnName"], Is.EqualTo("id")); + } - [Test] - public async Task IsAutoIncrementMetadataSupport() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "id SERIAL PRIMARY KEY", out var table); - - var command = new NpgsqlCommand($"SELECT * FROM {table}", conn); - - using var dr = command.ExecuteReader(CommandBehavior.KeyInfo); - var metadata = await GetSchemaTable(dr); -#pragma warning disable 8602 // Warning should be removable after rc2 (https://github.com/dotnet/runtime/pull/42215) - Assert.That(metadata!.Rows.Cast() - .Where(r => ((string)r["ColumnName"]).Contains("serial")) - .All(r => (bool)r["IsAutoIncrement"])); -#pragma warning restore 8602 - } + [Test] + public async Task IsAutoIncrement() + { + await using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Serial columns not supported on Redshift"); - [Test] - public async Task IsReadOnlyMetadataSupport() - { - using var conn = await 
OpenConnectionAsync(); - await using var _ = await GetTempTableName(conn, out var table); - await using var __ = await GetTempViewName(conn, out var view); + var table = await CreateTempTable(conn, "serial SERIAL, int INT"); + + var command = new NpgsqlCommand($"SELECT serial, int, 8 FROM {table}", conn); + await using var reader = command.ExecuteReader(CommandBehavior.KeyInfo); + var rows = (await GetSchemaTable(reader))!.Rows; - await conn.ExecuteNonQueryAsync($@" + Assert.That(rows[0]["IsAutoIncrement"], Is.True); + Assert.That(rows[1]["IsAutoIncrement"], Is.False); + Assert.That(rows[2]["IsAutoIncrement"], Is.False); + } + + [Test] + public async Task IsAutoIncrement_identity() + { + await using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Serial columns not supported on Redshift"); + MinimumPgVersion(conn, "10.0", "IDENTITY introduced in PostgreSQL 10"); + + var table = + await CreateTempTable(conn, "identity1 INT GENERATED BY DEFAULT AS IDENTITY, identity2 INT GENERATED ALWAYS AS IDENTITY"); + + await using var command = new NpgsqlCommand($"SELECT identity1, identity2 FROM {table}", conn); + await using var reader = command.ExecuteReader(CommandBehavior.KeyInfo); + var rows = (await GetSchemaTable(reader))!.Rows; + + Assert.That(rows[0]["IsAutoIncrement"], Is.True); + Assert.That(rows[1]["IsAutoIncrement"], Is.True); + } + + [Test] + public async Task IsIdentity() + { + await using var conn = await OpenConnectionAsync(); + IgnoreOnRedshift(conn, "Identity columns not support on Redshift"); + MinimumPgVersion(conn, "10.0", "IDENTITY introduced in PostgreSQL 10"); + var table = await CreateTempTable( + conn, + "identity1 INT GENERATED BY DEFAULT AS IDENTITY, identity2 INT GENERATED ALWAYS AS IDENTITY, serial SERIAL, int INT"); + + await using var cmd = new NpgsqlCommand($"SELECT identity1, identity2, serial, int, 8 FROM {table}", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | 
CommandBehavior.KeyInfo); + var rows = (await GetSchemaTable(reader))!.Rows; + + Assert.That(rows[0]["IsIdentity"], Is.True); + Assert.That(rows[1]["IsIdentity"], Is.True); + Assert.That(rows[2]["IsIdentity"], Is.False); + Assert.That(rows[3]["IsIdentity"], Is.False); + Assert.That(rows[4]["IsIdentity"], Is.False); + } + + [Test] + public async Task IsReadOnly() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + var view = await GetTempViewName(conn); + + await conn.ExecuteNonQueryAsync($@" CREATE TABLE {table} (id SERIAL PRIMARY KEY, int2 SMALLINT); CREATE OR REPLACE VIEW {view} (id, int2) AS SELECT id, int2 + int2 AS int2 FROM {table}"); - var command = new NpgsqlCommand($"SELECT * FROM {view}", conn); - - using var dr = command.ExecuteReader(); - var metadata = await GetSchemaTable(dr); - - foreach (var r in metadata!.Rows.OfType()) - { - switch ((string)r["ColumnName"]) - { - case "field_pk": - if (conn.PostgreSqlVersion < new Version("9.4")) - { - // 9.3 and earlier: IsUpdatable = False - Assert.IsTrue((bool)r["IsReadonly"], "field_pk"); - } - else - { - // 9.4: IsUpdatable = True - Assert.IsFalse((bool)r["IsReadonly"], "field_pk"); - } - break; - case "field_int2": - Assert.IsTrue((bool)r["IsReadonly"]); - break; - } - } - } + var command = new NpgsqlCommand($"SELECT id, int2 FROM {view}", conn); - // ReSharper disable once InconsistentNaming - [Test] - public async Task AllowDBNull() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "nullable INTEGER, non_nullable INTEGER NOT NULL", out var table); - - using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - using var metadata = await GetSchemaTable(reader); - foreach (var row in metadata!.Rows.OfType()) - { - var isNullable = (bool)row["AllowDBNull"]; - switch ((string)row["ColumnName"]) - { - 
case "nullable": - Assert.IsTrue(isNullable); - continue; - case "non_nullable": - Assert.IsFalse(isNullable); - continue; - } - } - } + using var dr = command.ExecuteReader(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var metadata = await GetSchemaTable(dr); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1027")] - public async Task WithoutResult() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT 1", conn); - using var reader = await cmd.ExecuteReaderAsync(); - reader.NextResult(); - // We're no longer on a result - var table = await GetSchemaTable(reader); - Assert.That(table, Is.Null); - } + var idRow = metadata!.Rows.OfType().FirstOrDefault(x => (string)x["ColumnName"] == "id"); + Assert.IsNotNull(idRow, "Unable to find metadata for id column"); + var int2Row = metadata.Rows.OfType().FirstOrDefault(x => (string)x["ColumnName"] == "int2"); + Assert.IsNotNull(int2Row, "Unable to find metadata for int2 column"); + + Assert.IsFalse((bool)idRow!["IsReadonly"]); + Assert.IsTrue((bool)int2Row!["IsReadonly"]); + } + + // ReSharper disable once InconsistentNaming + [Test] + public async Task AllowDBNull() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "nullable INTEGER, non_nullable INTEGER NOT NULL"); + + using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + using var metadata = await GetSchemaTable(reader); - [Test] - public async Task PrecisionAndScale() + var nullableRow = metadata!.Rows.OfType().FirstOrDefault(x => (string)x["ColumnName"] == "nullable"); + Assert.IsNotNull(nullableRow, "Unable to find metadata for nullable column"); + var nonNullableRow = metadata.Rows.OfType().FirstOrDefault(x => (string)x["ColumnName"] == "non_nullable"); + Assert.IsNotNull(nonNullableRow, "Unable to find metadata for non_nullable column"); + + 
Assert.IsTrue((bool)nullableRow!["AllowDBNull"]); + Assert.IsFalse((bool)nonNullableRow!["AllowDBNull"]); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1027")] + public async Task Without_result() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 1", conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.NextResult(); + // We're no longer on a result + var table = await GetSchemaTable(reader); + Assert.That(table, Is.Null); + } + + [Test] + public async Task Precision_and_scale() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 1::NUMERIC AS result", conn); + using var reader = await cmd.ExecuteReaderAsync(); + var schemaTable = await GetSchemaTable(reader); + foreach (var myField in schemaTable!.Rows.OfType()) { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT 1::NUMERIC AS result", conn); - using var reader = await cmd.ExecuteReaderAsync(); - var schemaTable = await GetSchemaTable(reader); - foreach (var myField in schemaTable!.Rows.OfType()) - { - Assert.That(myField["NumericScale"], Is.EqualTo(0)); - Assert.That(myField["NumericPrecision"], Is.EqualTo(0)); - } + Assert.That(myField["NumericScale"], Is.EqualTo(0)); + Assert.That(myField["NumericPrecision"], Is.EqualTo(0)); } + } - [Test] - public async Task SchemaOnly([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) - { - // if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - // return; + [Test] + public async Task SchemaOnly([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + // if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + // return; - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); - 
var query = $@" + var query = $@" SELECT 1 AS some_column; UPDATE {table} SET name='yo' WHERE 1=0; SELECT 1 AS some_other_column, 2"; - using var cmd = new NpgsqlCommand(query, conn); - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) - { - Assert.That(reader.Read(), Is.False); - var t = await GetSchemaTable(reader); - Assert.That(t!.Rows[0]["ColumnName"], Is.EqualTo("some_column")); - Assert.That(reader.NextResult(), Is.True); - Assert.That(reader.Read(), Is.False); - t = await GetSchemaTable(reader); - Assert.That(t!.Rows[0]["ColumnName"], Is.EqualTo("some_other_column")); - Assert.That(t.Rows[1]["ColumnName"], Is.EqualTo("?column?")); - Assert.That(reader.NextResult(), Is.False); - } - - // Close reader in the middle - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) - reader.Read(); + using var cmd = new NpgsqlCommand(query, conn); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) + { + Assert.That(reader.Read(), Is.False); + var t = await GetSchemaTable(reader); + Assert.That(t!.Rows[0]["ColumnName"], Is.EqualTo("some_column")); + Assert.That(reader.NextResult(), Is.True); + Assert.That(reader.Read(), Is.False); + t = await GetSchemaTable(reader); + Assert.That(t!.Rows[0]["ColumnName"], Is.EqualTo("some_other_column")); + Assert.That(t.Rows[1]["ColumnName"], Is.EqualTo("?column?")); + Assert.That(reader.NextResult(), Is.False); } - [Test] - public async Task BaseColumnName() - { - using var conn = OpenConnection(); + // Close reader in the middle + using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) + reader.Read(); + } + + [Test] + public async Task BaseColumnName() + { + using var conn = OpenConnection(); - conn.ExecuteNonQuery(@" + conn.ExecuteNonQuery(@" CREATE TEMP TABLE data ( Cod varchar(5) NOT NULL, Descr varchar(40), @@ -212,21 
+227,20 @@ CONSTRAINT PK_test_Cod PRIMARY KEY (Cod) ); "); - var cmd = new NpgsqlCommand("SELECT Cod as CodAlias, Descr as DescrAlias, Date FROM data", conn); + var cmd = new NpgsqlCommand("SELECT Cod as CodAlias, Descr as DescrAlias, Date FROM data", conn); - using var dr = cmd.ExecuteReader(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); - var dt = await GetSchemaTable(dr); + using var dr = cmd.ExecuteReader(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var dt = await GetSchemaTable(dr); - Assert.That(dt!.Rows[0]["BaseColumnName"].ToString(), Is.EqualTo("cod")); - Assert.That(dt.Rows[0]["ColumnName"].ToString(), Is.EqualTo("codalias")); - Assert.That(dt.Rows[1]["BaseColumnName"].ToString(), Is.EqualTo("descr")); - Assert.That(dt.Rows[1]["ColumnName"].ToString(), Is.EqualTo("descralias")); - Assert.That(dt.Rows[2]["BaseColumnName"].ToString(), Is.EqualTo("date")); - Assert.That(dt.Rows[2]["ColumnName"].ToString(), Is.EqualTo("date")); - } + Assert.That(dt!.Rows[0]["BaseColumnName"].ToString(), Is.EqualTo("cod")); + Assert.That(dt.Rows[0]["ColumnName"].ToString(), Is.EqualTo("codalias")); + Assert.That(dt.Rows[1]["BaseColumnName"].ToString(), Is.EqualTo("descr")); + Assert.That(dt.Rows[1]["ColumnName"].ToString(), Is.EqualTo("descralias")); + Assert.That(dt.Rows[2]["BaseColumnName"].ToString(), Is.EqualTo("date")); + Assert.That(dt.Rows[2]["ColumnName"].ToString(), Is.EqualTo("date")); + } - public ReaderOldSchemaTests(SyncOrAsync syncOrAsync) : base(syncOrAsync) { } + public ReaderOldSchemaTests(SyncOrAsync syncOrAsync) : base(syncOrAsync) { } - private async Task GetSchemaTable(NpgsqlDataReader dr) => IsAsync ? await dr.GetSchemaTableAsync() : dr.GetSchemaTable(); - } + async Task GetSchemaTable(NpgsqlDataReader dr) => IsAsync ? 
await dr.GetSchemaTableAsync() : dr.GetSchemaTable(); } diff --git a/test/Npgsql.Tests/ReaderTests.cs b/test/Npgsql.Tests/ReaderTests.cs index bfd7c37efb..8562ed83b7 100644 --- a/test/Npgsql.Tests/ReaderTests.cs +++ b/test/Npgsql.Tests/ReaderTests.cs @@ -9,1018 +9,984 @@ using System.Threading; using System.Threading.Tasks; using Npgsql.BackendMessages; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; using Npgsql.PostgresTypes; using Npgsql.Tests.Support; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; +using Npgsql.Util; using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +[TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.Default)] +[TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.Default)] +[TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.SequentialAccess)] +[TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.SequentialAccess)] +public class ReaderTests : MultiplexingTestBase { - [TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.Default)] - [TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.Default)] - [TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.SequentialAccess)] - [TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.SequentialAccess)] - public class ReaderTests : MultiplexingTestBase - { - [Test] - public async Task SeekColumns() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT 1,2,3", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - if (IsSequential) - Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); - else - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - Assert.That(reader.GetInt32(1), Is.EqualTo(2)); - if (IsSequential) - Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); - else 
- Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - } - } + static uint Int4Oid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Int4).Value; + static uint ByteaOid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Bytea).Value; - [Test] - public async Task NoResultSet() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "id INT", out var table); - - using (var cmd = new NpgsqlCommand($"INSERT INTO {table} VALUES (8)", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); - Assert.That(reader.Read(), Is.False); - Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); - Assert.That(reader.FieldCount, Is.EqualTo(0)); - Assert.That(reader.NextResult(), Is.False); - Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); - } - - using (var cmd = new NpgsqlCommand($"SELECT 1; INSERT INTO {table} VALUES (8)", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - await reader.NextResultAsync(); - Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); - Assert.That(reader.Read(), Is.False); - Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); - Assert.That(reader.FieldCount, Is.EqualTo(0)); - } - } - } + [Test] + public async Task Resumable_non_consumed_to_non_resumable() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand( "SELECT 'aaaaaaaa', 1", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + await reader.ReadAsync(); + + await reader.IsDBNullAsync(0); // resumable, no consumption + _ = reader.IsDBNull(0); // resumable, no consumption + await using var stream = await reader.GetStreamAsync(0); // non-resumable + if (IsSequential) + Assert.That(() => reader.GetString(0), Throws.Exception.TypeOf()); + } + + [Test] + 
public async Task Seek_columns() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 1,2,3", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + if (IsSequential) + Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); + else + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.GetInt32(1), Is.EqualTo(2)); + if (IsSequential) + Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); + else + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + } - [Test] - public async Task EmptyResultSet() + [Test] + public async Task No_resultset() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT"); + + using (var cmd = new NpgsqlCommand($"INSERT INTO {table} VALUES (8)", conn)) + using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT 1 AS foo WHERE FALSE", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - Assert.That(reader.Read(), Is.False); - Assert.That(reader.FieldCount, Is.EqualTo(1)); - Assert.That(reader.GetOrdinal("foo"), Is.EqualTo(0)); - Assert.That(() => reader[0], Throws.Exception.TypeOf()); - } + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); + Assert.That(reader.Read(), Is.False); + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); + Assert.That(reader.FieldCount, Is.EqualTo(0)); + Assert.That(reader.NextResult(), Is.False); + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); } - [Test] - public async Task FieldCount() + using (var cmd = new NpgsqlCommand($"SELECT 1; INSERT INTO {table} VALUES (8)", conn)) + using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { - using (var conn = await OpenConnectionAsync()) - { - await using 
var _ = await CreateTempTable(conn, "int INT", out var table); - - using (var cmd = new NpgsqlCommand("SELECT 1; SELECT 2,3", conn)) - { - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - Assert.That(reader.FieldCount, Is.EqualTo(1)); - Assert.That(reader.Read(), Is.True); - Assert.That(reader.FieldCount, Is.EqualTo(1)); - Assert.That(reader.Read(), Is.False); - Assert.That(reader.FieldCount, Is.EqualTo(1)); - Assert.That(reader.NextResult(), Is.True); - Assert.That(reader.FieldCount, Is.EqualTo(2)); - Assert.That(reader.NextResult(), Is.False); - Assert.That(reader.FieldCount, Is.EqualTo(0)); - } - - cmd.CommandText = $"INSERT INTO {table} (int) VALUES (1)"; - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - // Note MSDN docs that seem to say we should case -1 in this case: https://msdn.microsoft.com/en-us/library/system.data.idatarecord.fieldcount(v=vs.110).aspx - // But SqlClient returns 0 - Assert.That(() => reader.FieldCount, Is.EqualTo(0)); - - } - } - } + await reader.NextResultAsync(); + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); + Assert.That(reader.Read(), Is.False); + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); + Assert.That(reader.FieldCount, Is.EqualTo(0)); } + } + + [Test] + public async Task Empty_resultset() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 1 AS foo WHERE FALSE", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.GetOrdinal("foo"), Is.EqualTo(0)); + Assert.That(() => reader[0], Throws.Exception.TypeOf()); + } - [Test] - public async Task RecordsAffected() + [Test] + public async Task FieldCount() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "int INT"); + + using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2,3", conn); + 
using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "int INT", out var table); - - var sb = new StringBuilder(); - for (var i = 0; i < 10; i++) - sb.Append($"INSERT INTO {table} (int) VALUES ({i});"); - sb.Append("SELECT 1;"); // Testing, that on close reader consumes all rows (as insert doesn't have a result set, but select does) - for (var i = 10; i < 15; i++) - sb.Append($"INSERT INTO {table} (int) VALUES ({i});"); - var cmd = new NpgsqlCommand(sb.ToString(), conn); - var reader = await cmd.ExecuteReaderAsync(Behavior); - reader.Close(); - Assert.That(reader.RecordsAffected, Is.EqualTo(15)); - - cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); - reader = await cmd.ExecuteReaderAsync(Behavior); - reader.Close(); - Assert.That(reader.RecordsAffected, Is.EqualTo(-1)); - - cmd = new NpgsqlCommand($"UPDATE {table} SET int=int+1 WHERE int > 10", conn); - reader = await cmd.ExecuteReaderAsync(Behavior); - reader.Close(); - Assert.That(reader.RecordsAffected, Is.EqualTo(4)); - - cmd = new NpgsqlCommand($"UPDATE {table} SET int=8 WHERE int=666", conn); - reader = await cmd.ExecuteReaderAsync(Behavior); - reader.Close(); - Assert.That(reader.RecordsAffected, Is.EqualTo(0)); - - cmd = new NpgsqlCommand($"DELETE FROM {table} WHERE int > 10", conn); - reader = await cmd.ExecuteReaderAsync(Behavior); - reader.Close(); - Assert.That(reader.RecordsAffected, Is.EqualTo(4)); - } + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.NextResult(), Is.True); + Assert.That(reader.FieldCount, Is.EqualTo(2)); + Assert.That(reader.NextResult(), Is.False); + Assert.That(reader.FieldCount, Is.EqualTo(0)); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1037")] - 
public async Task Statements() + cmd.CommandText = $"INSERT INTO {table} (int) VALUES (1)"; + using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - - var query = -$@"INSERT INTO {table} (name) VALUES ('a'); -UPDATE {table} SET name='b' WHERE name='doesnt_exist'; -UPDATE {table} SET name='b'; -BEGIN; -SELECT name FROM {table}; -DELETE FROM {table}; -COMMIT;"; - using var cmd = new NpgsqlCommand(query, conn); - using var reader = await cmd.ExecuteReaderAsync(Behavior); + // Note MSDN docs that seem to say we should case -1 in this case: https://msdn.microsoft.com/en-us/library/system.data.idatarecord.fieldcount(v=vs.110).aspx + // But SqlClient returns 0 + Assert.That(() => reader.FieldCount, Is.EqualTo(0)); - var i = 0; - Assert.That(reader.Statements, Has.Count.EqualTo(7)); - Assert.That(reader.Statements[i].SQL, Is.EqualTo($"INSERT INTO {table} (name) VALUES ('a')")); - Assert.That(reader.Statements[i].StatementType, Is.EqualTo(StatementType.Insert)); - Assert.That(reader.Statements[i].Rows, Is.EqualTo(1)); - Assert.That(reader.Statements[++i].SQL, Is.EqualTo($"UPDATE {table} SET name='b' WHERE name='doesnt_exist'")); - Assert.That(reader.Statements[i].StatementType, Is.EqualTo(StatementType.Update)); - Assert.That(reader.Statements[i].Rows, Is.EqualTo(0)); - Assert.That(reader.Statements[++i].SQL, Is.EqualTo($"UPDATE {table} SET name='b'")); - Assert.That(reader.Statements[i].StatementType, Is.EqualTo(StatementType.Update)); - Assert.That(reader.Statements[i].Rows, Is.EqualTo(1)); - Assert.That(reader.Statements[++i].SQL, Is.EqualTo("BEGIN")); - Assert.That(reader.Statements[i].StatementType, Is.EqualTo(StatementType.Other)); - Assert.That(reader.Statements[i].Rows, Is.EqualTo(0)); - await reader.NextResultAsync(); // Consume SELECT result set to parse the CommandComplete - Assert.That(reader.Statements[++i].SQL, 
Is.EqualTo($"SELECT name FROM {table}")); - Assert.That(reader.Statements[i].StatementType, Is.EqualTo(StatementType.Select)); - Assert.That(reader.Statements[i].Rows, Is.EqualTo(1)); - Assert.That(reader.Statements[++i].SQL, Is.EqualTo($"DELETE FROM {table}")); - Assert.That(reader.Statements[i].StatementType, Is.EqualTo(StatementType.Delete)); - Assert.That(reader.Statements[i].Rows, Is.EqualTo(1)); - Assert.That(reader.Statements[++i].SQL, Is.EqualTo("COMMIT")); - Assert.That(reader.Statements[i].StatementType, Is.EqualTo(StatementType.Other)); - Assert.That(reader.Statements[i].Rows, Is.EqualTo(0)); } + } - [Test] - public async Task StatementOID() - { - using var conn = await OpenConnectionAsync(); + [Test] + public async Task RecordsAffected() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "int INT"); + + var sb = new StringBuilder(); + for (var i = 0; i < 10; i++) + sb.Append($"INSERT INTO {table} (int) VALUES ({i});"); + sb.Append("SELECT 1;"); // Testing, that on close reader consumes all rows (as insert doesn't have a result set, but select does) + for (var i = 10; i < 15; i++) + sb.Append($"INSERT INTO {table} (int) VALUES ({i});"); + var cmd = new NpgsqlCommand(sb.ToString(), conn); + var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(15)); + + cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); + reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(-1)); + + cmd = new NpgsqlCommand($"UPDATE {table} SET int=int+1 WHERE int > 10", conn); + reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(4)); + + cmd = new NpgsqlCommand($"UPDATE {table} SET int=8 WHERE int=666", conn); + reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(0)); + + cmd = new 
NpgsqlCommand($"DELETE FROM {table} WHERE int > 10", conn); + reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(4)); + + if (conn.PostgreSqlVersion.IsGreaterOrEqual(15)) + { + cmd = new NpgsqlCommand($"MERGE INTO {table} S USING (SELECT 2 as int) T ON T.int = S.int WHEN MATCHED THEN UPDATE SET int = S.int", conn); + reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(1)); + } + } - MaximumPgVersionExclusive(conn, "12.0", -"Support for 'CREATE TABLE ... WITH OIDS' has been removed in 12.0. See https://www.postgresql.org/docs/12/release-12.html#id-1.11.6.5.4"); +#pragma warning disable CS0618 + [Test] + public async Task StatementOID_legacy_batching() + { + using var conn = await OpenConnectionAsync(); - await using var _ = await GetTempTableName(conn, out var table); + MaximumPgVersionExclusive(conn, "12.0", + "Support for 'CREATE TABLE ... WITH OIDS' has been removed in 12.0. 
See https://www.postgresql.org/docs/12/release-12.html#id-1.11.6.5.4"); - var query = $@" + var table = await GetTempTableName(conn); + + var query = $@" CREATE TABLE {table} (name TEXT) WITH OIDS; INSERT INTO {table} (name) VALUES ('a'); UPDATE {table} SET name='b' WHERE name='doesnt_exist';"; - using (var cmd = new NpgsqlCommand(query,conn)) - { - using var reader = await cmd.ExecuteReaderAsync(Behavior); - - Assert.That(reader.Statements[0].OID, Is.EqualTo(0)); - Assert.That(reader.Statements[1].OID, Is.Not.EqualTo(0)); - Assert.That(reader.Statements[0].OID, Is.EqualTo(0)); - } - - using (var cmd = new NpgsqlCommand($"SELECT name FROM {table}; DELETE FROM {table}", conn)) - { - using var reader = await cmd.ExecuteReaderAsync(Behavior); + using (var cmd = new NpgsqlCommand(query,conn)) + { + using var reader = await cmd.ExecuteReaderAsync(Behavior); - await reader.NextResultAsync(); // Consume SELECT result set - Assert.That(reader.Statements[0].OID, Is.EqualTo(0)); - Assert.That(reader.Statements[1].OID, Is.EqualTo(0)); - } + Assert.That(reader.Statements[0].OID, Is.EqualTo(0)); + Assert.That(reader.Statements[1].OID, Is.Not.EqualTo(0)); + Assert.That(reader.Statements[0].OID, Is.EqualTo(0)); } - [Test] - public async Task GetStringWithParameter() + using (var cmd = new NpgsqlCommand($"SELECT name FROM {table}; DELETE FROM {table}", conn)) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - const string text = "Random text"; - await conn.ExecuteNonQueryAsync($@"INSERT INTO {table} (name) VALUES ('{text}')"); - - var command = new NpgsqlCommand($"SELECT name FROM {table} WHERE name = :value;", conn); - var param = new NpgsqlParameter - { - ParameterName = "value", - DbType = DbType.String, - Size = text.Length, - Value = text - }; - //param.NpgsqlDbType = NpgsqlDbType.Text; - command.Parameters.Add(param); - - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - 
dr.Read(); - var result = dr.GetString(0); - Assert.AreEqual(text, result); - } - } + using var reader = await cmd.ExecuteReaderAsync(Behavior); + + await reader.NextResultAsync(); // Consume SELECT result set + Assert.That(reader.Statements[0].OID, Is.EqualTo(0)); + Assert.That(reader.Statements[1].OID, Is.EqualTo(0)); } + } +#pragma warning restore CS0618 - [Test] - public async Task GetStringWithQuoteWithParameter() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await GetTempTableName(conn, out var table); - await conn.ExecuteNonQueryAsync($@" + [Test] + public async Task Get_string_with_parameter() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + const string text = "Random text"; + await conn.ExecuteNonQueryAsync($@"INSERT INTO {table} (name) VALUES ('{text}')"); + + var command = new NpgsqlCommand($"SELECT name FROM {table} WHERE name = :value;", conn); + var param = new NpgsqlParameter + { + ParameterName = "value", + DbType = DbType.String, + Size = text.Length, + Value = text + }; + //param.NpgsqlDbType = NpgsqlDbType.Text; + command.Parameters.Add(param); + + using var dr = await command.ExecuteReaderAsync(Behavior); + dr.Read(); + var result = dr.GetString(0); + Assert.AreEqual(text, result); + } + + [Test] + public async Task Get_string_with_quote_with_parameter() + { + using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + await conn.ExecuteNonQueryAsync($@" CREATE TABLE {table} (name TEXT); INSERT INTO {table} (name) VALUES ('Text with '' single quote');"); - const string test = "Text with ' single quote"; - var command = new NpgsqlCommand($"SELECT name FROM {table} WHERE name = :value;", conn); - - var param = new NpgsqlParameter(); - param.ParameterName = "value"; - param.DbType = DbType.String; - //param.NpgsqlDbType = NpgsqlDbType.Text; - param.Size = test.Length; - param.Value = test; - command.Parameters.Add(param); - 
- using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - var result = dr.GetString(0); - Assert.AreEqual(test, result); - } - } - } - - [Test] - public async Task GetValueByName() - { - using (var conn = await OpenConnectionAsync()) - { - using (var command = new NpgsqlCommand(@"SELECT 'Random text' AS real_column", conn)) - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - Assert.That(dr["real_column"], Is.EqualTo("Random text")); - Assert.That(() => dr["non_existing"], Throws.Exception.TypeOf()); - } - } - } - - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/794")] - public async Task GetFieldType() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(int))); - } - using (var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn)) - { - cmd.AllResultTypesAreUnknown = true; - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(string))); - } - } - } - } + const string test = "Text with ' single quote"; + var command = new NpgsqlCommand($"SELECT name FROM {table} WHERE name = :value;", conn); + + var param = new NpgsqlParameter(); + param.ParameterName = "value"; + param.DbType = DbType.String; + //param.NpgsqlDbType = NpgsqlDbType.Text; + param.Size = test.Length; + param.Value = test; + command.Parameters.Add(param); + + using var dr = await command.ExecuteReaderAsync(Behavior); + dr.Read(); + var result = dr.GetString(0); + Assert.AreEqual(test, result); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1096")] - public async Task GetFieldTypeSchemaOnly() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS 
some_column", conn)) - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) - { - reader.Read(); - Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(int))); - } - } - } + [Test] + public async Task Get_value_by_name() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand(@"SELECT 'Random text' AS real_column", conn); + using var dr = await command.ExecuteReaderAsync(Behavior); + dr.Read(); + Assert.That(dr["real_column"], Is.EqualTo("Random text")); + Assert.That(() => dr["non_existing"], Throws.Exception.TypeOf()); + } - [Test] - public async Task GetPostgresType() + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/794")] + public async Task GetFieldType() + { + using var conn = await OpenConnectionAsync(); + using (var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn)) + using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: Fails"); - - using (var conn = await OpenConnectionAsync()) - { - PostgresType intType; - using (var cmd = new NpgsqlCommand(@"SELECT 1::INTEGER AS some_column", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - intType = (PostgresBaseType)reader.GetPostgresType(0); - Assert.That(intType.Namespace, Is.EqualTo("pg_catalog")); - Assert.That(intType.Name, Is.EqualTo("integer")); - Assert.That(intType.FullName, Is.EqualTo("pg_catalog.integer")); - Assert.That(intType.DisplayName, Is.EqualTo("integer")); - Assert.That(intType.InternalName, Is.EqualTo("int4")); - } - - using (var cmd = new NpgsqlCommand(@"SELECT '{1}'::INTEGER[] AS some_column", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - var intArrayType = (PostgresArrayType)reader.GetPostgresType(0); - Assert.That(intArrayType.Name, Is.EqualTo("integer[]")); - Assert.That(intArrayType.Element, Is.SameAs(intType)); - Assert.That(intArrayType.DisplayName, 
Is.EqualTo("integer[]")); - Assert.That(intArrayType.InternalName, Is.EqualTo("_int4")); - Assert.That(intType.Array, Is.SameAs(intArrayType)); - } - } + reader.Read(); + Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(int))); } - - /// - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/787")] - [TestCase("integer")] - [TestCase("real")] - [TestCase("integer[]")] - [TestCase("character varying(10)")] - [TestCase("character varying")] - [TestCase("character varying(10)[]")] - [TestCase("character(10)")] - [TestCase("character", "character(1)")] - [TestCase("numeric(1000, 2)")] - [TestCase("numeric(1000)")] - [TestCase("numeric")] - [TestCase("timestamp without time zone")] - [TestCase("timestamp(2) without time zone")] - [TestCase("timestamp(2) with time zone")] - [TestCase("time without time zone")] - [TestCase("time(2) without time zone")] - [TestCase("time(2) with time zone")] - [TestCase("interval")] - [TestCase("interval(2)")] - [TestCase("bit", "bit(1)")] - [TestCase("bit(3)")] - [TestCase("bit varying")] - [TestCase("bit varying(3)")] - public async Task GetDataTypeName(string typeName, string? 
normalizedName = null) + using (var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn)) { - if (normalizedName == null) - normalizedName = typeName; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand($"SELECT NULL::{typeName} AS some_column", conn)) + cmd.AllResultTypesAreUnknown = true; using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo(normalizedName)); + Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(string))); } } + } - [Test] - public async Task GetDataTypeNameEnum() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: ReloadTypes"); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1096")] + public async Task GetFieldType_SchemaOnly() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + reader.Read(); + Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(int))); + } - using (var conn = await OpenConnectionAsync()) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.my_enum AS ENUM ('one')"); - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT 'one'::my_enum", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Does.StartWith("pg_temp").And.EndWith(".my_enum")); - } - } + [Test] + public async Task GetPostgresType() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: Fails"); + + using var conn = await OpenConnectionAsync(); + PostgresType intType; + using (var cmd = new NpgsqlCommand(@"SELECT 1::INTEGER AS some_column", conn)) + using (var reader = await cmd.ExecuteReaderAsync(Behavior)) + { + reader.Read(); + intType = (PostgresBaseType)reader.GetPostgresType(0); + Assert.That(intType.Namespace, Is.EqualTo("pg_catalog")); + Assert.That(intType.Name, 
Is.EqualTo("integer")); + Assert.That(intType.FullName, Is.EqualTo("pg_catalog.integer")); + Assert.That(intType.DisplayName, Is.EqualTo("integer")); + Assert.That(intType.InternalName, Is.EqualTo("int4")); + } + + using (var cmd = new NpgsqlCommand(@"SELECT '{1}'::INTEGER[] AS some_column", conn)) + using (var reader = await cmd.ExecuteReaderAsync(Behavior)) + { + reader.Read(); + var intArrayType = (PostgresArrayType)reader.GetPostgresType(0); + Assert.That(intArrayType.Name, Is.EqualTo("integer[]")); + Assert.That(intArrayType.Element, Is.SameAs(intType)); + Assert.That(intArrayType.DisplayName, Is.EqualTo("integer[]")); + Assert.That(intArrayType.InternalName, Is.EqualTo("_int4")); + Assert.That(intType.Array, Is.SameAs(intArrayType)); } + } - [Test] - public async Task GetDataTypeNameDomain() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: ReloadTypes"); + /// + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/787")] + [TestCase("integer")] + [TestCase("real")] + [TestCase("integer[]")] + [TestCase("character varying(10)")] + [TestCase("character varying")] + [TestCase("character varying(10)[]")] + [TestCase("character(10)")] + [TestCase("character")] + [TestCase("character(1)", "character")] + [TestCase("numeric(1000, 2)")] + [TestCase("numeric(1000)")] + [TestCase("numeric")] + [TestCase("timestamp without time zone")] + [TestCase("timestamp(2) without time zone")] + [TestCase("timestamp(2) with time zone")] + [TestCase("time without time zone")] + [TestCase("time(2) without time zone")] + [TestCase("time(2) with time zone")] + [TestCase("interval")] + [TestCase("interval(2)")] + [TestCase("bit", "bit(1)")] + [TestCase("bit(3)")] + [TestCase("bit varying")] + [TestCase("bit varying(3)")] + public async Task GetDataTypeName(string typeName, string? 
normalizedName = null) + { + if (normalizedName == null) + normalizedName = typeName; + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand($"SELECT NULL::{typeName} AS some_column", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + Assert.That(reader.GetDataTypeName(0), Is.EqualTo(normalizedName)); + } - using (var conn = await OpenConnectionAsync()) - { - conn.ExecuteNonQuery("CREATE DOMAIN pg_temp.my_domain AS VARCHAR(10)"); - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT 'one'::my_domain", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - // In the RowDescription, PostgreSQL sends the type OID of the underlying type and not of the domain. - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("character varying(10)")); - } - } - } + [Test] + public async Task GetDataTypeName_enum() + { + await using var dataSource = CreateDataSource(csb => csb.MaxPoolSize = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + var typeName = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS ENUM ('one')"); + await Task.Yield(); // TODO: fix multiplexing deadlock bug + conn.ReloadTypes(); + await using var cmd = new NpgsqlCommand($"SELECT 'one'::{typeName}", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + await reader.ReadAsync(); + Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"public.{typeName}")); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/794")] - public async Task GetDataTypeNameTypesUnknown() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand(@"SELECT 1::INTEGER AS some_column", conn)) - { - cmd.AllResultTypesAreUnknown = true; - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer")); - } - } - } - } + [Test] 
+ public async Task GetDataTypeName_domain() + { + await using var dataSource = CreateDataSource(csb => csb.MaxPoolSize = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + var typeName = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($"CREATE DOMAIN {typeName} AS VARCHAR(10)"); + await Task.Yield(); // TODO: fix multiplexing deadlock bug + conn.ReloadTypes(); + await using var cmd = new NpgsqlCommand($"SELECT 'one'::{typeName}", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + await reader.ReadAsync(); + // In the RowDescription, PostgreSQL sends the type OID of the underlying type and not of the domain. + Assert.That(reader.GetDataTypeName(0), Is.EqualTo("character varying(10)")); + } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/791")] - [IssueLink("https://github.com/npgsql/npgsql/issues/794")] - public async Task GetDataTypeOID() - { - using (var conn = await OpenConnectionAsync()) - { - var int4OID = await conn.ExecuteScalarAsync("SELECT oid FROM pg_type WHERE typname = 'int4'"); - using (var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - Assert.That(reader.GetDataTypeOID(0), Is.EqualTo(int4OID)); - } - using (var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn)) - { - cmd.AllResultTypesAreUnknown = true; - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - Assert.That(reader.GetDataTypeOID(0), Is.EqualTo(int4OID)); - } - } - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/794")] + public async Task GetDataTypeNameTypes_unknown() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand(@"SELECT 1::INTEGER AS some_column", conn); + cmd.AllResultTypesAreUnknown = true; + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + Assert.That(reader.GetDataTypeName(0), 
Is.EqualTo("integer")); + } - [Test] - public async Task GetName() + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/791")] + [IssueLink("https://github.com/npgsql/npgsql/issues/794")] + public async Task GetDataTypeOID() + { + using var conn = await OpenConnectionAsync(); + var int4OID = await conn.ExecuteScalarAsync("SELECT oid FROM pg_type WHERE typname = 'int4'"); + using (var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn)) + using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand(@"SELECT 1 AS some_column", conn)) - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - Assert.That(dr.GetName(0), Is.EqualTo("some_column")); - } - + reader.Read(); + Assert.That(reader.GetDataTypeOID(0), Is.EqualTo(int4OID)); } - - [Test] - public async Task GetFieldValueAsObject() + using (var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn)) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT 'foo'::TEXT", conn)) + cmd.AllResultTypesAreUnknown = true; using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo("foo")); + Assert.That(reader.GetDataTypeOID(0), Is.EqualTo(int4OID)); } } + } - [Test] - public async Task GetValues() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand(@"SELECT 'hello', 1, '2014-01-01'::DATE", conn)) - { - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - var values = new object[4]; - Assert.That(dr.GetValues(values), Is.EqualTo(3)); - Assert.That(values, Is.EqualTo(new object?[] { "hello", 1, new DateTime(2014, 1, 1), null })); - } - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - var values = new object[2]; - Assert.That(dr.GetValues(values), Is.EqualTo(2)); - 
Assert.That(values, Is.EqualTo(new object[] { "hello", 1 })); - } - } - } + [Test] + public async Task GetName() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand(@"SELECT 1 AS some_column", conn); + using var dr = await command.ExecuteReaderAsync(Behavior); + dr.Read(); + Assert.That(dr.GetName(0), Is.EqualTo("some_column")); + } - [Test] - public async Task GetProviderSpecificValues() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand(@"SELECT 'hello', 1, '2014-01-01'::DATE", conn)) - { - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - var values = new object[4]; - Assert.That(dr.GetProviderSpecificValues(values), Is.EqualTo(3)); - Assert.That(values, Is.EqualTo(new object?[] { "hello", 1, new NpgsqlDate(2014, 1, 1), null })); - } - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - var values = new object[2]; - Assert.That(dr.GetProviderSpecificValues(values), Is.EqualTo(2)); - Assert.That(values, Is.EqualTo(new object[] { "hello", 1 })); - } - } - } + [Test] + public async Task GetFieldValue_as_object() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 'foo'::TEXT", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + Assert.That(reader.GetFieldValue(0), Is.EqualTo("foo")); + } - [Test] - public async Task ExecuteReaderGettingEmptyResultSetWithOutputParameter() + [Test] + public async Task GetValues() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand(@"SELECT 'hello', 1, '2014-01-01'::DATE", conn); + using (var dr = await command.ExecuteReaderAsync(Behavior)) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - var command = new NpgsqlCommand($"SELECT * FROM {table} WHERE name = NULL;", conn); - var param = new 
NpgsqlParameter("some_param", NpgsqlDbType.Varchar); - param.Direction = ParameterDirection.Output; - command.Parameters.Add(param); - using (var dr = await command.ExecuteReaderAsync(Behavior)) - Assert.IsFalse(dr.NextResult()); - } + dr.Read(); + var values = new object[4]; + Assert.That(dr.GetValues(values), Is.EqualTo(3)); + Assert.That(values, Is.EqualTo(new object?[] { "hello", 1, new DateTime(2014, 1, 1), null })); } - - [Test] - public async Task GetValueFromEmptyResultset() + using (var dr = await command.ExecuteReaderAsync(Behavior)) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - using (var command = new NpgsqlCommand($"SELECT * FROM {table} WHERE name = :value;", conn)) - { - const string test = "Text single quote"; - var param = new NpgsqlParameter(); - param.ParameterName = "value"; - param.DbType = DbType.String; - //param.NpgsqlDbType = NpgsqlDbType.Text; - param.Size = test.Length; - param.Value = test; - command.Parameters.Add(param); - - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - // This line should throw the invalid operation exception as the datareader will - // have an empty resultset. 
- Assert.That(() => Console.WriteLine(dr.IsDBNull(1)), - Throws.Exception.TypeOf()); - } - } - } + dr.Read(); + var values = new object[2]; + Assert.That(dr.GetValues(values), Is.EqualTo(2)); + Assert.That(values, Is.EqualTo(new object[] { "hello", 1 })); } + } - [Test] - public async Task ReadPastDataReaderEnd() - { - using (var conn = await OpenConnectionAsync()) - { - var command = new NpgsqlCommand("SELECT 1", conn); - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - while (dr.Read()) {} - Assert.That(() => dr[0], Throws.Exception.TypeOf()); - } - } - } + [Test] + public async Task ExecuteReader_getting_empty_resultset_with_output_parameter() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + var command = new NpgsqlCommand($"SELECT * FROM {table} WHERE name = NULL;", conn); + var param = new NpgsqlParameter("some_param", NpgsqlDbType.Varchar); + param.Direction = ParameterDirection.Output; + command.Parameters.Add(param); + using var dr = await command.ExecuteReaderAsync(Behavior); + Assert.IsFalse(dr.NextResult()); + } + + [Test] + public async Task Get_value_from_empty_resultset() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + using var command = new NpgsqlCommand($"SELECT * FROM {table} WHERE name = :value;", conn); + const string test = "Text single quote"; + var param = new NpgsqlParameter(); + param.ParameterName = "value"; + param.DbType = DbType.String; + //param.NpgsqlDbType = NpgsqlDbType.Text; + param.Size = test.Length; + param.Value = test; + command.Parameters.Add(param); + + using var dr = await command.ExecuteReaderAsync(Behavior); + dr.Read(); + // This line should throw the invalid operation exception as the datareader will + // have an empty resultset. 
+ Assert.That(() => Console.WriteLine(dr.IsDBNull(1)), + Throws.Exception.TypeOf()); + } + + [Test] + public async Task Read_past_reader_end() + { + using var conn = await OpenConnectionAsync(); + var command = new NpgsqlCommand("SELECT 1", conn); + using var dr = await command.ExecuteReaderAsync(Behavior); + while (dr.Read()) {} + Assert.That(() => dr[0], Throws.Exception.TypeOf()); + } - [Test] - public async Task SingleResult() + [Test] + public async Task Reader_dispose_state_does_not_leak() + { + if (IsMultiplexing || Behavior != CommandBehavior.Default) + return; + + var startReaderClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var continueReaderClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await using var dataSource = CreateDataSource(); + await using var conn1 = await dataSource.OpenConnectionAsync(); + var connID = conn1.Connector!.Id; + var readerCloseTask = Task.Run(async () => { - using (var conn = await OpenConnectionAsync()) + using var cmd = conn1.CreateCommand(); + cmd.CommandText = "SELECT 1"; + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.CloseConnection); + reader.ReaderClosed += (s, e) => { - var cmd = new NpgsqlCommand(@"SELECT 1; SELECT 2", conn); - var rdr = await cmd.ExecuteReaderAsync(CommandBehavior.SingleResult); - Assert.That(rdr.Read(), Is.True); - Assert.That(rdr.GetInt32(0), Is.EqualTo(1)); - Assert.That(rdr.NextResult(), Is.False); - } - } + startReaderClosedTcs.SetResult(new()); + continueReaderClosedTcs.Task.GetAwaiter().GetResult(); + }; + }); + + await startReaderClosedTcs.Task; + await using var conn2 = await dataSource.OpenConnectionAsync(); + Assert.That(conn2.Connector!.Id, Is.EqualTo(connID)); + using var cmd = conn2.CreateCommand(); + cmd.CommandText = "SELECT 1"; + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(reader.State, Is.EqualTo(ReaderState.BeforeResult)); + 
continueReaderClosedTcs.SetResult(new()); + await readerCloseTask; + Assert.That(reader.State, Is.EqualTo(ReaderState.BeforeResult)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/400")] - public async Task ExceptionThrownFromExecuteQuery([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + [Test] + public async Task SingleResult() + { + await using var conn = await OpenConnectionAsync(); + await using var command = new NpgsqlCommand(@"SELECT 1; SELECT 2", conn); + var reader = await command.ExecuteReaderAsync(CommandBehavior.SingleResult | Behavior); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.NextResult(), Is.False); + } - using (var conn = await OpenConnectionAsync()) - { - await using var _ = GetTempFunctionName(conn, out var function); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/400")] + public async Task Exception_thrown_from_ExecuteReaderAsync([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; - await conn.ExecuteNonQueryAsync($@" + using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + + await conn.ExecuteNonQueryAsync($@" CREATE OR REPLACE FUNCTION {function}() RETURNS VOID AS 'BEGIN RAISE EXCEPTION ''testexception'' USING ERRCODE = ''12345''; END;' LANGUAGE 'plpgsql'; "); - using (var cmd = new NpgsqlCommand($"SELECT {function}()", conn)) - { - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - Assert.That(async () => await cmd.ExecuteReaderAsync(Behavior), Throws.Exception.TypeOf()); - } - } - } + using var cmd = new NpgsqlCommand($"SELECT {function}()", conn); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + Assert.That(async () => await cmd.ExecuteReaderAsync(Behavior), Throws.Exception.TypeOf()); 
+ } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1032")] - public async Task ExceptionThrownFromNextResult([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1032")] + public async Task Exception_thrown_from_NextResult([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; - using (var conn = await OpenConnectionAsync()) - { - await using var _ = GetTempFunctionName(conn, out var function); + using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); - await conn.ExecuteNonQueryAsync($@" + await conn.ExecuteNonQueryAsync($@" CREATE OR REPLACE FUNCTION {function}() RETURNS VOID AS 'BEGIN RAISE EXCEPTION ''testexception'' USING ERRCODE = ''12345''; END;' LANGUAGE 'plpgsql'; "); - using (var cmd = new NpgsqlCommand($"SELECT 1; SELECT {function}()", conn)) - { - if (prepare == PrepareOrNot.Prepared) - cmd.Prepare(); - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - Assert.That(() => reader.NextResult(), Throws.Exception.TypeOf()); - } - } - } + using var cmd = new NpgsqlCommand($"SELECT 1; SELECT {function}()", conn); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + Assert.That(() => reader.NextResult(), Throws.Exception.TypeOf()); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/967")] - public async Task NpgsqlExceptionReferencesStatement() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = GetTempFunctionName(conn, out var function); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/967")] + public async Task NpgsqlException_references_BatchCommand_with_single_command() + { + await using var conn = await OpenConnectionAsync(); 
+ var function = await GetTempFunctionName(conn); - await conn.ExecuteNonQueryAsync($@" + await conn.ExecuteNonQueryAsync($@" CREATE OR REPLACE FUNCTION {function}() RETURNS VOID AS 'BEGIN RAISE EXCEPTION ''testexception'' USING ERRCODE = ''12345''; END;' -LANGUAGE 'plpgsql'; - "); +LANGUAGE 'plpgsql'"); - // Exception in single-statement command - using (var cmd = new NpgsqlCommand($"SELECT {function}()", conn)) - { - try - { - await cmd.ExecuteReaderAsync(Behavior); - Assert.Fail(); - } - catch (PostgresException e) - { - Assert.That(e.Statement, Is.SameAs(cmd.Statements[0])); - } - } - - // Exception in multi-statement command - using (var cmd = new NpgsqlCommand($"SELECT 1; {function}()", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - try - { - await reader.NextResultAsync(); - Assert.Fail(); - } - catch (PostgresException e) - { - Assert.That(e.Statement, Is.SameAs(cmd.Statements[1])); - } - } - } - } + // We use NpgsqlConnection.CreateCommand to test that the command isn't recycled when referenced in an exception + var cmd = conn.CreateCommand(); + cmd.CommandText = $"SELECT {function}()"; - #region SchemaOnly + var exception = Assert.ThrowsAsync(() => cmd.ExecuteReaderAsync(Behavior))!; + Assert.That(exception.BatchCommand, Is.SameAs(cmd.InternalBatchCommands[0])); - [Test] - public async Task SchemaOnlyReturnsNoData() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) - Assert.That(reader.Read(), Is.False); - } + // Make sure the command isn't recycled by the connection when it's disposed - this is important since internal command + // resources are referenced by the exception above, which is very likely to escape the using statement of the command. 
+ cmd.Dispose(); + var cmd2 = conn.CreateCommand(); + Assert.AreNotSame(cmd2, cmd); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/967")] + public async Task NpgsqlException_references_BatchCommand_with_multiple_commands() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE OR REPLACE FUNCTION {function}() RETURNS VOID AS + 'BEGIN RAISE EXCEPTION ''testexception'' USING ERRCODE = ''12345''; END;' +LANGUAGE 'plpgsql'"); + + // We use NpgsqlConnection.CreateCommand to test that the command isn't recycled when referenced in an exception + var cmd = conn.CreateCommand(); + cmd.CommandText = $"SELECT 1; {function}()"; - [Test] - public async Task SchemaOnlyCommandBehaviorSupportFunctioncall() + await using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = GetTempFunctionName(conn, out var function); - - await conn.ExecuteNonQueryAsync($"CREATE OR REPLACE FUNCTION {function}() RETURNS SETOF integer as 'SELECT 1;' LANGUAGE 'sql';"); - var command = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; - using (var dr = await command.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) - { - var i = 0; - while (dr.Read()) - i++; - Assert.AreEqual(0, i); - } - } + var exception = Assert.ThrowsAsync(() => reader.NextResultAsync())!; + Assert.That(exception.BatchCommand, Is.SameAs(cmd.InternalBatchCommands[1])); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2827")] - public async Task SchemaOnlyNextResultBeyondEnd() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "id INT", out var table); + // Make sure the command isn't recycled by the connection when it's disposed - this is important since internal command + // resources are referenced by the exception above, which is very likely to 
escape the using statement of the command. + cmd.Dispose(); + var cmd2 = conn.CreateCommand(); + Assert.AreNotSame(cmd2, cmd); + } - using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - Assert.False(reader.NextResult()); - Assert.False(reader.NextResult()); - } + #region SchemaOnly - #endregion SchemaOnly + [Test] + public async Task SchemaOnly_returns_no_data() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + Assert.That(reader.Read(), Is.False); + } - #region GetOrdinal + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2827")] + public async Task SchemaOnly_next_result_beyond_end() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT"); - [Test] - public async Task GetOrdinal() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand(@"SELECT 0, 1 AS some_column WHERE 1=0", conn)) - using (var reader = await command.ExecuteReaderAsync(Behavior)) - { - Assert.That(reader.GetOrdinal("some_column"), Is.EqualTo(1)); - Assert.That(() => reader.GetOrdinal("doesn't_exist"), Throws.Exception.TypeOf()); - } - } + using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + Assert.False(reader.NextResult()); + Assert.False(reader.NextResult()); + } - [Test] - public async Task GetOrdinalInsensitivity() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("select 123 as FIELD1", conn)) - using (var reader = await command.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - Assert.That(reader.GetOrdinal("fieLd1"), Is.EqualTo(0)); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4124")] + 
public async Task SchemaOnly_GetDataTypeName_with_unsupported_type() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand(@"select aggfnoid from pg_aggregate", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + + Assert.That(reader.GetDataTypeName(0), Is.EqualTo("regproc")); + } + + #endregion SchemaOnly + + #region GetOrdinal + + [Test] + public async Task GetOrdinal() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand(@"SELECT 0, 1 AS some_column WHERE 1=0", conn); + using var reader = await command.ExecuteReaderAsync(Behavior); + Assert.That(reader.GetOrdinal("some_column"), Is.EqualTo(1)); + Assert.That(() => reader.GetOrdinal("doesn't_exist"), Throws.Exception.TypeOf()); + } + + [Test] + public async Task GetOrdinal_case_insensitive() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand("select 123 as FIELD1", conn); + using var reader = await command.ExecuteReaderAsync(Behavior); + reader.Read(); + Assert.That(reader.GetOrdinal("fieLd1"), Is.EqualTo(0)); + } + + [Test] + public async Task GetOrdinal_kana_insensitive() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand("select 123 as ヲァィゥェォャ", conn); + using var reader = await command.ExecuteReaderAsync(Behavior); + reader.Read(); + Assert.That(reader["ヲァィゥェォャ"], Is.EqualTo(123)); + } + + #endregion GetOrdinal - [Test] - public async Task GetOrdinalKanaInsensitive() + [Test] + public async Task Field_index_does_not_exist() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand("SELECT 1", conn); + using var dr = await command.ExecuteReaderAsync(Behavior); + dr.Read(); + Assert.That(() => dr[5], Throws.Exception.TypeOf()); + } + + [Test, Description("Performs some operations while a reader is still open and checks for exceptions")] + public async Task Reader_is_still_open() + { + await 
using var conn = await OpenConnectionAsync(); + // We might get the connection, on which the second command was already prepared, so prepare wouldn't start the UserAction + if (!IsMultiplexing) + conn.UnprepareAll(); + using var cmd1 = new NpgsqlCommand("SELECT 1", conn); + await using var reader1 = await cmd1.ExecuteReaderAsync(Behavior); + Assert.That(() => conn.ExecuteNonQuery("SELECT 1"), Throws.Exception.TypeOf()); + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Throws.Exception.TypeOf()); + + using var cmd2 = new NpgsqlCommand("SELECT 2", conn); + Assert.That(() => cmd2.ExecuteReader(Behavior), Throws.Exception.TypeOf()); + if (!IsMultiplexing) + Assert.That(() => cmd2.Prepare(), Throws.Exception.TypeOf()); + } + + [Test] + public async Task Cleans_up_ok_with_dispose_calls([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand("SELECT 1", conn); + using var dr = await command.ExecuteReaderAsync(Behavior); + dr.Read(); + dr.Close(); + + using var upd = conn.CreateCommand(); + upd.CommandText = "SELECT 1"; + if (prepare == PrepareOrNot.Prepared) + upd.Prepare(); + } + + [Test] + public async Task Null() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p1, @p2::TEXT", conn); + cmd.Parameters.Add(new NpgsqlParameter("p1", DbType.String) { Value = DBNull.Value }); + cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p2", Value = DBNull.Value }); + + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + + for (var i = 0; i < cmd.Parameters.Count; i++) { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("select 123 as ヲァィゥェォャ", conn)) - using (var reader = await command.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - 
Assert.That(reader["ヲァィゥェォャ"], Is.EqualTo(123)); - } + Assert.That(reader.IsDBNull(i), Is.True); + Assert.That(reader.IsDBNullAsync(i).Result, Is.True); + Assert.That(reader.GetValue(i), Is.EqualTo(DBNull.Value)); + Assert.That(reader.GetFieldValue(i), Is.EqualTo(DBNull.Value)); + Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(DBNull.Value)); + Assert.That(() => reader.GetString(i), Throws.Exception.TypeOf()); } + } - #endregion GetOrdinal + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/742")] + [IssueLink("https://github.com/npgsql/npgsql/issues/800")] + [IssueLink("https://github.com/npgsql/npgsql/issues/1234")] + [IssueLink("https://github.com/npgsql/npgsql/issues/1898")] + public async Task HasRows([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; - [Test] - public async Task FieldIndexDoesntExist() + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + + var command = new NpgsqlCommand($"SELECT 1; SELECT * FROM {table} WHERE name='does_not_exist'", conn); + if (prepare == PrepareOrNot.Prepared) + command.Prepare(); + using (var reader = await command.ExecuteReaderAsync(Behavior)) { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("SELECT 1", conn)) - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - Assert.That(() => dr[5], Throws.Exception.TypeOf()); - } + Assert.That(reader.HasRows, Is.True); + Assert.That(reader.HasRows, Is.True); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.HasRows, Is.True); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.HasRows, Is.True); + await reader.NextResultAsync(); + Assert.That(reader.HasRows, Is.False); } - [Test, Description("Performs some operations while a reader is still open and checks for exceptions")] - public async Task ReaderIsStillOpen() + 
command.CommandText = $"SELECT * FROM {table}"; + if (prepare == PrepareOrNot.Prepared) + command.Prepare(); + using (var reader = await command.ExecuteReaderAsync(Behavior)) { - await using var conn = await OpenConnectionAsync(); - // We might get the connection, on which the second command was already prepared, so prepare wouldn't start the UserAction - if (!IsMultiplexing) - conn.UnprepareAll(); - using var cmd1 = new NpgsqlCommand("SELECT 1", conn); - await using var reader1 = await cmd1.ExecuteReaderAsync(Behavior); - Assert.That(() => conn.ExecuteNonQuery("SELECT 1"), Throws.Exception.TypeOf()); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Throws.Exception.TypeOf()); - - using var cmd2 = new NpgsqlCommand("SELECT 2", conn); - Assert.That(() => cmd2.ExecuteReader(Behavior), Throws.Exception.TypeOf()); - if (!IsMultiplexing) - Assert.That(() => cmd2.Prepare(), Throws.Exception.TypeOf()); + reader.Read(); + Assert.That(reader.HasRows, Is.False); } - [Test] - public async Task CleansupOkWithDisposeCalls([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) + command.CommandText = "SELECT 1"; + if (prepare == PrepareOrNot.Prepared) + command.Prepare(); + using (var reader = await command.ExecuteReaderAsync(Behavior)) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("SELECT 1", conn)) - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - dr.Read(); - dr.Close(); - - using (var upd = conn.CreateCommand()) - { - upd.CommandText = "SELECT 1"; - if (prepare == PrepareOrNot.Prepared) - upd.Prepare(); - } - } + reader.Read(); + reader.Close(); + Assert.That(() => reader.HasRows, Throws.Exception.TypeOf()); } - [Test] - public async Task Null() + command.CommandText = $"INSERT INTO {table} (name) VALUES ('foo'); SELECT * FROM {table}"; + if (prepare == PrepareOrNot.Prepared) + command.Prepare(); + 
using (var reader = await command.ExecuteReaderAsync(Behavior)) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2::TEXT", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", DbType.String) { Value = DBNull.Value }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p2", Value = DBNull.Value }); - - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.IsDBNull(i), Is.True); - Assert.That(reader.IsDBNullAsync(i).Result, Is.True); - Assert.That(reader.GetValue(i), Is.EqualTo(DBNull.Value)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(DBNull.Value)); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(DBNull.Value)); - Assert.That(() => reader.GetString(i), Throws.Exception.TypeOf()); - } - } - } + Assert.That(reader.HasRows, Is.True); + reader.Read(); + Assert.That(reader.GetString(0), Is.EqualTo("foo")); } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/742")] - [IssueLink("https://github.com/npgsql/npgsql/issues/800")] - [IssueLink("https://github.com/npgsql/npgsql/issues/1234")] - [IssueLink("https://github.com/npgsql/npgsql/issues/1898")] - public async Task HasRows([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - - var command = new NpgsqlCommand($"SELECT 1; SELECT * FROM {table} WHERE name='does_not_exist'", conn); - if (prepare == PrepareOrNot.Prepared) - command.Prepare(); - using (var reader = await command.ExecuteReaderAsync(Behavior)) - { - Assert.That(reader.HasRows, Is.True); - Assert.That(reader.HasRows, Is.True); - Assert.That(reader.Read(), 
Is.True); - Assert.That(reader.HasRows, Is.True); - Assert.That(reader.Read(), Is.False); - Assert.That(reader.HasRows, Is.True); - await reader.NextResultAsync(); - Assert.That(reader.HasRows, Is.False); - } - - command.CommandText = $"SELECT * FROM {table}"; - if (prepare == PrepareOrNot.Prepared) - command.Prepare(); - using (var reader = await command.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - Assert.That(reader.HasRows, Is.False); - } - - command.CommandText = "SELECT 1"; - if (prepare == PrepareOrNot.Prepared) - command.Prepare(); - using (var reader = await command.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - reader.Close(); - Assert.That(() => reader.HasRows, Throws.Exception.TypeOf()); - } - - command.CommandText = $"INSERT INTO {table} (name) VALUES ('foo'); SELECT * FROM {table}"; - if (prepare == PrepareOrNot.Prepared) - command.Prepare(); - using (var reader = await command.ExecuteReaderAsync()) - { - Assert.That(reader.HasRows, Is.True); - reader.Read(); - Assert.That(reader.GetString(0), Is.EqualTo("foo")); - } - - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + [Test] + public async Task HasRows_without_resultset() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + using var command = new NpgsqlCommand($"DELETE FROM {table} WHERE name = 'unknown'", conn); + using var reader = await command.ExecuteReaderAsync(Behavior); + Assert.IsFalse(reader.HasRows); + } + + [Test] + public async Task Interval_as_TimeSpan() + { + using var conn = await OpenConnectionAsync(); + using var command = new NpgsqlCommand("SELECT CAST('1 hour' AS interval) AS dauer", conn); + using var dr = await command.ExecuteReaderAsync(Behavior); + Assert.IsTrue(dr.HasRows); + Assert.IsTrue(dr.Read()); + Assert.IsTrue(dr.HasRows); + var ts = dr.GetTimeSpan(0); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5439")] + public async Task SequentialBufferedSeek() + { 
+ await using var conn = await OpenConnectionAsync(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = """select v.i, jsonb_build_object(), current_timestamp + make_interval(0, 0, 0, 0, 0, 0, v.i), null::jsonb, '{"value": 42}'::jsonb from generate_series(1, 1000) as v(i)"""; + var rdr = await cmd.ExecuteReaderAsync(Behavior); + while (await rdr.ReadAsync()) { + var v1 = rdr[0]; + var v2 = rdr[1]; + //_ = rdr[2]; // uncomment line for successful execution + var v3 = rdr[3]; + var v4 = rdr[4]; } + } - [Test] - public async Task HasRowsWithoutResultset() - { - using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - using var command = new NpgsqlCommand($"DELETE FROM {table} WHERE name = 'unknown'", conn); - using var reader = await command.ExecuteReaderAsync(Behavior); - Assert.IsFalse(reader.HasRows); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5430")] + public async Task SequentialBufferedSeekLong() + { + await using var conn = await OpenConnectionAsync(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = """select v.i, repeat('1', 10), repeat('2', 10), repeat('3', 10), repeat('4', 10), 1, 2 from generate_series(1, 1000) as v(i)"""; + var rdr = await cmd.ExecuteReaderAsync(Behavior); + while (await rdr.ReadAsync()) + { + _ = rdr[0]; + _ = rdr[1]; + //_ = rdr[2]; + //_ = rdr[3]; + //_ = rdr[4]; + //_ = rdr[5]; // uncomment lines for successful execution + _ = rdr[6]; } + } - [Test] - public async Task IntervalAsTimeSpan() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("SELECT CAST('1 hour' AS interval) AS dauer", conn)) - using (var dr = await command.ExecuteReaderAsync(Behavior)) - { - Assert.IsTrue(dr.HasRows); - Assert.IsTrue(dr.Read()); - Assert.IsTrue(dr.HasRows); - var ts = dr.GetTimeSpan(0); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5430")] + public async Task 
SequentialBufferedSeekReread() + { + await using var conn = await OpenConnectionAsync(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = """select v.i, repeat('1', 10), repeat('2', 10), repeat('3', 10), repeat('4', 10), 1, NULL from generate_series(1, 1000) as v(i)"""; + var rdr = await cmd.ExecuteReaderAsync(Behavior); + while (await rdr.ReadAsync()) + { + _ = rdr[0]; + _ = rdr[1]; + //_ = rdr[2]; + //_ = rdr[3]; + //_ = rdr[4]; + //_ = rdr[5]; // uncomment lines for successful execution + _ = rdr.IsDBNull(6); + _ = rdr[6]; + Assert.True(rdr.IsDBNull(6)); } + } - [Test] - public async Task CloseConnectionInMiddleOfRow() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5484")] + public async Task GetFieldValueAsync_AsyncRead() + { + await using var conn = await OpenConnectionAsync(); + using var cmd = conn.CreateCommand(); + var expected = new string('a', conn.Settings.ReadBufferSize + 1); + cmd.CommandText = $"""select repeat('a', {conn.Settings.ReadBufferSize+1}) from generate_series(1, 1000)"""; + var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + while (await reader.ReadAsync()) { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT 1, 2", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - } - } + Assert.AreEqual(expected, await reader.GetFieldValueAsync(0)); } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/pull/1266")] - [Description("NextResult was throwing an ArgumentOutOfRangeException when trying to determine the statement to associate with the PostgresException")] - public async Task ReaderNextResultExceptionHandling() - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await GetTempTableName(conn, out var table1); - await using var __ = await GetTempTableName(conn, out var table2); - await using var ___ = GetTempFunctionName(conn, out var function); + [Test] + public async Task 
Close_connection_in_middle_of_row() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT 1, 2", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/pull/1266")] + [Description("NextResult was throwing an ArgumentOutOfRangeException when trying to determine the statement to associate with the PostgresException")] + public async Task Reader_next_result_exception_handling() + { + using var conn = await OpenConnectionAsync(); + var table1 = await GetTempTableName(conn); + var table2 = await GetTempTableName(conn); + var function = await GetTempFunctionName(conn); - var initializeTablesSql = $@" + var initializeTablesSql = $@" CREATE TABLE {table1} (value int NOT NULL); CREATE TABLE {table2} (value int UNIQUE); -ALTER TABLE ONLY {table1} ADD CONSTRAINT fkey FOREIGN KEY (value) REFERENCES {table2}(value) DEFERRABLE INITIALLY DEFERRED; -CREATE FUNCTION {function}(_value int) RETURNS int AS $BODY$ +ALTER TABLE ONLY {table1} ADD CONSTRAINT {table1}_{table2}_fk FOREIGN KEY (value) REFERENCES {table2}(value) DEFERRABLE INITIALLY DEFERRED; +CREATE OR REPLACE FUNCTION {function}(_value int) RETURNS int AS $BODY$ BEGIN INSERT INTO {table1}(value) VALUES(_value); RETURN _value; @@ -1028,1054 +994,1430 @@ public async Task ReaderNextResultExceptionHandling() $BODY$ LANGUAGE plpgsql VOLATILE"; - await conn.ExecuteNonQueryAsync(initializeTablesSql); - using (var cmd = new NpgsqlCommand($"SELECT {function}(1)", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { - Assert.That(() => reader.NextResult(), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("23503")); - } - } + await conn.ExecuteNonQueryAsync(initializeTablesSql); + using var cmd = new NpgsqlCommand($"SELECT {function}(1)", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + Assert.That(() => reader.NextResult(), + 
Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.ForeignKeyViolation)); + } + + [Test] + public async Task Invalid_cast() + { + using var conn = await OpenConnectionAsync(); + // Chunking type handler + using (var cmd = new NpgsqlCommand("SELECT 'foo'", conn)) + using (var reader = await cmd.ExecuteReaderAsync(Behavior)) + { + reader.Read(); + Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); } + // Simple type handler + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + using (var reader = await cmd.ExecuteReaderAsync(Behavior)) + { + reader.Read(); + Assert.That(() => reader.GetDateTime(0), Throws.Exception.TypeOf()); + } + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - [Test] - public async Task InvalidCast() + [Test, Description("Reads a lot of rows to make sure the long unoptimized path for Read() works")] + public async Task Many_reads() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand($"SELECT generate_series(1, {conn.Settings.ReadBufferSize})", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + for (var i = 1; i <= conn.Settings.ReadBufferSize; i++) { - using (var conn = await OpenConnectionAsync()) - { - // Chunking type handler - using (var cmd = new NpgsqlCommand("SELECT 'foo'", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); - } - // Simple type handler - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(() => reader.GetDate(0), Throws.Exception.TypeOf()); - } - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(i)); } + Assert.That(reader.Read(), Is.False); + } + + [Test] + public async Task 
Nullable_scalar() + { + // We read the same column multiple times + if (IsSequential) + return; + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); + var p1 = new NpgsqlParameter { ParameterName = "p1", Value = DBNull.Value, NpgsqlDbType = NpgsqlDbType.Smallint }; + var p2 = new NpgsqlParameter { ParameterName = "p2", Value = (short)8 }; + Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Smallint)); + Assert.That(p2.DbType, Is.EqualTo(DbType.Int16)); + cmd.Parameters.Add(p1); + cmd.Parameters.Add(p2); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + + for (var i = 0; i < cmd.Parameters.Count; i++) + { + Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(short))); + Assert.That(reader.GetDataTypeName(i), Is.EqualTo("smallint")); + } + + Assert.That(() => reader.GetFieldValue(0), Is.EqualTo(DBNull.Value)); + Assert.That(() => reader.GetFieldValue(0), Throws.TypeOf()); + Assert.That(() => reader.GetFieldValue(0), Throws.Nothing); + Assert.That(reader.GetFieldValue(0), Is.Null); + + Assert.That(() => reader.GetFieldValue(1), Throws.Nothing); + Assert.That(() => reader.GetFieldValue(1), Throws.Nothing); + Assert.That(() => reader.GetFieldValue(1), Throws.Nothing); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(8)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(8)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(8)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2913")] + public async Task Bug2913_reading_previous_query_messages() + { + // No point in testing for multiplexing, as every query may use another connection + if (IsMultiplexing) + return; - [Test, Description("Reads a lot of rows to make sure the long unoptimized path for Read() works")] - public async Task ManyReads() + var firstMrs = new ManualResetEventSlim(false); + var secondMrs = new ManualResetEventSlim(false); + + var secondQuery = Task.Run(async () => { - using (var conn = 
await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand($"SELECT generate_series(1, {conn.Settings.ReadBufferSize})", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) + firstMrs.Wait(); + await using var secondConn = await OpenConnectionAsync(); + using var secondCmd = new NpgsqlCommand(@"SELECT 1; SELECT 2;", secondConn); + await using var secondReader = await secondCmd.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); + + // Check, that StatementIndex is equals to default value + Assert.That(secondReader.StatementIndex, Is.EqualTo(0)); + secondMrs.Wait(); + // Check, that the first query didn't change StatementIndex + Assert.That(secondReader.StatementIndex, Is.EqualTo(0)); + }); + + await using (var firstConn = await OpenConnectionAsync()) + { + // Executing a query, which fails with NpgsqlException on reader disposing, as NotExistingTable doesn't exist + using var firstCmd = new NpgsqlCommand(@"SELECT 1; SELECT * FROM NotExistingTable;", firstConn); + await using var firstReader = await firstCmd.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); + + Assert.That(firstReader.StatementIndex, Is.EqualTo(0)); + + firstReader.ReaderClosed += (s, e) => { - for (var i = 1; i <= conn.Settings.ReadBufferSize; i++) - { - Assert.That(reader.Read(), Is.True); - Assert.That(reader.GetInt32(0), Is.EqualTo(i)); - } - Assert.That(reader.Read(), Is.False); - } + // Starting a second query, which in case of a bug uses firstConn + firstMrs.Set(); + // Waiting for the second query to start executing + Thread.Sleep(100); + // After waiting, reader is free to reset prepared statements, which also increments StatementIndex + }; + + Assert.ThrowsAsync(firstReader.NextResultAsync); + + secondMrs.Set(); } - [Test] - public async Task NullableScalar() + await secondQuery; + + // If we're here and a bug is still not fixed, we fail while executing reader, as we're reading skipped messages for the second query + await using var thirdConn = 
OpenConnection(); + using var thirdCmd = new NpgsqlCommand(@"SELECT 1; SELECT 2;", thirdConn); + await using var thirdReader = await thirdCmd.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/2913")] + [IssueLink("https://github.com/npgsql/npgsql/issues/3289")] + public async Task Reader_close_and_dispose() + { + await using var conn = await OpenConnectionAsync(); + using var cmd1 = conn.CreateCommand(); + cmd1.CommandText = "SELECT 1"; + + var reader1 = await cmd1.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); + await reader1.CloseAsync(); + + await conn.OpenAsync(); + cmd1.Connection = conn; + var reader2 = await cmd1.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); + Assert.That(reader1, Is.Not.SameAs(reader2)); + Assert.That(reader2.State, Is.EqualTo(ReaderState.BeforeResult)); + + await reader1.DisposeAsync(); + + Assert.That(reader2.State, Is.EqualTo(ReaderState.BeforeResult)); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/2964")] + public async Task Bug2964_connection_close_and_reader_dispose() + { + await using var conn = await OpenConnectionAsync(); + using var cmd1 = conn.CreateCommand(); + cmd1.CommandText = "SELECT 1"; + + var reader1 = await cmd1.ExecuteReaderAsync(Behavior); + await conn.CloseAsync(); + await conn.OpenAsync(); + + var reader2 = await cmd1.ExecuteReaderAsync(Behavior); + Assert.That(reader1, Is.Not.SameAs(reader2)); + Assert.That(reader2.State, Is.EqualTo(ReaderState.BeforeResult)); + + await reader1.DisposeAsync(); + + Assert.That(reader2.State, Is.EqualTo(ReaderState.BeforeResult)); + } + + [Test] + public async Task Reader_reuse_on_dispose() + { + await using var conn = await OpenConnectionAsync(); + await using var tx = await conn.BeginTransactionAsync(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 1"; + + var reader1 = await cmd.ExecuteReaderAsync(Behavior); + await 
reader1.ReadAsync(); + await reader1.DisposeAsync(); + + var reader2 = await cmd.ExecuteReaderAsync(Behavior); + Assert.That(reader1, Is.SameAs(reader2)); + await reader2.DisposeAsync(); + } + + [Test] + public async Task Unbound_reader_reuse() + { + await using var dataSource = CreateDataSource(csb => { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var p1 = new NpgsqlParameter { ParameterName = "p1", Value = DBNull.Value, NpgsqlDbType = NpgsqlDbType.Smallint }; - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = (short)8 }; - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Smallint)); - Assert.That(p2.DbType, Is.EqualTo(DbType.Int16)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(short))); - Assert.That(reader.GetDataTypeName(i), Is.EqualTo("smallint")); - } - - Assert.That(() => reader.GetFieldValue(0), Is.EqualTo(DBNull.Value)); - Assert.That(() => reader.GetFieldValue(0), Throws.TypeOf()); - Assert.That(() => reader.GetFieldValue(0), Throws.Nothing); - Assert.That(reader.GetFieldValue(0), Is.Null); - - Assert.That(() => reader.GetFieldValue(1), Throws.Nothing); - Assert.That(() => reader.GetFieldValue(1), Throws.Nothing); - Assert.That(() => reader.GetFieldValue(1), Throws.Nothing); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(8)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(8)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(8)); - } - } + csb.MinPoolSize = 1; + csb.MaxPoolSize = 1; + }); + await using var conn1 = await dataSource.OpenConnectionAsync(); + using var cmd1 = conn1.CreateCommand(); + cmd1.CommandText = "SELECT 1"; + var reader1 = await cmd1.ExecuteReaderAsync(Behavior); + await using (var __ = reader1) + { + Assert.That(async () => await reader1.ReadAsync(), 
Is.EqualTo(true)); + Assert.That(() => reader1.GetInt32(0), Is.EqualTo(1)); + + await reader1.CloseAsync(); + await conn1.CloseAsync(); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2913")] - public async Task ReaderReadingPreviousQueryMessagesBug() + await using var conn2 = await dataSource.OpenConnectionAsync(); + using var cmd2 = conn2.CreateCommand(); + cmd2.CommandText = "SELECT 2"; + var reader2 = await cmd2.ExecuteReaderAsync(Behavior); + await using (var __ = reader2) { - // No point in testing for multiplexing, as every query may use another connection - if (IsMultiplexing) - return; + Assert.That(async () => await reader2.ReadAsync(), Is.EqualTo(true)); + Assert.That(() => reader2.GetInt32(0), Is.EqualTo(2)); + Assert.That(reader1, Is.Not.SameAs(reader2)); - var firstMrs = new ManualResetEventSlim(false); - var secondMrs = new ManualResetEventSlim(false); + await reader2.CloseAsync(); + await conn2.CloseAsync(); + } - var secondQuery = Task.Run(async () => - { - firstMrs.Wait(); - await using var secondConn = await OpenConnectionAsync(); - using var secondCmd = new NpgsqlCommand(@"SELECT 1; SELECT 2;", secondConn); - await using var secondReader = await secondCmd.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); - - // Check, that StatementIndex is equals to default value - Assert.That(secondReader.StatementIndex, Is.EqualTo(0)); - secondMrs.Wait(); - // Check, that the first query didn't change StatementIndex - Assert.That(secondReader.StatementIndex, Is.EqualTo(0)); - }); - - await using (var firstConn = await OpenConnectionAsync()) - { - // Executing a query, which fails with NpgsqlException on reader disposing, as NotExistingTable doesn't exist - using var firstCmd = new NpgsqlCommand(@"SELECT 1; SELECT * FROM NotExistingTable;", firstConn); - await using var firstReader = await firstCmd.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); + await using var conn3 = await dataSource.OpenConnectionAsync(); + 
using var cmd3 = conn3.CreateCommand(); + cmd3.CommandText = "SELECT 3"; + var reader3 = await cmd3.ExecuteReaderAsync(Behavior); + await using (var __ = reader3) + { + Assert.That(async () => await reader3.ReadAsync(), Is.EqualTo(true)); + Assert.That(() => reader3.GetInt32(0), Is.EqualTo(3)); + Assert.That(reader1, Is.SameAs(reader3)); - Assert.That(firstReader.StatementIndex, Is.EqualTo(0)); + await reader3.CloseAsync(); + await conn3.CloseAsync(); + } + } - firstReader.ReaderClosed += (s, e) => - { - // Starting a second query, which in case of a bug uses firstConn - firstMrs.Set(); - // Waiting for the second query to start executing - Thread.Sleep(100); - // After waiting, reader is free to reset prepared statements, which also increments StatementIndex - }; + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3772")] + public async Task Bug3772() + { + if (!IsSequential) + return; - Assert.ThrowsAsync(firstReader.NextResultAsync); + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); - secondMrs.Set(); - } + var pgMock = await postmasterMock.WaitForServerConnection(); + pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid), new FieldDescription(ByteaOid)); + + var intValue = new byte[] { 0, 0, 0, 1 }; + var byteValue = new byte[] { 1, 2, 3, 4 }; + + var writeBuffer = pgMock.WriteBuffer; + writeBuffer.WriteByte((byte)BackendMessageCode.DataRow); + writeBuffer.WriteInt32(4 + 2 + intValue.Length + byteValue.Length + 8); + writeBuffer.WriteInt16(2); + writeBuffer.WriteInt32(intValue.Length); + writeBuffer.WriteBytes(intValue); + await pgMock.FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT some_int, some_byte FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + + await 
reader.ReadAsync(); + + reader.GetInt32(0); + + Assert.Zero(reader.Connector.ReadBuffer.ReadBytesLeft); + Assert.NotZero(reader.Connector.ReadBuffer.ReadPosition); + + writeBuffer.WriteInt32(byteValue.Length); + writeBuffer.WriteBytes(byteValue); + await pgMock + .WriteDataRow(intValue, Enumerable.Range(1, 100).Select(x => (byte)x).ToArray()) + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + await reader.GetFieldValueAsync(1); + + Assert.DoesNotThrowAsync(reader.ReadAsync); + } + + [Test] // #4377 + public async Task Dispose_does_not_swallow_exceptions([Values(true, false)] bool async) + { + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var tx = IsMultiplexing ? await conn.BeginTransactionAsync() : null; + var pgMock = await postmasterMock.WaitForServerConnection(); + + if (IsMultiplexing) + pgMock + .WriteEmptyQueryResponse() + .WriteReadyForQuery(TransactionStatus.InTransactionBlock); + + // Write responses for the query, but break the connection before sending CommandComplete/ReadyForQuery + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT 1", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + await reader.ReadAsync(); + + pgMock.Close(); + + if (async) + Assert.Throws(() => reader.Dispose()); + else + Assert.ThrowsAsync(async () => await reader.DisposeAsync()); + } + + [Test] + public async Task Read_string_as_char() + { + await using var conn = await OpenConnectionAsync(); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 'abcdefgh', 'ijklmnop'"; + + await using var reader = await 
cmd.ExecuteReaderAsync(Behavior); + Assert.IsTrue(await reader.ReadAsync()); + Assert.That(reader.GetChar(0), Is.EqualTo('a')); + if (Behavior == CommandBehavior.SequentialAccess) + Assert.Throws(() => reader.GetChar(0)); + else + Assert.That(reader.GetChar(0), Is.EqualTo('a')); + Assert.That(reader.GetChar(1), Is.EqualTo('i')); + } - await secondQuery; + #region GetBytes / GetStream - // If we're here and a bug is still not fixed, we fail while executing reader, as we're reading skipped messages for the second query - await using var thirdConn = OpenConnection(); - using var thirdCmd = new NpgsqlCommand(@"SELECT 1; SELECT 2;", thirdConn); - await using var thirdReader = await thirdCmd.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); - } + [Test] + public async Task GetBytes() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "bytes BYTEA"); + + // TODO: This is too small to actually test any interesting sequential behavior + byte[] expected = { 1, 2, 3, 4, 5 }; + var actual = new byte[expected.Length]; + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (bytes) VALUES ({EncodeByteaHex(expected)})"); + + var query = $"SELECT bytes, 'foo', bytes, 'bar', bytes, bytes FROM {table}"; + using var cmd = new NpgsqlCommand(query, conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + + Assert.That(reader.GetBytes(0, 0, actual, 0, 2), Is.EqualTo(2)); + Assert.That(actual[0], Is.EqualTo(expected[0])); + Assert.That(actual[1], Is.EqualTo(expected[1])); + Assert.That(reader.GetBytes(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); + if (IsSequential) + Assert.That(() => reader.GetBytes(0, 0, actual, 4, 1), + Throws.Exception.TypeOf(), "Seek back sequential"); + else + { + Assert.That(reader.GetBytes(0, 0, actual, 4, 1), Is.EqualTo(1)); + Assert.That(actual[4], Is.EqualTo(expected[0])); + } + Assert.That(reader.GetBytes(0, 2, actual, 2, 3), Is.EqualTo(3)); + 
Assert.That(actual, Is.EqualTo(expected)); + Assert.That(reader.GetBytes(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); + + Assert.That(reader.GetString(1), Is.EqualTo("foo")); + reader.GetBytes(2, 0, actual, 0, 2); + // Jump to another column from the middle of the column + reader.GetBytes(4, 0, actual, 0, 2); + Assert.That(reader.GetBytes(4, expected.Length - 1, actual, 0, 2), Is.EqualTo(1), + "Length greater than data length"); + Assert.That(actual[0], Is.EqualTo(expected[expected.Length - 1]), "Length greater than data length"); + Assert.That(() => reader.GetBytes(4, 0, actual, 0, actual.Length + 1), + Throws.Exception.TypeOf(), "Length great than output buffer length"); + // Close in the middle of a column + reader.GetBytes(5, 0, actual, 0, 2); + + //var result = (byte[]) cmd.ExecuteScalar(); + //Assert.AreEqual(2, result.Length); + } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/2913")] - [IssueLink("https://github.com/npgsql/npgsql/issues/3289")] - public async Task ReaderCloseAndDisposeBug() - { - await using var conn = await OpenConnectionAsync(); - using var cmd1 = conn.CreateCommand(); - cmd1.CommandText = "SELECT 1"; + [Test] + public async Task GetStream_second_time_throws([Values(true, false)] bool isAsync) + { + var expected = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + var streamGetter = BuildStreamGetter(isAsync); - var reader1 = await cmd1.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); - await reader1.CloseAsync(); + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand($"SELECT {EncodeByteaHex(expected)}::bytea", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); - await conn.OpenAsync(); - cmd1.Connection = conn; - var reader2 = await cmd1.ExecuteReaderAsync(Behavior | CommandBehavior.CloseConnection); - Assert.That(reader1, Is.Not.SameAs(reader2)); - Assert.That(reader2.State, Is.EqualTo(ReaderState.BeforeResult)); + await reader.ReadAsync(); - 
await reader1.DisposeAsync(); + using var stream = await streamGetter(reader, 0); - Assert.That(reader2.State, Is.EqualTo(ReaderState.BeforeResult)); - } + Assert.That(async () => await streamGetter(reader, 0), + Throws.Exception.TypeOf()); + } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/2964")] - public async Task ConnectionCloseAndReaderDisposeBug() - { - await using var conn = await OpenConnectionAsync(); - using var cmd1 = conn.CreateCommand(); - cmd1.CommandText = "SELECT 1"; + public static IEnumerable GetStreamCases() + { + var binary = MemoryMarshal + .AsBytes(Enumerable.Range(0, 1024).ToArray()) + .ToArray(); + yield return (binary, binary); + + var bigBinary = MemoryMarshal + .AsBytes(Enumerable.Range(0, 8193).ToArray()) + .ToArray(); + yield return (bigBinary, bigBinary); + + var bigint = 0xDEADBEEFL; + var bigintBinary = BitConverter.GetBytes( + BitConverter.IsLittleEndian + ? BinaryPrimitives.ReverseEndianness(bigint) + : bigint); + yield return (bigint, bigintBinary); + } - var reader1 = await cmd1.ExecuteReaderAsync(Behavior); - await conn.CloseAsync(); - await conn.OpenAsync(); + [Test] + public async Task GetStream( + [Values] bool isAsync, + [ValueSource(nameof(GetStreamCases))] (T Generic, byte[] Binary) value) + { + var streamGetter = BuildStreamGetter(isAsync); + var expected = value.Binary; + var actual = new byte[expected.Length]; - var reader2 = await cmd1.ExecuteReaderAsync(Behavior); - Assert.That(reader1, Is.Not.SameAs(reader2)); - Assert.That(reader2.State, Is.EqualTo(ReaderState.BeforeResult)); + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p, @p", conn) { Parameters = { new NpgsqlParameter("p", value.Generic) } }; + using var reader = await cmd.ExecuteReaderAsync(Behavior); - await reader1.DisposeAsync(); + await reader.ReadAsync(); - Assert.That(reader2.State, Is.EqualTo(ReaderState.BeforeResult)); - } + using var stream = await streamGetter(reader, 0); + 
Assert.That(stream.CanSeek, Is.EqualTo(Behavior == CommandBehavior.Default)); + Assert.That(stream.Length, Is.EqualTo(expected.Length)); - [Test] - public async Task ReaderReuseOnDispose() + var position = 0; + while (position < actual.Length) { - await using var conn = await OpenConnectionAsync(); - await using var tx = await conn.BeginTransactionAsync(); - using var cmd = conn.CreateCommand(); - cmd.CommandText = "SELECT 1"; - - var reader1 = await cmd.ExecuteReaderAsync(Behavior); - await reader1.ReadAsync(); - await reader1.DisposeAsync(); - - var reader2 = await cmd.ExecuteReaderAsync(Behavior); - Assert.That(reader1, Is.SameAs(reader2)); - await reader2.DisposeAsync(); + if (isAsync) + position += await stream.ReadAsync(actual, position, actual.Length - position); + else + position += stream.Read(actual, position, actual.Length - position); } - - [Test] - public async Task DisposeSwallowsExceptions([Values(true, false)] bool async) - { - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); - var pgMock = await postmasterMock.WaitForServerConnection(); - // Write responses for the query, but break the connection before sending CommandComplete/ReadyForQuery - await pgMock - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) - .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) - .FlushAsync(); + Assert.That(actual, Is.EqualTo(expected)); + } - using var cmd = new NpgsqlCommand("SELECT 1", conn); - using var reader = await cmd.ExecuteReaderAsync(); - await reader.ReadAsync(); + [Test] + public async Task Open_stream_when_changing_columns([Values(true, false)] bool isAsync) + { + var streamGetter = BuildStreamGetter(isAsync); + + using var conn = await OpenConnectionAsync(); + using var cmd = new 
NpgsqlCommand(@"SELECT @p, @p", conn); + var data = new byte[] { 1, 2, 3 }; + cmd.Parameters.Add(new NpgsqlParameter("p", data)); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + var stream = await streamGetter(reader, 0); + // ReSharper disable once UnusedVariable + var v = reader.GetValue(1); + Assert.That(() => stream.ReadByte(), Throws.Exception.TypeOf()); + } - pgMock.Close(); + [Test] + public async Task Open_stream_when_changing_rows([Values(true, false)] bool isAsync) + { + var streamGetter = BuildStreamGetter(isAsync); + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand(@"SELECT @p", conn); + var data = new byte[] { 1, 2, 3 }; + cmd.Parameters.Add(new NpgsqlParameter("p", data)); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + var s1 = await streamGetter(reader, 0); + reader.Read(); + Assert.That(() => s1.ReadByte(), Throws.Exception.TypeOf()); + } - if (async) - Assert.DoesNotThrow(() => reader.Dispose()); - else - Assert.DoesNotThrowAsync(async () => await reader.DisposeAsync()); - } + [Test] + public async Task GetBytes_with_null([Values(true, false)] bool isAsync) + { + var streamGetter = BuildStreamGetter(isAsync); + + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "bytes BYTEA"); + + var buf = new byte[8]; + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (bytes) VALUES (NULL)"); + using var cmd = new NpgsqlCommand($"SELECT bytes FROM {table}", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + Assert.That(reader.IsDBNull(0), Is.True); + Assert.That(() => reader.GetBytes(0, 0, buf, 0, 1), Throws.Exception.TypeOf(), "GetBytes"); + Assert.That(async () => await streamGetter(reader, 0), Throws.Exception.TypeOf(), "GetStream"); + Assert.That(() => reader.GetBytes(0, 0, null, 0, 0), Throws.Exception.TypeOf(), "GetBytes with null buffer"); + } - #region GetBytes / GetStream + 
static Func> BuildStreamGetter(bool isAsync) + => isAsync + ? (r, index) => r.GetStreamAsync(index) + : (r, index) => Task.FromResult(r.GetStream(index)); - [Test] - public async Task GetBytes() - { - using (var conn = await OpenConnectionAsync()) - { - await using var __ = await CreateTempTable(conn, "bytes BYTEA", out var table); - - // TODO: This is too small to actually test any interesting sequential behavior - byte[] expected = { 1, 2, 3, 4, 5 }; - var actual = new byte[expected.Length]; - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (bytes) VALUES ({EncodeByteaHex(expected)})"); - - var query = $"SELECT bytes, 'foo', bytes, 'bar', bytes, bytes FROM {table}"; - using (var cmd = new NpgsqlCommand(query, conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - - Assert.That(reader.GetBytes(0, 0, actual, 0, 2), Is.EqualTo(2)); - Assert.That(actual[0], Is.EqualTo(expected[0])); - Assert.That(actual[1], Is.EqualTo(expected[1])); - Assert.That(reader.GetBytes(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); - if (IsSequential) - Assert.That(() => reader.GetBytes(0, 0, actual, 4, 1), - Throws.Exception.TypeOf(), "Seek back sequential"); - else - { - Assert.That(reader.GetBytes(0, 0, actual, 4, 1), Is.EqualTo(1)); - Assert.That(actual[4], Is.EqualTo(expected[0])); - } - Assert.That(reader.GetBytes(0, 2, actual, 2, 3), Is.EqualTo(3)); - Assert.That(actual, Is.EqualTo(expected)); - Assert.That(reader.GetBytes(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); - - Assert.That(() => reader.GetBytes(1, 0, null, 0, 0), Throws.Exception.TypeOf(), - "GetBytes on non-bytea"); - Assert.That(() => reader.GetBytes(1, 0, actual, 0, 1), - Throws.Exception.TypeOf(), - "GetBytes on non-bytea"); - Assert.That(reader.GetString(1), Is.EqualTo("foo")); - reader.GetBytes(2, 0, actual, 0, 2); - // Jump to another column from the middle of the column - reader.GetBytes(4, 0, actual, 0, 2); - 
Assert.That(reader.GetBytes(4, expected.Length - 1, actual, 0, 2), Is.EqualTo(1), - "Length greater than data length"); - Assert.That(actual[0], Is.EqualTo(expected[expected.Length - 1]), "Length greater than data length"); - Assert.That(() => reader.GetBytes(4, 0, actual, 0, actual.Length + 1), - Throws.Exception.TypeOf(), "Length great than output buffer length"); - // Close in the middle of a column - reader.GetBytes(5, 0, actual, 0, 2); - } - - //var result = (byte[]) cmd.ExecuteScalar(); - //Assert.AreEqual(2, result.Length); - } - } + [Test] + public async Task GetStream_after_consuming_column_throws([Values] bool async) + { + if (!IsSequential) + return; - [Test] - public async Task GetStreamSecondTimeThrows([Values(true, false)] bool isAsync) - { - var expected = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var streamGetter = BuildStreamGetter(isAsync); + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand(@"SELECT '\xDEADBEEF'::bytea", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand($"SELECT {EncodeByteaHex(expected)}::bytea", conn); - using var reader = await cmd.ExecuteReaderAsync(Behavior); + _ = reader.GetFieldValue(0); - await reader.ReadAsync(); + if (async) + Assert.That(() => reader.GetStreamAsync(0), Throws.Exception.TypeOf()); + else + Assert.That(() => reader.GetStream(0), Throws.Exception.TypeOf()); + } - using var stream = await streamGetter(reader, 0); + [Test] + public async Task GetStream_in_middle_of_column_throws([Values] bool async) + { + if (!IsSequential) + return; - Assert.That(async () => await streamGetter(reader, 0), - Throws.Exception.TypeOf()); - } + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand(@"SELECT '\xDEADBEEF'::bytea", conn); + await using var reader = await 
cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); - public static IEnumerable GetStreamCases() - { - var binary = MemoryMarshal - .AsBytes(Enumerable.Range(0, 1024).ToArray()) - .ToArray(); - yield return (binary, binary); - - var bigBinary = MemoryMarshal - .AsBytes(Enumerable.Range(0, 8193).ToArray()) - .ToArray(); - yield return (bigBinary, bigBinary); - - var bigint = 0xDEADBEEFL; - var bigintBinary = BitConverter.GetBytes( - BitConverter.IsLittleEndian - ? BinaryPrimitives.ReverseEndianness(bigint) - : bigint); - yield return (bigint, bigintBinary); - } + _ = reader.GetBytes(0, 0, new byte[2], 0, 2); - [Test] - public async Task GetStream( - [ValueSource(nameof(GetStreamCases))] (T Generic, byte[] Binary) value, - [Values(true, false)] bool isAsync) - { - var streamGetter = BuildStreamGetter(isAsync); - var expected = value.Binary; - var actual = new byte[expected.Length]; + if (async) + Assert.That(() => reader.GetStreamAsync(0), Throws.Exception.TypeOf()); + else + Assert.That(() => reader.GetStream(0), Throws.Exception.TypeOf()); + } - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p, @p", conn) { Parameters = { new NpgsqlParameter("p", value.Generic) } }; - using var reader = await cmd.ExecuteReaderAsync(Behavior); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5223")] + public async Task GetStream_seek() + { + // Sequential doesn't allow to seek + if (IsSequential) + return; + + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 'abcdefgh'"; + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + var buffer = new byte[4]; + + await using var stream = reader.GetStream(0); + Assert.IsTrue(stream.CanSeek); + + var seekPosition = stream.Seek(-1, SeekOrigin.End); + Assert.That(seekPosition, Is.EqualTo(stream.Length - 1)); + var read = stream.Read(buffer); + 
Assert.That(read, Is.EqualTo(1)); + Assert.That(Encoding.ASCII.GetString(buffer, 0, 1), Is.EqualTo("h")); + read = stream.Read(buffer); + Assert.That(read, Is.EqualTo(0)); + + seekPosition = stream.Seek(2, SeekOrigin.Begin); + Assert.That(seekPosition, Is.EqualTo(2)); + read = stream.Read(buffer); + Assert.That(read, Is.EqualTo(buffer.Length)); + Assert.That(Encoding.ASCII.GetString(buffer), Is.EqualTo("cdef")); + + seekPosition = stream.Seek(-3, SeekOrigin.Current); + Assert.That(seekPosition, Is.EqualTo(3)); + read = stream.Read(buffer); + Assert.That(read, Is.EqualTo(buffer.Length)); + Assert.That(Encoding.ASCII.GetString(buffer), Is.EqualTo("defg")); + + stream.Position = 1; + read = stream.Read(buffer); + Assert.That(read, Is.EqualTo(buffer.Length)); + Assert.That(Encoding.ASCII.GetString(buffer), Is.EqualTo("bcde")); + } - await reader.ReadAsync(); + #endregion GetBytes / GetStream - using var stream = await streamGetter(reader, 0); - Assert.That(stream.CanSeek, Is.EqualTo(Behavior == CommandBehavior.Default)); - Assert.That(stream.Length, Is.EqualTo(expected.Length)); + #region GetChars / GetTextReader - var position = 0; - while (position < actual.Length) - position += await stream.ReadAsync(actual, position, actual.Length - position); + [Test] + public async Task GetChars() + { + using var conn = await OpenConnectionAsync(); + // TODO: This is too small to actually test any interesting sequential behavior + const string str = "ABCDE"; + var expected = str.ToCharArray(); + var actual = new char[expected.Length]; + + var queryText = $@"SELECT '{str}', 3, '{str}', 4, '{str}', '{str}', '{str}'"; + using var cmd = new NpgsqlCommand(queryText, conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + + Assert.That(reader.GetChars(0, 0, actual, 0, 2), Is.EqualTo(2)); + Assert.That(actual[0], Is.EqualTo(expected[0])); + Assert.That(actual[1], Is.EqualTo(expected[1])); + if (!IsSequential) + Assert.That(reader.GetChars(0, 0, null, 0, 
0), Is.EqualTo(expected.Length), "Bad column length"); + // Note: Unlike with bytea, finding out the length of the column consumes it (variable-width + // UTF8 encoding) + Assert.That(reader.GetChars(2, 0, actual, 0, 2), Is.EqualTo(2)); + if (IsSequential) + Assert.That(() => reader.GetChars(2, 0, actual, 4, 1), Throws.Exception.TypeOf(), "Seek back sequential"); + else + { + Assert.That(reader.GetChars(2, 0, actual, 4, 1), Is.EqualTo(1)); + Assert.That(actual[4], Is.EqualTo(expected[0])); + } + Assert.That(reader.GetChars(2, 2, actual, 2, 3), Is.EqualTo(3)); + Assert.That(actual, Is.EqualTo(expected)); + //Assert.That(reader.GetChars(2, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); + + Assert.That(() => reader.GetChars(3, 0, null, 0, 0), Throws.Exception.TypeOf(), "GetChars on non-text"); + Assert.That(() => reader.GetChars(3, 0, actual, 0, 1), Throws.Exception.TypeOf(), "GetChars on non-text"); + Assert.That(reader.GetInt32(3), Is.EqualTo(4)); + reader.GetChars(4, 0, actual, 0, 2); + // Jump to another column from the middle of the column + reader.GetChars(5, 0, actual, 0, 2); + Assert.That(reader.GetChars(5, expected.Length - 1, actual, 0, 2), Is.EqualTo(1), "Length greater than data length"); + Assert.That(actual[0], Is.EqualTo(expected[expected.Length - 1]), "Length greater than data length"); + Assert.That(() => reader.GetChars(5, 0, actual, 0, actual.Length + 1), Throws.Exception.TypeOf(), "Length great than output buffer length"); + // Close in the middle of a column + reader.GetChars(6, 0, actual, 0, 2); + } - Assert.That(actual, Is.EqualTo(expected)); - } + [Test] + public async Task GetChars_AdvanceConsumed() + { + const string value = "01234567"; - [Test] - public async Task OpenStreamWhenChangingColumns([Values(true, false)] bool isAsync) - { - var streamGetter = BuildStreamGetter(isAsync); + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand($"SELECT '{value}'", conn); + using var reader = await 
cmd.ExecuteReaderAsync(Behavior); + reader.Read(); - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand(@"SELECT @p, @p", conn)) - { - var data = new byte[] { 1, 2, 3 }; - cmd.Parameters.Add(new NpgsqlParameter("p", data)); - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - var stream = await streamGetter(reader, 0); - // ReSharper disable once UnusedVariable - var v = reader.GetValue(1); - Assert.That(() => stream.ReadByte(), Throws.Exception.TypeOf()); - } - } - } + var buffer = new char[2]; + // Don't start at the beginning of the column. + reader.GetChars(0, 2, buffer, 0, 2); + reader.GetChars(0, 4, buffer, 0, 2); + reader.GetChars(0, 6, buffer, 0, 2); - [Test] - public async Task OpenStreamWhenChangingRows([Values(true, false)] bool isAsync) + // Ask for data past the start and the previous point, exercising restart logic. + if (!IsSequential) { - var streamGetter = BuildStreamGetter(isAsync); - - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand(@"SELECT @p", conn)) - { - var data = new byte[] { 1, 2, 3 }; - cmd.Parameters.Add(new NpgsqlParameter("p", data)); - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - var s1 = await streamGetter(reader, 0); - reader.Read(); - Assert.That(() => s1.ReadByte(), Throws.Exception.TypeOf()); - } - } + reader.GetChars(0, 4, buffer, 0, 2); + reader.GetChars(0, 6, buffer, 0, 2); } + } - [Test] - public async Task GetBytesWithNull([Values(true, false)] bool isAsync) - { - var streamGetter = BuildStreamGetter(isAsync); - - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "bytes BYTEA", out var table); - - var buf = new byte[8]; - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (bytes) VALUES (NULL)"); - using (var cmd = new NpgsqlCommand($"SELECT bytes FROM {table}", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { 
- reader.Read(); - Assert.That(reader.IsDBNull(0), Is.True); - Assert.That(() => reader.GetBytes(0, 0, buf, 0, 1), Throws.Exception.TypeOf(), "GetBytes"); - Assert.That(async () => await streamGetter(reader, 0), Throws.Exception.TypeOf(), "GetStream"); - Assert.That(() => reader.GetBytes(0, 0, null, 0, 0), Throws.Exception.TypeOf(), "GetBytes with null buffer"); - } - } - } + [Test] + public async Task GetTextReader([Values(true, false)] bool isAsync) + { + Func> textReaderGetter; + if (isAsync) + textReaderGetter = (r, index) => r.GetTextReaderAsync(index); + else + textReaderGetter = (r, index) => Task.FromResult(r.GetTextReader(index)); + + using var conn = await OpenConnectionAsync(); + // TODO: This is too small to actually test any interesting sequential behavior + const string str = "ABCDE"; + var expected = str.ToCharArray(); + var actual = new char[expected.Length]; + //ExecuteNonQuery(String.Format(@"INSERT INTO data (field_text) VALUES ('{0}')", str)); + + var queryText = $@"SELECT '{str}', 'foo'"; + using var cmd = new NpgsqlCommand(queryText, conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + + var textReader = await textReaderGetter(reader, 0); + textReader.Read(actual, 0, 2); + Assert.That(actual[0], Is.EqualTo(expected[0])); + Assert.That(actual[1], Is.EqualTo(expected[1])); + Assert.That(async () => await textReaderGetter(reader, 0), + Throws.Exception.TypeOf(), + "Sequential text reader twice on same column"); + textReader.Read(actual, 2, 1); + Assert.That(actual[2], Is.EqualTo(expected[2])); + textReader.Dispose(); + + if (IsSequential) + Assert.That(() => reader.GetChars(0, 0, actual, 4, 1), + Throws.Exception.TypeOf(), "Seek back sequential"); + else + { + Assert.That(reader.GetChars(0, 0, actual, 4, 1), Is.EqualTo(1)); + Assert.That(actual[4], Is.EqualTo(expected[0])); + } + Assert.That(reader.GetString(1), Is.EqualTo("foo")); + } - static Func> BuildStreamGetter(bool isAsync) - => isAsync - ? 
(Func>)((r, index) => r.GetStreamAsync(index)) - : (r, index) => Task.FromResult(r.GetStream(index)); + [Test] + public async Task TextReader_zero_length_column() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT ''"; - #endregion GetBytes / GetStream + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + Assert.IsTrue(await reader.ReadAsync()); - #region GetChars / GetTextReader + using var textReader = reader.GetTextReader(0); + Assert.That(textReader.Peek(), Is.EqualTo(-1)); + Assert.That(textReader.ReadToEnd(), Is.EqualTo(string.Empty)); + } - [Test] - public async Task GetChars() - { - using (var conn = await OpenConnectionAsync()) - { - // TODO: This is too small to actually test any interesting sequential behavior - const string str = "ABCDE"; - var expected = str.ToCharArray(); - var actual = new char[expected.Length]; - - var queryText = $@"SELECT '{str}', 3, '{str}', 4, '{str}', '{str}', '{str}'"; - using (var cmd = new NpgsqlCommand(queryText, conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - - Assert.That(reader.GetChars(0, 0, actual, 0, 2), Is.EqualTo(2)); - Assert.That(actual[0], Is.EqualTo(expected[0])); - Assert.That(actual[1], Is.EqualTo(expected[1])); - Assert.That(reader.GetChars(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); - // Note: Unlike with bytea, finding out the length of the column consumes it (variable-width - // UTF8 encoding) - Assert.That(reader.GetChars(2, 0, actual, 0, 2), Is.EqualTo(2)); - if (IsSequential) - Assert.That(() => reader.GetChars(2, 0, actual, 4, 1), Throws.Exception.TypeOf(), "Seek back sequential"); - else - { - Assert.That(reader.GetChars(2, 0, actual, 4, 1), Is.EqualTo(1)); - Assert.That(actual[4], Is.EqualTo(expected[0])); - } - Assert.That(reader.GetChars(2, 2, actual, 2, 3), Is.EqualTo(3)); - Assert.That(actual, Is.EqualTo(expected)); - 
//Assert.That(reader.GetChars(2, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); - - Assert.That(() => reader.GetChars(3, 0, null, 0, 0), Throws.Exception.TypeOf(), "GetChars on non-text"); - Assert.That(() => reader.GetChars(3, 0, actual, 0, 1), Throws.Exception.TypeOf(), "GetChars on non-text"); - Assert.That(reader.GetInt32(3), Is.EqualTo(4)); - reader.GetChars(4, 0, actual, 0, 2); - // Jump to another column from the middle of the column - reader.GetChars(5, 0, actual, 0, 2); - Assert.That(reader.GetChars(5, expected.Length - 1, actual, 0, 2), Is.EqualTo(1), "Length greater than data length"); - Assert.That(actual[0], Is.EqualTo(expected[expected.Length - 1]), "Length greater than data length"); - Assert.That(() => reader.GetChars(5, 0, actual, 0, actual.Length + 1), Throws.Exception.TypeOf(), "Length great than output buffer length"); - // Close in the middle of a column - reader.GetChars(6, 0, actual, 0, 2); - } - } - } + [Test] + public async Task Open_TextReader_when_changing_columns() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand(@"SELECT 'some_text', 'some_text'", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + var textReader = reader.GetTextReader(0); + // ReSharper disable once UnusedVariable + var v = reader.GetValue(1); + Assert.That(() => textReader.Peek(), Throws.Exception.TypeOf()); + } - [Test] - public async Task GetTextReader([Values(true, false)] bool isAsync) - { - Func> textReaderGetter; - if (isAsync) - textReaderGetter = (r, index) => r.GetTextReaderAsync(index); - else - textReaderGetter = (r, index) => Task.FromResult(r.GetTextReader(index)); + [Test] + public async Task Open_TextReader_when_changing_rows() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand(@"SELECT 'some_text', 'some_text'", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + var tr1 = 
reader.GetTextReader(0); + reader.Read(); + Assert.That(() => tr1.Peek(), Throws.Exception.TypeOf()); + } - using (var conn = await OpenConnectionAsync()) - { - // TODO: This is too small to actually test any interesting sequential behavior - const string str = "ABCDE"; - var expected = str.ToCharArray(); - var actual = new char[expected.Length]; - //ExecuteNonQuery(String.Format(@"INSERT INTO data (field_text) VALUES ('{0}')", str)); - - var queryText = $@"SELECT '{str}', 'foo'"; - using (var cmd = new NpgsqlCommand(queryText, conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - - var textReader = await textReaderGetter(reader, 0); - textReader.Read(actual, 0, 2); - Assert.That(actual[0], Is.EqualTo(expected[0])); - Assert.That(actual[1], Is.EqualTo(expected[1])); - Assert.That(async () => await textReaderGetter(reader, 0), - Throws.Exception.TypeOf(), - "Sequential text reader twice on same column"); - textReader.Read(actual, 2, 1); - Assert.That(actual[2], Is.EqualTo(expected[2])); - textReader.Dispose(); - - if (IsSequential) - Assert.That(() => reader.GetChars(0, 0, actual, 4, 1), - Throws.Exception.TypeOf(), "Seek back sequential"); - else - { - Assert.That(reader.GetChars(0, 0, actual, 4, 1), Is.EqualTo(1)); - Assert.That(actual[4], Is.EqualTo(expected[0])); - } - Assert.That(reader.GetString(1), Is.EqualTo("foo")); - } - } - } + [Test] + public async Task GetChars_when_null() + { + var buf = new char[8]; + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT NULL::TEXT", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + Assert.That(reader.IsDBNull(0), Is.True); + Assert.That(() => reader.GetChars(0, 0, buf, 0, 1), Throws.Exception.TypeOf(), "GetChars"); + Assert.That(() => reader.GetTextReader(0), Throws.Exception.TypeOf(), "GetTextReader"); + Assert.That(() => reader.GetChars(0, 0, null, 0, 0), Throws.Exception.TypeOf(), "GetChars with null 
buffer"); + } - [Test] - public async Task OpenTextReaderWhenChangingColumns() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand(@"SELECT 'some_text', 'some_text'", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - var textReader = reader.GetTextReader(0); - // ReSharper disable once UnusedVariable - var v = reader.GetValue(1); - Assert.That(() => textReader.Peek(), Throws.Exception.TypeOf()); - } - } + [Test] + public async Task Reader_is_reused() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: Fails"); - [Test] - public async Task OpenReaderWhenChangingRows() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand(@"SELECT 'some_text', 'some_text'", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - var tr1 = reader.GetTextReader(0); - reader.Read(); - Assert.That(() => tr1.Peek(), Throws.Exception.TypeOf()); - } - } + using var conn = await OpenConnectionAsync(); + NpgsqlDataReader reader1; - [Test] - public async Task GetCharsWhenNull() + using (var cmd = new NpgsqlCommand("SELECT 8", conn)) + using (reader1 = await cmd.ExecuteReaderAsync(Behavior)) { - var buf = new char[8]; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT NULL::TEXT", conn)) - using (var reader = await cmd.ExecuteReaderAsync(Behavior)) - { - reader.Read(); - Assert.That(reader.IsDBNull(0), Is.True); - Assert.That(() => reader.GetChars(0, 0, buf, 0, 1), Throws.Exception.TypeOf(), "GetChars"); - Assert.That(() => reader.GetTextReader(0), Throws.Exception.TypeOf(), "GetTextReader"); - Assert.That(() => reader.GetChars(0, 0, null, 0, 0), Throws.Exception.TypeOf(), "GetChars with null buffer"); - } + reader1.Read(); + Assert.That(reader1.GetInt32(0), Is.EqualTo(8)); } - [Test] - public async Task ReaderIsReused() + using (var cmd = new NpgsqlCommand("SELECT 9", conn)) + using (var reader2 = 
await cmd.ExecuteReaderAsync(Behavior)) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: Fails"); - - using (var conn = await OpenConnectionAsync()) - { - NpgsqlDataReader reader1; - - using (var cmd = new NpgsqlCommand("SELECT 8", conn)) - using (reader1 = await cmd.ExecuteReaderAsync(Behavior)) - { - reader1.Read(); - Assert.That(reader1.GetInt32(0), Is.EqualTo(8)); - } - - using (var cmd = new NpgsqlCommand("SELECT 9", conn)) - using (var reader2 = await cmd.ExecuteReaderAsync(Behavior)) - { - Assert.That(reader2, Is.SameAs(reader1)); - reader2.Read(); - Assert.That(reader2.GetInt32(0), Is.EqualTo(9)); - } - } + Assert.That(reader2, Is.SameAs(reader1)); + reader2.Read(); + Assert.That(reader2.GetInt32(0), Is.EqualTo(9)); } + } - #endregion GetChars / GetTextReader + [Test] + public async Task GetTextReader_after_consuming_column_throws([Values] bool async) + { + if (!IsSequential) + return; -#if DEBUG - [Test, Description("Tests that everything goes well when a type handler generates a NpgsqlSafeReadException")] - [Timeout(5000)] - public async Task SafeReadException() - { - if (IsMultiplexing) - return; + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 'foo'", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); - using (var conn = await OpenConnectionAsync()) - { - // Temporarily reroute integer to go to a type handler which generates SafeReadExceptions - conn.TypeMapper.AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "integer", - TypeHandlerFactory = new ExplodingTypeHandlerFactory(true) - }.Build()); - using (var cmd = new NpgsqlCommand(@"SELECT 1, 'hello'", conn)) - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess)) - { - reader.Read(); - Assert.That(() => reader.GetInt32(0), - Throws.Exception.With.Message.EqualTo("Safe read exception as requested")); - 
Assert.That(reader.GetString(1), Is.EqualTo("hello")); - } - } - } + _ = reader.GetString(0); - [Test, Description("Tests that when a type handler generates an exception that isn't a NpgsqlSafeReadException, the connection is properly broken")] - [Timeout(5000)] - public async Task NonSafeReadException() - { - if (IsMultiplexing) - return; + if (async) + Assert.That(() => reader.GetTextReaderAsync(0), Throws.Exception.TypeOf()); + else + Assert.That(() => reader.GetTextReader(0), Throws.Exception.TypeOf()); + } - using (var conn = await OpenConnectionAsync()) - { - // Temporarily reroute integer to go to a type handler which generates some exception - conn.TypeMapper.AddMapping(new NpgsqlTypeMappingBuilder() - { - PgTypeName = "integer", - TypeHandlerFactory = new ExplodingTypeHandlerFactory(false) - }.Build()); - using (var cmd = new NpgsqlCommand(@"SELECT 1, 'hello'", conn)) - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess)) - { - reader.Read(); - Assert.That(() => reader.GetInt32(0), - Throws.Exception.With.Message.EqualTo("Non-safe read exception as requested")); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); - } - } - } -#endif + [Test] + public async Task GetTextReader_in_middle_of_column_throws([Values] bool async) + { + if (!IsSequential) + return; - #region Cancellation + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 'foo'", conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); - [Test, Description("Cancels ReadAsync via the cancellation token, with successful PG cancellation")] - public async Task ReadAsync_cancel_soft() - { - if (IsMultiplexing) - return; // Multiplexing, cancellation + _ = reader.GetChars(0, 0, new char[2], 0, 2); - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); - 
using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); + if (async) + Assert.That(() => reader.GetTextReaderAsync(0), Throws.Exception.TypeOf()); + else + Assert.That(() => reader.GetTextReader(0), Throws.Exception.TypeOf()); + } - // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) - var pgMock = await postmasterMock.WaitForServerConnection(); - await pgMock - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) - .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) - .FlushAsync(); + #endregion GetChars / GetTextReader - using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); - await using (var reader = await cmd.ExecuteReaderAsync()) - { - // Successfully read the first row - Assert.True(await reader.ReadAsync()); - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5450")] + public async Task EndRead_StreamActive([Values]bool async) + { + if (IsMultiplexing) + return; - // Attempt to read the second row - simulate blocking and cancellation - var cancellationSource = new CancellationTokenSource(); - var task = reader.ReadAsync(cancellationSource.Token); - cancellationSource.Cancel(); + const int columnLength = 1; - var (processId, _) = await postmasterMock.WaitForCancellationRequest(); - Assert.That(processId, Is.EqualTo(conn.ProcessID)); + await using var conn = await OpenConnectionAsync(); + var buffer = conn.Connector!.ReadBuffer; + buffer.FilledBytes += columnLength; + var reader = buffer.PgReader; + reader.Init(columnLength, DataFormat.Binary, resumable: false); + if (async) + await reader.StartReadAsync(Size.Unknown, CancellationToken.None); + else + reader.StartRead(Size.Unknown); - await pgMock - 
.WriteErrorResponse(PostgresErrorCodes.QueryCanceled) - .WriteReadyForQuery() - .FlushAsync(); + await using (var _ = reader.GetStream()) + { + if (async) + Assert.DoesNotThrowAsync(async () => await reader.EndReadAsync()); + else + Assert.DoesNotThrow(() => reader.EndRead()); + } - var exception = Assert.ThrowsAsync(async () => await task); - Assert.That(exception.InnerException, - Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); - Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + reader.Commit(resuming: false); + } - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); - } + [Test, Description("Tests that everything goes well when a type handler generates a NpgsqlSafeReadException")] + public async Task SafeReadException() + { + var dataSourceBuilder = CreateDataSourceBuilder(); + // Temporarily reroute integer to go to a type handler which generates SafeReadExceptions + dataSourceBuilder.AddTypeInfoResolverFactory(new ExplodingTypeHandlerResolverFactory(safe: true)); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await using var cmd = new NpgsqlCommand(@"SELECT 1, 'hello'", connection); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + await reader.ReadAsync(); + Assert.That(() => reader.GetInt32(0), + Throws.Exception.With.Message.EqualTo("Safe read exception as requested")); + Assert.That(reader.GetString(1), Is.EqualTo("hello")); + } - await pgMock.WriteScalarResponseAndFlush(1); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + [Test, Description("Tests that when a type handler generates an exception that isn't a NpgsqlSafeReadException, the connection is properly broken")] + public async Task Non_SafeReadException() + { + var dataSourceBuilder = CreateDataSourceBuilder(); + // Temporarily reroute 
integer to go to a type handler which generates some exception + dataSourceBuilder.AddTypeInfoResolverFactory(new ExplodingTypeHandlerResolverFactory(safe: false)); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await using var cmd = new NpgsqlCommand(@"SELECT 1, 'hello'", connection); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + await reader.ReadAsync(); + Assert.That(() => reader.GetInt32(0), Throws.Exception.With.Message.EqualTo("Broken")); + Assert.That(connection.FullState, Is.EqualTo(ConnectionState.Broken)); + Assert.That(connection.State, Is.EqualTo(ConnectionState.Closed)); + } - [Test, Description("Cancels NextResultAsync via the cancellation token, with successful PG cancellation")] - public async Task NextResult_cancel_soft() - { - if (IsMultiplexing) - return; // Multiplexing, cancellation + #region Cancellation - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); + [Test, Description("Cancels ReadAsync via the NpgsqlCommand.Cancel, with successful PG cancellation")] + public async Task ReadAsync_cancel_command_soft() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation - // Write responses to the query we're about to send, only for the first resultset (we'll attempt to read two) - var pgMock = await postmasterMock.WaitForServerConnection(); - await pgMock - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) - .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) - .WriteCommandComplete() - .FlushAsync(); + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + 
await using var conn = await dataSource.OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); - await using (var reader = await cmd.ExecuteReaderAsync()) - { - // Successfully read the first resultset - Assert.True(await reader.ReadAsync()); - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .FlushAsync(); - // Attempt to advance to the second resultset - simulate blocking and cancellation - var cancellationSource = new CancellationTokenSource(); - var task = reader.NextResultAsync(cancellationSource.Token); - cancellationSource.Cancel(); + using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); + await using (var reader = await cmd.ExecuteReaderAsync(Behavior)) + { + // Successfully read the first row + Assert.True(await reader.ReadAsync()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - var (processId, _) = await postmasterMock.WaitForCancellationRequest(); - Assert.That(processId, Is.EqualTo(conn.ProcessID)); + // Attempt to read the second row - simulate blocking and cancellation + var task = reader.ReadAsync(); + cmd.Cancel(); - await pgMock - .WriteErrorResponse(PostgresErrorCodes.QueryCanceled) - .WriteReadyForQuery() - .FlushAsync(); + var processId = (await postmasterMock.WaitForCancellationRequest()).ProcessId; + Assert.That(processId, Is.EqualTo(conn.ProcessID)); - var exception = Assert.ThrowsAsync(async () => await task); - Assert.That(exception.InnerException, - Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); - Assert.That(exception.CancellationToken, 
Is.EqualTo(cancellationSource.Token)); + await pgMock + .WriteErrorResponse(PostgresErrorCodes.QueryCanceled) + .WriteReadyForQuery() + .FlushAsync(); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); - } + var exception = Assert.ThrowsAsync(async () => await task)!; + Assert.That(exception.InnerException, + Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); - await pgMock.WriteScalarResponseAndFlush(1); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); } - [Test, Description("Cancels ReadAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] - public async Task ReadAsync_cancel_hard([Values(true, false)] bool passCancelledToken) - { - if (IsMultiplexing) - return; // Multiplexing, cancellation + await pgMock.WriteScalarResponseAndFlush(1); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); + [Test, Description("Cancels ReadAsync via the cancellation token, with successful PG cancellation")] + public async Task ReadAsync_cancel_soft() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation - // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) - var pgMock = await postmasterMock.WaitForServerConnection(); - await pgMock - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) - .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) - .FlushAsync(); + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + 
await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); - await using var reader = await cmd.ExecuteReaderAsync(Behavior); + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .FlushAsync(); + using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); + await using (var reader = await cmd.ExecuteReaderAsync(Behavior)) + { // Successfully read the first row Assert.True(await reader.ReadAsync()); Assert.That(reader.GetInt32(0), Is.EqualTo(1)); // Attempt to read the second row - simulate blocking and cancellation var cancellationSource = new CancellationTokenSource(); - if (passCancelledToken) - cancellationSource.Cancel(); var task = reader.ReadAsync(cancellationSource.Token); cancellationSource.Cancel(); - var (processId, _) = await postmasterMock.WaitForCancellationRequest(); + var processId = (await postmasterMock.WaitForCancellationRequest()).ProcessId; Assert.That(processId, Is.EqualTo(conn.ProcessID)); - // Send no response from server, wait for the cancellation attempt to time out - var exception = Assert.ThrowsAsync(async () => await task); - Assert.That(exception.InnerException, Is.TypeOf()); + await pgMock + .WriteErrorResponse(PostgresErrorCodes.QueryCanceled) + .WriteReadyForQuery() + .FlushAsync(); + + var exception = Assert.ThrowsAsync(async () => await task)!; + Assert.That(exception.InnerException, + Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); Assert.That(exception.CancellationToken, 
Is.EqualTo(cancellationSource.Token)); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); } - [Test, Description("Cancels NextResultAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] - public async Task NextResultAsync_cancel_hard([Values(true, false)] bool passCancelledToken) - { - if (IsMultiplexing) - return; // Multiplexing, cancellation - - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); - - // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) - var pgMock = await postmasterMock.WaitForServerConnection(); - await pgMock - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) - .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) - .WriteCommandComplete() - .FlushAsync(); - - using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); - await using var reader = await cmd.ExecuteReaderAsync(Behavior); + await pgMock.WriteScalarResponseAndFlush(1); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + [Test, Description("Cancels NextResultAsync via the cancellation token, with successful PG cancellation")] + public async Task NextResult_cancel_soft() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + + // Write responses to the query we're about to send, only for the first resultset (we'll attempt to read two) + var pgMock = 
await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .WriteCommandComplete() + .FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); + await using (var reader = await cmd.ExecuteReaderAsync(Behavior)) + { // Successfully read the first resultset Assert.True(await reader.ReadAsync()); Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - // Attempt to read the second row - simulate blocking and cancellation + // Attempt to advance to the second resultset - simulate blocking and cancellation var cancellationSource = new CancellationTokenSource(); - if (passCancelledToken) - cancellationSource.Cancel(); var task = reader.NextResultAsync(cancellationSource.Token); cancellationSource.Cancel(); - var (processId, _) = await postmasterMock.WaitForCancellationRequest(); + var processId = (await postmasterMock.WaitForCancellationRequest()).ProcessId; Assert.That(processId, Is.EqualTo(conn.ProcessID)); - // Send no response from server, wait for the cancellation attempt to time out - var exception = Assert.ThrowsAsync(async () => await task); - Assert.That(exception.InnerException, Is.TypeOf()); + await pgMock + .WriteErrorResponse(PostgresErrorCodes.QueryCanceled) + .WriteReadyForQuery() + .FlushAsync(); + + var exception = Assert.ThrowsAsync(async () => await task)!; + Assert.That(exception.InnerException, + Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); } - [Test, Description("Cancels sequential ReadAsGetFieldValueAsync")] - public async Task 
GetFieldValueAsync_sequential_cancel([Values(true, false)] bool passCancelledToken) - { - if (IsMultiplexing) - return; // Multiplexing, cancellation + await pgMock.WriteScalarResponseAndFlush(1); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } - if (!IsSequential) - return; + [Test, Description("Cancels ReadAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] + public async Task ReadAsync_cancel_hard([Values(true, false)] bool passCancelledToken) + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + + // Successfully read the first row + Assert.True(await reader.ReadAsync()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + + // Attempt to read the second row - simulate blocking and cancellation + var cancellationSource = new CancellationTokenSource(); + if (passCancelledToken) + cancellationSource.Cancel(); + var task = reader.ReadAsync(cancellationSource.Token); + cancellationSource.Cancel(); - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); + var processId = 
(await postmasterMock.WaitForCancellationRequest()).ProcessId; + Assert.That(processId, Is.EqualTo(conn.ProcessID)); - // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) - var pgMock = await postmasterMock.WaitForServerConnection(); - await pgMock - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea)) - .WriteDataRowWithFlush(new byte[10000]); + // Send no response from server, wait for the cancellation attempt to time out + var exception = Assert.ThrowsAsync(async () => await task)!; + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); - using var cmd = new NpgsqlCommand("SELECT some_bytea FROM some_table", conn); - await using var reader = await cmd.ExecuteReaderAsync(Behavior); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - await reader.ReadAsync(); + [Test, Description("Cancels NextResultAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] + public async Task NextResultAsync_cancel_hard([Values(true, false)] bool passCancelledToken) + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .WriteCommandComplete() + .FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); + await 
using var reader = await cmd.ExecuteReaderAsync(Behavior); + + // Successfully read the first resultset + Assert.True(await reader.ReadAsync()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + + // Attempt to read the second row - simulate blocking and cancellation + var cancellationSource = new CancellationTokenSource(); + if (passCancelledToken) + cancellationSource.Cancel(); + var task = reader.NextResultAsync(cancellationSource.Token); + cancellationSource.Cancel(); - using var cts = new CancellationTokenSource(); - if (passCancelledToken) - cts.Cancel(); - var task = reader.GetFieldValueAsync(0, cts.Token); - cts.Cancel(); + var processId = (await postmasterMock.WaitForCancellationRequest()).ProcessId; + Assert.That(processId, Is.EqualTo(conn.ProcessID)); - var exception = Assert.ThrowsAsync(async () => await task); - Assert.That(exception.InnerException, Is.Null); + // Send no response from server, wait for the cancellation attempt to time out + var exception = Assert.ThrowsAsync(async () => await task)!; + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - [Test, Description("Cancels sequential ReadAsGetFieldValueAsync")] - public async Task IsDBNullAsync_sequential_cancel([Values(true, false)] bool passCancelledToken) - { - if (IsMultiplexing) - return; // Multiplexing, cancellation + [Test, Description("Cancels sequential ReadAsGetFieldValueAsync")] + public async Task GetFieldValueAsync_sequential_cancel([Values(true, false)] bool passCancelledToken) + { + if (IsMultiplexing) + return; // Multiplexing, cancellation - if (!IsSequential) - return; + if (!IsSequential) + return; - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var 
connectionString); - await using var conn = await OpenConnectionAsync(connectionString); + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); - // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) - var pgMock = await postmasterMock.WaitForServerConnection(); - await pgMock - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea), new FieldDescription(PostgresTypeOIDs.Int4)) - .WriteDataRowWithFlush(new byte[10000], new byte[4]); + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(ByteaOid)) + .WriteDataRowWithFlush(new byte[10000]); - using var cmd = new NpgsqlCommand("SELECT some_bytea, some_int FROM some_table", conn); - await using var reader = await cmd.ExecuteReaderAsync(Behavior); + using var cmd = new NpgsqlCommand("SELECT some_bytea FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); - await reader.ReadAsync(); + await reader.ReadAsync(); - using var cts = new CancellationTokenSource(); - if (passCancelledToken) - cts.Cancel(); - var task = reader.IsDBNullAsync(1, cts.Token); + using var cts = new CancellationTokenSource(); + if (passCancelledToken) cts.Cancel(); + var task = reader.GetFieldValueAsync(0, cts.Token); + cts.Cancel(); - var exception = Assert.ThrowsAsync(async () => await task); - Assert.That(exception.InnerException, Is.Null); + var exception = Assert.ThrowsAsync(async () => await task)!; + Assert.That(exception.InnerException, Is.Null); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } + 
Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - #endregion Cancellation + [Test, Description("Cancels sequential ReadAsGetFieldValueAsync")] + public async Task IsDBNullAsync_sequential_cancel([Values(true, false)] bool passCancelledToken) + { + if (IsMultiplexing) + return; // Multiplexing, cancellation - #region Timeout + if (!IsSequential) + return; - [Test, Description("Timeouts sequential ReadAsGetFieldValueAsync")] - [Timeout(10000)] - public async Task GetFieldValueAsync_sequential_timeout() - { - if (IsMultiplexing) - return; // Multiplexing, cancellation + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); - if (!IsSequential) - return; + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(ByteaOid), new FieldDescription(Int4Oid)) + .WriteDataRowWithFlush(new byte[10000], new byte[4]); - var csb = new NpgsqlConnectionStringBuilder(ConnectionString); - csb.CommandTimeout = 3; - csb.CancellationTimeout = 15000; + using var cmd = new NpgsqlCommand("SELECT some_bytea, some_int FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); - await using var postmasterMock = PgPostmasterMock.Start(csb.ToString()); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); + await reader.ReadAsync(); - // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) - var pgMock = await postmasterMock.WaitForServerConnection(); - await pgMock - .WriteParseComplete() - .WriteBindComplete() 
- .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea)) - .WriteDataRowWithFlush(new byte[10000]); + using var cts = new CancellationTokenSource(); + if (passCancelledToken) + cts.Cancel(); + var task = reader.IsDBNullAsync(1, cts.Token); + cts.Cancel(); - using var cmd = new NpgsqlCommand("SELECT some_bytea FROM some_table", conn); - await using var reader = await cmd.ExecuteReaderAsync(Behavior); + var exception = Assert.ThrowsAsync(async () => await task)!; + Assert.That(exception.InnerException, Is.Null); - await reader.ReadAsync(); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + + [Test, Description("Cancellation does not work with the multiplexing")] + public async Task Cancel_multiplexing_disabled() + { + if (!IsMultiplexing) + return; + + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT generate_series(1, 100); SELECT generate_series(1, 100)", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + var cancelledToken = new CancellationToken(canceled: true); + Assert.IsTrue(await reader.ReadAsync()); + while (await reader.ReadAsync(cancelledToken)) { } + Assert.IsTrue(await reader.NextResultAsync(cancelledToken)); + while (await reader.ReadAsync(cancelledToken)) { } + Assert.IsFalse(conn.Connector!.UserCancellationRequested); + } - var task = reader.GetFieldValueAsync(0); + #endregion Cancellation - var exception = Assert.ThrowsAsync(async () => await task); - Assert.That(exception.InnerException, Is.TypeOf()); + #region Timeout - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } + [Test, Description("Timeouts sequential ReadAsGetFieldValueAsync")] + public async Task GetFieldValueAsync_sequential_timeout() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation - [Test, Description("Timeouts sequential IsDBNullAsync")] - [Timeout(10000)] - public async Task 
IsDBNullAsync_sequential_timeout() + if (!IsSequential) + return; + + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { - if (IsMultiplexing) - return; // Multiplexing, cancellation + CommandTimeout = 3, + CancellationTimeout = 15000 + }; - if (!IsSequential) - return; + await using var postmasterMock = PgPostmasterMock.Start(csb.ToString()); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); - var csb = new NpgsqlConnectionStringBuilder(ConnectionString); - csb.CommandTimeout = 3; - csb.CancellationTimeout = 15000; + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(ByteaOid)) + .WriteDataRowWithFlush(new byte[10000]); - await using var postmasterMock = PgPostmasterMock.Start(csb.ToString()); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); + using var cmd = new NpgsqlCommand("SELECT some_bytea FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); - // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) - var pgMock = await postmasterMock.WaitForServerConnection(); - await pgMock - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea), new FieldDescription(PostgresTypeOIDs.Int4)) - .WriteDataRowWithFlush(new byte[10000], new byte[4]); + await reader.ReadAsync(); - using var cmd = new NpgsqlCommand("SELECT some_bytea, some_int FROM some_table", conn); - await using var reader = await cmd.ExecuteReaderAsync(Behavior); + var task = reader.GetFieldValueAsync(0); - await 
reader.ReadAsync(); + var exception = Assert.ThrowsAsync(async () => await task)!; + Assert.That(exception.InnerException, Is.TypeOf()); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } - var task = reader.GetFieldValueAsync(0); + [Test, Description("Timeouts sequential IsDBNullAsync")] + public async Task IsDBNullAsync_sequential_timeout() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation - var exception = Assert.ThrowsAsync(async () => await task); - Assert.That(exception.InnerException, Is.TypeOf()); + if (!IsSequential) + return; - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - } + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + CommandTimeout = 3, + CancellationTimeout = 15000 + }; + + await using var postmasterMock = PgPostmasterMock.Start(csb.ToString()); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(ByteaOid), new FieldDescription(Int4Oid)) + .WriteDataRowWithFlush(new byte[10000], new byte[4]); + + using var cmd = new NpgsqlCommand("SELECT some_bytea, some_int FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + + await reader.ReadAsync(); - #endregion + var task = reader.GetFieldValueAsync(0); - #region Initialization / setup / teardown + var exception = Assert.ThrowsAsync(async () => await task)!; + Assert.That(exception.InnerException, Is.TypeOf()); - // ReSharper disable InconsistentNaming - readonly bool IsSequential; - readonly CommandBehavior Behavior; - // ReSharper restore InconsistentNaming + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); 
+ } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3446")] + public async Task Bug3446() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); - public ReaderTests(MultiplexingMode multiplexingMode, CommandBehavior behavior) : base(multiplexingMode) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteDataRow(new byte[4]) + .FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); + await using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { - Behavior = behavior; - IsSequential = (Behavior & CommandBehavior.SequentialAccess) != 0; + await reader.ReadAsync(); + cmd.Cancel(); + await postmasterMock.WaitForCancellationRequest(); + await pgMock + .WriteErrorResponse(PostgresErrorCodes.QueryCanceled) + .WriteReadyForQuery() + .FlushAsync(); } - #endregion + Assert.That(conn.Connector!.State, Is.EqualTo(ConnectorState.Ready)); } - #region Mock Type Handlers + #endregion + + #region Initialization / setup / teardown - class ExplodingTypeHandlerFactory : NpgsqlTypeHandlerFactory + // ReSharper disable InconsistentNaming + readonly bool IsSequential; + readonly CommandBehavior Behavior; + // ReSharper restore InconsistentNaming + + public ReaderTests(MultiplexingMode multiplexingMode, CommandBehavior behavior) : base(multiplexingMode) { - readonly bool _safe; - internal ExplodingTypeHandlerFactory(bool safe) => _safe = safe; - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new ExplodingTypeHandler(postgresType, _safe); + Behavior = behavior; + IsSequential = (Behavior & 
CommandBehavior.SequentialAccess) != 0; } - class ExplodingTypeHandler : NpgsqlSimpleTypeHandler + #endregion +} + +#region Mock Type Handlers + +sealed class ExplodingTypeHandlerResolverFactory : PgTypeInfoResolverFactory +{ + readonly bool _safe; + public ExplodingTypeHandlerResolverFactory(bool safe) => _safe = safe; + + public override IPgTypeInfoResolver CreateResolver() => new Resolver(_safe); + public override IPgTypeInfoResolver? CreateArrayResolver() => null; + + sealed class Resolver : IPgTypeInfoResolver { readonly bool _safe; - internal ExplodingTypeHandler(PostgresType postgresType, bool safe) - : base(postgresType) => _safe = safe; + public Resolver(bool safe) => _safe = safe; - public override int Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) { - buf.ReadInt32(); + if (dataTypeName == DataTypeNames.Int4 && (type == typeof(int) || type is null)) + return new(options, new ExplodingTypeHandler(_safe), DataTypeNames.Int4); - throw _safe - ? new Exception("Safe read exception as requested") - : buf.Connector.Break(new Exception("Non-safe read exception as requested")); + return null; } - - public override int ValidateAndGetLength(int value, NpgsqlParameter? parameter) => throw new NotSupportedException(); - public override void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) => throw new NotSupportedException(); } +} - #endregion +class ExplodingTypeHandler : PgBufferedConverter +{ + readonly bool _safe; + + internal ExplodingTypeHandler(bool safe) => _safe = safe; + + public override Size GetSize(SizeContext context, int value, ref object? 
writeState) + => throw new NotSupportedException(); + + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(format, out bufferRequirements); + + protected override void WriteCore(PgWriter writer, int value) + => throw new NotSupportedException(); + + protected override int ReadCore(PgReader reader) + { + if (_safe) + throw new Exception("Safe read exception as requested"); + + reader.BreakConnection(); + return default; + } } + +#endregion diff --git a/test/Npgsql.Tests/Replication/CommonLogicalReplicationTests.cs b/test/Npgsql.Tests/Replication/CommonLogicalReplicationTests.cs index 06b9700c04..a8a363a583 100644 --- a/test/Npgsql.Tests/Replication/CommonLogicalReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/CommonLogicalReplicationTests.cs @@ -6,207 +6,265 @@ using Npgsql.Replication.Internal; using NpgsqlTypes; -namespace Npgsql.Tests.Replication +namespace Npgsql.Tests.Replication; + +/// +/// Tests for common logical replication functionality. +/// +/// +/// While these tests might seem superfluous since we perform similar tests +/// for the individual logical replication tests, they are in fact not, because +/// the methods they test are extension points for plugin developers. +/// +[Platform(Exclude = "MacOsX", Reason = "Replication tests are flaky in CI on Mac")] +[NonParallelizable] +public class CommonLogicalReplicationTests : SafeReplicationTestBase { - /// - /// Tests for common logical replication functionality. - /// - /// - /// While these tests might seem superfluous since we perform similar tests - /// for the individual logical replication tests, they are in fact not, because - /// the methods they test are extension points for plugin developers. 
- /// - public class CommonLogicalReplicationTests : SafeReplicationTestBase - { - // We use the test_decoding logical decoding plugin for the common - // logical replication tests because it has existed since the - // beginning of logical decoding and by that has the best backwards - // compatibility. - const string OutputPlugin = "test_decoding"; - - [TestCase(true)] - [TestCase(false)] - public Task CreateReplicationSlotForPlugin(bool temporary) - => SafeReplicationTest( - async (slotName, _) => + // We use the test_decoding logical decoding plugin for the common + // logical replication tests because it has existed since the + // beginning of logical decoding and by that has the best backwards + // compatibility. + const string OutputPlugin = "test_decoding"; + + [Test] + public Task CreateLogicalReplicationSlot([Values]bool temporary, [Values]bool twoPhase) + => SafeReplicationTest( + async (slotName, _) => + { + await using var c = await OpenConnectionAsync(); + if (twoPhase) + TestUtil.MinimumPgVersion(c, "15.0", "Replication slots with two phase commit support were introduced in PostgreSQL 15"); + if (temporary) + TestUtil.MinimumPgVersion(c, "10.0", "Temporary replication slots were introduced in PostgreSQL 10"); + + await using var rc = await OpenReplicationConnectionAsync(); + var options = await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, temporary, twoPhase: twoPhase); + + using var cmd = + new NpgsqlCommand($"SELECT * FROM pg_replication_slots WHERE slot_name = '{options.SlotName}'", + c); + await using var reader = await cmd.ExecuteReaderAsync(); + + Assert.That(reader.Read, Is.True); + Assert.That(reader.GetFieldValue(reader.GetOrdinal("slot_type")), Is.EqualTo("logical")); + if (c.PostgreSqlVersion >= Version.Parse("15.0")) + Assert.That(reader.GetFieldValue(reader.GetOrdinal("two_phase")), Is.EqualTo(twoPhase)); + if (c.PostgreSqlVersion >= Version.Parse("10.0")) + Assert.That(reader.GetFieldValue(reader.GetOrdinal("temporary")), 
Is.EqualTo(temporary)); + Assert.That(reader.GetFieldValue(reader.GetOrdinal("active")), Is.EqualTo(temporary)); + if (c.PostgreSqlVersion >= Version.Parse("9.6")) + Assert.That(reader.GetFieldValue(reader.GetOrdinal("confirmed_flush_lsn")), + Is.EqualTo(options.ConsistentPoint)); + Assert.That(reader.Read, Is.False); + }, nameof(CreateLogicalReplicationSlot) + (temporary ? "_tmp" : "") + (twoPhase ? "_tp" : "")); + + [Test] + public Task CreateLogicalReplicationSlot_NoExport([Values]bool temporary, [Values]bool twoPhase) + => SafeReplicationTest( + async (slotName, _) => + { + await using var c = await OpenConnectionAsync(); + if (temporary) + TestUtil.MinimumPgVersion(c, "10.0", "Temporary replication slots were introduced in PostgreSQL 10"); + if (twoPhase) + TestUtil.MinimumPgVersion(c, "15.0", "Replication slots with two phase commit support were introduced in PostgreSQL 15"); + + TestUtil.MinimumPgVersion(c, "10.0", "The *_SNAPSHOT syntax was introduced in PostgreSQL 10"); + await using var rc = await OpenReplicationConnectionAsync(); + var options = await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, temporary, LogicalSlotSnapshotInitMode.NoExport, twoPhase); + Assert.That(options.SnapshotName, Is.Null); + }, nameof(CreateLogicalReplicationSlot_NoExport) + (temporary ? "_tmp" : "") + (twoPhase ? 
"_tp" : "")); + + [Test(Description = "Tests whether we throw a helpful exception about the unsupported *_SNAPSHOT syntax on old servers.")] + [TestCase(LogicalSlotSnapshotInitMode.Export)] + [TestCase(LogicalSlotSnapshotInitMode.NoExport)] + [TestCase(LogicalSlotSnapshotInitMode.Use)] + public Task CreateLogicalReplicationSlot_with_SnapshotInitMode_on_old_postgres_throws(LogicalSlotSnapshotInitMode mode) + => SafeReplicationTest( + async (slotName, _) => + { + await using var c = await OpenConnectionAsync(); + TestUtil.MaximumPgVersionExclusive(c, "10.0", "The *_SNAPSHOT syntax was introduced in PostgreSQL 10"); + Assert.That(async () => { - await using var c = await OpenConnectionAsync(); - if (temporary) - TestUtil.MinimumPgVersion(c, "10.0", "Temporary replication slots were introduced in PostgreSQL 10"); + await using var rc = await OpenReplicationConnectionAsync(); + await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, slotSnapshotInitMode: mode); + }, Throws.InstanceOf() + .With.Message.StartsWith("The EXPORT_SNAPSHOT, USE_SNAPSHOT and NOEXPORT_SNAPSHOT syntax was introduced in PostgreSQL") + .And.InnerException.TypeOf() + .And.InnerException.Property("SqlState").EqualTo(PostgresErrorCodes.SyntaxError)); + }); + [Test(Description = "Tests whether we throw a helpful exception about unsupported temporary replication slots on old servers.")] + public Task CreateLogicalReplicationSlot_with_isTemporary_set_to_true_on_old_postgres_throws() + => SafeReplicationTest( + async (slotName, _) => + { + await using var c = await OpenConnectionAsync(); + TestUtil.MaximumPgVersionExclusive(c, "10.0", "Temporary replication slots were introduced in PostgreSQL 10"); + Assert.That(async () => + { await using var rc = await OpenReplicationConnectionAsync(); - var options = await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, temporary); + await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, isTemporary: true); + }, Throws.InstanceOf() + 
.With.Message.StartsWith("Temporary replication slots were introduced in PostgreSQL") + .And.InnerException.TypeOf() + .And.InnerException.Property("SqlState").EqualTo(PostgresErrorCodes.SyntaxError)); + }); - using var cmd = - new NpgsqlCommand($"SELECT * FROM pg_replication_slots WHERE slot_name = '{options.SlotName}'", - c); - await using var reader = await cmd.ExecuteReaderAsync(); + [Test(Description = "Tests whether we throw a helpful exception about the unsupported TWO_PHASE syntax on old servers.")] + public Task CreateLogicalReplicationSlot_with_twoPhase_set_to_true_on_old_postgres_throws() + => SafeReplicationTest( + async (slotName, _) => + { + await using var c = await OpenConnectionAsync(); + TestUtil.MaximumPgVersionExclusive(c, "15.0", + "Logical replication support for prepared transactions was introduced in PostgreSQL 15"); + Assert.That(async () => + { + await using var rc = await OpenReplicationConnectionAsync(); + await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, twoPhase: true); + }, Throws.InstanceOf() + .With.Message.StartsWith("Logical replication support for prepared transactions was introduced in PostgreSQL") + .And.InnerException.TypeOf() + .And.InnerException.Property("SqlState").EqualTo(PostgresErrorCodes.SyntaxError)); + }); + [Test(Description = "We can use the exported snapshot to query the database in the very moment the replication slot was created.")] + public Task CreateLogicalReplicationSlot_Export([Values]bool temporary, [Values]bool twoPhase, [Values]bool implicitInitMode) + => SafeReplicationTest( + async (slotName, tableName) => + { + await using var c = await OpenConnectionAsync(); + if (temporary) + TestUtil.MinimumPgVersion(c, "10.0", "Temporary replication slots were introduced in PostgreSQL 10"); + if (twoPhase) + TestUtil.MinimumPgVersion(c, "15.0", "Replication slots with two phase commit support were introduced in PostgreSQL 15"); + if (!implicitInitMode) + TestUtil.MinimumPgVersion(c, "10.0", "The 
*_SNAPSHOT syntax was introduced in PostgreSQL 10"); + await using (var transaction = c.BeginTransaction()) + { + await c.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (value text)"); + await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} (value) VALUES('Before snapshot')"); + transaction.Commit(); + } + await using var rc = await OpenReplicationConnectionAsync(); + var options = await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, temporary, implicitInitMode ? null : LogicalSlotSnapshotInitMode.Export, twoPhase); + await using (var transaction = c.BeginTransaction()) + { + await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} (value) VALUES('After snapshot')"); + transaction.Commit(); + } + await using (var transaction = c.BeginTransaction(IsolationLevel.RepeatableRead)) + { + await c.ExecuteScalarAsync($"SET TRANSACTION SNAPSHOT '{options.SnapshotName}';", transaction); + using var cmd = new NpgsqlCommand($"SELECT value FROM {tableName}", c, transaction); + await using var reader = await cmd.ExecuteReaderAsync(); Assert.That(reader.Read, Is.True); - Assert.That(reader.GetFieldValue(reader.GetOrdinal("slot_type")), Is.EqualTo("logical")); - if (c.PostgreSqlVersion >= Version.Parse("10.0")) - Assert.That(reader.GetFieldValue(reader.GetOrdinal("temporary")), Is.EqualTo(temporary)); - Assert.That(reader.GetFieldValue(reader.GetOrdinal("active")), Is.EqualTo(temporary)); - if (c.PostgreSqlVersion >= Version.Parse("9.6")) - Assert.That(reader.GetFieldValue(reader.GetOrdinal("confirmed_flush_lsn")), - Is.EqualTo(options.ConsistentPoint)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo("Before snapshot")); Assert.That(reader.Read, Is.False); - }, nameof(CreateReplicationSlotForPlugin) + temporary); + } + }, nameof(CreateLogicalReplicationSlot_Export) + (temporary ? "_tmp" : "") + (twoPhase ? "_tp" : "") + (implicitInitMode ? 
"_i" : "")); - [Test] - public Task CreateReplicationSlotForPluginNoExportSnapshot() - => SafeReplicationTest( - async (slotName, _) => + [Test(Description = "Since we currently don't provide an API to start a transaction on a logical replication connection, " + + "USE_SNAPSHOT currently doesn't work and always leads to an exception. On the other hand, starting" + + "a transaction would only be useful if we'd also provide an API to issue commands.")] + public Task CreateLogicalReplicationSlot_Use([Values]bool temporary, [Values]bool twoPhase) + => SafeReplicationTest( + async (slotName, _) => + { + await using var c = await OpenConnectionAsync(); + if (temporary) + TestUtil.MinimumPgVersion(c, "10.0", "Temporary replication slots were introduced in PostgreSQL 10"); + if (twoPhase) + TestUtil.MinimumPgVersion(c, "15.0", "Replication slots with two phase commit support were introduced in PostgreSQL 15"); + + TestUtil.MinimumPgVersion(c, "10.0", "The *_SNAPSHOT syntax was introduced in PostgreSQL 10"); + Assert.That(async () => { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "10.0", "The *_SNAPSHOT syntax was introduced in PostgreSQL 10"); await using var rc = await OpenReplicationConnectionAsync(); - var options = await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, slotSnapshotInitMode: LogicalSlotSnapshotInitMode.NoExport); - Assert.That(options.SnapshotName, Is.Null); - }); - - [Test(Description = "Tests whether we throw a helpful exception about the unsupported *_SNAPSHOT syntax on old servers.")] - [TestCase(LogicalSlotSnapshotInitMode.Export)] - [TestCase(LogicalSlotSnapshotInitMode.NoExport)] - [TestCase(LogicalSlotSnapshotInitMode.Use)] - public Task CreateReplicationSlotForPluginExportSnapshotSyntaxThrows(LogicalSlotSnapshotInitMode mode) - => SafeReplicationTest( - async (slotName, _) => - { - await using var c = await OpenConnectionAsync(); - TestUtil.MaximumPgVersionExclusive(c, "10.0", "The *_SNAPSHOT syntax 
was introduced in PostgreSQL 10"); - Assert.That(async () => - { - await using var rc = await OpenReplicationConnectionAsync(); - await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, slotSnapshotInitMode: mode); - }, Throws.InstanceOf() - .With.Message.StartsWith("The EXPORT_SNAPSHOT, USE_SNAPSHOT and NOEXPORT_SNAPSHOT syntax was introduced in PostgreSQL") - .And.InnerException.TypeOf() - .And.InnerException.Property("SqlState").EqualTo(PostgresErrorCodes.SyntaxError)); - }); - - [Test(Description = "We can use the exported snapshot to query the database in the very moment the replication slot was created.")] - public Task CreateReplicationSlotForPluginExportSnapshot() - => SafeReplicationTest( - async (slotName, tableName) => + await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, temporary, LogicalSlotSnapshotInitMode.Use, twoPhase); + }, Throws.InstanceOf() + .With.Property("SqlState") + .EqualTo("XX000") + .And.Message.Contains( + c.PostgreSqlVersion.Major < 15 + ? "USE_SNAPSHOT" + : "(SNAPSHOT 'use')" + )); + }, nameof(CreateLogicalReplicationSlot_Use) + (temporary ? "_tmp" : "") + (twoPhase ? 
"_tp" : "")); + + [Test] + public void CreateLogicalReplicationSlot_with_null_slot_throws() + => Assert.That(async () => + { + await using var rc = await OpenReplicationConnectionAsync(); + await rc.CreateLogicalReplicationSlot(null!, OutputPlugin); + }, Throws.ArgumentNullException + .With.Property("ParamName") + .EqualTo("slotName")); + + [Test] + public Task CreateLogicalReplicationSlot_with_null_output_plugin_throws() + => SafeReplicationTest( + (slotName, _) => + { + Assert.That(async () => { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "10.0", "The *_SNAPSHOT syntax was introduced in PostgreSQL 10"); - await using (var transaction = c.BeginTransaction()) - { - await c.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (value text)"); - await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} (value) VALUES('Before snapshot')"); - transaction.Commit(); - } await using var rc = await OpenReplicationConnectionAsync(); - var options = await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, slotSnapshotInitMode: LogicalSlotSnapshotInitMode.Export); - await using (var transaction = c.BeginTransaction()) - { - await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} (value) VALUES('After snapshot')"); - transaction.Commit(); - } - await using (var transaction = c.BeginTransaction(IsolationLevel.RepeatableRead)) - { - await c.ExecuteScalarAsync($"SET TRANSACTION SNAPSHOT '{options.SnapshotName}';", transaction); - using var cmd = new NpgsqlCommand($"SELECT value FROM {tableName}", c, transaction); - await using var reader = await cmd.ExecuteReaderAsync(); - Assert.That(reader.Read, Is.True); - Assert.That(reader.GetFieldValue(0), Is.EqualTo("Before snapshot")); - Assert.That(reader.Read, Is.False); - } - }); - - [Test(Description = "Since we currently don't provide an API to start a transaction on a logical replication connection, " + - "USE_SNAPSHOT currently doesn't work and always leads to an exception. 
On the other hand, starting" + - "a transaction would only be useful if we'd also provide an API to issue commands.")] - public Task CreateReplicationSlotForPluginUseSnapshot() - => SafeReplicationTest( - async (slotName, _) => - { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "10.0", "The *_SNAPSHOT syntax was introduced in PostgreSQL 10"); - Assert.That(async () => - { - await using var rc = await OpenReplicationConnectionAsync(); - await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, slotSnapshotInitMode: LogicalSlotSnapshotInitMode.Use); - }, Throws.InstanceOf() - .With.Property("SqlState") - .EqualTo("XX000") - .And.Message.Contains("USE_SNAPSHOT")); - }); - - [Test] - public void CreateReplicationSlotForPluginNullSlot() - => Assert.That(async () => + await rc.CreateLogicalReplicationSlot(slotName, null!); + }, Throws.ArgumentNullException + .With.Property("ParamName") + .EqualTo("outputPlugin")); + return Task.CompletedTask; + }); + + [Test] + public Task CreateLogicalReplicationSlot_with_cancelled_token() + => SafeReplicationTest( + (slotName, _) => { - await using var rc = await OpenReplicationConnectionAsync(); - await rc.CreateLogicalReplicationSlot(null!, OutputPlugin); - }, Throws.ArgumentNullException - .With.Property("ParamName") - .EqualTo("slotName")); - - [Test] - public Task CreateReplicationSlotForPluginNullPlugin() - => SafeReplicationTest( - (slotName, _) => + Assert.That(async () => { - Assert.That(async () => - { - await using var rc = await OpenReplicationConnectionAsync(); - await rc.CreateLogicalReplicationSlot(slotName, null!); - }, Throws.ArgumentNullException - .With.Property("ParamName") - .EqualTo("outputPlugin")); - return Task.CompletedTask; - }); - - [Test] - public Task CreateReplicationSlotForPluginCancelled() - => SafeReplicationTest( - (slotName, _) => - { - Assert.That(async () => - { - await using var rc = await OpenReplicationConnectionAsync(); - using var cts = 
GetCancelledCancellationTokenSource(); - await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, cancellationToken: cts.Token); - }, Throws.Exception.AssignableTo()); - return Task.CompletedTask; - }); - - [Test] - public Task CreateReplicationSlotForPluginInvalidSlotSnapshotInitMode() - => SafeReplicationTest( - (slotName, _) => + await using var rc = await OpenReplicationConnectionAsync(); + var token = GetCancelledCancellationToken(); + await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, cancellationToken: token); + }, Throws.Exception.AssignableTo()); + return Task.CompletedTask; + }); + + [Test] + public Task CreateLogicalReplicationSlot_with_invalid_SnapshotInitMode_throws() + => SafeReplicationTest( + (slotName, _) => + { + Assert.That(async () => { - Assert.That(async () => - { - await using var rc = await OpenReplicationConnectionAsync(); - await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, slotSnapshotInitMode: (LogicalSlotSnapshotInitMode)42); - }, Throws.InstanceOf() - .With.Property("ParamName") - .EqualTo("slotSnapshotInitMode") - .And.Property("ActualValue") - .EqualTo((LogicalSlotSnapshotInitMode)42)); - return Task.CompletedTask; - }); - - [Test] - public Task CreateReplicationSlotForPluginDisposed() - => SafeReplicationTest( - (slotName, _) => + await using var rc = await OpenReplicationConnectionAsync(); + await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin, slotSnapshotInitMode: (LogicalSlotSnapshotInitMode)42); + }, Throws.InstanceOf() + .With.Property("ParamName") + .EqualTo("slotSnapshotInitMode") + .And.Property("ActualValue") + .EqualTo((LogicalSlotSnapshotInitMode)42)); + return Task.CompletedTask; + }); + + [Test] + public Task CreateLogicalReplicationSlot_with_disposed_connection_throws() + => SafeReplicationTest( + (slotName, _) => + { + Assert.That(async () => { - Assert.That(async () => - { - var rc = await OpenReplicationConnectionAsync(); - await rc.DisposeAsync(); - await 
rc.CreateLogicalReplicationSlot(slotName, OutputPlugin); - }, Throws.InstanceOf() - .With.Property(nameof(ObjectDisposedException.ObjectName)) - .EqualTo(nameof(LogicalReplicationConnection))); - return Task.CompletedTask; - }); - - protected override string Postfix => "commonl_l"; - } + var rc = await OpenReplicationConnectionAsync(); + await rc.DisposeAsync(); + await rc.CreateLogicalReplicationSlot(slotName, OutputPlugin); + }, Throws.InstanceOf() + .With.Property(nameof(ObjectDisposedException.ObjectName)) + .EqualTo(nameof(LogicalReplicationConnection))); + return Task.CompletedTask; + }); + + protected override string Postfix => "commonl_l"; } diff --git a/test/Npgsql.Tests/Replication/CommonReplicationTests.cs b/test/Npgsql.Tests/Replication/CommonReplicationTests.cs index 370aea81e8..36a11b434a 100644 --- a/test/Npgsql.Tests/Replication/CommonReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/CommonReplicationTests.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Runtime.CompilerServices; @@ -11,517 +10,545 @@ using Npgsql.Replication.TestDecoding; using NpgsqlTypes; -namespace Npgsql.Tests.Replication +namespace Npgsql.Tests.Replication; + +[TestFixture(typeof(LogicalReplicationConnection))] +[TestFixture(typeof(PhysicalReplicationConnection))] +[Platform(Exclude = "MacOsX", Reason = "Replication tests are flaky in CI on Mac")] +[NonParallelizable] +public class CommonReplicationTests : SafeReplicationTestBase + where TConnection : ReplicationConnection, new() { - [TestFixture(typeof(LogicalReplicationConnection))] - [TestFixture(typeof(PhysicalReplicationConnection))] - public class CommonReplicationTests : SafeReplicationTestBase - where TConnection : ReplicationConnection, new() + #region Open + + [Test, NonParallelizable] + public async Task Open() { - #region Open + await using var rc = await OpenReplicationConnectionAsync(); + } - [Test] - public async Task 
Open() + [Test] + public void Open_with_cancelled_token() + => Assert.That(async () => { - await using var rc = await OpenReplicationConnectionAsync(); - } + var token = GetCancelledCancellationToken(); + await using var rc = await OpenReplicationConnectionAsync(cancellationToken: token); + }, Throws.Exception.AssignableTo()); - [Test] - public void OpenCancelled() - => Assert.That(async () => - { - using var cts = GetCancelledCancellationTokenSource(); - await using var rc = await OpenReplicationConnectionAsync(cancellationToken: cts.Token); - }, Throws.Exception.AssignableTo()); + [Test] + public void Open_on_disposed_connection() + => Assert.That(async () => + { + var rc = await OpenReplicationConnectionAsync(); + await rc.DisposeAsync(); + await rc.Open(); + }, Throws.InstanceOf() + .With.Property(nameof(ObjectDisposedException.ObjectName)) + .EqualTo(typeof(TConnection).Name)); - [Test] - public void OpenDisposed() - => Assert.That(async () => - { - var rc = await OpenReplicationConnectionAsync(); - await rc.DisposeAsync(); - await rc.Open(); - }, Throws.InstanceOf() - .With.Property(nameof(ObjectDisposedException.ObjectName)) - .EqualTo(typeof(TConnection).Name)); + #endregion Open - #endregion Open + #region IdentifySystem - #region IdentifySystem + [Test] + public async Task IdentifySystem() + { + await using var rc = await OpenReplicationConnectionAsync(); + var info = await rc.IdentifySystem(); + Assert.That(info.Timeline, Is.GreaterThan(0)); + } - [Test] - public async Task IdentifySystem() + [Test] + public void IdentifySystem_with_cancelled_token() + => Assert.That(async () => { await using var rc = await OpenReplicationConnectionAsync(); - var info = await rc.IdentifySystem(); - Assert.That(info.Timeline, Is.GreaterThan(0)); - } + var token = GetCancelledCancellationToken(); + await rc.IdentifySystem(token); + }, Throws.Exception.AssignableTo()); - [Test] - public void IdentifySystemCancelled() - => Assert.That(async () => - { - await using var rc = 
await OpenReplicationConnectionAsync(); - using var cts = GetCancelledCancellationTokenSource(); - await rc.IdentifySystem(cts.Token); - }, Throws.Exception.AssignableTo()); + [Test] + public void IdentifySystem_on_disposed_connection() + => Assert.That(async () => + { + var rc = await OpenReplicationConnectionAsync(); + await rc.DisposeAsync(); + await rc.IdentifySystem(); + }, Throws.InstanceOf() + .With.Property(nameof(ObjectDisposedException.ObjectName)) + .EqualTo(typeof(TConnection).Name)); - [Test] - public void IdentifySystemDisposed() - => Assert.That(async () => - { - var rc = await OpenReplicationConnectionAsync(); - await rc.DisposeAsync(); - await rc.IdentifySystem(); - }, Throws.InstanceOf() - .With.Property(nameof(ObjectDisposedException.ObjectName)) - .EqualTo(typeof(TConnection).Name)); + #endregion IdentifySystem - #endregion IdentifySystem + #region Show - #region Show + [Test] + public async Task Show() + { + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command was added to the Streaming Replication Protocol in PostgreSQL 10"); - [Test] - public async Task Show() - { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command was added to the Streaming Replication Protocol in PostgreSQL 10"); + await using var rc = await OpenReplicationConnectionAsync(); + Assert.That(await rc.Show("integer_datetimes"), Is.EqualTo("on")); + } - await using var rc = await OpenReplicationConnectionAsync(); - Assert.That(await rc.Show("integer_datetimes"), Is.EqualTo("on")); - } + [Test] + public async Task Show_with_null_argument_throws() + { + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command was added to the Streaming Replication Protocol in PostgreSQL 10"); - [Test] - public async Task ShowNullArgument() + Assert.That(async () => { - await using var c = await OpenConnectionAsync(); - 
TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command was added to the Streaming Replication Protocol in PostgreSQL 10"); + await using var rc = await OpenReplicationConnectionAsync(); + await rc.Show(null!); + }, Throws.ArgumentNullException + .With.Property("ParamName") + .EqualTo("parameterName")); + } - Assert.That(async () => - { - await using var rc = await OpenReplicationConnectionAsync(); - await rc.Show(null!); - }, Throws.ArgumentNullException - .With.Property("ParamName") - .EqualTo("parameterName")); - } + [Test] + public async Task Show_with_cancelled_token() + { + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command was added to the Streaming Replication Protocol in PostgreSQL 10"); - [Test] - public async Task ShowCancelled() + Assert.That(async () => { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command was added to the Streaming Replication Protocol in PostgreSQL 10"); + await using var rc = await OpenReplicationConnectionAsync(); + var token = GetCancelledCancellationToken(); + await rc.Show("integer_datetimes", token); + }, Throws.Exception.AssignableTo()); + } - Assert.That(async () => - { - await using var rc = await OpenReplicationConnectionAsync(); - using var cts = GetCancelledCancellationTokenSource(); - await rc.Show("integer_datetimes", cts.Token); - }, Throws.Exception.AssignableTo()); - } + [Test] + public async Task Show_on_disposed_connection() + { + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command was added to the Streaming Replication Protocol in PostgreSQL 10"); - [Test] - public async Task ShowDisposed() + Assert.That(async () => { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command was added to the Streaming Replication Protocol in PostgreSQL 10"); + var rc = await OpenReplicationConnectionAsync(); + await 
rc.DisposeAsync(); + await rc.Show("integer_datetimes"); + }, Throws.InstanceOf() + .With.Property(nameof(ObjectDisposedException.ObjectName)) + .EqualTo(typeof(TConnection).Name)); + } - Assert.That(async () => - { - var rc = await OpenReplicationConnectionAsync(); - await rc.DisposeAsync(); - await rc.Show("integer_datetimes"); - }, Throws.InstanceOf() - .With.Property(nameof(ObjectDisposedException.ObjectName)) - .EqualTo(typeof(TConnection).Name)); - } + #endregion Show - #endregion Show + #region TimelineHistory - #region TimelineHistory + [Test, Explicit("After initdb a PostgreSQL cluster only has one timeline and no timeline history so this command fails. " + + "You need to explicitly create multiple timelines (e. g. via PITR or by promoting a standby) for this test to work.")] + public async Task TimelineHistory() + { + await using var rc = await OpenReplicationConnectionAsync(); + var systemInfo = await rc.IdentifySystem(); + var info = await rc.TimelineHistory(systemInfo.Timeline); + Assert.That(info.FileName, Is.Not.Null); + Assert.That(info.Content.Length, Is.GreaterThan(0)); + } - [Test, Explicit("After initdb a PostgreSQL cluster only has one timeline and no timeline history so this command fails. " + - "You need to explicitly create multiple timelines (e. g. 
via PITR or by promoting a standby) for this test to work.")] - public async Task TimelineHistory() + [Test] + public void TimelineHistory_with_cancelled_token() + => Assert.That(async () => { await using var rc = await OpenReplicationConnectionAsync(); var systemInfo = await rc.IdentifySystem(); - var info = await rc.TimelineHistory(systemInfo.Timeline); - Assert.That(info.FileName, Is.Not.Null); - Assert.That(info.Content.Length, Is.GreaterThan(0)); - } + var token = GetCancelledCancellationToken(); + await rc.TimelineHistory(systemInfo.Timeline, token); + }, Throws.Exception.AssignableTo()); - [Test] - public void TimelineHistoryCancelled() - => Assert.That(async () => - { - await using var rc = await OpenReplicationConnectionAsync(); - var systemInfo = await rc.IdentifySystem(); - using var cts = GetCancelledCancellationTokenSource(); - await rc.TimelineHistory(systemInfo.Timeline, cts.Token); - }, Throws.Exception.AssignableTo()); - - [Test] - public void TimelineHistoryNonExisting() - => Assert.That(async () => + [Test] + public void TimelineHistory_with_non_existing_timeline() + => Assert.That(async () => + { + await using var rc = await OpenReplicationConnectionAsync(); + await rc.TimelineHistory(uint.MaxValue); + }, Throws + .InstanceOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedFile) + .Or + .InstanceOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.CharacterNotInRepertoire)); + + [Test] + public void TimelineHistory_on_disposed_connection() + => Assert.That(async () => + { + var rc = await OpenReplicationConnectionAsync(); + var systemInfo = await rc.IdentifySystem(); + await rc.DisposeAsync(); + await rc.TimelineHistory(systemInfo.Timeline); + }, Throws.InstanceOf() + .With.Property(nameof(ObjectDisposedException.ObjectName)) + .EqualTo(typeof(TConnection).Name)); + + #endregion TimelineHistory + + #region DropReplicationSlot + + [Test] + public void 
DropReplicationSlot_with_null_slot_throws() + => Assert.That(async () => + { + await using var rc = await OpenReplicationConnectionAsync(); + await rc.DropReplicationSlot(null!); + }, Throws.ArgumentNullException + .With.Property("ParamName") + .EqualTo("slotName")); + + [Test] + public Task DropReplicationSlot_with_cancelled_token() + => SafeReplicationTest( + async (slotName, _) => { + await CreateReplicationSlot(slotName); await using var rc = await OpenReplicationConnectionAsync(); - await rc.TimelineHistory(uint.MaxValue); - }, Throws - .InstanceOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedFile) - .Or - .InstanceOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.CharacterNotInRepertoire)); - - [Test] - public void TimelineHistoryDisposed() - => Assert.That(async () => + var token = GetCancelledCancellationToken(); + Assert.That(async () => await rc.DropReplicationSlot(slotName, cancellationToken: token), Throws.Exception.AssignableTo()); + }); + + [Test] + public Task DropReplicationSlot_on_disposed_connection() + => SafeReplicationTest( + async (slotName, _) => { + await CreateReplicationSlot(slotName); var rc = await OpenReplicationConnectionAsync(); - var systemInfo = await rc.IdentifySystem(); await rc.DisposeAsync(); - await rc.TimelineHistory(systemInfo.Timeline); - }, Throws.InstanceOf() - .With.Property(nameof(ObjectDisposedException.ObjectName)) - .EqualTo(typeof(TConnection).Name)); + Assert.That(async () => await rc.DropReplicationSlot(slotName), Throws.InstanceOf() + .With.Property(nameof(ObjectDisposedException.ObjectName)) + .EqualTo(typeof(TConnection).Name)); + }); - #endregion TimelineHistory + #endregion - #region DropReplicationSlot - - [Test] - public void DropReplicationSlotNullSlot() - => Assert.That(async () => + [Test(Description = "Tests whether our automated feedback thread prevents the backend from disconnecting due to wal_sender_timeout")] + public 
Task Replication_survives_pauses_longer_than_wal_sender_timeout() + => SafeReplicationTest( + async (slotName, tableName) => { - await using var rc = await OpenReplicationConnectionAsync(); - await rc.DropReplicationSlot(null!); - }, Throws.ArgumentNullException - .With.Property("ParamName") - .EqualTo("slotName")); - - [Test] - public Task DropReplicationSlotCancelled() - => SafeReplicationTest( - async (slotName, _) => - { - await CreateReplicationSlot(slotName); - await using var rc = await OpenReplicationConnectionAsync(); - using var cts = GetCancelledCancellationTokenSource(); - Assert.That(async () => await rc.DropReplicationSlot(slotName, cancellationToken: cts.Token), Throws.Exception.AssignableTo()); - }); + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command, which is required to run this test was added to the Streaming Replication Protocol in PostgreSQL 10"); + await c.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL);"); + await using var rc = await OpenReplicationConnectionAsync(new NpgsqlConnectionStringBuilder(ConnectionString)); + var walSenderTimeout = ParseTimespan(await rc.Show("wal_sender_timeout")); + var info = await rc.IdentifySystem(); + if (walSenderTimeout > TimeSpan.FromSeconds(3) && !TestUtil.IsOnBuildServer) + Assert.Ignore($"wal_sender_timeout is set to {walSenderTimeout}, skipping"); + var walReceiverStatusInterval = TimeSpan.FromTicks(walSenderTimeout.Ticks / 2L); + rc.WalReceiverStatusInterval = walReceiverStatusInterval; + await CreateReplicationSlot(slotName); + await c.ExecuteNonQueryAsync($"INSERT INTO \"{tableName}\" (name) VALUES ('val1')"); + using var streamingCts = new CancellationTokenSource(); + + var replicationEnumerator = StartReplication(rc, slotName, info.XLogPos, streamingCts.Token).GetAsyncEnumerator(streamingCts.Token); + Assert.That(await replicationEnumerator.MoveNextAsync(), Is.True); + + await 
Task.Delay(walSenderTimeout * 1.1, CancellationToken.None); + + await c.ExecuteNonQueryAsync($"INSERT INTO \"{tableName}\" (name) VALUES ('val2')"); + Assert.That(await replicationEnumerator.MoveNextAsync(), Is.True); + + var message = replicationEnumerator.Current; + Assert.That(message.WalStart, Is.GreaterThanOrEqualTo(info.XLogPos)); + Assert.That(message.WalEnd, Is.GreaterThanOrEqualTo(message.WalStart)); + + streamingCts.Cancel(); + + Assert.That(async () => { while (await replicationEnumerator.MoveNextAsync()){} }, Throws.Exception.AssignableTo() + .With.InnerException.InstanceOf() + .And.InnerException.Property(nameof(PostgresException.SqlState)) + .EqualTo(PostgresErrorCodes.QueryCanceled)); + + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + + [Test(Description = "Tests whether synchronous replication works the way it should.")] + [Explicit("Test is flaky (on Windows)")] + public Task Synchronous_replication() + => SafeReplicationTest( + async (slotName, tableName) => + { + await using var c = await OpenConnectionAsync(); + //TestUtil.MinimumPgVersion(c, "9.4", "Logical Replication was introduced in PostgreSQL 9.4"); + // + TestUtil.MinimumPgVersion(c, "12.0", "Setting wal_sender_timeout at runtime was introduced in in PostgreSQL 12"); + + var synchronousCommit = (string)(await c.ExecuteScalarAsync("SHOW synchronous_commit"))!; + if (synchronousCommit != "local") + TestUtil.IgnoreExceptOnBuildServer("Ignoring because synchronous_commit isn't 'local'"); + var synchronousStandbyNames = (string)(await c.ExecuteScalarAsync("SHOW synchronous_standby_names"))!; + if (synchronousStandbyNames != "npgsql_test_sync_standby") + TestUtil.IgnoreExceptOnBuildServer("Ignoring because synchronous_standby_names isn't 'npgsql_test_sync_standby'"); + + await c.ExecuteNonQueryAsync(@$" + CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); + "); - [Test] - public Task DropReplicationSlotDisposed() - => 
SafeReplicationTest( - async (slotName, _) => + await using var rc = await OpenReplicationConnectionAsync(new NpgsqlConnectionStringBuilder(ConnectionString) { - await CreateReplicationSlot(slotName); - await using var rc = await OpenReplicationConnectionAsync(); - await rc.DisposeAsync(); - Assert.That(async () => await rc.DropReplicationSlot(slotName), Throws.InstanceOf() - .With.Property(nameof(ObjectDisposedException.ObjectName)) - .EqualTo(typeof(TConnection).Name)); + // This must be one of the configured synchronous_standby_names from postgresql.conf + ApplicationName = "npgsql_test_sync_standby", + // We need wal_sender_timeout to be at least twice checkpoint_timeout to avoid getting feedback requests + // from the backend in physical replication which makes this test fail, so we disable it for this test. + Options = "-c wal_sender_timeout=0" }); - - #endregion - - [Test(Description = "Tests whether our automated feedback thread prevents the backend from disconnecting due to wal_sender_timeout")] - public Task ReplicationSurvivesPausesLongerThanWalSenderTimeout() - => SafeReplicationTest( - async (slotName, tableName) => + var info = await rc.IdentifySystem(); + + // Set WalReceiverStatusInterval to infinite so that the automated feedback doesn't interfere with + // our manual feedback + rc.WalReceiverStatusInterval = Timeout.InfiniteTimeSpan; + + await CreateReplicationSlot(slotName); + using var streamingCts = new CancellationTokenSource(); + var messages = ParseMessages( + StartReplication(rc, slotName, info.XLogPos, streamingCts.Token)) + .GetAsyncEnumerator(); + + var value1String = Guid.NewGuid().ToString("B"); + // We need to start a separate thread here as the insert command wil not complete until + // the transaction successfully completes (which we block here from the standby side) and by that + // will occupy the connection it is bound to. 
+ var insertTask = Task.Run(async () => { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "10.0", "The SHOW command, which is required to run this test was added to the Streaming Replication Protocol in PostgreSQL 10"); - var messages = new ConcurrentQueue(); - await c.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL);"); - await using var rc = await OpenReplicationConnectionAsync(new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = slotName - }); - var walSenderTimeout = ParseTimespan(await rc.Show("wal_sender_timeout")); - var info = await rc.IdentifySystem(); - if (walSenderTimeout > TimeSpan.FromSeconds(3) && !TestUtil.IsOnBuildServer) - Assert.Ignore($"wal_sender_timeout is set to {walSenderTimeout}, skipping"); - Console.WriteLine($"The server wal_sender_timeout is configured to {walSenderTimeout}"); - var walReceiverStatusInterval = TimeSpan.FromTicks(walSenderTimeout.Ticks / 2L); - Console.WriteLine($"Setting {nameof(ReplicationConnection)}.{nameof(ReplicationConnection.WalReceiverStatusInterval)} to {walReceiverStatusInterval}"); - rc.WalReceiverStatusInterval = walReceiverStatusInterval; - await CreateReplicationSlot(slotName); - await c.ExecuteNonQueryAsync($"INSERT INTO \"{tableName}\" (name) VALUES ('val1')"); - using var streamingCts = new CancellationTokenSource(); - - var replicationEnumerator = StartReplication(rc, slotName, info.XLogPos, streamingCts.Token).GetAsyncEnumerator(streamingCts.Token); - - var delay = TimeSpan.FromTicks((long)(walSenderTimeout.Ticks * 1.1)); - Console.WriteLine($"Going to sleep for {delay}"); - await Task.Delay(delay, CancellationToken.None); - - Assert.That(await replicationEnumerator.MoveNextAsync(), Is.True); - var message = replicationEnumerator.Current; - Assert.That(message.WalStart, Is.GreaterThanOrEqualTo(info.XLogPos)); - Assert.That(message.WalEnd, Is.GreaterThanOrEqualTo(message.WalStart)); - - 
streamingCts.Cancel(); - - Assert.That(async () => { while (await replicationEnumerator.MoveNextAsync()){} }, Throws.Exception.AssignableTo() - .With.InnerException.InstanceOf() - .And.InnerException.Property(nameof(PostgresException.SqlState)) - .EqualTo(PostgresErrorCodes.QueryCanceled)); - - await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + await using var dataSource = CreateDataSource(csb => csb.Options = "-c synchronous_commit=on"); + await using var insertConn = await dataSource.OpenConnectionAsync(); + await insertConn.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('{value1String}')"); }); - [Test(Description = "Tests whether synchronous replication works the way it should.")] - [Explicit("Test is flaky (on Windows)")] - public Task SynchronousReplication() - => SafeReplicationTest( - async (slotName, tableName) => - { - await using var c = await OpenConnectionAsync(); - //TestUtil.MinimumPgVersion(c, "9.4", "Logical Replication was introduced in PostgreSQL 9.4"); - // - TestUtil.MinimumPgVersion(c, "12.0", "Setting wal_sender_timeout at runtime was introduced in in PostgreSQL 12"); - - var synchronousCommit = (string)(await c.ExecuteScalarAsync("SHOW synchronous_commit"))!; - if (synchronousCommit != "local") - TestUtil.IgnoreExceptOnBuildServer("Ignoring because synchronous_commit isn't 'local'"); - var synchronousStandbyNames = (string)(await c.ExecuteScalarAsync("SHOW synchronous_standby_names"))!; - if (synchronousStandbyNames != "npgsql_test_sync_standby") - TestUtil.IgnoreExceptOnBuildServer("Ignoring because synchronous_standby_names isn't 'npgsql_test_sync_standby'"); - - await c.ExecuteNonQueryAsync(@$" - CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); - "); - - await using var rc = await OpenReplicationConnectionAsync(new NpgsqlConnectionStringBuilder(ConnectionString) - { - // This must be one of the configured synchronous_standby_names from postgresql.conf - ApplicationName 
= "npgsql_test_sync_standby", - // We need wal_sender_timeout to be at least twice checkpoint_timeout to avoid getting feedback requests - // from the backend in physical replication which makes this test fail, so we disable it for this test. - Options = "-c wal_sender_timeout=0" - }); - var info = await rc.IdentifySystem(); - - // Set WalReceiverStatusInterval to infinite so that the automated feedback doesn't interfere with - // our manual feedback - rc.WalReceiverStatusInterval = Timeout.InfiniteTimeSpan; - - await CreateReplicationSlot(slotName); - using var streamingCts = new CancellationTokenSource(); - var messages = ParseMessages( - StartReplication(rc, slotName, info.XLogPos, streamingCts.Token)) - .GetAsyncEnumerator(); - - var value1String = Guid.NewGuid().ToString("B"); - // We need to start a separate thread here as the insert command wil not complete until - // the transaction successfully completes (which we block here from the standby side) and by that - // will occupy the connection it is bound to. 
- var insertTask = Task.Run(async () => - { - await using var insertConn = await OpenConnectionAsync(new NpgsqlConnectionStringBuilder(ConnectionString) - { - Options = "-c synchronous_commit=on" - }); - await insertConn.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('{value1String}')"); - }); - - var commitLsn = await GetCommitLsn(value1String); + var commitLsn = await GetCommitLsn(value1String); - var result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); - Assert.That(result, Is.Null); // Not committed yet because we didn't report fsync yet + var result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); + Assert.That(result, Is.Null); // Not committed yet because we didn't report fsync yet - // Report last received LSN - await rc.SendStatusUpdate(CancellationToken.None); + // Report last received LSN + await rc.SendStatusUpdate(CancellationToken.None); - result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); - Assert.That(result, Is.Null); // Not committed yet because we still didn't report fsync yet + result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); + Assert.That(result, Is.Null); // Not committed yet because we still didn't report fsync yet - // Report last applied LSN - rc.LastAppliedLsn = commitLsn; - await rc.SendStatusUpdate(CancellationToken.None); + // Report last applied LSN + rc.LastAppliedLsn = commitLsn; + await rc.SendStatusUpdate(CancellationToken.None); - result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); - Assert.That(result, Is.Null); // Not committed yet because we still didn't report fsync yet + result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); + Assert.That(result, Is.Null); // Not committed yet because we still didn't report fsync yet - // Report last flushed LSN - 
rc.LastFlushedLsn = commitLsn; - await rc.SendStatusUpdate(CancellationToken.None); + // Report last flushed LSN + rc.LastFlushedLsn = commitLsn; + await rc.SendStatusUpdate(CancellationToken.None); - await insertTask; - result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); - Assert.That(result, Is.EqualTo(value1String)); // Now it's committed because we reported fsync + await insertTask; + result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); + Assert.That(result, Is.EqualTo(value1String)); // Now it's committed because we reported fsync - var value2String = Guid.NewGuid().ToString("B"); - insertTask = Task.Run(async () => - { - await using var insertConn = OpenConnection(new NpgsqlConnectionStringBuilder(ConnectionString) - { - Options = "-c synchronous_commit=remote_apply" - }); - await insertConn.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('{value2String}')"); - }); + var value2String = Guid.NewGuid().ToString("B"); + insertTask = Task.Run(async () => + { + await using var dataSource = CreateDataSource(csb => csb.Options = "-c synchronous_commit=remote_apply"); + await using var insertConn = await dataSource.OpenConnectionAsync(); + await insertConn.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('{value2String}')"); + }); - commitLsn = await GetCommitLsn(value2String); + commitLsn = await GetCommitLsn(value2String); - result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); - Assert.That(result, Is.EqualTo(value1String)); // Not committed yet because we didn't report apply yet + result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); + Assert.That(result, Is.EqualTo(value1String)); // Not committed yet because we didn't report apply yet - // Report last received LSN - await rc.SendStatusUpdate(CancellationToken.None); + // Report last received LSN + await 
rc.SendStatusUpdate(CancellationToken.None); - result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); - Assert.That(result, Is.EqualTo(value1String)); // Not committed yet because we still didn't report apply yet + result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); + Assert.That(result, Is.EqualTo(value1String)); // Not committed yet because we still didn't report apply yet - // Report last applied LSN - rc.LastAppliedLsn = commitLsn; - await rc.SendStatusUpdate(CancellationToken.None); + // Report last applied LSN + rc.LastAppliedLsn = commitLsn; + await rc.SendStatusUpdate(CancellationToken.None); - await insertTask; - result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); - Assert.That(result, Is.EqualTo(value2String)); // Now it's committed because we reported apply + await insertTask; + result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); + Assert.That(result, Is.EqualTo(value2String)); // Now it's committed because we reported apply - var value3String = Guid.NewGuid().ToString("B"); - insertTask = Task.Run(async () => - { - await using var insertConn = OpenConnection(new NpgsqlConnectionStringBuilder(ConnectionString) - { - Options = "-c synchronous_commit=remote_write" - }); - await insertConn.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('{value3String}')"); - }); + var value3String = Guid.NewGuid().ToString("B"); + insertTask = Task.Run(async () => + { + await using var dataSource = CreateDataSource(csb => csb.Options = "-c synchronous_commit=remote_write"); + await using var insertConn = await dataSource.OpenConnectionAsync(); + await insertConn.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('{value3String}')"); + }); - await GetCommitLsn(value3String); + await GetCommitLsn(value3String); - result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} 
ORDER BY id DESC LIMIT 1;"); - Assert.That(result, Is.EqualTo(value2String)); // Not committed yet because we didn't report receive yet + result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); + Assert.That(result, Is.EqualTo(value2String)); // Not committed yet because we didn't report receive yet - // Report last received LSN - await rc.SendStatusUpdate(CancellationToken.None); + // Report last received LSN + await rc.SendStatusUpdate(CancellationToken.None); - await insertTask; - result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); - Assert.That(result, Is.EqualTo(value3String)); // Now it's committed because we reported receive + await insertTask; + result = await c.ExecuteScalarAsync($"SELECT name FROM {tableName} ORDER BY id DESC LIMIT 1;"); + Assert.That(result, Is.EqualTo(value3String)); // Now it's committed because we reported receive - streamingCts.Cancel(); - Assert.That(async () => await messages.MoveNextAsync(), Throws.Exception.AssignableTo() - .With.InnerException.InstanceOf() - .And.InnerException.Property(nameof(PostgresException.SqlState)) - .EqualTo(PostgresErrorCodes.QueryCanceled)); - await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + streamingCts.Cancel(); + Assert.That(async () => await messages.MoveNextAsync(), Throws.Exception.AssignableTo() + .With.InnerException.InstanceOf() + .And.InnerException.Property(nameof(PostgresException.SqlState)) + .EqualTo(PostgresErrorCodes.QueryCanceled)); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); - static async IAsyncEnumerable<(NpgsqlLogSequenceNumber Lsn, string? MessageData)> ParseMessages( - IAsyncEnumerable messages) + static async IAsyncEnumerable<(NpgsqlLogSequenceNumber Lsn, string? 
MessageData)> ParseMessages( + IAsyncEnumerable messages) + { + await foreach (var msg in messages) { - await foreach (var msg in messages) + if (typeof(TConnection) == typeof(PhysicalReplicationConnection)) { - if (typeof(TConnection) == typeof(PhysicalReplicationConnection)) - { - var buffer = new MemoryStream(); - ((XLogDataMessage)msg).Data.CopyTo(buffer); - // Hack: This is really gruesome but we really have no idea how many - // messages we get in physical replication - var messageString = Encoding.ASCII.GetString(buffer.ToArray()); - yield return (msg.WalEnd, messageString); - } - else - { - yield return (msg.WalEnd, null); - } + var buffer = new MemoryStream(); + ((XLogDataMessage)msg).Data.CopyTo(buffer); + // Hack: This is really gruesome but we really have no idea how many + // messages we get in physical replication + var messageString = Encoding.ASCII.GetString(buffer.ToArray()); + yield return (msg.WalEnd, messageString); + } + else + { + yield return (msg.WalEnd, null); } } + } - async Task GetCommitLsn(string valueString) - { - if (typeof(TConnection) == typeof(PhysicalReplicationConnection)) - while (await messages.MoveNextAsync()) - if (messages.Current.MessageData!.Contains(valueString)) - return messages.Current.Lsn; + async Task GetCommitLsn(string valueString) + { + if (typeof(TConnection) == typeof(PhysicalReplicationConnection)) + while (await messages.MoveNextAsync()) + if (messages.Current.MessageData!.Contains(valueString)) + return messages.Current.Lsn; - // NpgsqlLogicalReplicationConnection - // Begin Transaction, Insert, Commit Transaction - for (var i = 0; i < 3; i++) - Assert.True(await messages.MoveNextAsync()); - return messages.Current.Lsn; + // NpgsqlLogicalReplicationConnection + // Begin Transaction, Insert, Commit Transaction + for (var i = 0; i < 3; i++) + Assert.True(await messages.MoveNextAsync()); + return messages.Current.Lsn; - } - }); + } + }); - #region BaseBackup + #region BaseBackup - // ToDo: Implement BaseBackup 
and create tests for it + // ToDo: Implement BaseBackup and create tests for it - #endregion + #endregion - async Task CreateReplicationSlot(string slotName) - { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync(typeof(TConnection) == typeof(PhysicalReplicationConnection) - ? $"SELECT pg_create_physical_replication_slot('{slotName}')" - : $"SELECT pg_create_logical_replication_slot ('{slotName}', 'test_decoding')"); - } + #region BugTests - async IAsyncEnumerable StartReplication(TConnection connection, string slotName, - NpgsqlLogSequenceNumber xLogPos, [EnumeratorCancellation] CancellationToken cancellationToken) - { - if (typeof(TConnection) == typeof(PhysicalReplicationConnection)) + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3534")] + [NonParallelizable] + public Task Bug3534() + => SafeReplicationTest( + async (slotName, _) => { - var slot = new PhysicalReplicationSlot(slotName); - var rc = (PhysicalReplicationConnection)(ReplicationConnection)connection; - await foreach (var msg in rc.StartReplication(slot, xLogPos, cancellationToken)) + await using var rc = await OpenReplicationConnectionAsync(); + var info = await rc.IdentifySystem(); + await CreateReplicationSlot(slotName); + using var streamingCts = new CancellationTokenSource(); + rc.WalReceiverStatusInterval = TimeSpan.FromSeconds(1D); + rc.WalReceiverTimeout = TimeSpan.FromSeconds(3D); + await using var replicationEnumerator = StartReplication(rc, slotName, info.XLogPos, streamingCts.Token).GetAsyncEnumerator(streamingCts.Token); + + var replicationMessageTask = replicationEnumerator.MoveNextAsync(); + streamingCts.CancelAfter(rc.WalReceiverTimeout * 2); + + Assert.Multiple(() => { - yield return msg; - } - } - else if (typeof(TConnection) == typeof(LogicalReplicationConnection)) + Assert.That(async () => + { + // We only expect one transaction here but we need to keep polling + // because in physical replication we can't prevent internal transactions + 
// from being sent to the replication connection + while (true) + { + await replicationMessageTask; + replicationMessageTask = replicationEnumerator.MoveNextAsync(); + } + }, Throws.Exception.AssignableTo()); + Assert.That(streamingCts.IsCancellationRequested); + }); + }); + + #endregion + + async Task CreateReplicationSlot(string slotName) + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(typeof(TConnection) == typeof(PhysicalReplicationConnection) + ? $"SELECT pg_create_physical_replication_slot('{slotName}')" + : $"SELECT pg_create_logical_replication_slot ('{slotName}', 'test_decoding')"); + } + + async IAsyncEnumerable StartReplication(TConnection connection, string slotName, + NpgsqlLogSequenceNumber xLogPos, [EnumeratorCancellation] CancellationToken cancellationToken) + { + if (typeof(TConnection) == typeof(PhysicalReplicationConnection)) + { + var slot = new PhysicalReplicationSlot(slotName); + var rc = (PhysicalReplicationConnection)(ReplicationConnection)connection; + await foreach (var msg in rc.StartReplication(slot, xLogPos, cancellationToken)) { - var slot = new TestDecodingReplicationSlot(slotName); - var rc = (LogicalReplicationConnection)(ReplicationConnection)connection; - await foreach (var msg in rc.StartReplication(slot, cancellationToken, walLocation: xLogPos)) - { - yield return msg; - } + yield return msg; } } - - static TimeSpan ParseTimespan(string str) + else if (typeof(TConnection) == typeof(LogicalReplicationConnection)) { - var span = str.AsSpan(); - var pos = 0; - var number = 0; - while (pos < span.Length) + var slot = new TestDecodingReplicationSlot(slotName); + var rc = (LogicalReplicationConnection)(ReplicationConnection)connection; + await foreach (var msg in rc.StartReplication(slot, cancellationToken, options: new TestDecodingOptions(skipEmptyXacts: true), walLocation: xLogPos)) { - var c = span[pos]; - if (!char.IsDigit(c)) - break; - number = number * 10 + (c - 0x30); - pos++; + yield 
return msg; } + } + } - if (number == 0) - return Timeout.InfiniteTimeSpan; - if ("ms".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) - return TimeSpan.FromMilliseconds(number); - if ("s".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) - return TimeSpan.FromSeconds(number); - if ("min".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) - return TimeSpan.FromMinutes(number); - if ("h".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) - return TimeSpan.FromHours(number); - if ("d".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) - return TimeSpan.FromDays(number); - - throw new ArgumentException($"Can not parse timestamp '{span.ToString()}'"); + static TimeSpan ParseTimespan(string str) + { + var span = str.AsSpan(); + var pos = 0; + var number = 0; + while (pos < span.Length) + { + var c = span[pos]; + if (!char.IsDigit(c)) + break; + number = number * 10 + (c - 0x30); + pos++; } - protected override string Postfix => - "common_" + - new TConnection() switch - { - LogicalReplicationConnection _ => "_l", - PhysicalReplicationConnection _ => "_p", - _ => throw new ArgumentOutOfRangeException($"{typeof(TConnection)} is not expected.") - }; + if (number == 0) + return Timeout.InfiniteTimeSpan; + if ("ms".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) + return TimeSpan.FromMilliseconds(number); + if ("s".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) + return TimeSpan.FromSeconds(number); + if ("min".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) + return TimeSpan.FromMinutes(number); + if ("h".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) + return TimeSpan.FromHours(number); + if ("d".AsSpan().Equals(span.Slice(pos), StringComparison.Ordinal)) + return TimeSpan.FromDays(number); + + throw new ArgumentException($"Can not parse timestamp '{span.ToString()}'"); } + + protected override string Postfix => + "common_" + + new TConnection() switch + { + 
LogicalReplicationConnection _ => "_l", + PhysicalReplicationConnection _ => "_p", + _ => throw new ArgumentOutOfRangeException($"{typeof(TConnection)} is not expected.") + }; } diff --git a/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs b/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs index f04eeb8ac5..3eb3921b79 100644 --- a/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs @@ -1,5 +1,8 @@ using System; using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -7,468 +10,1438 @@ using Npgsql.Replication; using Npgsql.Replication.PgOutput; using Npgsql.Replication.PgOutput.Messages; - -namespace Npgsql.Tests.Replication +using TruncateOptions = Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions; +using ReplicaIdentitySetting = Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests.Replication; + +[TestFixture(ProtocolVersion.V1, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.DefaultTransactionMode)] +[TestFixture(ProtocolVersion.V1, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.DefaultTransactionMode)] +[TestFixture(ProtocolVersion.V2, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.StreamingTransactionMode)] +[TestFixture(ProtocolVersion.V3, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.DefaultTransactionMode)] +[TestFixture(ProtocolVersion.V3, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.StreamingTransactionMode)] +// We currently don't execute all possible combinations of settings for efficiency reasons because they don't +// interact in the current implementation. 
+// Feel free to uncomment some or all of the following lines if the implementation changed or you suspect a +// problem with some combination. +// [TestFixture(ProtocolVersion.V1, ReplicationDataMode.TextReplicationDataMode, TransactionMode.NonStreamingTransactionMode)] +// [TestFixture(ProtocolVersion.V2, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.DefaultTransactionMode)] +// [TestFixture(ProtocolVersion.V2, ReplicationDataMode.TextReplicationDataMode, TransactionMode.NonStreamingTransactionMode)] +// [TestFixture(ProtocolVersion.V2, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.DefaultTransactionMode)] +// [TestFixture(ProtocolVersion.V2, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.StreamingTransactionMode)] +// [TestFixture(ProtocolVersion.V3, ReplicationDataMode.TextReplicationDataMode, TransactionMode.NonStreamingTransactionMode)] +// [TestFixture(ProtocolVersion.V3, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.DefaultTransactionMode)] +// [TestFixture(ProtocolVersion.V3, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.StreamingTransactionMode)] +[NonParallelizable] // These tests aren't designed to be parallelizable +public class PgOutputReplicationTests : SafeReplicationTestBase { - public class PgOutputReplicationTests : SafeReplicationTestBase + readonly ulong _protocolVersion; + readonly bool? _binary; + readonly bool? _streaming; + + bool IsBinary => _binary ?? false; + bool IsStreaming => _streaming ?? 
false; + ulong Version => _protocolVersion; + + public PgOutputReplicationTests(ProtocolVersion protocolVersion, ReplicationDataMode dataMode, TransactionMode transactionMode) { - [Test] - public Task CreateReplicationSlot() - => SafeReplicationTest( - async (slotName, _) => - { - await using var c = await OpenConnectionAsync(); - await using var rc = await OpenReplicationConnectionAsync(); - var options = await rc.CreatePgOutputReplicationSlot(slotName); - - using var cmd = - new NpgsqlCommand($"SELECT * FROM pg_replication_slots WHERE slot_name = '{options.Name}'", - c); - await using var reader = await cmd.ExecuteReaderAsync(); - - Assert.That(reader.Read, Is.True); - Assert.That(reader.GetFieldValue(reader.GetOrdinal("slot_type")), Is.EqualTo("logical")); - Assert.That(reader.GetFieldValue(reader.GetOrdinal("plugin")), Is.EqualTo("pgoutput")); - Assert.That(reader.Read, Is.False); - }); - - [Test(Description = "Tests whether INSERT commands get replicated as Logical Replication Protocol Messages")] - public Task Insert() - => SafeReplicationTest( - async (slotName, tableName, publicationName) => - { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync(@$" -CREATE TABLE {tableName} (id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL); -CREATE PUBLICATION {publicationName} FOR TABLE {tableName}; -"); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreatePgOutputReplicationSlot(slotName); - await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('val1'), ('val2')"); + _protocolVersion = (ulong)protocolVersion; + _binary = dataMode == ReplicationDataMode.BinaryReplicationDataMode + ? true + : dataMode == ReplicationDataMode.TextReplicationDataMode + ? false + : null; + _streaming = transactionMode == TransactionMode.StreamingTransactionMode + ? true + : transactionMode == TransactionMode.NonStreamingTransactionMode + ? 
false + : null; + } - using var streamingCts = new CancellationTokenSource(); - var messages = SkipEmptyTransactions(rc.StartReplication(slot, new PgOutputReplicationOptions(publicationName), streamingCts.Token)) - .GetAsyncEnumerator(); + [Test] + public Task CreatePgOutputReplicationSlot() + { + // There's nothing special here for binary data or when streaming so only execute once + if (IsBinary || IsStreaming) + return Task.CompletedTask; - // Begin Transaction - _ = await NextMessage(messages); + return SafeReplicationTest( + async (slotName, _) => + { - // Relation - var relMsg = await NextMessage(messages); - Assert.That(relMsg.RelationReplicaIdentitySetting, Is.EqualTo('d')); - Assert.That(relMsg.Namespace, Is.EqualTo("public")); - Assert.That(relMsg.RelationName, Is.EqualTo(tableName)); - Assert.That(relMsg.Columns.Length, Is.EqualTo(2)); - Assert.That(relMsg.Columns.Span[0].ColumnName, Is.EqualTo("id")); - Assert.That(relMsg.Columns.Span[1].ColumnName, Is.EqualTo("name")); - - // Insert first value - var insertMsg = await NextMessage(messages); - Assert.That(insertMsg.NewRow.Length, Is.EqualTo(2)); - Assert.That(insertMsg.NewRow.Span[0].Value, Is.EqualTo("1")); - Assert.That(insertMsg.NewRow.Span[1].Value, Is.EqualTo("val1")); - - // Insert second value - insertMsg = await NextMessage(messages); - Assert.That(insertMsg.NewRow.Length, Is.EqualTo(2)); - Assert.That(insertMsg.NewRow.Span[0].Value, Is.EqualTo("2")); - Assert.That(insertMsg.NewRow.Span[1].Value, Is.EqualTo("val2")); - - // Commit Transaction - _ = await NextMessage(messages); + await using var c = await OpenConnectionAsync(); + await using var rc = await OpenReplicationConnectionAsync(); + var options = await rc.CreatePgOutputReplicationSlot(slotName); - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); - }); + using var cmd = + new NpgsqlCommand($"SELECT * FROM pg_replication_slots WHERE 
slot_name = '{options.Name}'", + c); + await using var reader = await cmd.ExecuteReaderAsync(); + + Assert.That(reader.Read, Is.True); + Assert.That(reader.GetFieldValue(reader.GetOrdinal("slot_type")), Is.EqualTo("logical")); + Assert.That(reader.GetFieldValue(reader.GetOrdinal("plugin")), Is.EqualTo("pgoutput")); + Assert.That(reader.Read, Is.False); + }); + } - [Test(Description = "Tests whether UPDATE commands get replicated as Logical Replication Protocol Messages for tables using the default replica identity")] - public Task UpdateForDefaultReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName, publicationName) => + [Test(Description = "Tests whether INSERT commands get replicated as Logical Replication Protocol Messages")] + public Task Insert() + => SafePgOutputReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NULL); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$"INSERT INTO {tableName} VALUES (1, 'val1'), (2, NULL), (3, 'ignored'); + INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(4, 15000) s(i);"); + await tran.CommitAsync(); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + // Begin Transaction + var transactionXid = await AssertTransactionStart(messages); + + // Relation + var relationMsg = await NextMessage(messages); + Assert.That(relationMsg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(relationMsg.ReplicaIdentity, Is.EqualTo(ReplicaIdentitySetting.Default)); + Assert.That(relationMsg.Namespace, Is.EqualTo("public")); + Assert.That(relationMsg.RelationName, Is.EqualTo(tableName)); + Assert.That(relationMsg.Columns.Count, Is.EqualTo(2)); + Assert.That(relationMsg.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMsg.Columns[1].ColumnName, Is.EqualTo("name")); + + // Insert first value + var insertMsg = await NextMessage(messages); + Assert.That(insertMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(insertMsg.Relation, Is.SameAs(relationMsg)); + var columnEnumerator = insertMsg.NewRow.GetAsyncEnumerator(); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + if (IsBinary) + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo(1)); + else + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("1")); + + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + Assert.That(columnEnumerator.Current.IsDBNull, Is.False); + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("val1")); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.False); + + // Insert second value + insertMsg = await NextMessage(messages); + Assert.That(insertMsg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(insertMsg.Relation, Is.SameAs(relationMsg)); + columnEnumerator = insertMsg.NewRow.GetAsyncEnumerator(); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + if (IsBinary) + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo(2)); + else + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("2")); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + Assert.That(columnEnumerator.Current.IsDBNull, Is.True); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.False); + + // Insert third value + insertMsg = await NextMessage(messages); + Assert.That(insertMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(insertMsg.Relation, Is.SameAs(relationMsg)); + await foreach (var tuple in insertMsg.NewRow) // Don't consume the value to trigger eventual bugs + Assert.That(tuple.Kind, IsBinary ? Is.EqualTo(TupleDataKind.BinaryValue) : Is.EqualTo(TupleDataKind.TextValue)); + + // Remaining inserts + for (var insertCount = 0; insertCount < 14997; insertCount++) { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync(@$" + await NextMessage(messages); + } + + // Commit Transaction + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + + [Test(Description = "Tests whether UPDATE commands get replicated as Logical Replication Protocol Messages for tables using the default replica identity")] + public Task Update_for_default_replica_identity() + => SafeReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NOT NULL); + INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(1, 15000) s(i); + CREATE 
PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$"UPDATE {tableName} SET name='val1_updated' WHERE id = 1; + UPDATE {tableName} SET name = md5(name) WHERE id > 1"); + await tran.CommitAsync(); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + // Begin Transaction + var transactionXid = await AssertTransactionStart(messages); + + // Relation + var relationMsg = await NextMessage(messages); + Assert.That(relationMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(relationMsg.ReplicaIdentity, Is.EqualTo(ReplicaIdentitySetting.Default)); + Assert.That(relationMsg.Namespace, Is.EqualTo("public")); + Assert.That(relationMsg.RelationName, Is.EqualTo(tableName)); + Assert.That(relationMsg.Columns.Count, Is.EqualTo(2)); + Assert.That(relationMsg.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMsg.Columns[1].ColumnName, Is.EqualTo("name")); + + // Update + var updateMsg = await NextMessage(messages); + Assert.That(updateMsg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(updateMsg.Relation, Is.SameAs(relationMsg)); + var columnEnumerator = updateMsg.NewRow.GetAsyncEnumerator(); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + if (IsBinary) + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo(1)); + else + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("1")); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + Assert.That(columnEnumerator.Current.IsDBNull, Is.False); + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("val1_updated")); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.False); + + // Remaining updates + for (var updateCount = 0; updateCount < 14999; updateCount++) + await NextMessage(messages); + + // Commit Transaction + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + + [Test(Description = "Tests whether UPDATE commands get replicated as Logical Replication Protocol Messages for tables using an index as replica identity")] + public Task Update_for_index_replica_identity() + => SafeReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + var indexName = $"i_{tableName.Substring(2)}"; + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NOT NULL); + CREATE UNIQUE INDEX {indexName} ON {tableName} (name); + ALTER TABLE {tableName} REPLICA IDENTITY USING INDEX {indexName}; + INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(1, 15000) s(i); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await using var tran = await c.BeginTransactionAsync(); + await 
c.ExecuteNonQueryAsync(@$"UPDATE {tableName} SET name='val1_updated' WHERE id = 1; + UPDATE {tableName} SET name = md5(name) WHERE id > 1"); + await tran.CommitAsync(); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + // Begin Transaction + var transactionXid = await AssertTransactionStart(messages); + + // Relation + var relationMsg = await NextMessage(messages); + Assert.That(relationMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(relationMsg.ReplicaIdentity, Is.EqualTo(ReplicaIdentitySetting.IndexWithIndIsReplIdent)); + Assert.That(relationMsg.Namespace, Is.EqualTo("public")); + Assert.That(relationMsg.RelationName, Is.EqualTo(tableName)); + Assert.That(relationMsg.Columns.Count, Is.EqualTo(2)); + Assert.That(relationMsg.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMsg.Columns[1].ColumnName, Is.EqualTo("name")); + + // Update + var updateMsg = await NextMessage(messages); + Assert.That(updateMsg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(updateMsg.Relation, Is.SameAs(relationMsg)); + + var oldRowColumnEnumerator = updateMsg.Key.GetAsyncEnumerator(); + Assert.That(await oldRowColumnEnumerator.MoveNextAsync(), Is.True); + Assert.That(oldRowColumnEnumerator.Current.IsDBNull, Is.True); + Assert.That(await oldRowColumnEnumerator.MoveNextAsync(), Is.True); + Assert.That(await oldRowColumnEnumerator.Current.Get(), Is.EqualTo("val1")); + Assert.That(await oldRowColumnEnumerator.MoveNextAsync(), Is.False); + + var newRowColumnEnumerator = updateMsg.NewRow.GetAsyncEnumerator(); + Assert.That(await newRowColumnEnumerator.MoveNextAsync(), Is.True); + if (IsBinary) + Assert.That(await newRowColumnEnumerator.Current.Get(), Is.EqualTo(1)); + else + Assert.That(await newRowColumnEnumerator.Current.Get(), Is.EqualTo("1")); + Assert.That(await newRowColumnEnumerator.MoveNextAsync(), Is.True); + Assert.That(await newRowColumnEnumerator.Current.Get(), Is.EqualTo("val1_updated")); + Assert.That(await newRowColumnEnumerator.MoveNextAsync(), Is.False); + + // Remaining updates + for (var updateCount = 0; updateCount < 14999; updateCount++) + await NextMessage(messages); + + // Commit Transaction + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + + [Test(Description = "Tests whether UPDATE commands get replicated as Logical Replication Protocol Messages for tables using full replica identity")] + public Task Update_for_full_replica_identity() + => SafeReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NOT NULL); + ALTER TABLE {tableName} REPLICA IDENTITY FULL; + INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(1, 15000) s(i); + 
CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$"UPDATE {tableName} SET name='val1_updated' WHERE id = 1; + UPDATE {tableName} SET name = md5(name) WHERE id > 1"); + await tran.CommitAsync(); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + // Begin Transaction + var transactionXid = await AssertTransactionStart(messages); + + // Relation + var relationMsg = await NextMessage(messages); + Assert.That(relationMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(relationMsg.ReplicaIdentity, Is.EqualTo(ReplicaIdentitySetting.AllColumns)); + Assert.That(relationMsg.Namespace, Is.EqualTo("public")); + Assert.That(relationMsg.RelationName, Is.EqualTo(tableName)); + Assert.That(relationMsg.Columns.Count, Is.EqualTo(2)); + Assert.That(relationMsg.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMsg.Columns[1].ColumnName, Is.EqualTo("name")); + + // Update + var updateMsg = await NextMessage(messages); + Assert.That(updateMsg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(updateMsg.Relation, Is.SameAs(relationMsg)); + + var oldRowColumnEnumerator = updateMsg.OldRow.GetAsyncEnumerator(); + Assert.That(await oldRowColumnEnumerator.MoveNextAsync(), Is.True); + if (IsBinary) + Assert.That(await oldRowColumnEnumerator.Current.Get(), Is.EqualTo(1)); + else + Assert.That(await oldRowColumnEnumerator.Current.Get(), Is.EqualTo("1")); + Assert.That(await oldRowColumnEnumerator.MoveNextAsync(), Is.True); + Assert.That(await oldRowColumnEnumerator.Current.Get(), Is.EqualTo("val1")); + Assert.That(await oldRowColumnEnumerator.MoveNextAsync(), Is.False); + + var newRowColumnEnumerator = updateMsg.NewRow.GetAsyncEnumerator(); + Assert.That(await newRowColumnEnumerator.MoveNextAsync(), Is.True); + Assert.That(await newRowColumnEnumerator.MoveNextAsync(), Is.True); + Assert.That(await newRowColumnEnumerator.Current.Get(), Is.EqualTo("val1_updated")); + Assert.That(await newRowColumnEnumerator.MoveNextAsync(), Is.False); + + // Remaining updates + for (var updateCount = 0; updateCount < 14999; updateCount++) + await NextMessage(messages); + + // Commit Transaction + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + Assert.That(async () => await messages.MoveNextAsync(), Throws.Exception.AssignableTo() + .With.InnerException.InstanceOf() + .And.InnerException.Property(nameof(PostgresException.SqlState)) + .EqualTo(PostgresErrorCodes.QueryCanceled)); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + + [Test(Description = "Tests whether DELETE commands get replicated as Logical Replication Protocol Messages for tables using the default replica identity")] + public Task Delete_for_default_replica_identity() + => SafeReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NOT NULL); + 
INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(1, 15000) s(i); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$"DELETE FROM {tableName} WHERE id = 1; + DELETE FROM {tableName} WHERE id > 1"); + await tran.CommitAsync(); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + // Begin Transaction + var transactionXid = await AssertTransactionStart(messages); + + // Relation + var relationMsg = await NextMessage(messages); + Assert.That(relationMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(relationMsg.ReplicaIdentity, Is.EqualTo(ReplicaIdentitySetting.Default)); + Assert.That(relationMsg.Namespace, Is.EqualTo("public")); + Assert.That(relationMsg.RelationName, Is.EqualTo(tableName)); + Assert.That(relationMsg.Columns.Count, Is.EqualTo(2)); + Assert.That(relationMsg.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMsg.Columns[1].ColumnName, Is.EqualTo("name")); + + // Delete + var deleteMsg = await NextMessage(messages); + Assert.That(deleteMsg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(deleteMsg.Relation, Is.SameAs(relationMsg)); + var columnEnumerator = deleteMsg.Key.GetAsyncEnumerator(); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + if (IsBinary) + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo(1)); + else + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("1")); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + Assert.That(columnEnumerator.Current.IsDBNull, Is.True); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.False); + + // Remaining deletes + for (var deleteCount = 0; deleteCount < 14999; deleteCount++) + await NextMessage(messages); + + // Commit Transaction + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + + [Test(Description = "Tests whether DELETE commands get replicated as Logical Replication Protocol Messages for tables using an index as replica identity")] + public Task Delete_for_index_replica_identity() + => SafeReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + var indexName = $"i_{tableName.Substring(2)}"; + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NOT NULL); + CREATE UNIQUE INDEX {indexName} ON {tableName} (name); + ALTER TABLE {tableName} REPLICA IDENTITY USING INDEX {indexName}; + INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(1, 15000) s(i); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$"DELETE FROM {tableName} WHERE id = 1; + DELETE FROM {tableName} WHERE id > 1"); 
+ await tran.CommitAsync(); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + // Begin Transaction + var transactionXid = await AssertTransactionStart(messages); + + // Relation + var relationMsg = await NextMessage(messages); + Assert.That(relationMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(relationMsg.ReplicaIdentity, Is.EqualTo(ReplicaIdentitySetting.IndexWithIndIsReplIdent)); + Assert.That(relationMsg.Namespace, Is.EqualTo("public")); + Assert.That(relationMsg.RelationName, Is.EqualTo(tableName)); + Assert.That(relationMsg.Columns.Count, Is.EqualTo(2)); + Assert.That(relationMsg.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMsg.Columns[1].ColumnName, Is.EqualTo("name")); + + // Delete + var deleteMsg = await NextMessage(messages); + Assert.That(deleteMsg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(deleteMsg.Relation, Is.SameAs(relationMsg)); + var columnEnumerator = deleteMsg.Key.GetAsyncEnumerator(); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + Assert.That(columnEnumerator.Current.IsDBNull, Is.True); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("val1")); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.False); + + // Remaining deletes + for (var deleteCount = 0; deleteCount < 14999; deleteCount++) + await NextMessage(messages); + + // Commit Transaction + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + + [Test(Description = "Tests whether DELETE commands get replicated as Logical Replication Protocol Messages for tables using full replica identity")] + public Task Delete_for_full_replica_identity() + => SafeReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NOT NULL); + ALTER TABLE {tableName} REPLICA IDENTITY FULL; + INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(1, 15000) s(i); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$"DELETE FROM {tableName} WHERE id = 1; + DELETE FROM {tableName} WHERE id > 1"); + await tran.CommitAsync(); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + 
.GetAsyncEnumerator(); + + // Begin Transaction + var transactionXid = await AssertTransactionStart(messages); + + // Relation + var relationMsg = await NextMessage(messages); + Assert.That(relationMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(relationMsg.ReplicaIdentity, Is.EqualTo(ReplicaIdentitySetting.AllColumns)); + Assert.That(relationMsg.Namespace, Is.EqualTo("public")); + Assert.That(relationMsg.RelationName, Is.EqualTo(tableName)); + Assert.That(relationMsg.Columns.Count, Is.EqualTo(2)); + Assert.That(relationMsg.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMsg.Columns[1].ColumnName, Is.EqualTo("name")); + + // Delete + var deleteMsg = await NextMessage(messages); + Assert.That(deleteMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(deleteMsg.Relation, Is.SameAs(relationMsg)); + var columnEnumerator = deleteMsg.OldRow.GetAsyncEnumerator(); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + if (IsBinary) + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo(1)); + else + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("1")); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + Assert.That(columnEnumerator.Current.IsDBNull, Is.False); + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("val1")); + Assert.That(await columnEnumerator.MoveNextAsync(), Is.False); + + // Remaining deletes + for (var deleteCount = 0; deleteCount < 14999; deleteCount++) + await NextMessage(messages); + + // Commit Transaction + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + + [Test(Description = "Tests whether TRUNCATE commands get replicated as Logical Replication Protocol Messages on PostgreSQL 11 and above")] + [TestCase(TruncateOptions.None)] + 
[TestCase(TruncateOptions.Cascade)] + [TestCase(TruncateOptions.RestartIdentity)] + [TestCase(TruncateOptions.Cascade | TruncateOptions.RestartIdentity)] + public Task Truncate(TruncateOptions truncateOptionFlags) + => SafeReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "11.0", "Replication of TRUNCATE commands was introduced in PostgreSQL 11"); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL); + INSERT INTO {tableName} (name) VALUES ('val1'); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + var sb = new StringBuilder("TRUNCATE TABLE ").Append(tableName); + if (truncateOptionFlags.HasFlag(TruncateOptions.RestartIdentity)) + sb.Append(" RESTART IDENTITY"); + if (truncateOptionFlags.HasFlag(TruncateOptions.Cascade)) + sb.Append(" CASCADE"); + sb.Append($"; INSERT INTO {tableName} (name) SELECT 'val' || i::text FROM generate_series(1, 15000) s(i);"); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(sb.ToString()); + await tran.CommitAsync(); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + // Begin Transaction + var transactionXid = await AssertTransactionStart(messages); + + // Relation + var relationMessage = await NextMessage(messages); + Assert.That(relationMessage.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(relationMessage.ReplicaIdentity, Is.EqualTo(ReplicaIdentitySetting.Default)); + Assert.That(relationMessage.Namespace, Is.EqualTo("public")); + Assert.That(relationMessage.RelationName, Is.EqualTo(tableName)); + Assert.That(relationMessage.Columns.Count, Is.EqualTo(2)); + Assert.That(relationMessage.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMessage.Columns[1].ColumnName, Is.EqualTo("name")); + + // Truncate + var truncateMsg = await NextMessage(messages); + Assert.That(truncateMsg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(truncateMsg.Options, Is.EqualTo(truncateOptionFlags)); + Assert.That(truncateMsg.Relations.Single(), Is.SameAs(relationMessage)); + + // Remaining inserts + // Since the inserts run in the same transaction as the truncate, we'll + // get a RelationMessage after every StreamStartMessage + for (var insertCount = 0; insertCount < 15000; insertCount++) + await NextMessage(messages, expectRelationMessage: true); + + // Commit Transaction + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }, nameof(Truncate) + truncateOptionFlags.ToString("D")); + + [Test(Description = "Tests whether disposing while replicating will get us stuck forever.")] + public Task Dispose_while_replicating() + => SafeReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$" CREATE TABLE {tableName} (id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL); -INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); CREATE PUBLICATION {publicationName} FOR TABLE {tableName}; "); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreatePgOutputReplicationSlot(slotName); - await 
c.ExecuteNonQueryAsync($"UPDATE {tableName} SET name='val1' WHERE name='val'"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('value 1'), ('value 2');"); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + await NextMessage(messages); + }, nameof(Dispose_while_replicating)); + + [Platform(Exclude = "MacOsX", Reason = "Test is flaky in CI on Mac, see https://github.com/npgsql/npgsql/issues/5294")] + [TestCase(true, true)] + [TestCase(true, false)] + [TestCase(false, false)] + [Test(Description = "Tests whether logical decoding messages get replicated as Logical Replication Protocol Messages on PostgreSQL 14 and above")] + public Task LogicalDecodingMessage(bool writeMessages, bool readMessages) + => SafeReplicationTest( + async (slotName, tableName, publicationName) => + { + const string prefix = "My test Prefix"; + const string transactionalMessage = "A transactional message"; + const string nonTransactionalMessage = "A non-transactional message"; + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "14.0", "Replication of logical decoding messages was introduced in PostgreSQL 14"); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NOT NULL); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$"SELECT pg_logical_emit_message(true, '{prefix}', '{transactionalMessage}'); + INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(1, 15000) s(i);", tran); + 
await tran.CommitAsync(); + + await using var tran2 = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$"SELECT pg_logical_emit_message(false, '{prefix}', '{nonTransactionalMessage}'); + INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(15001, 15010) s(i); + SELECT pg_logical_emit_message(true, '{prefix}', '{transactionalMessage}'); + INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(15011, 30000) s(i); + SELECT pg_logical_emit_message(false, '{prefix}', '{nonTransactionalMessage}'); + ", tran2); + await tran2.RollbackAsync(); + await c.ExecuteNonQueryAsync(@$"SELECT pg_switch_wal();"); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, + GetOptions(publicationName, writeMessages), streamingCts.Token)) + .GetAsyncEnumerator(); + + // Begin Transaction 1 + var transactionXid = await AssertTransactionStart(messages); + + // LogicalDecodingMessage + if (writeMessages) + { + var msg = await NextMessage(messages); + Assert.That(msg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(msg.Flags, Is.EqualTo(1)); + Assert.That(msg.Prefix, Is.EqualTo(prefix)); + Assert.That(msg.Data.Length, Is.EqualTo(transactionalMessage.Length)); + if (readMessages) + { + var buffer = new MemoryStream(); + await msg.Data.CopyToAsync(buffer, CancellationToken.None); + Assert.That(rc.Encoding.GetString(buffer.ToArray()), Is.EqualTo(transactionalMessage)); + } + } - using var streamingCts = new CancellationTokenSource(); - var messages = SkipEmptyTransactions(rc.StartReplication(slot, new PgOutputReplicationOptions(publicationName), streamingCts.Token)) - .GetAsyncEnumerator(); + // Relation + await NextMessage(messages); - // Begin Transaction - _ = await NextMessage(messages); + // Inserts + for (var insertCount = 0; insertCount < 15000; insertCount++) + await NextMessage(messages); - // Relation - var relMsg = await NextMessage(messages); - Assert.That(relMsg.RelationReplicaIdentitySetting, Is.EqualTo('d')); - Assert.That(relMsg.Namespace, Is.EqualTo("public")); - Assert.That(relMsg.RelationName, Is.EqualTo(tableName)); - Assert.That(relMsg.Columns.Length, Is.EqualTo(2)); - Assert.That(relMsg.Columns.Span[0].ColumnName, Is.EqualTo("id")); - Assert.That(relMsg.Columns.Span[1].ColumnName, Is.EqualTo("name")); - - // Update - var updateMsg = await NextMessage(messages); - Assert.That(updateMsg.NewRow.Length, Is.EqualTo(2)); - Assert.That(updateMsg.NewRow.Span[0].Value, Is.EqualTo("1")); - Assert.That(updateMsg.NewRow.Span[1].Value, Is.EqualTo("val1")); - - // Commit Transaction - _ = await NextMessage(messages); + // Commit Transaction 1 + await AssertTransactionCommit(messages); - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); - }); + // LogicalDecodingMessage 1 (non-transactional) + if (writeMessages) + { + var msg = await NextMessage(messages); + Assert.That(msg.TransactionXid, Is.Null); + 
Assert.That(msg.Flags, Is.EqualTo(0)); + Assert.That(msg.Prefix, Is.EqualTo(prefix)); + Assert.That(msg.Data.Length, Is.EqualTo(nonTransactionalMessage.Length)); + if (readMessages) + { + var buffer = new MemoryStream(); + await msg.Data.CopyToAsync(buffer, CancellationToken.None); + Assert.That(rc.Encoding.GetString(buffer.ToArray()), Is.EqualTo(nonTransactionalMessage)); + } + } - [Test(Description = "Tests whether UPDATE commands get replicated as Logical Replication Protocol Messages for tables using an index as replica identity")] - public Task UpdateForIndexReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName, publicationName) => + if (IsStreaming) { - await using var c = await OpenConnectionAsync(); - var indexName = $"i_{tableName.Substring(2)}"; - await c.ExecuteNonQueryAsync(@$" -CREATE TABLE {tableName} (id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL); -CREATE UNIQUE INDEX {indexName} ON {tableName} (name); -ALTER TABLE {tableName} REPLICA IDENTITY USING INDEX {indexName}; -INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); -CREATE PUBLICATION {publicationName} FOR TABLE {tableName}; -"); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreatePgOutputReplicationSlot(slotName); - await c.ExecuteNonQueryAsync($"UPDATE {tableName} SET name='val1' WHERE name='val'"); + // Begin Transaction 2 + transactionXid = await AssertTransactionStart(messages); - using var streamingCts = new CancellationTokenSource(); - var messages = SkipEmptyTransactions(rc.StartReplication(slot, new PgOutputReplicationOptions(publicationName), streamingCts.Token)) - .GetAsyncEnumerator(); + // Relation + await NextMessage(messages); - // Begin Transaction - _ = await NextMessage(messages); + // Inserts + for (var insertCount = 0; insertCount < 10; insertCount++) + await NextMessage(messages); - // Relation - var relMsg = await NextMessage(messages); - Assert.That(relMsg.RelationReplicaIdentitySetting, 
Is.EqualTo('i')); - Assert.That(relMsg.Namespace, Is.EqualTo("public")); - Assert.That(relMsg.RelationName, Is.EqualTo(tableName)); - Assert.That(relMsg.Columns.Length, Is.EqualTo(2)); - Assert.That(relMsg.Columns.Span[0].ColumnName, Is.EqualTo("id")); - Assert.That(relMsg.Columns.Span[1].ColumnName, Is.EqualTo("name")); - - // Update - var updateMsg = await NextMessage(messages); - Assert.That(updateMsg.KeyRow!.Length, Is.EqualTo(2)); - Assert.That(updateMsg.KeyRow!.Span[0].Value, Is.Null); - Assert.That(updateMsg.KeyRow!.Span[1].Value, Is.EqualTo("val")); - Assert.That(updateMsg.NewRow.Length, Is.EqualTo(2)); - Assert.That(updateMsg.NewRow.Span[0].Value, Is.EqualTo("1")); - Assert.That(updateMsg.NewRow.Span[1].Value, Is.EqualTo("val1")); - - // Commit Transaction - _ = await NextMessage(messages); + // LogicalDecodingMessage 2 (transactional) + if (writeMessages) + { + var msg = await NextMessage(messages); + Assert.That(msg.TransactionXid, IsStreaming ? Is.EqualTo(transactionXid) : Is.Null); + Assert.That(msg.Flags, Is.EqualTo(1)); + Assert.That(msg.Prefix, Is.EqualTo(prefix)); + Assert.That(msg.Data.Length, Is.EqualTo(transactionalMessage.Length)); + if (readMessages) + { + var buffer = new MemoryStream(); + await msg.Data.CopyToAsync(buffer, CancellationToken.None); + Assert.That(rc.Encoding.GetString(buffer.ToArray()), Is.EqualTo(transactionalMessage)); + } + } - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); - }); + // Further inserts + // We don't try to predict how many insert messages we get here + // since the streaming transaction will most likely abort before + // we reach the expected number + while (await messages.MoveNextAsync() && messages.Current is InsertMessage + || messages.Current is StreamStopMessage + && await messages.MoveNextAsync() + && messages.Current is StreamStartMessage + && await messages.MoveNextAsync() + && 
messages.Current is InsertMessage) + { + // Ignore + } + } + else if (writeMessages) + await messages.MoveNextAsync(); - [Test(Description = "Tests whether UPDATE commands get replicated as Logical Replication Protocol Messages for tables using full replica identity")] - public Task UpdateForFullReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName, publicationName) => + // LogicalDecodingMessage 3 (non-transactional) + if (writeMessages) { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync(@$" -CREATE TABLE {tableName} (id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL); -ALTER TABLE {tableName} REPLICA IDENTITY FULL; -INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); -CREATE PUBLICATION {publicationName} FOR TABLE {tableName}; -"); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreatePgOutputReplicationSlot(slotName); - await c.ExecuteNonQueryAsync($"UPDATE {tableName} SET name='val1' WHERE name='val'"); + var msg = (LogicalDecodingMessage)messages.Current; + Assert.That(msg.TransactionXid, Is.Null); + Assert.That(msg.Flags, Is.EqualTo(0)); + Assert.That(msg.Prefix, Is.EqualTo(prefix)); + Assert.That(msg.Data.Length, Is.EqualTo(nonTransactionalMessage.Length)); + if (readMessages) + { + var buffer = new MemoryStream(); + await msg.Data.CopyToAsync(buffer, CancellationToken.None); + Assert.That(rc.Encoding.GetString(buffer.ToArray()), Is.EqualTo(nonTransactionalMessage)); + } - using var streamingCts = new CancellationTokenSource(); - var messages = SkipEmptyTransactions(rc.StartReplication(slot, new PgOutputReplicationOptions(publicationName), streamingCts.Token)) - .GetAsyncEnumerator(); + if (IsStreaming) + await messages.MoveNextAsync(); + } - // Begin Transaction - _ = await NextMessage(messages); + // Rollback Transaction 2 + if (IsStreaming) + Assert.That(messages.Current, Is.TypeOf()); - // Relation - var relMsg = await NextMessage(messages); - 
Assert.That(relMsg.RelationReplicaIdentitySetting, Is.EqualTo('f')); - Assert.That(relMsg.Namespace, Is.EqualTo("public")); - Assert.That(relMsg.RelationName, Is.EqualTo(tableName)); - Assert.That(relMsg.Columns.Length, Is.EqualTo(2)); - Assert.That(relMsg.Columns.Span[0].ColumnName, Is.EqualTo("id")); - Assert.That(relMsg.Columns.Span[1].ColumnName, Is.EqualTo("name")); - - // Update - var updateMsg = await NextMessage(messages); - Assert.That(updateMsg.OldRow!.Length, Is.EqualTo(2)); - Assert.That(updateMsg.OldRow!.Span[0].Value, Is.EqualTo("1")); - Assert.That(updateMsg.OldRow!.Span[1].Value, Is.EqualTo("val")); - Assert.That(updateMsg.NewRow.Length, Is.EqualTo(2)); - Assert.That(updateMsg.NewRow.Span[0].Value, Is.EqualTo("1")); - Assert.That(updateMsg.NewRow.Span[1].Value, Is.EqualTo("val1")); - - // Commit Transaction - _ = await NextMessage(messages); + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }, $"{GetObjectName(nameof(LogicalDecodingMessage))}_m_{BoolToChar(writeMessages)}"); - streamingCts.Cancel(); - Assert.That(async () => await messages.MoveNextAsync(), Throws.Exception.AssignableTo() - .With.InnerException.InstanceOf() - .And.InnerException.Property(nameof(PostgresException.SqlState)) - .EqualTo(PostgresErrorCodes.QueryCanceled)); - await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); - }); + [Test] + public Task Stream() + { + // We don't test transaction streaming here because there's nothing special in that case + if (IsStreaming) + return Task.CompletedTask; - [Test(Description = "Tests whether DELETE commands get replicated as Logical Replication Protocol Messages for tables using the default replica identity")] - public Task DeleteForDefaultReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName, publicationName) => - { - await using var c = await OpenConnectionAsync(); - await 
c.ExecuteNonQueryAsync(@$" -CREATE TABLE {tableName} (id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL); -INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); -CREATE PUBLICATION {publicationName} FOR TABLE {tableName}; -"); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreatePgOutputReplicationSlot(slotName); - await c.ExecuteNonQueryAsync($"DELETE FROM {tableName} WHERE name='val2'"); + return SafePgOutputReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (bytes bytea); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); - using var streamingCts = new CancellationTokenSource(); - var messages = SkipEmptyTransactions(rc.StartReplication(slot, new PgOutputReplicationOptions(publicationName), streamingCts.Token)) - .GetAsyncEnumerator(); + var bytes = new byte[16384]; + for (var i = 0; i < 10; i++) + bytes[i] = (byte)i; - // Begin Transaction - _ = await NextMessage(messages); + using (var command = new NpgsqlCommand($"INSERT INTO {tableName} VALUES ($1)", c)) + { + command.Parameters.Add(new() { Value = bytes }); + await command.ExecuteNonQueryAsync(); + } - // Relation - var relMsg = await NextMessage(messages); - Assert.That(relMsg.RelationReplicaIdentitySetting, Is.EqualTo('d')); - Assert.That(relMsg.Namespace, Is.EqualTo("public")); - Assert.That(relMsg.RelationName, Is.EqualTo(tableName)); - Assert.That(relMsg.Columns.Length, Is.EqualTo(2)); - Assert.That(relMsg.Columns.Span[0].ColumnName, Is.EqualTo("id")); - Assert.That(relMsg.Columns.Span[1].ColumnName, Is.EqualTo("name")); - - // Delete - var deleteMsg = await NextMessage(messages); - Assert.That(deleteMsg.KeyRow!.Length, Is.EqualTo(2)); - Assert.That(deleteMsg.KeyRow.Span[0].Value, Is.EqualTo("2")); - 
Assert.That(deleteMsg.KeyRow.Span[1].Value, Is.Null); - - // Commit Transaction - _ = await NextMessage(messages); + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); - }); + await AssertTransactionStart(messages); + await NextMessage(messages); + var insertMsg = await NextMessage(messages); + var columnEnumerator = insertMsg.NewRow.GetAsyncEnumerator(); + await columnEnumerator.MoveNextAsync(); - [Test(Description = "Tests whether DELETE commands get replicated as Logical Replication Protocol Messages for tables using an index as replica identity")] - public Task DeleteForIndexReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName, publicationName) => + var stream = columnEnumerator.Current.GetStream(); + Assert.That(() => columnEnumerator.Current.GetStream(), Throws.Exception.TypeOf()); + Assert.That(() => columnEnumerator.Current.Get(), Throws.Exception.TypeOf()); + Assert.That(() => columnEnumerator.Current.Get(), Throws.Exception.TypeOf()); + + if (IsBinary) { - await using var c = await OpenConnectionAsync(); - var indexName = $"i_{tableName.Substring(2)}"; - await c.ExecuteNonQueryAsync(@$" -CREATE TABLE {tableName} (id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL); -CREATE UNIQUE INDEX {indexName} ON {tableName} (name); -ALTER TABLE {tableName} REPLICA IDENTITY USING INDEX {indexName}; -INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); -CREATE PUBLICATION {publicationName} FOR TABLE {tableName}; -"); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreatePgOutputReplicationSlot(slotName); - await c.ExecuteNonQueryAsync($"DELETE FROM {tableName} WHERE name='val2'"); + var someBytes = new 
byte[10]; + Assert.That(await stream.ReadAsync(someBytes, 0, 10), Is.EqualTo(10)); + Assert.That(someBytes, Is.EquivalentTo(bytes[..10])); + } + else + { + // We assume bytea hex format here + var hexString = "\\x" + BitConverter.ToString(bytes[..10]).Replace("-", string.Empty); + var expected = Encoding.ASCII.GetBytes(hexString); + var someBytes = new byte[expected.Length]; + Assert.That(await stream.ReadAsync(someBytes, 0, someBytes.Length), Is.EqualTo(someBytes.Length)); + Assert.That(someBytes, Is.EquivalentTo(expected)); + } - using var streamingCts = new CancellationTokenSource(); - var messages = SkipEmptyTransactions(rc.StartReplication(slot, new PgOutputReplicationOptions(publicationName), streamingCts.Token)) - .GetAsyncEnumerator(); + await AssertTransactionCommit(messages); - // Begin Transaction - _ = await NextMessage(messages); + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + } - // Relation - var relMsg = await NextMessage(messages); - Assert.That(relMsg.RelationReplicaIdentitySetting, Is.EqualTo('i')); - Assert.That(relMsg.Namespace, Is.EqualTo("public")); - Assert.That(relMsg.RelationName, Is.EqualTo(tableName)); - Assert.That(relMsg.Columns.Length, Is.EqualTo(2)); - Assert.That(relMsg.Columns.Span[0].ColumnName, Is.EqualTo("id")); - Assert.That(relMsg.Columns.Span[1].ColumnName, Is.EqualTo("name")); - - // Delete - var deleteMsg = await NextMessage(messages); - Assert.That(deleteMsg.KeyRow!.Length, Is.EqualTo(2)); - Assert.That(deleteMsg.KeyRow.Span[0].Value, Is.Null); - Assert.That(deleteMsg.KeyRow.Span[1].Value, Is.EqualTo("val2")); - - // Commit Transaction - _ = await NextMessage(messages); + [Test] + public Task TextReader() + { + // We don't test transaction streaming here because there's nothing special in that case + if (IsStreaming) + return Task.CompletedTask; - streamingCts.Cancel(); - await 
AssertReplicationCancellation(messages); - await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); - }); + return SafePgOutputReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NULL); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + var expectedText = "val1"; + await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} VALUES (1, '{expectedText}')"); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + await AssertTransactionStart(messages); + await NextMessage(messages); + var insertMsg = await NextMessage(messages); + var columnEnumerator = insertMsg.NewRow.GetAsyncEnumerator(); + await columnEnumerator.MoveNextAsync(); // We are not interested in the id field + await columnEnumerator.MoveNextAsync(); + using var reader = columnEnumerator.Current.GetTextReader(); + Assert.That(await reader.ReadToEndAsync(), Is.EqualTo(expectedText)); + + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + } - [Test(Description = "Tests whether DELETE commands get replicated as Logical Replication Protocol Messages for tables using full replica identity")] - public Task DeleteForFullReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName, publicationName) => + [Test] + public Task ValueMetadata() + { + // We don't test transaction streaming here because there's nothing special in that case + if (IsStreaming) + return Task.CompletedTask; + + return 
SafePgOutputReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (id INT PRIMARY KEY, name TEXT NULL); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} VALUES (1, 'val1')"); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + await AssertTransactionStart(messages); + await NextMessage(messages); + var insertMsg = await NextMessage(messages); + var columnEnumerator = insertMsg.NewRow.GetAsyncEnumerator(); + await columnEnumerator.MoveNextAsync(); + + Assert.That(columnEnumerator.Current.GetFieldType(), Is.SameAs(IsBinary ? typeof(int) : typeof(string))); + Assert.That(columnEnumerator.Current.GetPostgresType().Name, Is.EqualTo("integer")); + Assert.That(columnEnumerator.Current.GetDataTypeName(), Is.EqualTo("integer")); + Assert.That(columnEnumerator.Current.IsUnchangedToastedValue, Is.False); + + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + } + + [Test] + public Task Null() + { + // We don't test transaction streaming here because there's nothing special in that case + if (IsStreaming) + return Task.CompletedTask; + + return SafePgOutputReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (int1 INT, int2 INT); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + var rc = await 
OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} VALUES (1, 1), (NULL, NULL)"); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + .GetAsyncEnumerator(); + + await AssertTransactionStart(messages); + await NextMessage(messages); + + // non-null + var columnEnumerator = (await NextMessage(messages)).NewRow.GetAsyncEnumerator(); + await columnEnumerator.MoveNextAsync(); + Assert.That(columnEnumerator.Current.IsDBNull, Is.False); + Assert.That(columnEnumerator.Current.IsUnchangedToastedValue, Is.False); + if (IsBinary) + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo(1)); + else + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("1")); + await columnEnumerator.MoveNextAsync(); + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo(IsBinary ? 
1 : "1")); + + // null + columnEnumerator = (await NextMessage(messages)).NewRow.GetAsyncEnumerator(); + await columnEnumerator.MoveNextAsync(); + Assert.That(columnEnumerator.Current.IsDBNull, Is.True); + Assert.That(columnEnumerator.Current.IsUnchangedToastedValue, Is.False); + if (IsBinary) + Assert.That(() => columnEnumerator.Current.Get(), Throws.Exception.TypeOf()); + else + Assert.That(() => columnEnumerator.Current.Get(), Throws.Exception.TypeOf()); + await columnEnumerator.MoveNextAsync(); + Assert.That(await columnEnumerator.Current.Get(), Is.SameAs(DBNull.Value)); + + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }); + } + + [NpgsqlTypes.PgName("descriptor")] + public class Descriptor + { + [NpgsqlTypes.PgName("id")] + public long Id { get; set; } + + [NpgsqlTypes.PgName("name")] + public string Name { get; set; } = string.Empty; + } + +#pragma warning disable CS0618 // GlobalTypeMapper is obsolete + [Test, NonParallelizable] + public Task CompositeType() + { + // We don't test transaction streaming here because there's nothing special in that case + if (IsStreaming) + return Task.CompletedTask; + + return SafePgOutputReplicationTest( + async (slotName, tableName, publicationName) => + { + await using var adminConnection = await OpenConnectionAsync(); + await adminConnection.ExecuteNonQueryAsync(@$" +DROP TYPE IF EXISTS descriptor CASCADE; +CREATE TYPE descriptor AS (id bigint, name text); +CREATE TABLE {tableName} (descriptor_field descriptor); +CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + + NpgsqlConnection.GlobalTypeMapper.MapComposite("descriptor"); + + try { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync(@$" -CREATE TABLE {tableName} (id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL); -ALTER TABLE {tableName} REPLICA 
IDENTITY FULL; -INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); -CREATE PUBLICATION {publicationName} FOR TABLE {tableName}; -"); - var rc = await OpenReplicationConnectionAsync(); + + // Use a one-time connection string to make sure we get a new data source without cached mappings. + // In regular tests we'd use a data source, but replication doesn't work with data sources (yet). + // In addition, clear the DatabaseInfo cache. + using var _ = CreateTempPool(ConnectionString, out var connString); + var rc = await OpenReplicationConnectionAsync(connString); var slot = await rc.CreatePgOutputReplicationSlot(slotName); - await c.ExecuteNonQueryAsync($"DELETE FROM {tableName} WHERE name='val2'"); + var expected = new Descriptor { Id = 1248, Name = "My Descriptor" }; + var stringValue = $"({expected.Id},\"{expected.Name}\")"; + + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} VALUES ('{stringValue}')"); using var streamingCts = new CancellationTokenSource(); - var messages = SkipEmptyTransactions(rc.StartReplication(slot, new PgOutputReplicationOptions(publicationName), streamingCts.Token)) + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) .GetAsyncEnumerator(); - // Begin Transaction - _ = await NextMessage(messages); + await AssertTransactionStart(messages); + await NextMessage(messages); + await NextMessage(messages); - // Relation - var relMsg = await NextMessage(messages); - Assert.That(relMsg.RelationReplicaIdentitySetting, Is.EqualTo('f')); - Assert.That(relMsg.Namespace, Is.EqualTo("public")); - Assert.That(relMsg.RelationName, Is.EqualTo(tableName)); - Assert.That(relMsg.Columns.Length, Is.EqualTo(2)); - Assert.That(relMsg.Columns.Span[0].ColumnName, Is.EqualTo("id")); - Assert.That(relMsg.Columns.Span[1].ColumnName, Is.EqualTo("name")); - - // Delete - var deleteMsg = await NextMessage(messages); - 
Assert.That(deleteMsg.OldRow!.Length, Is.EqualTo(2)); - Assert.That(deleteMsg.OldRow.Span[0].Value, Is.EqualTo("2")); - Assert.That(deleteMsg.OldRow.Span[1].Value, Is.EqualTo("val2")); - - // Commit Transaction - _ = await NextMessage(messages); + // non-null + var columnEnumerator = (await NextMessage(messages)).NewRow.GetAsyncEnumerator(); + await columnEnumerator.MoveNextAsync(); + Assert.That(columnEnumerator.Current.IsDBNull, Is.False); + Assert.That(columnEnumerator.Current.IsUnchangedToastedValue, Is.False); + if (IsBinary) + { + var result = await columnEnumerator.Current.Get(); + Assert.That(result.Id, Is.EqualTo(expected.Id)); + Assert.That(result.Name, Is.EqualTo(expected.Name)); + } + else + Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo(stringValue)); + + await columnEnumerator.MoveNextAsync(); + + await AssertTransactionCommit(messages); streamingCts.Cancel(); await AssertReplicationCancellation(messages); await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); - }); - - [Test(Description = "Tests whether TRUNCATE commands get replicated as Logical Replication Protocol Messages on PostgreSQL 11 and above")] - [TestCase(TruncateOptions.None)] - [TestCase(TruncateOptions.Cascade)] - [TestCase(TruncateOptions.RestartIdentity)] - [TestCase(TruncateOptions.Cascade | TruncateOptions.RestartIdentity)] - public Task Truncate(TruncateOptions truncateOptionFlags) - => SafeReplicationTest( - async (slotName, tableName, publicationName) => + } + finally { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "11.0", "Replication of TRUNCATE commands was introduced in PostgreSQL 11"); - await c.ExecuteNonQueryAsync(@$" -CREATE TABLE {tableName} (id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL); -INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); -CREATE PUBLICATION {publicationName} FOR TABLE {tableName}; -"); - var rc = await OpenReplicationConnectionAsync(); - var slot 
= await rc.CreatePgOutputReplicationSlot(slotName); - StringBuilder sb = new StringBuilder("TRUNCATE TABLE ").Append(tableName); - if (truncateOptionFlags.HasFlag(TruncateOptions.RestartIdentity)) - sb.Append(" RESTART IDENTITY"); - if (truncateOptionFlags.HasFlag(TruncateOptions.Cascade)) - sb.Append(" CASCADE"); - await c.ExecuteNonQueryAsync(sb.ToString()); + await adminConnection.ExecuteNonQueryAsync("DROP TYPE IF EXISTS descriptor CASCADE;"); + + NpgsqlConnection.GlobalTypeMapper.Reset(); + } + }); + } +#pragma warning restore CS0618 // GlobalTypeMapper is obsolete + + [Test] + public Task TwoPhase([Values]bool commit) + { + // Streaming of prepared transaction is only supported for + // logical streaming replication protocol >= 3 + if (_protocolVersion < 3UL) + return Task.CompletedTask; + return SafePgOutputReplicationTest( + async (slotName, tableName, publicationName) => + { + var gid = Guid.NewGuid().ToString(); + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$"CREATE TABLE {tableName} (a int primary key, b varchar); + CREATE PUBLICATION {publicationName} FOR TABLE {tableName};"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName, twoPhase: true); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$"INSERT INTO {tableName} SELECT i, 'val' || i::text FROM generate_series(1, 15000) s(i); + PREPARE TRANSACTION '{gid}';"); + try + { using var streamingCts = new CancellationTokenSource(); - var messages = SkipEmptyTransactions(rc.StartReplication(slot, new PgOutputReplicationOptions(publicationName), streamingCts.Token)) + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) .GetAsyncEnumerator(); // Begin Transaction - _ = await NextMessage(messages); + var transactionXid = await AssertTransactionStart(messages); // Relation - var relMsg = await 
NextMessage(messages); - Assert.That(relMsg.RelationReplicaIdentitySetting, Is.EqualTo('d')); - Assert.That(relMsg.Namespace, Is.EqualTo("public")); - Assert.That(relMsg.RelationName, Is.EqualTo(tableName)); - Assert.That(relMsg.Columns.Length, Is.EqualTo(2)); - Assert.That(relMsg.Columns.Span[0].ColumnName, Is.EqualTo("id")); - Assert.That(relMsg.Columns.Span[1].ColumnName, Is.EqualTo("name")); - - // Truncate - var truncateMsg = await NextMessage(messages); - Assert.That(truncateMsg.Options, Is.EqualTo(truncateOptionFlags)); - Assert.That(truncateMsg.RelationIds.Length, Is.EqualTo(1)); - - // Commit Transaction - _ = await NextMessage(messages); + await NextMessage(messages); + + // Remaining inserts + for (var insertCount = 0; insertCount < 15000; insertCount++) + { + await NextMessage(messages); + } + + var prepareMessageBase = await AssertPrepare(messages); + Assert.That(prepareMessageBase.TransactionXid, Is.EqualTo(transactionXid)); + Assert.That(prepareMessageBase.TransactionGid, Is.EqualTo(gid)); + + if (commit) + { + await c.ExecuteNonQueryAsync(@$"COMMIT PREPARED '{gid}';"); + + var commitPreparedMessage = await NextMessage(messages); + Assert.That(commitPreparedMessage.TransactionXid, Is.EqualTo(transactionXid)); + Assert.That(commitPreparedMessage.TransactionGid, Is.EqualTo(gid)); + } + else + { + await c.ExecuteNonQueryAsync(@$"ROLLBACK PREPARED '{gid}';"); + + var rollbackPreparedMessage = await NextMessage(messages); + Assert.That(rollbackPreparedMessage.TransactionXid, Is.EqualTo(transactionXid)); + Assert.That(rollbackPreparedMessage.TransactionGid, Is.EqualTo(gid)); + } streamingCts.Cancel(); await AssertReplicationCancellation(messages); await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); - }, nameof(Truncate) + truncateOptionFlags.ToString("D")); + } + finally + { + try + { + await using var cx = await OpenConnectionAsync(); + await cx.ExecuteNonQueryAsync(@$"ROLLBACK PREPARED '{gid}';"); + } + catch + { + // Give 
up + } + } + }, $"{GetObjectName(nameof(TwoPhase))}_{(commit ? "commit" : "rollback")}"); + } + + + [Test(Description = "Tests whether columns of internally cached RelationMessage instances are accidentally overwritten.")] + [IssueLink("https://github.com/npgsql/npgsql/issues/4633")] + public Task Bug4633() + { + // We don't need all the various test cases here since the bug gets triggered in any case + if (IsStreaming || IsBinary || Version > 1) + return Task.CompletedTask; + + return SafePgOutputReplicationTest( + async (slotName, tableNames, publicationName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$" +CREATE TABLE {tableNames[0]} +( + id uuid NOT NULL, + text text NOT NULL, + created_at timestamp with time zone NOT NULL, + CONSTRAINT pk_{tableNames[0]} PRIMARY KEY (id) +); +CREATE TABLE {tableNames[1]} +( + id uuid NOT NULL, + message_id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + CONSTRAINT pk_{tableNames[1]} PRIMARY KEY (id), + CONSTRAINT fk_{tableNames[1]}_message_id FOREIGN KEY (message_id) REFERENCES {tableNames[0]} (id) +); +CREATE PUBLICATION {publicationName} FOR TABLE {tableNames[0]}, {tableNames[1]} WITH (PUBLISH = 'insert');"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreatePgOutputReplicationSlot(slotName); + + await using var tran = await c.BeginTransactionAsync(); + await c.ExecuteNonQueryAsync(@$" +INSERT INTO {tableNames[0]} VALUES ('B6CB5293-F65E-4F48-A74B-06D5355DAA74', 'random', now()); +INSERT INTO {tableNames[1]} VALUES ('55870BEC-C42E-4AB0-83BA-225BB7777B37', 'B6CB5293-F65E-4F48-A74B-06D5355DAA74', now()); +INSERT INTO {tableNames[0]} VALUES ('5F89F5FE-6F4F-465F-BB87-716B1413F88D', 'another random', now());"); + await tran.CommitAsync(); + + using var streamingCts = new CancellationTokenSource(); + var messages = SkipEmptyTransactions(rc.StartReplication(slot, GetOptions(publicationName), streamingCts.Token)) + 
.GetAsyncEnumerator(); + + // Begin Transaction + var transactionXid = await AssertTransactionStart(messages); + + // First Relation + var relationMsg = await NextMessage(messages); + var relation1Name = relationMsg.RelationName; + var relation1Id = relationMsg.RelationId; + Assert.That(relation1Name, Is.EqualTo(tableNames[0])); + Assert.That(relationMsg.Columns.Count, Is.EqualTo(3)); + Assert.That(relationMsg.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMsg.Columns[1].ColumnName, Is.EqualTo("text")); + Assert.That(relationMsg.Columns[2].ColumnName, Is.EqualTo("created_at")); + + // Insert first value + var insertMsg = await NextMessage(messages); + Assert.That(insertMsg.Relation.RelationName, Is.EqualTo(relation1Name)); + Assert.That(insertMsg.Relation.RelationId, Is.EqualTo(relation1Id)); + Assert.That(insertMsg.Relation.Columns.Count, Is.EqualTo(3)); + Assert.That(insertMsg.Relation.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(insertMsg.Relation.Columns[1].ColumnName, Is.EqualTo("text")); + Assert.That(insertMsg.Relation.Columns[2].ColumnName, Is.EqualTo("created_at")); + + // Second Relation + relationMsg = await NextMessage(messages); + var relation2Name = relationMsg.RelationName; + var relation2Id = relationMsg.RelationId; + Assert.That(relation2Name, Is.EqualTo(tableNames[1])); + Assert.That(relationMsg.Columns.Count, Is.EqualTo(3)); + Assert.That(relationMsg.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(relationMsg.Columns[1].ColumnName, Is.EqualTo("message_id")); + Assert.That(relationMsg.Columns[2].ColumnName, Is.EqualTo("created_at")); + + // Insert second value + insertMsg = await NextMessage(messages); + Assert.That(insertMsg.Relation.RelationName, Is.EqualTo(relation2Name)); + Assert.That(insertMsg.Relation.RelationId, Is.EqualTo(relation2Id)); + Assert.That(insertMsg.Relation.Columns.Count, Is.EqualTo(3)); + Assert.That(insertMsg.Relation.Columns[0].ColumnName, Is.EqualTo("id")); + 
Assert.That(insertMsg.Relation.Columns[1].ColumnName, Is.EqualTo("message_id")); + Assert.That(insertMsg.Relation.Columns[2].ColumnName, Is.EqualTo("created_at")); + + // Insert third value + insertMsg = await NextMessage(messages); + Assert.That(insertMsg.Relation.RelationName, Is.EqualTo(relation1Name)); + Assert.That(insertMsg.Relation.RelationId, Is.EqualTo(relation1Id)); + Assert.That(insertMsg.Relation.Columns.Count, Is.EqualTo(3)); + Assert.That(insertMsg.Relation.Columns[0].ColumnName, Is.EqualTo("id")); + Assert.That(insertMsg.Relation.Columns[1].ColumnName, Is.EqualTo("text")); + Assert.That(insertMsg.Relation.Columns[2].ColumnName, Is.EqualTo("created_at")); + + // Commit Transaction + await AssertTransactionCommit(messages); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); + }, 2); + } - async ValueTask NextMessage(IAsyncEnumerator enumerator) - where TExpected : PgOutputReplicationMessage + #region Non-Test stuff (helper methods, initialization, enums, ...) 
+ + async Task AssertTransactionStart(IAsyncEnumerator messages) + { + Assert.True(await messages.MoveNextAsync()); + + switch (messages.Current) + { + case StreamStartMessage streamStartMessage: + Assert.That(IsStreaming); + return streamStartMessage.TransactionXid; + case BeginMessage beginMessage: + Assert.That(!IsStreaming); + return beginMessage.TransactionXid; + case BeginPrepareMessage beginPrepareMessage: + Assert.That(!IsStreaming); + return beginPrepareMessage.TransactionXid; + default: + Assert.Fail("Expected transaction start message but got: " + messages.Current); + throw new Exception(); + } + } + + async Task AssertTransactionCommit(IAsyncEnumerator messages) + { + Assert.True(await messages.MoveNextAsync()); + + switch (messages.Current) + { + case StreamStopMessage: + Assert.That(IsStreaming); + Assert.True(await messages.MoveNextAsync()); + Assert.That(messages.Current, Is.TypeOf()); + return; + case CommitMessage: + return; + default: + Assert.Fail("Expected transaction end message but got: " + messages.Current); + throw new Exception(); + } + } + + async Task AssertPrepare(IAsyncEnumerator enumerator) + { + Assert.True(await enumerator.MoveNextAsync()); + if (IsStreaming && enumerator.Current is StreamStopMessage) { Assert.True(await enumerator.MoveNextAsync()); - Assert.That(enumerator.Current, Is.TypeOf()); - return (TExpected)enumerator.Current!; + Assert.That(enumerator.Current, Is.TypeOf()); + return (PrepareMessageBase)enumerator.Current!; } - /// - /// Unfortunately, empty transactions may get randomly created by PG because of auto-vacuuming; these cause test failures as we - /// assert for specific expected message types. This filters them out. 
- /// - async IAsyncEnumerable SkipEmptyTransactions(IAsyncEnumerable messages) + Assert.That(enumerator.Current, Is.TypeOf()); + return (PrepareMessageBase)enumerator.Current!; + } + + async ValueTask NextMessage(IAsyncEnumerator enumerator, bool expectRelationMessage = false) + where TExpected : PgOutputReplicationMessage + { + Assert.True(await enumerator.MoveNextAsync()); + if (IsStreaming && enumerator.Current is StreamStopMessage) { - var enumerator = messages.GetAsyncEnumerator(); - while (await enumerator.MoveNextAsync()) + Assert.True(await enumerator.MoveNextAsync()); + Assert.That(enumerator.Current, Is.TypeOf()); + Assert.True(await enumerator.MoveNextAsync()); + if (expectRelationMessage) { - if (enumerator.Current is BeginMessage) - { - var current = enumerator.Current.Clone(); - if (!await enumerator.MoveNextAsync()) - { - yield return current; - yield break; - } + Assert.That(enumerator.Current, Is.TypeOf()); + Assert.True(await enumerator.MoveNextAsync()); + } + } - var next = enumerator.Current; - if (next is CommitMessage) - continue; + Assert.That(enumerator.Current, Is.TypeOf()); + return (TExpected)enumerator.Current!; + } + /// + /// Unfortunately, empty transactions may get randomly created by PG because of auto-vacuuming; these cause test failures as we + /// assert for specific expected message types. This filters them out. 
+ /// + async IAsyncEnumerable SkipEmptyTransactions(IAsyncEnumerable messages) + { + var enumerator = messages.GetAsyncEnumerator(); + while (await enumerator.MoveNextAsync()) + { + if (enumerator.Current is BeginMessage) + { + var current = enumerator.Current; + if (!await enumerator.MoveNextAsync()) + { yield return current; - yield return next; - continue; + yield break; } - yield return enumerator.Current; + var next = enumerator.Current; + if (next is CommitMessage) + continue; + + yield return current; + yield return next; + continue; } + + yield return enumerator.Current; } + } + + PgOutputReplicationOptions GetOptions(string publicationName, bool? messages = null) + => new(publicationName, _protocolVersion, _binary, _streaming, messages); + + Task SafePgOutputReplicationTest(Func testAction, [CallerMemberName] string memberName = "") + => SafeReplicationTest(testAction, GetObjectName(memberName)); + + Task SafePgOutputReplicationTest(Func testAction, int tableCount, [CallerMemberName] string memberName = "") + => SafeReplicationTest(testAction, tableCount, GetObjectName(memberName)); + + string GetObjectName(string memberName) + { + var sb = new StringBuilder(memberName) + .Append("_v").Append(_protocolVersion); + if (_binary.HasValue) + sb.Append("_b_").Append(BoolToChar(_binary.Value)); + if (_streaming.HasValue) + sb.Append("_s_").Append(BoolToChar(_streaming.Value)); + return sb.ToString(); + } + + static char BoolToChar(bool value) + => value ? 
't' : 'f'; + - protected override string Postfix => "pgoutput_l"; + protected override string Postfix => "pgoutput_l"; - [OneTimeSetUp] - public async Task SetUp() + [OneTimeSetUp] + public async Task SetUp() + { + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "10.0", "The Logical Replication Protocol (via pgoutput plugin) was introduced in PostgreSQL 10"); + if (_protocolVersion > 2) + TestUtil.MinimumPgVersion(c, "15.0", "Logical Streaming Replication Protocol version 3 was introduced in PostgreSQL 15"); + if (_protocolVersion > 1) + TestUtil.MinimumPgVersion(c, "14.0", "Logical Streaming Replication Protocol version 2 was introduced in PostgreSQL 14"); + if (IsBinary) + TestUtil.MinimumPgVersion(c, "14.0", "Sending replication values in binary representation was introduced in PostgreSQL 14"); + if (IsStreaming) { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "10.0", "The Logical Replication Protocol (via pgoutput plugin) was introduced in PostgreSQL 10"); + TestUtil.MinimumPgVersion(c, "14.0", "Streaming of in-progress transactions was introduced in PostgreSQL 14"); + var logicalDecodingWorkMem = (string)(await c.ExecuteScalarAsync("SHOW logical_decoding_work_mem"))!; + if (logicalDecodingWorkMem != "64kB") + { + TestUtil.IgnoreExceptOnBuildServer( + $"logical_decoding_work_mem is set to '{logicalDecodingWorkMem}', but must be set to '64kB' in order for the " + + "streaming replication tests to work correctly. Skipping replication tests"); + } } } + + public enum ProtocolVersion : ulong + { + V1 = 1UL, + V2 = 2UL, + V3 = 3UL, + } + public enum ReplicationDataMode + { + DefaultReplicationDataMode, + TextReplicationDataMode, + BinaryReplicationDataMode, + } + public enum TransactionMode + { + DefaultTransactionMode, + NonStreamingTransactionMode, + StreamingTransactionMode, + } + + #endregion Non-Test stuff (helper methods, initialization, ennums, ...) 
} diff --git a/test/Npgsql.Tests/Replication/PhysicalReplicationTests.cs b/test/Npgsql.Tests/Replication/PhysicalReplicationTests.cs index e1b1694aec..59698b87ac 100644 --- a/test/Npgsql.Tests/Replication/PhysicalReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/PhysicalReplicationTests.cs @@ -3,86 +3,85 @@ using System.Threading.Tasks; using NUnit.Framework; using Npgsql.Replication; +using NpgsqlTypes; -namespace Npgsql.Tests.Replication +namespace Npgsql.Tests.Replication; + +[Explicit("Flakiness")] +public class PhysicalReplicationTests : SafeReplicationTestBase { - [Explicit("Flakiness")] - public class PhysicalReplicationTests : SafeReplicationTestBase - { - [Test] - public Task CreateReplicationSlot() - => SafeReplicationTest( - async (slotName, _) => - { - await using var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreateReplicationSlot(slotName); - - await using var c = await OpenConnectionAsync(); - using var cmd = - new NpgsqlCommand($"SELECT * FROM pg_replication_slots WHERE slot_name = '{slot.Name}'", - c); - await using var reader = await cmd.ExecuteReaderAsync(); - - Assert.That(reader.Read, Is.True); - Assert.That(reader.GetFieldValue(reader.GetOrdinal("slot_type")), Is.EqualTo("physical")); - Assert.That(reader.Read, Is.False); - await rc.DropReplicationSlot(slotName); - }); - - [Test] - public Task WithSlot() - => SafeReplicationTest( - async (slotName, tableName) => - { - // var messages = new ConcurrentQueue<(NpgsqlLogSequenceNumber WalStart, NpgsqlLogSequenceNumber WalEnd, byte[] data)>(); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreateReplicationSlot(slotName); - var info = await rc.IdentifySystem(); - - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(slot, info.XLogPos, streamingCts.Token).GetAsyncEnumerator(); - - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (value text)"); - 
- for (var i = 1; i <= 10; i++) - await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} VALUES ('Value {i}')"); - - // We can't assert a lot in physical replication. - // Since we're replicating in the scope of the whole cluster, - // other transactions possibly from system processes can - // interfere here, inserting additional messages, but more - // likely we'll get everything in one big chunk. - Assert.True(await messages.MoveNextAsync()); - var message = messages.Current; - Assert.That(message.WalStart, Is.EqualTo(info.XLogPos)); - Assert.That(message.WalEnd, Is.GreaterThan(message.WalStart)); - Assert.That(message.Data.Length, Is.GreaterThan(0)); - - streamingCts.Cancel(); - var exception = Assert.ThrowsAsync(Is.AssignableTo(), async () => await messages.MoveNextAsync()); - if (c.PostgreSqlVersion < Version.Parse("9.4")) - { - Assert.That(exception, Has.InnerException.InstanceOf() - .And.InnerException.Property(nameof(PostgresException.SqlState)) - .EqualTo(PostgresErrorCodes.QueryCanceled)); - } - }); - - [Test] - public async Task WithoutSlot() - { - var rc = await OpenReplicationConnectionAsync(); - var info = await rc.IdentifySystem(); + [Test] + public Task CreateReplicationSlot([Values]bool temporary, [Values]bool reserveWal) + => SafeReplicationTest( + async (slotName, _) => + { + await using var c = await OpenConnectionAsync(); + if (reserveWal) + TestUtil.MinimumPgVersion(c, "10.0", "The RESERVE_WAL syntax was introduced in PostgreSQL 10"); + if (temporary) + TestUtil.MinimumPgVersion(c, "10.0", "Temporary replication slots were introduced in PostgreSQL 10"); + + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateReplicationSlot(slotName, temporary, reserveWal); - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(info.XLogPos, streamingCts.Token).GetAsyncEnumerator(); + using var cmd = + new NpgsqlCommand($"SELECT * FROM pg_replication_slots WHERE slot_name = 
'{slot.Name}'", + c); + await using var reader = await cmd.ExecuteReaderAsync(); - var tableName = "t_physicalreplicationwithoutslot_p"; - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync($"CREATE TABLE IF NOT EXISTS {tableName} (value text)"); - try + Assert.That(reader.Read, Is.True); + Assert.That(reader.GetFieldValue(reader.GetOrdinal("slot_type")), Is.EqualTo("physical")); + Assert.That(reader.Read, Is.False); + await rc.DropReplicationSlot(slotName); + }, nameof(CreateReplicationSlot) + (temporary ? "_t" : "") + (reserveWal ? "_r" : "")); + + [TestCase(true, true)] + [TestCase(true, false)] + [TestCase(false, false)] + public Task ReadReplicationSlot(bool createSlot, bool reserveWal) + => SafeReplicationTest( + async (slotName, _) => { + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "15.0", "The READ_REPLICATION_SLOT command was introduced in PostgreSQL 15"); + if (createSlot) + await c.ExecuteNonQueryAsync($"SELECT pg_create_physical_replication_slot('{slotName}', {reserveWal}, false)"); + using var cmd = + new NpgsqlCommand($@"SELECT slot_name, substring(pg_walfile_name(restart_lsn), 1, 8)::bigint AS timeline_id, restart_lsn + FROM pg_replication_slots + WHERE slot_name = '{slotName}'", c); + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(reader.Read, Is.EqualTo(createSlot)); + var expectedSlotName = createSlot ? reader.GetFieldValue(reader.GetOrdinal("slot_name")) : null; + var expectedTli = createSlot ? (uint?)reader.GetFieldValue(reader.GetOrdinal("timeline_id")) : null; + var expectedRestartLsn = createSlot ? 
reader.GetFieldValue(reader.GetOrdinal("restart_lsn")) : null; + Assert.That(reader.Read, Is.False); + await using var rc = await OpenReplicationConnectionAsync(); + + var slot = await rc.ReadReplicationSlot(slotName); + + Assert.That(slot?.Name, Is.EqualTo(expectedSlotName)); + Assert.That(slot?.RestartTimeline, Is.EqualTo(expectedTli)); + Assert.That(slot?.RestartLsn, Is.EqualTo(expectedRestartLsn)); + + }, $"{nameof(ReadReplicationSlot)}_{reserveWal}"); + + [Test] + public Task Replication_with_slot() + => SafeReplicationTest( + async (slotName, tableName) => + { + // var messages = new ConcurrentQueue<(NpgsqlLogSequenceNumber WalStart, NpgsqlLogSequenceNumber WalEnd, byte[] data)>(); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateReplicationSlot(slotName); + var info = await rc.IdentifySystem(); + + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(slot, info.XLogPos, streamingCts.Token).GetAsyncEnumerator(); + + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (value text)"); + for (var i = 1; i <= 10; i++) await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} VALUES ('Value {i}')"); @@ -105,13 +104,50 @@ public async Task WithoutSlot() .And.InnerException.Property(nameof(PostgresException.SqlState)) .EqualTo(PostgresErrorCodes.QueryCanceled)); } - } - finally + }); + + [Test] + public async Task Replication_without_slot() + { + await using var rc = await OpenReplicationConnectionAsync(); + var info = await rc.IdentifySystem(); + + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(info.XLogPos, streamingCts.Token).GetAsyncEnumerator(); + + var tableName = "t_physicalreplicationwithoutslot_p"; + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync($"CREATE TABLE IF NOT EXISTS {tableName} (value text)"); + try + { + for (var i = 1; i <= 10; i++) + 
await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} VALUES ('Value {i}')"); + + // We can't assert a lot in physical replication. + // Since we're replicating in the scope of the whole cluster, + // other transactions possibly from system processes can + // interfere here, inserting additional messages, but more + // likely we'll get everything in one big chunk. + Assert.True(await messages.MoveNextAsync()); + var message = messages.Current; + Assert.That(message.WalStart, Is.EqualTo(info.XLogPos)); + Assert.That(message.WalEnd, Is.GreaterThan(message.WalStart)); + Assert.That(message.Data.Length, Is.GreaterThan(0)); + + streamingCts.Cancel(); + var exception = Assert.ThrowsAsync(Is.AssignableTo(), async () => await messages.MoveNextAsync()); + if (c.PostgreSqlVersion < Version.Parse("9.4")) { - await c.ExecuteNonQueryAsync($"DROP TABLE {tableName}"); + Assert.That(exception, Has.InnerException.InstanceOf() + .And.InnerException.Property(nameof(PostgresException.SqlState)) + .EqualTo(PostgresErrorCodes.QueryCanceled)); } } - - protected override string Postfix => "physical_p"; + finally + { + await c.ExecuteNonQueryAsync($"DROP TABLE {tableName}"); + } } + + protected override string Postfix => "physical_p"; } diff --git a/test/Npgsql.Tests/Replication/SafeReplicationTestBase.cs b/test/Npgsql.Tests/Replication/SafeReplicationTestBase.cs index c5786f8000..77f67eaf4b 100644 --- a/test/Npgsql.Tests/Replication/SafeReplicationTestBase.cs +++ b/test/Npgsql.Tests/Replication/SafeReplicationTestBase.cs @@ -7,135 +7,169 @@ using NUnit.Framework; using Npgsql.Replication; -namespace Npgsql.Tests.Replication +namespace Npgsql.Tests.Replication; + +public abstract class SafeReplicationTestBase : TestBase + where TConnection : ReplicationConnection, new() { - public abstract class SafeReplicationTestBase : TestBase - where TConnection : ReplicationConnection, new() + protected abstract string Postfix { get; } + + int _maxIdentifierLength; + static Version 
CurrentServerVersion = null!; + + [OneTimeSetUp] + public async Task OneTimeSetUp() { - protected abstract string Postfix { get; } + await using var conn = await OpenConnectionAsync(); + CurrentServerVersion = conn.PostgreSqlVersion; + _maxIdentifierLength = int.Parse((string)(await conn.ExecuteScalarAsync("SHOW max_identifier_length"))!); + } - int _maxIdentifierLength; - static Version CurrentServerVersion = null!; + [SetUp] + public async Task Setup() + { + await using var conn = await OpenConnectionAsync(); + var walLevel = (string)(await conn.ExecuteScalarAsync("SHOW wal_level"))!; + if (walLevel != "logical") + TestUtil.IgnoreExceptOnBuildServer("wal_level needs to be set to 'logical' in the PostgreSQL conf"); - [OneTimeSetUp] - public async Task OneTimeSetUp() + var maxWalSenders = int.Parse((string)(await conn.ExecuteScalarAsync("SHOW max_wal_senders"))!); + if (maxWalSenders < 50) { - await using var conn = await OpenConnectionAsync(); - CurrentServerVersion = conn.PostgreSqlVersion; - _maxIdentifierLength = int.Parse((string)(await conn.ExecuteScalarAsync("SHOW max_identifier_length"))!); + TestUtil.IgnoreExceptOnBuildServer( + $"max_wal_senders is too low ({maxWalSenders}) and could lead to transient failures. Skipping replication tests"); } + } - [SetUp] - public async Task Setup() + private protected Task OpenReplicationConnectionAsync( + NpgsqlConnectionStringBuilder csb, + CancellationToken cancellationToken = default) + => OpenReplicationConnectionAsync(csb.ToString(), cancellationToken); + + private protected async Task OpenReplicationConnectionAsync( + string? connectionString = null, + CancellationToken cancellationToken = default) + { + var c = new TConnection { ConnectionString = connectionString ?? 
ConnectionString }; + await c.Open(cancellationToken); + return c; + } + + private protected static async Task AssertReplicationCancellation(IAsyncEnumerator enumerator, bool streamingStarted = true) + { + try { - await using var conn = await OpenConnectionAsync(); - var walLevel = (string)(await conn.ExecuteScalarAsync("SHOW wal_level"))!; - if (walLevel != "logical") - TestUtil.IgnoreExceptOnBuildServer("wal_level needs to be set to 'logical' in the PostgreSQL conf"); - - var maxWalSenders = int.Parse((string)(await conn.ExecuteScalarAsync("SHOW max_wal_senders"))!); - if (maxWalSenders < 50) - TestUtil.IgnoreExceptOnBuildServer( - $"max_wal_senders is too low ({maxWalSenders}) and could lead to transient failures. Skipping replication tests"); + var succeeded = await enumerator.MoveNextAsync(); + Assert.Fail(succeeded + ? $"Expected replication cancellation but got message: {enumerator.Current}" + : "Expected replication cancellation but reached enumeration end instead"); } - - private protected async Task OpenReplicationConnectionAsync(NpgsqlConnectionStringBuilder? csb = null, CancellationToken cancellationToken = default) + catch (Exception e) { - var c = new TConnection { ConnectionString = csb?.ToString() ?? ConnectionString }; - await c.Open(cancellationToken); - return c; + Assert.That(e, streamingStarted && CurrentServerVersion >= Pg10Version + ? 
Is.AssignableTo<OperationCanceledException>() + .With.InnerException.InstanceOf<PostgresException>() + .And.InnerException.Property(nameof(PostgresException.SqlState)) + .EqualTo(PostgresErrorCodes.QueryCanceled) + : Is.AssignableTo<OperationCanceledException>() + .With.InnerException.Null); + } + } + + private protected Task SafeReplicationTest(Func<string, string, Task> testAction, [CallerMemberName] string memberName = "") + => SafeReplicationTestCore((slotName, tableNames, publicationName) => testAction(slotName, tableNames[0]), 1, memberName); + + private protected Task SafeReplicationTest(Func<string, string, string, Task> testAction, [CallerMemberName] string memberName = "") + => SafeReplicationTestCore((slotName, tableNames, publicationName) => testAction(slotName, tableNames[0], publicationName), 1, memberName); - private protected static async Task AssertReplicationCancellation(IAsyncEnumerator enumerator) + private protected Task SafeReplicationTest(Func<string, string[], string, Task> testAction, int tableCount, [CallerMemberName] string memberName = "") + => SafeReplicationTestCore(testAction, tableCount, memberName); + + static readonly Version Pg10Version = new(10, 0); + + async Task SafeReplicationTestCore(Func<string, string[], string, Task> testAction, int tableCount, string memberName) + { + // if the supplied name is too long we create one from a guid. + var baseName = $"{memberName}_{Postfix}"; + var name = (baseName.Length > _maxIdentifierLength - 4 ? Guid.NewGuid().ToString("N") : baseName).ToLowerInvariant(); + var slotName = $"s_{name}".ToLowerInvariant(); + var tableNames = new string[tableCount]; + for (var i = tableNames.Length - 1; i >= 0; i--) { - try - { - var succeeded = await enumerator.MoveNextAsync(); - Assert.Fail(succeeded - ? $"Expected replication cancellation but got message: {enumerator.Current}" - : "Expected replication cancellation but reached enumeration end instead"); - } - catch (Exception e) - { - Assert.That(e, CurrentServerVersion >= Pg10Version - ? 
Is.AssignableTo() - .With.InnerException.InstanceOf() - .And.InnerException.Property(nameof(PostgresException.SqlState)) - .EqualTo(PostgresErrorCodes.QueryCanceled) - : Is.AssignableTo() - .With.InnerException.Null); - } + tableNames[i] = $"t{(tableCount == 1 ? "" : i.ToString())}_{name}".ToLowerInvariant(); } + var publicationName = $"p_{name}".ToLowerInvariant(); - private protected Task SafeReplicationTest(Func testAction, [CallerMemberName] string memberName = "") - => SafeReplicationTestCore((slotName, tableName, publicationName) => testAction(slotName, tableName), memberName); + await Cleanup(); - private protected Task SafeReplicationTest(Func testAction, [CallerMemberName] string memberName = "") - => SafeReplicationTestCore(testAction, memberName); - - static readonly Version Pg10Version = new Version(10, 0); + try + { + await testAction(slotName, tableNames, publicationName); + } + finally + { + await Cleanup(); + } - async Task SafeReplicationTestCore(Func testAction, string memberName) + async Task Cleanup() { - // if the supplied name is too long we create on from a guid. - var baseName = $"{memberName}_{Postfix}"; - var name = (baseName.Length > _maxIdentifierLength - 4 ? Guid.NewGuid().ToString("N") : baseName).ToLowerInvariant(); - var slotName = $"s_{name}".ToLowerInvariant(); - var tableName = $"t_{name}".ToLowerInvariant(); - var publicationName = $"p_{name}".ToLowerInvariant(); + await using var c = await OpenConnectionAsync(); try { - await testAction(slotName, tableName, publicationName); + await DropSlot(); } - finally + catch (PostgresException e) when (e.SqlState == PostgresErrorCodes.ObjectInUse && e.Message.Contains(slotName)) { - await using var c = await OpenConnectionAsync(); - try + // The slot is still in use. Probably because we didn't terminate + // the streaming replication properly. + // The following is ugly, but let's try to clean up after us if we can. 
+ var pid = Regex.Match(e.MessageText, "PID (?<pid>\\d+)", RegexOptions.IgnoreCase).Groups["pid"]; + if (pid.Success) { - await DropSlot(); + await c.ExecuteNonQueryAsync($"SELECT pg_terminate_backend ({pid.Value})"); + for (var i = 0; (bool)(await c.ExecuteScalarAsync($"SELECT EXISTS(SELECT * FROM pg_stat_replication where pid = {pid.Value})"))! && i < 20; i++) + await Task.Delay(TimeSpan.FromSeconds(1)); } - catch (PostgresException e) when (e.SqlState == PostgresErrorCodes.ObjectInUse && e.Message.Contains(slotName)) + else { - // The slot is still in use. Probably because we didn't terminate - // the streaming replication properly. - // The following is ugly, but let's try to clean up after us if we can. - var pid = Regex.Match(e.MessageText, "PID (?<pid>\\d+)", RegexOptions.IgnoreCase).Groups["pid"]; - if (pid.Success) - { - await c.ExecuteNonQueryAsync($"SELECT pg_terminate_backend ({pid.Value})"); - } // Old backends don't report the PID - for (var i = 0; (bool)(await c.ExecuteScalarAsync("SELECT EXISTS(SELECT * FROM pg_stat_replication)"))! && i < 30; i++) + for (var i = 0; (bool)(await c.ExecuteScalarAsync("SELECT EXISTS(SELECT * FROM pg_stat_replication)"))! && i < 20; i++) + await Task.Delay(TimeSpan.FromSeconds(1)); + } + try + { + await DropSlot(); } + catch (PostgresException e2) when (e2.SqlState == PostgresErrorCodes.ObjectInUse && e2.Message.Contains(slotName)) + { + // We failed to drop the slot, even after 20 seconds. Swallow the exception to avoid failing the test, we'll + // likely drop it the next time the test is executed (Cleanup is executed before starting the test as well). 
- if (c.PostgreSqlVersion >= Pg10Version) - await c.ExecuteNonQueryAsync($"DROP PUBLICATION IF EXISTS {publicationName}"); + return; + } + } + + if (c.PostgreSqlVersion >= Pg10Version) + await c.ExecuteNonQueryAsync($"DROP PUBLICATION IF EXISTS {publicationName}"); - await c.ExecuteNonQueryAsync($"DROP TABLE IF EXISTS {tableName}"); + for (var i = tableNames.Length - 1; i >= 0; i--) + await c.ExecuteNonQueryAsync($"DROP TABLE IF EXISTS {tableNames[i]} CASCADE;"); - async Task DropSlot() + async Task DropSlot() + { + try { - try - { - await c.ExecuteNonQueryAsync($"SELECT pg_drop_replication_slot('{slotName}')"); - } - catch (PostgresException ex) when (ex.SqlState == PostgresErrorCodes.UndefinedObject && ex.Message.Contains(slotName)) - { - // Temporary slots might already have been deleted - // We don't care as long as it's gone and we don't have to clean it up - } + await c.ExecuteNonQueryAsync($"SELECT pg_drop_replication_slot('{slotName}')"); + } + catch (PostgresException ex) when (ex.SqlState == PostgresErrorCodes.UndefinedObject && ex.Message.Contains(slotName)) + { + // Temporary slots might already have been deleted + // We don't care as long as it's gone and we don't have to clean it up } } } - - private protected static CancellationTokenSource GetCancelledCancellationTokenSource() - { - var cts = new CancellationTokenSource(); - cts.Cancel(); - return cts; - } } + + private protected static CancellationToken GetCancelledCancellationToken() => new(canceled: true); } diff --git a/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs b/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs index 5781936578..5d7c633f6c 100644 --- a/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs @@ -1,342 +1,342 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using NUnit.Framework; using 
Npgsql.Replication; using Npgsql.Replication.TestDecoding; -namespace Npgsql.Tests.Replication +namespace Npgsql.Tests.Replication; + +/// +/// These tests are meant to run on PostgreSQL versions back to 9.4 where the +/// implementation of logical replication was still somewhat incomplete. +/// Please don't change them without confirming that they still work on those old versions. +/// +[Platform(Exclude = "MacOsX", Reason = "Replication tests are flaky in CI on Mac")] +[NonParallelizable] // These tests aren't designed to be parallelizable +public class TestDecodingReplicationTests : SafeReplicationTestBase { - /// - /// These tests are meant to run on PostgreSQL versions back to 9.4 where the - /// implementation of logical replication was still somewhat incomplete. - /// Please don't change them without confirming that they still work on those old versions. - /// - public class TestDecodingReplicationTests : SafeReplicationTestBase - { - [Test] - public Task CreateReplicationSlot() - => SafeReplicationTest( - async (slotName, _) => - { - await using var rc = await OpenReplicationConnectionAsync(); - var options = await rc.CreateTestDecodingReplicationSlot(slotName); - - await using var c = await OpenConnectionAsync(); - using var cmd = - new NpgsqlCommand($"SELECT * FROM pg_replication_slots WHERE slot_name = '{options.Name}'", - c); - await using var reader = await cmd.ExecuteReaderAsync(); - - Assert.That(reader.Read, Is.True); - Assert.That(reader.GetFieldValue(reader.GetOrdinal("slot_type")), Is.EqualTo("logical")); - Assert.That(reader.GetFieldValue(reader.GetOrdinal("plugin")), Is.EqualTo("test_decoding")); - Assert.That(reader.Read, Is.False); - }); - - [Test(Description = "Tests whether INSERT commands get replicated via test_decoding plugin")] - public Task Insert() - => SafeReplicationTest( - async (slotName, tableName) => - { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (id serial 
PRIMARY KEY, name TEXT NOT NULL)"); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreateTestDecodingReplicationSlot(slotName); - - await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('val1'), ('val2')"); - - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); - - // Begin Transaction - var message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("BEGIN ")); - - // Insert first value - message = await NextMessage(messages); - Assert.That(message.Data, - Is.EqualTo($"table public.{tableName}: INSERT: id[integer]:1 name[text]:'val1'")); - - // Insert second value - message = await NextMessage(messages); - Assert.That(message.Data, - Is.EqualTo($"table public.{tableName}: INSERT: id[integer]:2 name[text]:'val2'")); - - // Commit Transaction - message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("COMMIT ")); - - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - }); - - [Test(Description = "Tests whether UPDATE commands get replicated via test_decoding plugin for tables using the default replica identity")] - public Task UpdateForDefaultReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName) => - { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync($@"CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); + [Test] + public Task CreateTestDecodingReplicationSlot() + => SafeReplicationTest( + async (slotName, _) => + { + await using var rc = await OpenReplicationConnectionAsync(); + var options = await rc.CreateTestDecodingReplicationSlot(slotName); + + await using var c = await OpenConnectionAsync(); + using var cmd = + new NpgsqlCommand($"SELECT * FROM pg_replication_slots WHERE slot_name = '{options.Name}'", + c); + await using var reader = await 
cmd.ExecuteReaderAsync(); + + Assert.That(reader.Read, Is.True); + Assert.That(reader.GetFieldValue(reader.GetOrdinal("slot_type")), Is.EqualTo("logical")); + Assert.That(reader.GetFieldValue(reader.GetOrdinal("plugin")), Is.EqualTo("test_decoding")); + Assert.That(reader.Read, Is.False); + }); + + [Test(Description = "Tests whether INSERT commands get replicated via test_decoding plugin")] + public Task Insert() + => SafeReplicationTest( + async (slotName, tableName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync($"CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL)"); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateTestDecodingReplicationSlot(slotName); + + await c.ExecuteNonQueryAsync($"INSERT INTO {tableName} (name) VALUES ('val1'), ('val2')"); + + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); + + // Begin Transaction + var message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("BEGIN ")); + + // Insert first value + message = await NextMessage(messages); + Assert.That(message.Data, + Is.EqualTo($"table public.{tableName}: INSERT: id[integer]:1 name[text]:'val1'")); + + // Insert second value + message = await NextMessage(messages); + Assert.That(message.Data, + Is.EqualTo($"table public.{tableName}: INSERT: id[integer]:2 name[text]:'val2'")); + + // Commit Transaction + message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("COMMIT ")); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + }); + + [Test(Description = "Tests whether UPDATE commands get replicated via test_decoding plugin for tables using the default replica identity")] + public Task Update_for_default_replica_identity() + => SafeReplicationTest( + async (slotName, tableName) => 
+ { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync($@"CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); INSERT INTO {tableName} (name) VALUES ('val'), ('val2')"); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreateTestDecodingReplicationSlot(slotName); - - await c.ExecuteNonQueryAsync($"UPDATE {tableName} SET name='val1' WHERE name='val'"); - - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); - - // Begin Transaction - var message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("BEGIN ")); - - // Update - message = await NextMessage(messages); - Assert.That(message.Data, - Is.EqualTo($"table public.{tableName}: UPDATE: id[integer]:1 name[text]:'val1'")); - - // Commit Transaction - message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("COMMIT ")); - - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - }); - - [Test(Description = "Tests whether UPDATE commands get replicated via test_decoding plugin for tables using an index as replica identity")] - public Task UpdateForIndexReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName) => - { - await using var c = await OpenConnectionAsync(); - var indexName = $"i_{tableName.Substring(2)}"; - await c.ExecuteNonQueryAsync(@$" + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateTestDecodingReplicationSlot(slotName); + + await c.ExecuteNonQueryAsync($"UPDATE {tableName} SET name='val1' WHERE name='val'"); + + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); + + // Begin Transaction + var message = await NextMessage(messages); + 
Assert.That(message.Data, Does.StartWith("BEGIN ")); + + // Update + message = await NextMessage(messages); + Assert.That(message.Data, + Is.EqualTo($"table public.{tableName}: UPDATE: id[integer]:1 name[text]:'val1'")); + + // Commit Transaction + message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("COMMIT ")); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + }); + + [Test(Description = "Tests whether UPDATE commands get replicated via test_decoding plugin for tables using an index as replica identity")] + public Task Update_for_index_replica_identity() + => SafeReplicationTest( + async (slotName, tableName) => + { + await using var c = await OpenConnectionAsync(); + var indexName = $"i_{tableName.Substring(2)}"; + await c.ExecuteNonQueryAsync(@$" CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); CREATE UNIQUE INDEX {indexName} ON {tableName} (name); ALTER TABLE {tableName} REPLICA IDENTITY USING INDEX {indexName}; INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); "); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreateTestDecodingReplicationSlot(slotName); - - await c.ExecuteNonQueryAsync($"UPDATE {tableName} SET name='val1' WHERE name='val'"); - - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); - - // Begin Transaction - var message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("BEGIN ")); - - // Update - message = await NextMessage(messages); - Assert.That(message.Data, - Is.EqualTo($"table public.{tableName}: UPDATE: old-key: name[text]:'val' new-tuple: id[integer]:1 name[text]:'val1'")); - - // Commit Transaction - message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("COMMIT ")); - - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); 
- }); - - [Test(Description = "Tests whether UPDATE commands get replicated via test_decoding plugin for tables using full replica identity")] - public Task UpdateForFullReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName) => - { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync(@$" + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateTestDecodingReplicationSlot(slotName); + + await c.ExecuteNonQueryAsync($"UPDATE {tableName} SET name='val1' WHERE name='val'"); + + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); + + // Begin Transaction + var message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("BEGIN ")); + + // Update + message = await NextMessage(messages); + Assert.That(message.Data, + Is.EqualTo($"table public.{tableName}: UPDATE: old-key: name[text]:'val' new-tuple: id[integer]:1 name[text]:'val1'")); + + // Commit Transaction + message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("COMMIT ")); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + }); + + [Test(Description = "Tests whether UPDATE commands get replicated via test_decoding plugin for tables using full replica identity")] + public Task Update_for_full_replica_identity() + => SafeReplicationTest( + async (slotName, tableName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$" CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); ALTER TABLE {tableName} REPLICA IDENTITY FULL; INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); "); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreateTestDecodingReplicationSlot(slotName); - - await c.ExecuteNonQueryAsync($"UPDATE {tableName} SET name='val1' WHERE 
name='val'"); - - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); - - // Begin Transaction - var message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("BEGIN ")); - - // Update - message = await NextMessage(messages); - Assert.That(message.Data, - Is.EqualTo($"table public.{tableName}: UPDATE: old-key: id[integer]:1 name[text]:'val' new-tuple: id[integer]:1 name[text]:'val1'")); - - // Commit Transaction - message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("COMMIT ")); - - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - }); - - [Test(Description = "Tests whether DELETE commands get replicated via test_decoding plugin for tables using the default replica identity")] - public Task DeleteForDefaultReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName) => - { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync(@$" + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateTestDecodingReplicationSlot(slotName); + + await c.ExecuteNonQueryAsync($"UPDATE {tableName} SET name='val1' WHERE name='val'"); + + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); + + // Begin Transaction + var message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("BEGIN ")); + + // Update + message = await NextMessage(messages); + Assert.That(message.Data, + Is.EqualTo($"table public.{tableName}: UPDATE: old-key: id[integer]:1 name[text]:'val' new-tuple: id[integer]:1 name[text]:'val1'")); + + // Commit Transaction + message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("COMMIT ")); + + 
streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + }); + + [Test(Description = "Tests whether DELETE commands get replicated via test_decoding plugin for tables using the default replica identity")] + public Task Delete_for_default_replica_identity() + => SafeReplicationTest( + async (slotName, tableName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$" CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); "); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreateTestDecodingReplicationSlot(slotName); - - await c.ExecuteNonQueryAsync($"DELETE FROM {tableName} WHERE name='val2'"); - - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); - - // Begin Transaction - var message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("BEGIN ")); - - // Delete - message = await NextMessage(messages); - Assert.That(message.Data, - Is.EqualTo($"table public.{tableName}: DELETE: id[integer]:2")); - - // Commit Transaction - message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("COMMIT ")); - - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - }); - - [Test(Description = "Tests whether DELETE commands get replicated via test_decoding plugin for tables using an index as replica identity")] - public Task DeleteForIndexReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName) => - { - await using var c = await OpenConnectionAsync(); - var indexName = $"i_{tableName.Substring(2)}"; - await c.ExecuteNonQueryAsync(@$" + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateTestDecodingReplicationSlot(slotName); + + await c.ExecuteNonQueryAsync($"DELETE FROM 
{tableName} WHERE name='val2'"); + + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); + + // Begin Transaction + var message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("BEGIN ")); + + // Delete + message = await NextMessage(messages); + Assert.That(message.Data, + Is.EqualTo($"table public.{tableName}: DELETE: id[integer]:2")); + + // Commit Transaction + message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("COMMIT ")); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + }); + + [Test(Description = "Tests whether DELETE commands get replicated via test_decoding plugin for tables using an index as replica identity")] + public Task Delete_for_index_replica_identity() + => SafeReplicationTest( + async (slotName, tableName) => + { + await using var c = await OpenConnectionAsync(); + var indexName = $"i_{tableName.Substring(2)}"; + await c.ExecuteNonQueryAsync(@$" CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); CREATE UNIQUE INDEX {indexName} ON {tableName} (name); ALTER TABLE {tableName} REPLICA IDENTITY USING INDEX {indexName}; INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); "); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreateTestDecodingReplicationSlot(slotName); - - await c.ExecuteNonQueryAsync($"DELETE FROM {tableName} WHERE name='val2'"); - - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); - - // Begin Transaction - var message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("BEGIN ")); - - // Delete - message = await NextMessage(messages); - Assert.That(message.Data, - Is.EqualTo($"table public.{tableName}: DELETE: 
name[text]:'val2'")); - - // Commit Transaction - message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("COMMIT ")); - - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - }); - - [Test(Description = "Tests whether DELETE commands get replicated via test_decoding plugin for tables using full replica identity")] - public Task DeleteForFullReplicaIdentity() - => SafeReplicationTest( - async (slotName, tableName) => - { - await using var c = await OpenConnectionAsync(); - await c.ExecuteNonQueryAsync(@$" + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateTestDecodingReplicationSlot(slotName); + + await c.ExecuteNonQueryAsync($"DELETE FROM {tableName} WHERE name='val2'"); + + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); + + // Begin Transaction + var message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("BEGIN ")); + + // Delete + message = await NextMessage(messages); + Assert.That(message.Data, + Is.EqualTo($"table public.{tableName}: DELETE: name[text]:'val2'")); + + // Commit Transaction + message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("COMMIT ")); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + }); + + [Test(Description = "Tests whether DELETE commands get replicated via test_decoding plugin for tables using full replica identity")] + public Task Delete_for_full_replica_identity() + => SafeReplicationTest( + async (slotName, tableName) => + { + await using var c = await OpenConnectionAsync(); + await c.ExecuteNonQueryAsync(@$" CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); ALTER TABLE {tableName} REPLICA IDENTITY FULL; INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); "); - var rc = await 
OpenReplicationConnectionAsync(); - var slot = await rc.CreateTestDecodingReplicationSlot(slotName); - - await c.ExecuteNonQueryAsync($"DELETE FROM {tableName} WHERE name='val2'"); - - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); - - // Begin Transaction - var message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("BEGIN ")); - - // Delete - message = await NextMessage(messages); - Assert.That(message.Data, - Is.EqualTo($"table public.{tableName}: DELETE: id[integer]:2 name[text]:'val2'")); - - // Commit Transaction - message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("COMMIT ")); - - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - }); - - [Test(Description = "Tests whether TRUNCATE commands get replicated via test_decoding plugin")] - public Task Truncate() - => SafeReplicationTest( - async (slotName, tableName) => - { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "11.0", "Replication of TRUNCATE commands was introduced in PostgreSQL 11"); - await c.ExecuteNonQueryAsync(@$" + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateTestDecodingReplicationSlot(slotName); + + await c.ExecuteNonQueryAsync($"DELETE FROM {tableName} WHERE name='val2'"); + + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); + + // Begin Transaction + var message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("BEGIN ")); + + // Delete + message = await NextMessage(messages); + Assert.That(message.Data, + Is.EqualTo($"table public.{tableName}: DELETE: id[integer]:2 name[text]:'val2'")); + + // Commit Transaction + message = await 
NextMessage(messages); + Assert.That(message.Data, Does.StartWith("COMMIT ")); + + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + }); + + [Test(Description = "Tests whether TRUNCATE commands get replicated via test_decoding plugin")] + public Task Truncate() + => SafeReplicationTest( + async (slotName, tableName) => + { + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "11.0", "Replication of TRUNCATE commands was introduced in PostgreSQL 11"); + await c.ExecuteNonQueryAsync(@$" CREATE TABLE {tableName} (id serial PRIMARY KEY, name TEXT NOT NULL); INSERT INTO {tableName} (name) VALUES ('val'), ('val2'); "); - var rc = await OpenReplicationConnectionAsync(); - var slot = await rc.CreateTestDecodingReplicationSlot(slotName); + await using var rc = await OpenReplicationConnectionAsync(); + var slot = await rc.CreateTestDecodingReplicationSlot(slotName); - await c.ExecuteNonQueryAsync($"TRUNCATE TABLE {tableName} RESTART IDENTITY CASCADE"); + await c.ExecuteNonQueryAsync($"TRUNCATE TABLE {tableName} RESTART IDENTITY CASCADE"); - using var streamingCts = new CancellationTokenSource(); - var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); + using var streamingCts = new CancellationTokenSource(); + var messages = rc.StartReplication(slot, streamingCts.Token, new TestDecodingOptions(skipEmptyXacts: true)).GetAsyncEnumerator(); - // Begin Transaction - var message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("BEGIN ")); + // Begin Transaction + var message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("BEGIN ")); - // Truncate - message = await NextMessage(messages); - Assert.That(message.Data, - Is.EqualTo($"table public.{tableName}: TRUNCATE: restart_seqs cascade")); + // Truncate + message = await NextMessage(messages); + Assert.That(message.Data, + Is.EqualTo($"table 
public.{tableName}: TRUNCATE: restart_seqs cascade")); - // Commit Transaction - message = await NextMessage(messages); - Assert.That(message.Data, Does.StartWith("COMMIT ")); + // Commit Transaction + message = await NextMessage(messages); + Assert.That(message.Data, Does.StartWith("COMMIT ")); - streamingCts.Cancel(); - await AssertReplicationCancellation(messages); - }); + streamingCts.Cancel(); + await AssertReplicationCancellation(messages); + }); - static async ValueTask NextMessage(IAsyncEnumerator enumerator) - { - Assert.True(await enumerator.MoveNextAsync()); - return enumerator.Current!; - } + static async ValueTask NextMessage(IAsyncEnumerator enumerator) + { + Assert.True(await enumerator.MoveNextAsync()); + return enumerator.Current!; + } - protected override string Postfix => "test_encoding_l"; + protected override string Postfix => "test_encoding_l"; - [OneTimeSetUp] - public async Task SetUp() - { - await using var c = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(c, "9.4", "Logical Replication was introduced in PostgreSQL 9.4"); - } + [OneTimeSetUp] + public async Task SetUp() + { + await using var c = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(c, "9.4", "Logical Replication was introduced in PostgreSQL 9.4"); } } diff --git a/test/Npgsql.Tests/SchemaTests.cs b/test/Npgsql.Tests/SchemaTests.cs index 5638bfe6e0..e65fc48cf2 100644 --- a/test/Npgsql.Tests/SchemaTests.cs +++ b/test/Npgsql.Tests/SchemaTests.cs @@ -1,547 +1,599 @@ -using System; +using NpgsqlTypes; +using NUnit.Framework; +using System; using System.Data; using System.Data.Common; using System.Linq; using System.Text.RegularExpressions; using System.Threading.Tasks; -using NpgsqlTypes; -using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -#pragma warning disable 8602 // Warning should be removable after rc2 (https://github.com/dotnet/runtime/pull/42215) +namespace Npgsql.Tests; -namespace Npgsql.Tests +public class SchemaTests : SyncOrAsyncTestBase { - 
public class SchemaTests : SyncOrAsyncTestBase + [Test] + public async Task MetaDataCollections() { - [Test] - public async Task MetaDataCollectionNames() - { - using var conn = OpenConnection(); - var metaDataCollections = await GetSchema(conn, DbMetaDataCollectionNames.MetaDataCollections); - Assert.That(metaDataCollections.Rows, Has.Count.GreaterThan(0)); - foreach (var row in metaDataCollections.Rows.OfType()) - { - var collectionName = (string)row!["CollectionName"]; - Assert.That(conn.GetSchema(collectionName), Is.Not.Null, $"Collection {collectionName} advertise in MetaDataCollections but is null"); - } - } + await using var conn = await OpenConnectionAsync(); - [Test, Description("Calling GetSchema() without a parameter should be the same as passing MetaDataCollections")] - public async Task NoParameter() - { - using var conn = OpenConnection(); - var dataTable1 = await GetSchema(conn); - var collections1 = dataTable1.Rows - .Cast() - .Select(r => (string)r["CollectionName"]) - .ToList(); - var dataTable2 = await GetSchema(conn, DbMetaDataCollectionNames.MetaDataCollections); - var collections2 = dataTable2.Rows - .Cast() - .Select(r => (string)r["CollectionName"]) - .ToList(); - Assert.That(collections1, Is.EquivalentTo(collections2)); - } + var metaDataCollections = await GetSchema(conn, DbMetaDataCollectionNames.MetaDataCollections); + Assert.That(metaDataCollections.Rows, Has.Count.GreaterThan(0)); - [Test, Description("Calling GetSchema(collectionName [, restrictions]) case insensive collectionName can be used")] - public async Task CaseInsensitiveCollectionName() + foreach (var row in metaDataCollections.Rows.OfType()) { - using var conn = OpenConnection(); - var dataTable1 = await GetSchema(conn, DbMetaDataCollectionNames.MetaDataCollections); - var collections1 = dataTable1.Rows - .Cast() - .Select(r => (string)r["CollectionName"]) - .ToList(); - - var dataTable2 = await GetSchema(conn, "METADATACOLLECTIONS"); - var collections2 = dataTable2.Rows - 
.Cast() - .Select(r => (string)r["CollectionName"]) - .ToList(); - - var dataTable3 = await GetSchema(conn, "metadatacollections"); - var collections3 = dataTable3.Rows - .Cast() - .Select(r => (string)r["CollectionName"]) - .ToList(); - - var dataTable4 = await GetSchema(conn, "MetaDataCollections"); - var collections4 = dataTable4.Rows - .Cast() - .Select(r => (string)r["CollectionName"]) - .ToList(); - - var dataTable5 = await GetSchema(conn, "METADATACOLLECTIONS", null!); - var collections5 = dataTable5.Rows - .Cast() - .Select(r => (string)r["CollectionName"]) - .ToList(); - - var dataTable6 = await GetSchema(conn, "metadatacollections", null!); - var collections6 = dataTable6.Rows - .Cast() - .Select(r => (string)r["CollectionName"]) - .ToList(); - - var dataTable7 = await GetSchema(conn, "MetaDataCollections", null!); - var collections7 = dataTable7.Rows - .Cast() - .Select(r => (string)r["CollectionName"]) - .ToList(); - - Assert.That(collections1, Is.EquivalentTo(collections2)); - Assert.That(collections1, Is.EquivalentTo(collections3)); - Assert.That(collections1, Is.EquivalentTo(collections4)); - Assert.That(collections1, Is.EquivalentTo(collections5)); - Assert.That(collections1, Is.EquivalentTo(collections6)); - Assert.That(collections1, Is.EquivalentTo(collections7)); + var collectionName = (string)row!["CollectionName"]; + Assert.That(await GetSchema(conn, collectionName), Is.Not.Null, $"Collection {collectionName} advertise in MetaDataCollections but is null"); } + } - [Test] - public async Task DataSourceInformation() - { - using var conn = OpenConnection(); - var dataTable = await GetSchema(conn, DbMetaDataCollectionNames.MetaDataCollections); - var metadata = dataTable.Rows - .Cast() - .Single(r => r["CollectionName"].Equals("DataSourceInformation")); - Assert.That(metadata["NumberOfRestrictions"], Is.Zero); - Assert.That(metadata["NumberOfIdentifierParts"], Is.Zero); + [Test, Description("Calling GetSchema() without a parameter should be the 
same as passing MetaDataCollections")] + public async Task No_parameter() + { + await using var conn = await OpenConnectionAsync(); - var dataSourceInfo = await GetSchema(conn, DbMetaDataCollectionNames.DataSourceInformation); - var row = dataSourceInfo.Rows.Cast().Single(); + var dataTable1 = await GetSchema(conn); + var collections1 = dataTable1.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); - Assert.That(row["DataSourceProductName"], Is.EqualTo("Npgsql")); + var dataTable2 = await GetSchema(conn, DbMetaDataCollectionNames.MetaDataCollections); + var collections2 = dataTable2.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); - var pgVersion = conn.PostgreSqlVersion; - Assert.That(row["DataSourceProductVersion"], Is.EqualTo(pgVersion.ToString())); + Assert.That(collections1, Is.EquivalentTo(collections2)); + } - var parsedNormalizedVersion = Version.Parse((string)row["DataSourceProductVersionNormalized"]); - Assert.That(parsedNormalizedVersion, Is.EqualTo(conn.PostgreSqlVersion)); + [Test, Description("Calling GetSchema(collectionName [, restrictions]) case insensive collectionName can be used")] + public async Task Case_insensitive_collection_name() + { + await using var conn = await OpenConnectionAsync(); + + var dataTable1 = await GetSchema(conn, DbMetaDataCollectionNames.MetaDataCollections); + var collections1 = dataTable1.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable2 = await GetSchema(conn, "METADATACOLLECTIONS"); + var collections2 = dataTable2.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable3 = await GetSchema(conn, "metadatacollections"); + var collections3 = dataTable3.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable4 = await GetSchema(conn, "MetaDataCollections"); + var collections4 = dataTable4.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var 
dataTable5 = await GetSchema(conn, "METADATACOLLECTIONS", null!); + var collections5 = dataTable5.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable6 = await GetSchema(conn, "metadatacollections", null!); + var collections6 = dataTable6.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable7 = await GetSchema(conn, "MetaDataCollections", null!); + var collections7 = dataTable7.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + Assert.That(collections1, Is.EquivalentTo(collections2)); + Assert.That(collections1, Is.EquivalentTo(collections3)); + Assert.That(collections1, Is.EquivalentTo(collections4)); + Assert.That(collections1, Is.EquivalentTo(collections5)); + Assert.That(collections1, Is.EquivalentTo(collections6)); + Assert.That(collections1, Is.EquivalentTo(collections7)); + } - Assert.That(Regex.Match("\"some_identifier\"", (string)row["QuotedIdentifierPattern"]).Groups[1].Value, - Is.EqualTo("some_identifier")); - } + [Test] + public async Task DataSourceInformation() + { + await using var conn = await OpenConnectionAsync(); + var dataTable = await GetSchema(conn, DbMetaDataCollectionNames.MetaDataCollections); + var metadata = dataTable.Rows + .Cast() + .Single(r => r["CollectionName"].Equals("DataSourceInformation")); + Assert.That(metadata["NumberOfRestrictions"], Is.Zero); + Assert.That(metadata["NumberOfIdentifierParts"], Is.Zero); - [Test] - public async Task DataTypes() - { - using var conn = OpenConnection(); - conn.ExecuteNonQuery("CREATE TYPE pg_temp.test_enum AS ENUM ('a', 'b')"); - conn.ExecuteNonQuery("CREATE TYPE pg_temp.test_composite AS (a INTEGER)"); - conn.ExecuteNonQuery("CREATE DOMAIN pg_temp.us_postal_code AS TEXT"); - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(); - conn.TypeMapper.MapComposite(); - - var dataTable = await GetSchema(conn, DbMetaDataCollectionNames.MetaDataCollections); - var metadata = dataTable.Rows - .Cast() - 
.Single(r => r["CollectionName"].Equals("DataTypes")); - Assert.That(metadata["NumberOfRestrictions"], Is.Zero); - Assert.That(metadata["NumberOfIdentifierParts"], Is.Zero); - - var dataTypes = await GetSchema(conn, DbMetaDataCollectionNames.DataTypes); - - var intRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("integer")); - Assert.That(intRow["DataType"], Is.EqualTo("System.Int32")); - Assert.That(intRow["ProviderDbType"], Is.EqualTo((int)NpgsqlDbType.Integer)); - Assert.That(intRow["IsUnsigned"], Is.False); - Assert.That(intRow["OID"], Is.EqualTo(23)); - - var textRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("text")); - Assert.That(textRow["DataType"], Is.EqualTo("System.String")); - Assert.That(textRow["ProviderDbType"], Is.EqualTo((int)NpgsqlDbType.Text)); - Assert.That(textRow["IsUnsigned"], Is.SameAs(DBNull.Value)); - Assert.That(textRow["OID"], Is.EqualTo(25)); - - var numericRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("numeric")); - Assert.That(numericRow["DataType"], Is.EqualTo("System.Decimal")); - Assert.That(numericRow["ProviderDbType"], Is.EqualTo((int)NpgsqlDbType.Numeric)); - Assert.That(numericRow["IsUnsigned"], Is.False); - Assert.That(numericRow["OID"], Is.EqualTo(1700)); - Assert.That(numericRow["CreateFormat"], Is.EqualTo("NUMERIC({0},{1})")); - Assert.That(numericRow["CreateParameters"], Is.EqualTo("precision, scale")); - - var intArrayRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("integer[]")); - Assert.That(intArrayRow["DataType"], Is.EqualTo("System.Int32[]")); - Assert.That(intArrayRow["ProviderDbType"], Is.EqualTo((int)(NpgsqlDbType.Integer | NpgsqlDbType.Array))); - Assert.That(intArrayRow["OID"], Is.EqualTo(1007)); - Assert.That(intArrayRow["CreateFormat"], Is.EqualTo("INTEGER[]")); - - var numericArrayRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("numeric[]")); - Assert.That(numericArrayRow["CreateFormat"], Is.EqualTo("NUMERIC({0},{1})[]")); - 
Assert.That(numericArrayRow["CreateParameters"], Is.EqualTo("precision, scale")); - - var intRangeRow = dataTypes.Rows.Cast().Single(r => ((string)r["TypeName"]).EndsWith("int4range")); - Assert.That(intRangeRow["DataType"], Does.StartWith("NpgsqlTypes.NpgsqlRange`1[[System.Int32")); - Assert.That(intRangeRow["ProviderDbType"], Is.EqualTo((int)(NpgsqlDbType.Integer | NpgsqlDbType.Range))); - Assert.That(intRangeRow["OID"], Is.EqualTo(3904)); - - var enumRow = dataTypes.Rows.Cast().Single(r => ((string)r["TypeName"]).EndsWith(".test_enum")); - Assert.That(enumRow["DataType"], Is.EqualTo("Npgsql.Tests.SchemaTests+TestEnum")); - Assert.That(enumRow["ProviderDbType"], Is.SameAs(DBNull.Value)); - - var compositeRow = dataTypes.Rows.Cast().Single(r => ((string)r["TypeName"]).EndsWith(".test_composite")); - Assert.That(compositeRow["DataType"], Is.EqualTo("Npgsql.Tests.SchemaTests+TestComposite")); - Assert.That(compositeRow["ProviderDbType"], Is.SameAs(DBNull.Value)); - - var domainRow = dataTypes.Rows.Cast().Single(r => ((string)r["TypeName"]).EndsWith(".us_postal_code")); - Assert.That(domainRow["DataType"], Is.EqualTo("System.String")); - Assert.That(domainRow["ProviderDbType"], Is.EqualTo((int)NpgsqlDbType.Text)); - Assert.That(domainRow["IsBestMatch"], Is.False); - } + var dataSourceInfo = await GetSchema(conn, DbMetaDataCollectionNames.DataSourceInformation); + var row = dataSourceInfo.Rows.Cast().Single(); - enum TestEnum { A, B }; + Assert.That(row["DataSourceProductName"], Is.EqualTo("Npgsql")); - class TestComposite { public int A { get; set; } } + var pgVersion = conn.PostgreSqlVersion; + Assert.That(row["DataSourceProductVersion"], Is.EqualTo(pgVersion.ToString())); - [Test] - public async Task Restrictions() - { - using var conn = OpenConnection(); - var restrictions = await GetSchema(conn, DbMetaDataCollectionNames.Restrictions); - Assert.That(restrictions.Rows, Has.Count.GreaterThan(0)); - } + var parsedNormalizedVersion = 
Version.Parse((string)row["DataSourceProductVersionNormalized"]); + Assert.That(parsedNormalizedVersion, Is.EqualTo(conn.PostgreSqlVersion)); - [Test] - public async Task ReservedWords() - { - using var conn = OpenConnection(); - var reservedWords = await GetSchema(conn, DbMetaDataCollectionNames.ReservedWords); - Assert.That(reservedWords.Rows, Has.Count.GreaterThan(0)); - } + Assert.That(Regex.Match("\"some_identifier\"", (string)row["QuotedIdentifierPattern"]).Groups[1].Value, + Is.EqualTo("some_identifier")); + } - [Test] - public async Task ForeignKeys() - { - using var conn = OpenConnection(); - var dt = await GetSchema(conn, "ForeignKeys"); - Assert.IsNotNull(dt); - } + [Test] + public async Task DataTypes() + { + await using var adminConnection = await OpenConnectionAsync(); + var enumType = await GetTempTypeName(adminConnection); + var compositeType = await GetTempTypeName(adminConnection); + var domainType = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {enumType} AS ENUM ('a', 'b'); +CREATE TYPE {compositeType} AS (a INTEGER); +CREATE DOMAIN {domainType} AS TEXT"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(enumType); + dataSourceBuilder.MapComposite(compositeType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var dataTable = await GetSchema(connection, DbMetaDataCollectionNames.MetaDataCollections); + var metadata = dataTable.Rows + .Cast() + .Single(r => r["CollectionName"].Equals("DataTypes")); + Assert.That(metadata["NumberOfRestrictions"], Is.Zero); + Assert.That(metadata["NumberOfIdentifierParts"], Is.Zero); + + var dataTypes = await GetSchema(connection, DbMetaDataCollectionNames.DataTypes); + + var intRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("integer")); + Assert.That(intRow["DataType"], Is.EqualTo("System.Int32")); + 
Assert.That(intRow["ProviderDbType"], Is.EqualTo((int)NpgsqlDbType.Integer)); + Assert.That(intRow["IsUnsigned"], Is.False); + Assert.That(intRow["OID"], Is.EqualTo(23)); + + var textRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("text")); + Assert.That(textRow["DataType"], Is.EqualTo("System.String")); + Assert.That(textRow["ProviderDbType"], Is.EqualTo((int)NpgsqlDbType.Text)); + Assert.That(textRow["IsUnsigned"], Is.SameAs(DBNull.Value)); + Assert.That(textRow["OID"], Is.EqualTo(25)); + + var numericRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("numeric")); + Assert.That(numericRow["DataType"], Is.EqualTo("System.Decimal")); + Assert.That(numericRow["ProviderDbType"], Is.EqualTo((int)NpgsqlDbType.Numeric)); + Assert.That(numericRow["IsUnsigned"], Is.False); + Assert.That(numericRow["OID"], Is.EqualTo(1700)); + Assert.That(numericRow["CreateFormat"], Is.EqualTo("NUMERIC({0},{1})")); + Assert.That(numericRow["CreateParameters"], Is.EqualTo("precision, scale")); + + var intArrayRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("integer[]")); + Assert.That(intArrayRow["DataType"], Is.EqualTo("System.Int32[]")); + Assert.That(intArrayRow["ProviderDbType"], Is.EqualTo((int)(NpgsqlDbType.Integer | NpgsqlDbType.Array))); + Assert.That(intArrayRow["OID"], Is.EqualTo(1007)); + Assert.That(intArrayRow["CreateFormat"], Is.EqualTo("INTEGER[]")); + + var numericArrayRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("numeric[]")); + Assert.That(numericArrayRow["CreateFormat"], Is.EqualTo("NUMERIC({0},{1})[]")); + Assert.That(numericArrayRow["CreateParameters"], Is.EqualTo("precision, scale")); + + var intRangeRow = dataTypes.Rows.Cast().Single(r => ((string)r["TypeName"]).EndsWith("int4range")); + Assert.That(intRangeRow["DataType"], Does.StartWith("NpgsqlTypes.NpgsqlRange`1[[System.Int32")); + Assert.That(intRangeRow["ProviderDbType"], Is.EqualTo((int)(NpgsqlDbType.Integer | NpgsqlDbType.Range))); + 
Assert.That(intRangeRow["OID"], Is.EqualTo(3904)); + + var enumRow = dataTypes.Rows.Cast().Single(r => ((string)r["TypeName"]).EndsWith("." + enumType)); + Assert.That(enumRow["DataType"], Is.EqualTo("Npgsql.Tests.SchemaTests+TestEnum")); + Assert.That(enumRow["ProviderDbType"], Is.SameAs(DBNull.Value)); + + var compositeRow = dataTypes.Rows.Cast().Single(r => ((string)r["TypeName"]).EndsWith("." + compositeType)); + Assert.That(compositeRow["DataType"], Is.EqualTo("Npgsql.Tests.SchemaTests+TestComposite")); + Assert.That(compositeRow["ProviderDbType"], Is.SameAs(DBNull.Value)); + + var domainRow = dataTypes.Rows.Cast().Single(r => ((string)r["TypeName"]).EndsWith("." + domainType)); + Assert.That(domainRow["DataType"], Is.EqualTo("System.String")); + Assert.That(domainRow["ProviderDbType"], Is.EqualTo((int)NpgsqlDbType.Text)); + Assert.That(domainRow["IsBestMatch"], Is.False); + } - [Test] - public async Task ParameterMarkerFormats() - { - using var conn = OpenConnection(); - var dt = await GetSchema(conn, "DataSourceInformation"); - var parameterMarkerFormat = (string)dt.Rows[0]["ParameterMarkerFormat"]; - - conn.ExecuteNonQuery("CREATE TEMP TABLE data (int INTEGER)"); - conn.ExecuteNonQuery("INSERT INTO data (int) VALUES (4)"); - using var command = conn.CreateCommand(); - const string parameterName = "@p_int"; - command.CommandText = "SELECT * FROM data WHERE int=" + - string.Format(parameterMarkerFormat, parameterName); - command.Parameters.Add(new NpgsqlParameter(parameterName, 4)); - using var reader = command.ExecuteReader(); - Assert.IsTrue(reader.Read()); - // This is OK, when no exceptions are occurred. 
- } + enum TestEnum { A, B }; - [Test] - public async Task PrecisionAndScale() - { - using var conn = OpenConnection(); - conn.ExecuteNonQuery(@"CREATE TEMP TABLE data (explicit_both NUMERIC(10,2), explicit_precision NUMERIC(10), implicit_both NUMERIC, integer INTEGER, text TEXT)"); - var dataTable = await GetSchema(conn, "Columns"); - var rows = dataTable.Rows.Cast().ToList(); - - var explicitBoth = rows.Single(r => (string)r["column_name"] == "explicit_both"); - Assert.That(explicitBoth["numeric_precision"], Is.EqualTo(10)); - Assert.That(explicitBoth["numeric_scale"], Is.EqualTo(2)); - - var explicitPrecision = rows.Single(r => (string)r["column_name"] == "explicit_precision"); - Assert.That(explicitPrecision["numeric_precision"], Is.EqualTo(10)); - Assert.That(explicitPrecision["numeric_scale"], Is.EqualTo(0)); // Not good - - // Consider exposing actual precision/scale even for implicit - var implicitBoth = rows.Single(r => (string)r["column_name"] == "implicit_both"); - Assert.That(implicitBoth["numeric_precision"], Is.EqualTo(DBNull.Value)); - Assert.That(implicitBoth["numeric_scale"], Is.EqualTo(DBNull.Value)); - - var integer = rows.Single(r => (string)r["column_name"] == "integer"); - Assert.That(integer["numeric_precision"], Is.EqualTo(32)); - Assert.That(integer["numeric_scale"], Is.EqualTo(0)); - - var text = rows.Single(r => (string)r["column_name"] == "text"); - Assert.That(text["numeric_precision"], Is.EqualTo(DBNull.Value)); - Assert.That(text["numeric_scale"], Is.EqualTo(DBNull.Value)); - } + class TestComposite { public int A { get; set; } } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1831")] - public async Task NoSystemTables() - { - using (var conn = OpenConnection()) - { - var dataTable = await GetSchema(conn, "Tables"); - var tables = dataTable.Rows - .Cast() - .Select(r => (string)r["TABLE_NAME"]) - .ToList(); - Assert.That(tables, Does.Not.Contain("pg_type")); // schema pg_catalog - Assert.That(tables, 
Does.Not.Contain("tables")); // schema information_schema - } - - using (var conn = OpenConnection()) - { - var dataTable = await GetSchema(conn, "Views"); - var views = dataTable.Rows - .Cast() - .Select(r => (string)r["TABLE_NAME"]) - .ToList(); - Assert.That(views, Does.Not.Contain("pg_user")); // schema pg_catalog - Assert.That(views, Does.Not.Contain("views")); // schema information_schema - } - } + [Test] + public async Task Restrictions() + { + await using var conn = await OpenConnectionAsync(); + var restrictions = await GetSchema(conn, DbMetaDataCollectionNames.Restrictions); + Assert.That(restrictions.Rows, Has.Count.GreaterThan(0)); + } - [Test] - public async Task GetSchemaWithRestrictions() - { - // We can't use temporary tables because GetSchema filters out that in WHERE clause. - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS data"); - conn.ExecuteNonQuery("CREATE TABLE data (bar INTEGER)"); - - try - { - string[] restrictions = { null!, null!, "data" }; - var dt = await GetSchema(conn, "Tables", restrictions); - foreach (var row in dt.Rows.OfType()) - { - Assert.That(row["table_name"], Is.EqualTo("data")); - } - } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS data"); - } - } - - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery("DROP VIEW IF EXISTS view"); - conn.ExecuteNonQuery("CREATE VIEW view AS SELECT 8 AS foo"); - - try - { - string[] restrictions = { null!, null!, "view" }; - var dt = await GetSchema(conn, "Views", restrictions); - foreach (var row in dt.Rows.OfType()) - { - Assert.That(row["table_name"], Is.EqualTo("view")); - } - } - finally - { - conn.ExecuteNonQuery("DROP VIEW IF EXISTS view"); - } - } - } + [Test] + public async Task ReservedWords() + { + await using var conn = await OpenConnectionAsync(); + var reservedWords = await GetSchema(conn, DbMetaDataCollectionNames.ReservedWords); + Assert.That(reservedWords.Rows, Has.Count.GreaterThan(0)); + } - [Test] - public 
async Task PrimaryKey() - { - using var conn = OpenConnection(); - try - { - conn.ExecuteNonQuery(@" -DROP TABLE IF EXISTS data; -CREATE TABLE data (id INT, f1 INT); -ALTER TABLE data ADD PRIMARY KEY (id);"); - string[] restrictions = { null!, null!, "data" }; - var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", restrictions); - var column = dataTable.Rows.Cast().Single(); - - Assert.That(column["table_schema"], Is.EqualTo("public")); - Assert.That(column["table_name"], Is.EqualTo("data")); - Assert.That(column["column_name"], Is.EqualTo("id")); - Assert.That(column["constraint_type"], Is.EqualTo("PRIMARY KEY")); - } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS data"); - } - } + [Test] + public async Task Databases() + { + await using var conn = await OpenConnectionAsync(); + var database = await conn.ExecuteScalarAsync("SELECT current_database()"); - [Test] - public async Task PrimaryKeyComposite() - { - using var conn = OpenConnection(); - try - { - conn.ExecuteNonQuery(@" -DROP TABLE IF EXISTS data; -CREATE TABLE data (id1 INT, id2 INT, f1 INT); -ALTER TABLE data ADD PRIMARY KEY (id1, id2);"); - string[] restrictions = { null!, null!, "data" }; - var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", restrictions); - var columns = dataTable.Rows.Cast() - .OrderBy(r => r["ordinal_number"]).ToList(); - - Assert.That(columns.All(r => r["table_schema"].Equals("public"))); - Assert.That(columns.All(r => r["table_name"].Equals("data"))); - Assert.That(columns.All(r => r["constraint_type"].Equals("PRIMARY KEY"))); - - Assert.That(columns[0]["column_name"], Is.EqualTo("id1")); - Assert.That(columns[1]["column_name"], Is.EqualTo("id2")); - } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS data"); - } - } + var dataTable = await GetSchema(conn, "Databases"); + var databases = dataTable.Rows + .Cast() + .Select(r => (string)r["database_name"]) + .ToList(); - [Test] - public async Task UniqueConstraint() - { - using var conn = 
OpenConnection(); - try - { - conn.ExecuteNonQuery(@" -DROP TABLE IF EXISTS data; -CREATE TABLE data (f1 INT, f2 INT); -ALTER TABLE data ADD UNIQUE (f1, f2);"); - string[] restrictions = { null!, null!, "data" }; - var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", restrictions); - var rows = dataTable.Rows.Cast().ToList(); - } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS data"); - } - } + Assert.That(databases, Does.Contain(database)); + } - [Test] - public async Task UniqueIndexComposite() - { - using var conn = OpenConnection(); - try - { - conn.ExecuteNonQuery(@" -DROP TABLE IF EXISTS data; -CREATE TABLE data (f1 INT, f2 INT); -CREATE UNIQUE INDEX idx_unique ON data (f1, f2); -"); - var database = conn.ExecuteScalar("SELECT current_database()"); - - string[] restrictions = { null!, null!, "data" }; - var dataTable = await GetSchema(conn, "INDEXES", restrictions); - var index = dataTable.Rows.Cast().Single(); - - Assert.That(index["table_schema"], Is.EqualTo("public")); - Assert.That(index["table_name"], Is.EqualTo("data")); - Assert.That(index["index_name"], Is.EqualTo("idx_unique")); - Assert.That(index["type_desc"], Is.EqualTo("")); - - string[] indexColumnRestrictions = { null!, null!, "data" }; - var dataTable2 = await GetSchema(conn, "INDEXCOLUMNS", indexColumnRestrictions); - var columns = dataTable2.Rows.Cast().ToList(); - - Assert.That(columns.All(r => r["constraint_catalog"].Equals(database))); - Assert.That(columns.All(r => r["constraint_schema"].Equals("public"))); - Assert.That(columns.All(r => r["constraint_name"].Equals("idx_unique"))); - Assert.That(columns.All(r => r["table_catalog"].Equals(database))); - Assert.That(columns.All(r => r["table_schema"].Equals("public"))); - Assert.That(columns.All(r => r["table_name"].Equals("data"))); - Assert.That(columns.All(r => r["index_name"].Equals("idx_unique"))); - - Assert.That(columns[0]["column_name"], Is.EqualTo("f1")); - Assert.That(columns[1]["column_name"], Is.EqualTo("f2")); 
- } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS data"); - } - } + [Test] + public async Task Schemata() + { + await using var conn = await OpenConnectionAsync(); + var schema = await CreateTempSchema(conn); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1886")] - public async Task ColumnSchemaDataTypes() - { - using var conn = OpenConnection(); - try - { - conn.ExecuteNonQuery(@" -DROP TABLE IF EXISTS types_table; -CREATE TABLE types_table -( - p0 integer NOT NULL, - achar char, - char character(3), - vchar character varying(10), - text text, - bytea bytea, - abit bit(1), - bit bit(3), - vbit bit varying(5), - boolean boolean, - smallint smallint, - integer integer, - bigint bigint, - real real, - double double precision, - numeric numeric, - money money, - date date, - timetz time with time zone, - timestamptz timestamp with time zone, - time time without time zone, - timestamp timestamp without time zone, - point point, - box box, - lseg lseg, - path path, - polygon polygon, - circle circle, - line line, - inet inet, - macaddr macaddr, - uuid uuid, - interval interval, - name name, - refcursor refcursor, - numrange numrange, - oidvector oidvector, - ""bigint[]"" bigint[], - cidr cidr, - maccaddr8 macaddr8, - jsonb jsonb, - json json, - xml xml, - tsvector tsvector, - tsquery tsquery, - tid tid, - xid xid, - cid cid, - CONSTRAINT types_table_pkey PRIMARY KEY(p0) -) -"); - var database = conn.ExecuteScalar("SELECT current_database()"); - - string[] restrictions = { "npgsql_tests", "public", "types_table", null! 
}; - var columnsSchema = await GetSchema(conn, "Columns", restrictions); - var columns = columnsSchema.Rows.Cast().ToList(); - - var dataTypes = await GetSchema(conn, DbMetaDataCollectionNames.DataTypes); - - columns.ForEach(col => Assert.That(dataTypes.Rows.Cast().Any(row => row["TypeName"].Equals(col["data_type"])), Is.True)); - } - finally - { - conn.ExecuteNonQuery("DROP TABLE IF EXISTS types_table"); - } - } + var dataTable = await GetSchema(conn, "Schemata"); + var row = dataTable.Rows.Cast().Single(r => (string)r["schema_name"] == schema); + + Assert.That(row["catalog_name"], Is.EqualTo(await conn.ExecuteScalarAsync("SELECT current_database()"))); + Assert.That(row["schema_owner"], Is.EqualTo(await conn.ExecuteScalarAsync("SELECT current_user"))); + } - public SchemaTests(SyncOrAsync syncOrAsync) : base(syncOrAsync) { } + [Test] + public async Task ForeignKeys() + { + await using var conn = await OpenConnectionAsync(); + var dt = await GetSchema(conn, "ForeignKeys"); + Assert.IsNotNull(dt); + } - private async Task GetSchema(NpgsqlConnection conn) - => IsAsync ? await conn.GetSchemaAsync() : conn.GetSchema(); + [Test] + public async Task ParameterMarkerFormat() + { + await using var conn = await OpenConnectionAsync(); - private async Task GetSchema(NpgsqlConnection conn, string collectionName) - => IsAsync ? await conn.GetSchemaAsync(collectionName) : conn.GetSchema(collectionName); + var table = await CreateTempTable(conn, "int INTEGER"); + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (int) VALUES (4)"); - private async Task GetSchema(NpgsqlConnection conn, string collectionName, string?[] restrictions) - => IsAsync ? 
await conn.GetSchemaAsync(collectionName, restrictions) : conn.GetSchema(collectionName, restrictions); + var dt = await GetSchema(conn, "DataSourceInformation"); + var parameterMarkerFormat = (string)dt.Rows[0]["ParameterMarkerFormat"]; + + await using var command = conn.CreateCommand(); + const string parameterName = "@p_int"; + command.CommandText = $"SELECT * FROM {table} WHERE int=" + string.Format(parameterMarkerFormat, parameterName); + command.Parameters.Add(new NpgsqlParameter(parameterName, 4)); + await using var reader = await command.ExecuteReaderAsync(); + Assert.IsTrue(reader.Read()); } + + [Test] + public async Task Precision_and_scale() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable( + conn, "explicit_both NUMERIC(10,2), explicit_precision NUMERIC(10), implicit_both NUMERIC, integer INTEGER, text TEXT"); + + var dataTable = await GetSchema(conn, "Columns", new[] { null, null, table }); + var rows = dataTable.Rows.Cast().ToList(); + + var explicitBoth = rows.Single(r => (string)r["column_name"] == "explicit_both"); + Assert.That(explicitBoth["numeric_precision"], Is.EqualTo(10)); + Assert.That(explicitBoth["numeric_scale"], Is.EqualTo(2)); + + var explicitPrecision = rows.Single(r => (string)r["column_name"] == "explicit_precision"); + Assert.That(explicitPrecision["numeric_precision"], Is.EqualTo(10)); + Assert.That(explicitPrecision["numeric_scale"], Is.EqualTo(0)); // Not good + + // Consider exposing actual precision/scale even for implicit + var implicitBoth = rows.Single(r => (string)r["column_name"] == "implicit_both"); + Assert.That(implicitBoth["numeric_precision"], Is.EqualTo(DBNull.Value)); + Assert.That(implicitBoth["numeric_scale"], Is.EqualTo(DBNull.Value)); + + var integer = rows.Single(r => (string)r["column_name"] == "integer"); + Assert.That(integer["numeric_precision"], Is.EqualTo(32)); + Assert.That(integer["numeric_scale"], Is.EqualTo(0)); + + var text = rows.Single(r => 
(string)r["column_name"] == "text"); + Assert.That(text["numeric_precision"], Is.EqualTo(DBNull.Value)); + Assert.That(text["numeric_scale"], Is.EqualTo(DBNull.Value)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1831")] + public async Task No_system_tables() + { + await using var conn = await OpenConnectionAsync(); + + var dataTable = await GetSchema(conn, "Tables"); + var tables = dataTable.Rows + .Cast() + .Select(r => (string)r["TABLE_NAME"]) + .ToList(); + Assert.That(tables, Does.Not.Contain("pg_type")); // schema pg_catalog + Assert.That(tables, Does.Not.Contain("tables")); // schema information_schema + + dataTable = await GetSchema(conn, "Views"); + var views = dataTable.Rows + .Cast() + .Select(r => (string)r["TABLE_NAME"]) + .ToList(); + Assert.That(views, Does.Not.Contain("pg_user")); // schema pg_catalog + Assert.That(views, Does.Not.Contain("views")); // schema information_schema + } + + [Test] + public async Task GetSchema_tables_with_restrictions() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "bar INTEGER"); + + var dt = await GetSchema(conn, "Tables", new[] { null, null, table }); + foreach (var row in dt.Rows.OfType()) + Assert.That(row["table_name"], Is.EqualTo(table)); + } + + [Test] + public async Task GetSchema_views_with_restrictions() + { + await using var conn = await OpenConnectionAsync(); + var view = await GetTempViewName(conn); + + await conn.ExecuteNonQueryAsync($"CREATE VIEW {view} AS SELECT 8 AS foo"); + + var dt = await GetSchema(conn, "Views", new[] { null, null, view }); + foreach (var row in dt.Rows.OfType()) + Assert.That(row["table_name"], Is.EqualTo(view)); + } + + [Test] + public async Task GetSchema_materialized_views_with_restrictions() + { + await using var conn = await OpenConnectionAsync(); + var viewName = await GetTempMaterializedViewName(conn); + + await conn.ExecuteNonQueryAsync($"CREATE MATERIALIZED VIEW {viewName} AS SELECT 8 AS foo"); + + 
var dt = await GetSchema(conn, "MaterializedViews", new[] { null, viewName, null, null }); + foreach (var row in dt.Rows.OfType()) + Assert.That(row["table_name"], Is.EqualTo(viewName)); + } + + [Test] + public async Task Primary_key() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id INT PRIMARY KEY, f1 INT"); + + var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", new[] { null, null, table }); + var column = dataTable.Rows.Cast().Single(); + + Assert.That(column["table_schema"], Is.EqualTo("public")); + Assert.That(column["table_name"], Is.EqualTo(table)); + Assert.That(column["column_name"], Is.EqualTo("id")); + Assert.That(column["constraint_type"], Is.EqualTo("PRIMARY KEY")); + } + + [Test] + public async Task Primary_key_composite() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "id1 INT, id2 INT, f1 INT, PRIMARY KEY (id1, id2)"); + + var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", new[] { null, null, table }); + var columns = dataTable.Rows.Cast().OrderBy(r => r["ordinal_number"]).ToList(); + + Assert.That(columns.All(r => r["table_schema"].Equals("public"))); + Assert.That(columns.All(r => r["table_name"].Equals(table))); + Assert.That(columns.All(r => r["constraint_type"].Equals("PRIMARY KEY"))); + + Assert.That(columns[0]["column_name"], Is.EqualTo("id1")); + Assert.That(columns[1]["column_name"], Is.EqualTo("id2")); + } + + [Test] + public async Task Unique_constraint() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "f1 INT, f2 INT, UNIQUE (f1, f2)"); + + var database = await conn.ExecuteScalarAsync("SELECT current_database()"); + + var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", new[] { null, null, table }); + var columns = dataTable.Rows.Cast().ToList(); + + Assert.That(columns.All(r => r["constraint_catalog"].Equals(database))); + Assert.That(columns.All(r => 
r["constraint_schema"].Equals("public"))); + Assert.That(columns.All(r => r["constraint_name"] is not null)); + Assert.That(columns.All(r => r["table_catalog"].Equals(database))); + Assert.That(columns.All(r => r["table_schema"].Equals("public"))); + Assert.That(columns.All(r => r["table_name"].Equals(table))); + Assert.That(columns.All(r => r["constraint_type"].Equals("UNIQUE KEY"))); + + Assert.That(columns.Count, Is.EqualTo(2)); + + // Columns are not necessarily in the correct order + var firstColumn = columns.FirstOrDefault(x => (string)x["column_name"] == "f1")!; + Assert.NotNull(firstColumn); + Assert.That(firstColumn["ordinal_number"], Is.EqualTo(1)); + + var secondColumn = columns.FirstOrDefault(x => (string)x["column_name"] == "f2")!; + Assert.NotNull(secondColumn); + Assert.That(secondColumn["ordinal_number"], Is.EqualTo(2)); + } + + [Test] + public async Task Unique_index_composite() + { + await using var conn = await OpenConnectionAsync(); + var table = await GetTempTableName(conn); + var constraint = table + "_uq"; + await conn.ExecuteNonQueryAsync(@$" +CREATE TABLE {table} ( + f1 INT, + f2 INT, + CONSTRAINT {constraint} UNIQUE (f1, f2) +)"); + + var database = await conn.ExecuteScalarAsync("SELECT current_database()"); + + var dataTable = await GetSchema(conn, "INDEXES", new[] { null, null, table }); + var index = dataTable.Rows.Cast().Single(); + + Assert.That(index["table_schema"], Is.EqualTo("public")); + Assert.That(index["table_name"], Is.EqualTo(table)); + Assert.That(index["index_name"], Is.EqualTo(constraint)); + Assert.That(index["type_desc"], Is.EqualTo("")); + + string[] indexColumnRestrictions = { null!, null!, table }; + var dataTable2 = await GetSchema(conn, "INDEXCOLUMNS", indexColumnRestrictions); + var columns = dataTable2.Rows.Cast().ToList(); + + Assert.That(columns.All(r => r["constraint_catalog"].Equals(database))); + Assert.That(columns.All(r => r["constraint_schema"].Equals("public"))); + Assert.That(columns.All(r => 
r["constraint_name"].Equals(constraint))); + Assert.That(columns.All(r => r["table_catalog"].Equals(database))); + Assert.That(columns.All(r => r["table_schema"].Equals("public"))); + Assert.That(columns.All(r => r["table_name"].Equals(table))); + + Assert.That(columns[0]["column_name"], Is.EqualTo("f1")); + Assert.That(columns[1]["column_name"], Is.EqualTo("f2")); + + string[] indexColumnRestrictions3 = { (string) database! , "public", table, constraint, "f1" }; + var dataTable3 = await GetSchema(conn, "INDEXCOLUMNS", indexColumnRestrictions3); + var columns3 = dataTable3.Rows.Cast().ToList(); + Assert.That(columns3.Count, Is.EqualTo(1)); + Assert.That(columns3.Single()["column_name"], Is.EqualTo("f1")); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1886")] + public async Task Column_schema_data_types() + { + await using var conn = await OpenConnectionAsync(); + + var columnDefinition = @" +p0 integer PRIMARY KEY NOT NULL, +achar char, +char character(3), +vchar character varying(10), +text text, +bytea bytea, +abit bit(1), +bit bit(3), +vbit bit varying(5), +boolean boolean, +smallint smallint, +integer integer, +bigint bigint, +real real, +double double precision, +numeric numeric, +money money, +date date, +timetz time with time zone, +timestamptz timestamp with time zone, +time time without time zone, +timestamp timestamp without time zone, +point point, +box box, +lseg lseg, +path path, +polygon polygon, +circle circle, +line line, +inet inet, +macaddr macaddr, +uuid uuid, +interval interval, +name name, +refcursor refcursor, +numrange numrange, +oidvector oidvector, +""bigint[]"" bigint[], +cidr cidr, +maccaddr8 macaddr8, +jsonb jsonb, +json json, +xml xml, +tsvector tsvector, +tsquery tsquery, +tid tid, +xid xid, +cid cid"; + var table = await CreateTempTable(conn, columnDefinition); + + var columnsSchema = await GetSchema(conn, "Columns", new[] { null, null, table }); + var columns = columnsSchema.Rows.Cast().ToList(); + + var dataTypes 
= await GetSchema(conn, DbMetaDataCollectionNames.DataTypes); + + var nonMatching = columns.FirstOrDefault(col => !dataTypes.Rows.Cast().Any(row => row["TypeName"].Equals(col["data_type"]))); + if (nonMatching is not null) + Assert.Fail($"Could not find matching data type for column {nonMatching["column_name"]} with type {nonMatching["data_type"]}"); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4392")] + public async Task Enum_in_public_schema() + { + await using var conn = await OpenConnectionAsync(); + var enumName = await GetTempTypeName(conn); + var table = await GetTempTableName(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE TYPE {enumName} AS ENUM ('red', 'yellow', 'blue'); +CREATE TABLE {table} (color {enumName});"); + + var dataTable = await GetSchema(conn, "Columns", new[] { null, null, table }); + var row = dataTable.Rows.Cast().Single(); + Assert.That(row["data_type"], Is.EqualTo(enumName)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4392")] + public async Task Enum_in_non_public_schema() + { + await using var conn = await OpenConnectionAsync(); + const string enumName = "my_enum"; + var schema = await CreateTempSchema(conn); + var table = await GetTempTableName(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE TYPE {schema}.{enumName} AS ENUM ('red', 'yellow', 'blue'); +CREATE TABLE {table} (color {schema}.{enumName});"); + + var dataTable = await GetSchema(conn, "Columns", new[] { null, null, table }); + var row = dataTable.Rows.Cast().Single(); + Assert.That(row["data_type"], Is.EqualTo($"{schema}.{enumName}")); + } + + [Test] + public async Task SlimBuilder_introspection_without_unsupported_type_exceptions() + { + await using var dataSource = new NpgsqlSlimDataSourceBuilder(ConnectionString).Build(); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(() => GetSchema(conn, DbMetaDataCollectionNames.DataTypes), Throws.Nothing); + } + + public SchemaTests(SyncOrAsync 
syncOrAsync) : base(syncOrAsync) { } + + // ReSharper disable MethodHasAsyncOverload + async Task GetSchema(NpgsqlConnection conn) + => IsAsync ? await conn.GetSchemaAsync() : conn.GetSchema(); + + async Task GetSchema(NpgsqlConnection conn, string collectionName) + => IsAsync ? await conn.GetSchemaAsync(collectionName) : conn.GetSchema(collectionName); + + async Task GetSchema(NpgsqlConnection conn, string collectionName, string?[] restrictions) + => IsAsync ? await conn.GetSchemaAsync(collectionName, restrictions) : conn.GetSchema(collectionName, restrictions); + // ReSharper restore MethodHasAsyncOverload } diff --git a/test/Npgsql.Tests/SecurityTests.cs b/test/Npgsql.Tests/SecurityTests.cs index c341dda6e5..8600942969 100644 --- a/test/Npgsql.Tests/SecurityTests.cs +++ b/test/Npgsql.Tests/SecurityTests.cs @@ -1,225 +1,485 @@ using System; +using System.IO; +using System.Runtime.InteropServices; +using System.Security.Authentication; using System.Threading; +using System.Threading.Tasks; +using Npgsql.Properties; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class SecurityTests : TestBase { - public class SecurityTests : TestBase + [Test, Description("Establishes an SSL connection, assuming a self-signed server certificate")] + public void Basic_ssl() { - [Test, Description("Establishes an SSL connection, assuming a self-signed server certificate")] - public void BasicSsl() + using var dataSource = CreateDataSource(csb => { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - SslMode = SslMode.Require, - TrustServerCertificate = true - }; + csb.SslMode = SslMode.Require; + }); + using var conn = dataSource.OpenConnection(); + Assert.That(conn.IsSecure, Is.True); + } + + [Test, Description("Default user must run with md5 password encryption")] + public void Default_user_uses_md5_password() + { + if (!IsOnBuildServer) + Assert.Ignore("Only executed in CI"); + + using var 
dataSource = CreateDataSource(csb => + { + csb.SslMode = SslMode.Require; + }); + using var conn = dataSource.OpenConnection(); + Assert.That(conn.IsScram, Is.False); + Assert.That(conn.IsScramPlus, Is.False); + } + + [Test, Description("Makes sure a certificate whose root CA isn't known isn't accepted")] + public void Reject_self_signed_certificate([Values(SslMode.VerifyCA, SslMode.VerifyFull)] SslMode sslMode) + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + SslMode = sslMode, + CheckCertificateRevocation = false, + }; + + using var _ = CreateTempPool(csb, out var connString); + using var conn = new NpgsqlConnection(connString); + + var ex = Assert.Throws(conn.Open)!; + Assert.That(ex.InnerException, Is.TypeOf()); + } + + [Test, Description("Makes sure that ssl_renegotiation_limit is always 0, renegotiation is buggy")] + public void No_ssl_renegotiation() + { + using var dataSource = CreateDataSource(csb => + { + csb.SslMode = SslMode.Require; + }); + using var conn = dataSource.OpenConnection(); + Assert.That(conn.ExecuteScalar("SHOW ssl_renegotiation_limit"), Is.EqualTo("0")); + conn.ExecuteNonQuery("DISCARD ALL"); + Assert.That(conn.ExecuteScalar("SHOW ssl_renegotiation_limit"), Is.EqualTo("0")); + } - using (var conn = OpenConnection(csb)) - Assert.That(conn.IsSecure, Is.True); + [Test, Description("Makes sure that when SSL is disabled IsSecure returns false")] + public void IsSecure_without_ssl() + { + using var dataSource = CreateDataSource(csb => csb.SslMode = SslMode.Disable); + using var conn = dataSource.OpenConnection(); + Assert.That(conn.IsSecure, Is.False); + } + + [Test, Explicit("Needs to be set up (and run with with Kerberos credentials on Linux)")] + public void IntegratedSecurity_with_Username() + { + var username = Environment.UserName; + if (username == null) + throw new Exception("Could find username"); + + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) + { + Username = username, + Password = 
null + }.ToString(); + using var conn = new NpgsqlConnection(connString); + try + { + conn.Open(); + } + catch (Exception e) + { + if (IsOnBuildServer) + throw; + Console.WriteLine(e); + Assert.Ignore("Integrated security (GSS/SSPI) doesn't seem to be set up"); } + } - [Test, Description("Default user must run with md5 password encryption")] - public void DefaultUserUsesMd5Password() + [Test, Explicit("Needs to be set up (and run with with Kerberos credentials on Linux)")] + public void IntegratedSecurity_without_Username() + { + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - SslMode = SslMode.Require, - TrustServerCertificate = true - }; + Username = null, + Password = null + }.ToString(); + using var conn = new NpgsqlConnection(connString); + try + { + conn.Open(); + } + catch (Exception e) + { + if (IsOnBuildServer) + throw; + Console.WriteLine(e); + Assert.Ignore("Integrated security (GSS/SSPI) doesn't seem to be set up"); + } + } - using (var conn = OpenConnection(csb)) + [Test, Explicit("Needs to be set up (and run with with Kerberos credentials on Linux)")] + public void Connection_database_is_populated_on_Open() + { + var connString = new NpgsqlConnectionStringBuilder(ConnectionString) + { + Username = null, + Password = null, + Database = null + }.ToString(); + using var conn = new NpgsqlConnection(connString); + try + { + conn.Open(); + } + catch (Exception e) + { + if (IsOnBuildServer) + throw; + Console.WriteLine(e); + Assert.Ignore("Integrated security (GSS/SSPI) doesn't seem to be set up"); + } + Assert.That(conn.Database, Is.Not.Null); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1718")] + public void Bug1718() + { + using var dataSource = CreateDataSource(csb => + { + csb.SslMode = SslMode.Require; + }); + using var conn = dataSource.OpenConnection(); + using var tx = conn.BeginTransaction(); + using var cmd = CreateSleepCommand(conn, 
10000); + var cts = new CancellationTokenSource(1000).Token; + Assert.That(async () => await cmd.ExecuteNonQueryAsync(cts), Throws.Exception + .TypeOf() + .With.InnerException.TypeOf() + .With.InnerException.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); + } + + [Test] + public void ScramPlus() + { + try + { + using var dataSource = CreateDataSource(csb => + { + csb.SslMode = SslMode.Require; + csb.Username = "npgsql_tests_scram"; + csb.Password = "npgsql_tests_scram"; + }); + using var conn = dataSource.OpenConnection(); + // scram-sha-256-plus only works beginning from PostgreSQL 11 + if (conn.PostgreSqlVersion.Major >= 11) { Assert.That(conn.IsScram, Is.False); + Assert.That(conn.IsScramPlus, Is.True); + } + else + { + Assert.That(conn.IsScram, Is.True); Assert.That(conn.IsScramPlus, Is.False); } } + catch (Exception e) when (!IsOnBuildServer) + { + Console.WriteLine(e); + Assert.Ignore("scram-sha-256-plus doesn't seem to be set up"); + } + } - [Test, Description("Makes sure a certificate whose root CA isn't known isn't accepted")] - public void RejectSelfSignedCertificate() + [Test] + public void ScramPlus_channel_binding([Values] ChannelBinding channelBinding) + { + try { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) + using var dataSource = CreateDataSource(csb => { - SslMode = SslMode.Require - }.ToString(); + csb.SslMode = SslMode.Require; + csb.Username = "npgsql_tests_scram"; + csb.Password = "npgsql_tests_scram"; + csb.ChannelBinding = channelBinding; + }); + // scram-sha-256-plus only works beginning from PostgreSQL 11 + MinimumPgVersion(dataSource, "11.0"); + using var conn = dataSource.OpenConnection(); - using (var conn = new NpgsqlConnection(connString)) + if (channelBinding == ChannelBinding.Disable) { - // The following is necessary since a pooled connector may exist from a previous - // SSL test - NpgsqlConnection.ClearPool(conn); - - // TODO: Specific exception, align with 
SslStream - Assert.That(() => conn.Open(), Throws.Exception); + Assert.That(conn.IsScram, Is.True); + Assert.That(conn.IsScramPlus, Is.False); + } + else + { + Assert.That(conn.IsScram, Is.False); + Assert.That(conn.IsScramPlus, Is.True); } } + catch (Exception e) when (!IsOnBuildServer) + { + Console.WriteLine(e); + Assert.Ignore("scram-sha-256-plus doesn't seem to be set up"); + } + } - [Test, Description("Makes sure that ssl_renegotiation_limit is always 0, renegotiation is buggy")] - public void NoSslRenegotiation() + [Test] + public async Task Connect_with_only_ssl_allowed_user([Values] bool multiplexing, [Values] bool keepAlive) + { + if (multiplexing && keepAlive) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - SslMode = SslMode.Require, - TrustServerCertificate = true - }; + Assert.Ignore("Multiplexing doesn't support keepalive"); + } - using (var conn = OpenConnection(csb)) + try + { + await using var dataSource = CreateDataSource(csb => { - Assert.That(conn.ExecuteScalar("SHOW ssl_renegotiation_limit"), Is.EqualTo("0")); - conn.ExecuteNonQuery("DISCARD ALL"); - Assert.That(conn.ExecuteScalar("SHOW ssl_renegotiation_limit"), Is.EqualTo("0")); - } + csb.SslMode = SslMode.Allow; + csb.Username = "npgsql_tests_ssl"; + csb.Password = "npgsql_tests_ssl"; + csb.Multiplexing = multiplexing; + csb.KeepAlive = keepAlive ? 
10 : 0; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.IsTrue(conn.IsSecure); } - - [Test, Description("Makes sure that when SSL is disabled IsSecure returns false")] - public void NonSecure() + catch (Exception e) when (!IsOnBuildServer) { - using (var conn = OpenConnection()) - Assert.That(conn.IsSecure, Is.False); + Console.WriteLine(e); + Assert.Ignore("Only ssl user doesn't seem to be set up"); } + } - [Test, Explicit("Needs to be set up (and run with with Kerberos credentials on Linux)")] - public void IntegratedSecurityWithUsername() + [Test] + [Platform(Exclude = "Win", Reason = "Postgresql doesn't close connection correctly on windows which might result in missing error message")] + public async Task Connect_with_only_non_ssl_allowed_user([Values] bool multiplexing, [Values] bool keepAlive) + { + if (multiplexing && keepAlive) { - var username = Environment.UserName; - if (username == null) - throw new Exception("Could find username"); + Assert.Ignore("Multiplexing doesn't support keepalive"); + } - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { - IntegratedSecurity = true, - Username = username, - Password = null - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) + try + { + await using var dataSource = CreateDataSource(csb => { - try - { - conn.Open(); - } - catch (Exception e) - { - if (TestUtil.IsOnBuildServer) - throw; - Console.WriteLine(e); - Assert.Ignore("Integrated security (GSS/SSPI) doesn't seem to be set up"); - } - } + csb.SslMode = SslMode.Prefer; + csb.Username = "npgsql_tests_nossl"; + csb.Password = "npgsql_tests_nossl"; + csb.Multiplexing = multiplexing; + csb.KeepAlive = keepAlive ? 
10 : 0; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.IsFalse(conn.IsSecure); + } + catch (NpgsqlException ex) when (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && ex.InnerException is IOException) + { + // Windows server to windows client invites races that can cause the socket to be reset before all data can be read. + // https://www.postgresql.org/message-id/flat/90b34057-4176-7bb0-0dbb-9822a5f6425b%40greiz-reinsdorf.de + // https://www.postgresql.org/message-id/flat/16678-253e48d34dc0c376@postgresql.org + Assert.Ignore(); + } + catch (Exception e) when (!IsOnBuildServer) + { + Console.WriteLine(e); + Assert.Ignore("Only nonssl user doesn't seem to be set up"); } + } + + [Test] + public async Task DataSource_UserCertificateValidationCallback_is_invoked([Values] bool acceptCertificate) + { + var callbackWasInvoked = false; - [Test, Explicit("Needs to be set up (and run with with Kerberos credentials on Linux)")] - public void IntegratedSecurityWithoutUsername() + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.SslMode = SslMode.Require; + dataSourceBuilder.UseUserCertificateValidationCallback((_, _, _, _) => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - IntegratedSecurity = true, - Username = null, - Password = null - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) - { - try - { - conn.Open(); - } - catch (Exception e) - { - if (TestUtil.IsOnBuildServer) - throw; - Console.WriteLine(e); - Assert.Ignore("Integrated security (GSS/SSPI) doesn't seem to be set up"); - } - } + callbackWasInvoked = true; + return acceptCertificate; + }); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = dataSource.CreateConnection(); + + if (acceptCertificate) + Assert.DoesNotThrowAsync(async () => await connection.OpenAsync()); + else + { + var ex = Assert.ThrowsAsync(async () => await 
connection.OpenAsync())!; + Assert.That(ex.InnerException, Is.TypeOf()); } - [Test, Explicit("Needs to be set up (and run with with Kerberos credentials on Linux)")] - public void ConnectionDatabasePopulatedOnConnect() + Assert.That(callbackWasInvoked); + } + + [Test] + public async Task Connection_UserCertificateValidationCallback_is_invoked([Values] bool acceptCertificate) + { + var callbackWasInvoked = false; + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.SslMode = SslMode.Require; + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = dataSource.CreateConnection(); + connection.UserCertificateValidationCallback = (_, _, _, _) => { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - IntegratedSecurity = true, - Username = null, - Password = null, - Database = null - }.ToString(); - using (var conn = new NpgsqlConnection(connString)) - { - try - { - conn.Open(); - } - catch (Exception e) - { - if (TestUtil.IsOnBuildServer) - throw; - Console.WriteLine(e); - Assert.Ignore("Integrated security (GSS/SSPI) doesn't seem to be set up"); - } - Assert.That(conn.Database, Is.Not.Null); - } + callbackWasInvoked = true; + return acceptCertificate; + }; + + if (acceptCertificate) + Assert.DoesNotThrowAsync(async () => await connection.OpenAsync()); + else + { + var ex = Assert.ThrowsAsync(async () => await connection.OpenAsync())!; + Assert.That(ex.InnerException, Is.TypeOf()); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1718")] - [Timeout(12000)] - public void Bug1718() + Assert.That(callbackWasInvoked); + } + + [Test] + public void Connect_with_Verify_and_callback_throws([Values(SslMode.VerifyCA, SslMode.VerifyFull)] SslMode sslMode) + { + using var dataSource = CreateDataSource(csb => csb.SslMode = sslMode); + using var connection = dataSource.CreateConnection(); + connection.UserCertificateValidationCallback = (_, _, _, _) => true; + + var ex = 
Assert.ThrowsAsync(async () => await connection.OpenAsync())!; + Assert.That(ex.Message, Is.EqualTo(string.Format(NpgsqlStrings.CannotUseSslVerifyWithUserCallback, sslMode))); + } + + [Test] + public void Connect_with_RootCertificate_and_callback_throws() + { + using var dataSource = CreateDataSource(csb => { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - SslMode = SslMode.Require, - TrustServerCertificate = true - }; + csb.SslMode = SslMode.Require; + csb.RootCertificate = "foo"; + }); + using var connection = dataSource.CreateConnection(); + connection.UserCertificateValidationCallback = (_, _, _, _) => true; - using (var conn = OpenConnection(csb)) - using (var cmd = CreateSleepCommand(conn, 10000)) - { - var cts = new CancellationTokenSource(1000).Token; - Assert.That(async () => await cmd.ExecuteNonQueryAsync(cts), Throws.Exception - .TypeOf() - .With.InnerException.TypeOf() - .With.InnerException.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); - } + var ex = Assert.ThrowsAsync(async () => await connection.OpenAsync())!; + Assert.That(ex.Message, Is.EqualTo(string.Format(NpgsqlStrings.CannotUseSslRootCertificateWithUserCallback))); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4305")] + public async Task Bug4305_Secure([Values] bool async) + { + await using var dataSource = CreateDataSource(csb => + { + csb.SslMode = SslMode.Require; + csb.Username = "npgsql_tests_ssl"; + csb.Password = "npgsql_tests_ssl"; + csb.MaxPoolSize = 1; + }); + + NpgsqlConnection conn = default!; + + try + { + conn = await dataSource.OpenConnectionAsync(); + Assert.IsTrue(conn.IsSecure); + } + catch (Exception e) when (!IsOnBuildServer) + { + Console.WriteLine(e); + Assert.Ignore("Only ssl user doesn't seem to be set up"); } - [Test] - [Timeout(2000)] - public void ConnectToDatabaseUsingScramPlus() + await using var __ = conn; + await using var cmd = conn.CreateCommand(); + await using (var tx = 
await conn.BeginTransactionAsync()) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - SslMode = SslMode.Require, - TrustServerCertificate = true, - Username = "npgsql_tests_scram", - Password = "npgsql_tests_scram", - }; + var originalConnector = conn.Connector; - try - { - using var conn = OpenConnection(csb); - // scram-sha-256-plus only works begining from PostgreSQL 11 - if (conn.PostgreSqlVersion.Major >= 11) - { - Assert.That(conn.IsScramPlus, Is.True); - } - } - catch (Exception e) when (!TestUtil.IsOnBuildServer) - { - Console.WriteLine(e); - Assert.Ignore("scram-sha-256-plus doesn't seem to be set up"); - } + cmd.CommandText = "select pg_sleep(30)"; + cmd.CommandTimeout = 3; + var ex = async + ? Assert.ThrowsAsync(() => cmd.ExecuteNonQueryAsync())! + : Assert.Throws(() => cmd.ExecuteNonQuery())!; + Assert.That(ex.InnerException, Is.TypeOf()); + + await conn.CloseAsync(); + await conn.OpenAsync(); + + Assert.AreSame(originalConnector, conn.Connector); } - #region Setup / Teardown / Utils + cmd.CommandText = "SELECT 1"; + if (async) + Assert.DoesNotThrowAsync(async () => await cmd.ExecuteNonQueryAsync()); + else + Assert.DoesNotThrow(() => cmd.ExecuteNonQuery()); + } - [SetUp] - public void CheckSslSupport() + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4305")] + public async Task Bug4305_not_Secure([Values] bool async) + { + await using var dataSource = CreateDataSource(csb => { - using (var conn = OpenConnection()) - { - var sslSupport = (string)conn.ExecuteScalar("SHOW ssl")!; - if (sslSupport == "off") - TestUtil.IgnoreExceptOnBuildServer("SSL support isn't enabled at the backend"); - } + csb.SslMode = SslMode.Disable; + csb.Username = "npgsql_tests_nossl"; + csb.Password = "npgsql_tests_nossl"; + csb.MaxPoolSize = 1; + }); + + NpgsqlConnection conn = default!; + + try + { + conn = await dataSource.OpenConnectionAsync(); + Assert.IsFalse(conn.IsSecure); + } + catch (Exception e) when (!IsOnBuildServer) + { + 
Console.WriteLine(e); + Assert.Ignore("Only nossl user doesn't seem to be set up"); } - #endregion + await using var __ = conn; + var originalConnector = conn.Connector; + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "select pg_sleep(30)"; + cmd.CommandTimeout = 3; + var ex = async + ? Assert.ThrowsAsync(() => cmd.ExecuteNonQueryAsync())! + : Assert.Throws(() => cmd.ExecuteNonQuery())!; + Assert.That(ex.InnerException, Is.TypeOf()); + + await conn.CloseAsync(); + await conn.OpenAsync(); + + Assert.AreSame(originalConnector, conn.Connector); + + cmd.CommandText = "SELECT 1"; + if (async) + Assert.DoesNotThrowAsync(async () => await cmd.ExecuteNonQueryAsync()); + else + Assert.DoesNotThrow(() => cmd.ExecuteNonQuery()); + } + + #region Setup / Teardown / Utils + + [OneTimeSetUp] + public void CheckSslSupport() + { + using var conn = OpenConnection(); + var sslSupport = (string)conn.ExecuteScalar("SHOW ssl")!; + if (sslSupport == "off") + IgnoreExceptOnBuildServer("SSL support isn't enabled at the backend"); } + + #endregion } diff --git a/test/Npgsql.Tests/SnakeCaseNameTranslatorTests.cs b/test/Npgsql.Tests/SnakeCaseNameTranslatorTests.cs index 4167040b45..52de32bccf 100644 --- a/test/Npgsql.Tests/SnakeCaseNameTranslatorTests.cs +++ b/test/Npgsql.Tests/SnakeCaseNameTranslatorTests.cs @@ -1,53 +1,74 @@ using System.Collections.Generic; +using System.Globalization; using System.Linq; using Npgsql.NameTranslation; using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class SnakeCaseNameTranslatorTests { - [TestFixture] - public class SnakeCaseNameTranslatorTests + static readonly CultureInfo trTRCulture = new("tr-TR"); + static readonly CultureInfo enUSCulture = new("en-US"); + + [Test, TestCaseSource(typeof(SnakeCaseNameTranslatorTests), nameof(TestCases))] + public string TranslateTypeName(CultureInfo? 
culture, string value, bool legacyMode) + => new NpgsqlSnakeCaseNameTranslator(legacyMode, culture).TranslateTypeName(value); + + [Test, TestCaseSource(typeof(SnakeCaseNameTranslatorTests), nameof(TestCases))] + public string TranslateMemberName(CultureInfo? culture, string value, bool legacyMode) + => new NpgsqlSnakeCaseNameTranslator(legacyMode, culture).TranslateMemberName(value); + + static IEnumerable TestCases => new (CultureInfo? culture, string value, string legacyResult, string result)[] + { + (null, "Hi!! This is text. Time to test.", "hi!! _this is text. _time to test.", "hi_this_is_text_time_to_test"), + (null, "9999-12-31T23:59:59.9999999Z", "9999-12-31_t23:59:59.9999999_z", "9999_12_31t23_59_59_9999999z"), + (null, "FK_post_simple_blog_BlogId", "f_k_post_simple_blog__blog_id", "fk_post_simple_blog_blog_id"), + (null, "already_snake_case_ ", "already_snake_case_ ", "already_snake_case_"), + (null, "SHOUTING_CASE", "s_h_o_u_t_i_n_g__c_a_s_e", "shouting_case"), + (null, "IsJSONProperty", "is_j_s_o_n_property", "is_json_property"), + (null, "SnA__ kEcAsE", "sn_a__ k_ec_as_e", "sn_a__k_ec_as_e"), + (null, "SnA__kEcAsE", "sn_a__k_ec_as_e", "sn_a__k_ec_as_e"), + (null, "SnAkEcAsE", "sn_ak_ec_as_e", "sn_ak_ec_as_e"), + (null, "URLValue", "u_r_l_value", "url_value"), + (null, "Xml2Json", "xml2_json", "xml2json"), + (null, " IPhone ", " _i_phone ", "i_phone"), + (null, "I Phone", "i _phone", "i_phone"), + (null, " IPhone", " _i_phone", "i_phone"), + (null, "I Phone", "i _phone", "i_phone"), + (null, "IPhone", "i_phone", "i_phone"), + (null, "iPhone", "i_phone", "i_phone"), + (null, "IsCIA", "is_c_i_a", "is_cia"), + (null, "Person", "person", "person"), + (null, "ABC123", "a_b_c123", "abc123"), + (null, "VmQ", "vm_q", "vm_q"), + (null, "URL", "u_r_l", "url"), + (null, "AB1", "a_b1", "ab1"), + (null, "ID", "i_d", "id"), + (null, "I", "i", "i"), + (null, "", "", ""), + (trTRCulture, "IPhone", "ı_phone", "ı_phone"), // dotless I -> dotless ı + (enUSCulture, 
"IPhone", "i_phone", "i_phone"), + (CultureInfo.InvariantCulture, "IPhone", "i_phone", "i_phone"), + }.SelectMany(x => new[] { - [Test, TestCaseSource(typeof(SnakeCaseNameTranslatorTests), nameof(TestCases))] - public string TranslateTypeName(string value, bool legacyMode) - => new NpgsqlSnakeCaseNameTranslator(legacyMode).TranslateTypeName(value); - - [Test, TestCaseSource(typeof(SnakeCaseNameTranslatorTests), nameof(TestCases))] - public string TranslateMemberName(string value, bool legacyMode) - => new NpgsqlSnakeCaseNameTranslator(legacyMode).TranslateMemberName(value); - - static IEnumerable TestCases => new (string value, string legacyResult, string result)[] - { - ("Hi!! This is text. Time to test.", "hi!! _this is text. _time to test.", "hi_this_is_text_time_to_test"), - ("9999-12-31T23:59:59.9999999Z", "9999-12-31_t23:59:59.9999999_z", "9999_12_31t23_59_59_9999999z"), - ("FK_post_simple_blog_BlogId", "f_k_post_simple_blog__blog_id", "fk_post_simple_blog_blog_id"), - ("already_snake_case_ ", "already_snake_case_ ", "already_snake_case_"), - ("SHOUTING_CASE", "s_h_o_u_t_i_n_g__c_a_s_e", "shouting_case"), - ("IsJSONProperty", "is_j_s_o_n_property", "is_json_property"), - ("SnA__ kEcAsE", "sn_a__ k_ec_as_e", "sn_a__k_ec_as_e"), - ("SnA__kEcAsE", "sn_a__k_ec_as_e", "sn_a__k_ec_as_e"), - ("SnAkEcAsE", "sn_ak_ec_as_e", "sn_ak_ec_as_e"), - ("URLValue", "u_r_l_value", "url_value"), - ("Xml2Json", "xml2_json", "xml2json"), - (" IPhone ", " _i_phone ", "i_phone"), - ("I Phone", "i _phone", "i_phone"), - (" IPhone", " _i_phone", "i_phone"), - ("I Phone", "i _phone", "i_phone"), - ("IPhone", "i_phone", "i_phone"), - ("iPhone", "i_phone", "i_phone"), - ("IsCIA", "is_c_i_a", "is_cia"), - ("Person", "person", "person"), - ("ABC123", "a_b_c123", "abc123"), - ("VmQ", "vm_q", "vm_q"), - ("URL", "u_r_l", "url"), - ("AB1", "a_b1", "ab1"), - ("ID", "i_d", "id"), - ("I", "i", "i"), - ("", "", "") - }.SelectMany(x => new[] - { - new TestCaseData(x.value, 
true).Returns(x.legacyResult), - new TestCaseData(x.value, false).Returns(x.result), - }); + new TestCaseData(x.culture, x.value, true).Returns(x.legacyResult), + new TestCaseData(x.culture, x.value, false).Returns(x.result), + }); + + [Test, Description("Checks translating a name with letter 'I' in Turkish locale with default setting (Invariant Culture)")] + [SetCulture("tr-TR")] + public void TurkeyTest() + { + var translator = new NpgsqlSnakeCaseNameTranslator(); + var legacyTranslator = new NpgsqlSnakeCaseNameTranslator(true); + + const string clrName = "IPhone"; + const string expected = "i_phone"; + + Assert.AreEqual(expected, translator.TranslateMemberName(clrName)); + Assert.AreEqual(expected, translator.TranslateTypeName(clrName)); + Assert.AreEqual(expected, legacyTranslator.TranslateMemberName(clrName)); + Assert.AreEqual(expected, legacyTranslator.TranslateTypeName(clrName)); } } diff --git a/test/Npgsql.Tests/SqlQueryParserTests.cs b/test/Npgsql.Tests/SqlQueryParserTests.cs index 6e3e9b647c..1044b707fc 100644 --- a/test/Npgsql.Tests/SqlQueryParserTests.cs +++ b/test/Npgsql.Tests/SqlQueryParserTests.cs @@ -3,183 +3,203 @@ using System.Linq; using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +class SqlQueryParserTests { - class SqlQueryParserTests + [Test] + public void Parameter_simple() { - [Test] - public void ParamSimple() - { - _params.AddWithValue(":p1", "foo"); - _params.AddWithValue(":p2", "bar"); - _parser.ParseRawQuery("SELECT :p1, :p2", _params, _queries); - Assert.That(_queries.Single().InputParameters, Is.EqualTo(_params)); - } - - [Test] - public void ParamNameWithDot() - { - _params.AddWithValue(":a.parameter", "foo"); - _parser.ParseRawQuery("INSERT INTO data (field_char5) VALUES (:a.parameter)", _params, _queries); - Assert.That(_queries.Single().InputParameters.Single(), Is.SameAs(_params.Single())); - } - - [Test, Description("Checks several scenarios in which the SQL is supposed to pass untouched")] - 
[TestCase(@"SELECT to_tsvector('fat cats ate rats') @@ to_tsquery('cat & rat')", TestName="AtAt")] - [TestCase(@"SELECT 'cat'::tsquery @> 'cat & rat'::tsquery", TestName = "AtGt")] - [TestCase(@"SELECT 'cat'::tsquery <@ 'cat & rat'::tsquery", TestName = "AtLt")] - [TestCase(@"SELECT 'b''la'", TestName = "DoubleTicks")] - [TestCase(@"SELECT 'type(''m.response'')#''O''%'", TestName = "DoubleTicks2")] - [TestCase(@"SELECT 'abc'':str''a:str'", TestName = "DoubleTicks3")] - [TestCase(@"SELECT 1 FROM "":str""", TestName = "DoubleQuoted")] - [TestCase(@"SELECT 1 FROM 'yo'::str", TestName = "DoubleColons")] - [TestCase("SELECT $\u00ffabc0$literal string :str :int$\u00ffabc0 $\u00ffabc0$", TestName = "DollarQuotes")] - [TestCase("SELECT $$:str$$", TestName = "DollarQuotesNoTag")] - public void Untouched(string sql) - { - _params.AddWithValue(":param", "foo"); - _parser.ParseRawQuery(sql, _params, _queries); - Assert.That(_queries.Single().SQL, Is.EqualTo(sql)); - Assert.That(_queries.Single().InputParameters, Is.Empty); - } - - [Test] - [TestCase(@"SELECT 1<:param", TestName = "LessThan")] - [TestCase(@"SELECT 1>:param", TestName = "GreaterThan")] - [TestCase(@"SELECT 1<>:param", TestName = "NotEqual")] - [TestCase("SELECT--comment\r:param", TestName="LineComment")] - public void ParamGetsBound(string sql) - { - _params.AddWithValue(":param", "foo"); - _parser.ParseRawQuery(sql, _params, _queries); - Assert.That(_queries.Single().InputParameters.Single(), Is.SameAs(_params.Single())); - } + var parameters = new NpgsqlParameter[] { new(":p1", "foo"), new(":p2", "foo") }; + var result = ParseCommand("SELECT :p1, :p2", parameters).Single(); + Assert.That(result.FinalCommandText, Is.EqualTo("SELECT $1, $2")); + Assert.That(result.PositionalParameters, Is.EquivalentTo(parameters)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1177")] - public void ParamGetsBoundNonAscii() - { - _params.AddWithValue("漢字", "foo"); - _parser.ParseRawQuery("SELECT @漢字", _params, 
_queries); - Assert.That(_queries.Single().InputParameters.Single(), Is.SameAs(_params.Single())); - } - - [Test] - [TestCase(@"SELECT e'ab\'c:param'", TestName = "Estring")] - [TestCase(@"SELECT/*/* -- nested comment :int /*/* *//*/ **/*/*/*/1")] - [TestCase(@"SELECT 1, + [Test] + public void Parameter_name_with_dot() + { + var p = new NpgsqlParameter(":a.parameter", "foo"); + var results = ParseCommand("INSERT INTO data (field_char5) VALUES (:a.parameter)", p); + Assert.That(results.Single().PositionalParameters.Single(), Is.SameAs(p)); + } + + [Test, Description("Checks several scenarios in which the SQL is supposed to pass untouched")] + [TestCase(@"SELECT to_tsvector('fat cats ate rats') @@ to_tsquery('cat & rat')", TestName="AtAt")] + [TestCase(@"SELECT 'cat'::tsquery @> 'cat & rat'::tsquery", TestName = "AtGt")] + [TestCase(@"SELECT 'cat'::tsquery <@ 'cat & rat'::tsquery", TestName = "AtLt")] + [TestCase(@"SELECT 'b''la'", TestName = "DoubleTicks")] + [TestCase(@"SELECT 'type(''m.response'')#''O''%'", TestName = "DoubleTicks2")] + [TestCase(@"SELECT 'abc'':str''a:str'", TestName = "DoubleTicks3")] + [TestCase(@"SELECT 1 FROM "":str""", TestName = "DoubleQuoted")] + [TestCase(@"SELECT 1 FROM 'yo'::str", TestName = "DoubleColons")] + [TestCase("SELECT $\u00ffabc0$literal string :str :int$\u00ffabc0 $\u00ffabc0$", TestName = "DollarQuotes")] + [TestCase("SELECT $$:str$$", TestName = "DollarQuotesNoTag")] + public void Untouched(string sql) + { + var results = ParseCommand(sql, new NpgsqlParameter(":param", "foo")); + Assert.That(results.Single().FinalCommandText, Is.EqualTo(sql)); + Assert.That(results.Single().PositionalParameters, Is.Empty); + } + + [Test] + [TestCase(@"SELECT 1<:param", TestName = "LessThan")] + [TestCase(@"SELECT 1>:param", TestName = "GreaterThan")] + [TestCase(@"SELECT 1<>:param", TestName = "NotEqual")] + [TestCase("SELECT--comment\r:param", TestName="LineComment")] + public void Parameter_gets_bound(string sql) + { + var p = new 
NpgsqlParameter(":param", "foo"); + var results = ParseCommand(sql, p); + Assert.That(results.Single().PositionalParameters.Single(), Is.SameAs(p)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1177")] + public void Parameter_gets_bound_non_ascii() + { + var p = new NpgsqlParameter("漢字", "foo"); + var results = ParseCommand("SELECT @漢字", p); + Assert.That(results.Single().PositionalParameters.Single(), Is.SameAs(p)); + } + + [Test] + [TestCase(@"SELECT e'ab\'c:param'", TestName = "Estring")] + [TestCase(@"SELECT/*/* -- nested comment :int /*/* *//*/ **/*/*/*/1")] + [TestCase(@"SELECT 1, -- Comment, @param and also :param 2", TestName = "LineComment")] - public void ParamDoesntGetBound(string sql) - { - _params.AddWithValue(":param", "foo"); - _parser.ParseRawQuery(sql, _params, _queries); - Assert.That(_queries.Single().InputParameters, Is.Empty); - } + public void Parameter_does_not_get_bound(string sql) + { + var p = new NpgsqlParameter(":param", "foo"); + var results = ParseCommand(sql, p); + Assert.That(results.Single().PositionalParameters, Is.Empty); + } - [Test] - public void MultiqueryWithParams() - { - var p1 = new NpgsqlParameter("p1", DbType.String); - _params.Add(p1); - var p2 = new NpgsqlParameter("p2", DbType.String); - _params.Add(p2); - var p3 = new NpgsqlParameter("p3", DbType.String); - _params.Add(p3); - - _parser.ParseRawQuery("SELECT @p3, @p1; SELECT @p2, @p3", _params, _queries); - - Assert.That(_queries, Has.Count.EqualTo(2)); - Assert.That(_queries[0].InputParameters[0], Is.SameAs(p3)); - Assert.That(_queries[0].InputParameters[1], Is.SameAs(p1)); - Assert.That(_queries[1].InputParameters[0], Is.SameAs(p2)); - Assert.That(_queries[1].InputParameters[1], Is.SameAs(p3)); - } - - [Test] - public void NoOutputParameters() - { - var p = new NpgsqlParameter("p", DbType.String) { Direction = ParameterDirection.Output }; - _params.Add(p); - Assert.That(() => _parser.ParseRawQuery("SELECT @p", _params, _queries), 
Throws.Exception); - } + [Test] + public void Non_conforming_string() + { + var result = ParseCommand(@"SELECT 'abc\':str''a:str'").Single(); + Assert.That(result.FinalCommandText, Is.EqualTo(@"SELECT 'abc\':str''a:str'")); + Assert.That(result.PositionalParameters, Is.Empty); + } - [Test] - public void MissingParamIsIgnored() - { - _parser.ParseRawQuery("SELECT @p; SELECT 1", _params, _queries); - Assert.That(_queries[0].SQL, Is.EqualTo("SELECT @p")); - Assert.That(_queries[1].SQL, Is.EqualTo("SELECT 1")); - Assert.That(_queries[0].InputParameters, Is.Empty); - Assert.That(_queries[1].InputParameters, Is.Empty); - } - - [Test] - public void ConsecutiveSemicolons() - { - _parser.ParseRawQuery(";;SELECT 1", _params, _queries); - Assert.That(_queries, Has.Count.EqualTo(1)); - } + [Test] + public void Multiquery_with_parameters() + { + var parameters = new NpgsqlParameter[] + { + new("p1", DbType.String), + new("p2", DbType.String), + new("p3", DbType.String), + }; + + var results = ParseCommand("SELECT @p3, @p1; SELECT @p2, @p3", parameters); + + Assert.That(results, Has.Count.EqualTo(2)); + Assert.That(results[0].FinalCommandText, Is.EqualTo("SELECT $1, $2")); + Assert.That(results[0].PositionalParameters[0], Is.SameAs(parameters[2])); + Assert.That(results[0].PositionalParameters[1], Is.SameAs(parameters[0])); + Assert.That(results[1].FinalCommandText, Is.EqualTo("SELECT $1, $2")); + Assert.That(results[1].PositionalParameters[0], Is.SameAs(parameters[1])); + Assert.That(results[1].PositionalParameters[1], Is.SameAs(parameters[2])); + } - [Test] - public void TrailingSemicolon() - { - _parser.ParseRawQuery("SELECT 1;", _params, _queries); - Assert.That(_queries, Has.Count.EqualTo(1)); - } + [Test] + public void No_output_parameters() + { + var p = new NpgsqlParameter("p", DbType.String) { Direction = ParameterDirection.Output }; + Assert.That(() => ParseCommand("SELECT @p", p), Throws.Exception); + } - [Test] - public void Empty() - { - _parser.ParseRawQuery("", 
_params, _queries); - Assert.That(_queries, Has.Count.EqualTo(1)); - } + [Test] + public void Missing_parameter_is_ignored() + { + var results = ParseCommand("SELECT @p; SELECT 1"); + Assert.That(results[0].FinalCommandText, Is.EqualTo("SELECT @p")); + Assert.That(results[1].FinalCommandText, Is.EqualTo("SELECT 1")); + Assert.That(results[0].PositionalParameters, Is.Empty); + Assert.That(results[1].PositionalParameters, Is.Empty); + } - [Test] - public void SemicolonInParentheses() - { - _parser.ParseRawQuery("CREATE OR REPLACE RULE test AS ON UPDATE TO test DO (SELECT 1; SELECT 1)", _params, _queries); - Assert.That(_queries, Has.Count.EqualTo(1)); - } + [Test] + public void Consecutive_semicolons() + { + var results = ParseCommand(";;SELECT 1"); + Assert.That(results, Has.Count.EqualTo(3)); + Assert.That(results[0].FinalCommandText, Is.Empty); + Assert.That(results[1].FinalCommandText, Is.Empty); + Assert.That(results[2].FinalCommandText, Is.EqualTo("SELECT 1")); + } - [Test] - public void SemicolonAfterParentheses() - { - _parser.ParseRawQuery("CREATE OR REPLACE RULE test AS ON UPDATE TO test DO (SELECT 1); SELECT 1", _params, _queries); - Assert.That(_queries, Has.Count.EqualTo(2)); - } + [Test] + public void Trailing_semicolon() + { + var results = ParseCommand("SELECT 1;"); + Assert.That(results, Has.Count.EqualTo(1)); + Assert.That(results[0].FinalCommandText, Is.EqualTo("SELECT 1")); + } - [Test] - public void ReduceNumberOfStatements() - { - _parser.ParseRawQuery("SELECT 1; SELECT 2", _params, _queries); - Assert.That(_queries, Has.Count.EqualTo(2)); - _parser.ParseRawQuery("SELECT 1", _params, _queries); - Assert.That(_queries, Has.Count.EqualTo(1)); - } + [Test] + public void Empty() + { + var results = ParseCommand(""); + Assert.That(results, Has.Count.EqualTo(1)); + Assert.That(results[0].FinalCommandText, Is.Empty); + } + + [Test] + public void Semicolon_in_parentheses() + { + var results = ParseCommand("CREATE OR REPLACE RULE test AS ON UPDATE TO 
test DO (SELECT 1; SELECT 1)"); + Assert.That(results, Has.Count.EqualTo(1)); + Assert.That(results[0].FinalCommandText, Is.EqualTo("CREATE OR REPLACE RULE test AS ON UPDATE TO test DO (SELECT 1; SELECT 1)")); + } + + [Test] + public void Semicolon_after_parentheses() + { + var results = ParseCommand("CREATE OR REPLACE RULE test AS ON UPDATE TO test DO (SELECT 1); SELECT 1"); + Assert.That(results, Has.Count.EqualTo(2)); + Assert.That(results[0].FinalCommandText, Is.EqualTo("CREATE OR REPLACE RULE test AS ON UPDATE TO test DO (SELECT 1)")); + Assert.That(results[1].FinalCommandText, Is.EqualTo("SELECT 1")); + } + + [Test] + public void Reduce_number_of_statements() + { + var parser = new SqlQueryParser(); + + var cmd = new NpgsqlCommand("SELECT 1; SELECT 2"); + parser.ParseRawQuery(cmd); + Assert.That(cmd.InternalBatchCommands, Has.Count.EqualTo(2)); + + cmd.CommandText = "SELECT 1"; + parser.ParseRawQuery(cmd); + Assert.That(cmd.InternalBatchCommands, Has.Count.EqualTo(1)); + } #if TODO - [Test] - public void TrimWhitespace() - { - _parser.ParseRawQuery(" SELECT 1\t", _params, _queries); - Assert.That(_queries.Single().Sql, Is.EqualTo("SELECT 1")); - } + [Test] + public void Trim_whitespace() + { + _parser.ParseRawQuery(" SELECT 1\t", _params, _queries, standardConformingStrings: true); + Assert.That(_queries.Single().Sql, Is.EqualTo("SELECT 1")); + } #endif - #region Setup / Teardown / Utils + #region Setup / Teardown / Utils - SqlQueryParser _parser = default!; - List _queries = default!; - NpgsqlParameterCollection _params = default!; + List ParseCommand(string sql, params NpgsqlParameter[] parameters) + => ParseCommand(sql, parameters, standardConformingStrings: true); - [SetUp] - public void SetUp() - { - _parser = new SqlQueryParser(); - _queries = new List(); - _params = new NpgsqlParameterCollection(); - } - - #endregion + List ParseCommand(string sql, NpgsqlParameter[] parameters, bool standardConformingStrings) + { + var cmd = new NpgsqlCommand(sql); + 
cmd.Parameters.AddRange(parameters); + var parser = new SqlQueryParser(); + parser.ParseRawQuery(cmd, standardConformingStrings); + return cmd.InternalBatchCommands; } + + #endregion } diff --git a/test/Npgsql.Tests/StoredProcedureTests.cs b/test/Npgsql.Tests/StoredProcedureTests.cs new file mode 100644 index 0000000000..84acb51b36 --- /dev/null +++ b/test/Npgsql.Tests/StoredProcedureTests.cs @@ -0,0 +1,431 @@ +using System.Data; +using System.Threading.Tasks; +using Npgsql.PostgresTypes; +using NpgsqlTypes; +using NUnit.Framework; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests; + +public class StoredProcedureTests : TestBase +{ + [Test] + [TestCase(true, false)] + [TestCase(false, true)] + [TestCase(true, true)] + public async Task With_input_parameters(bool withPositional, bool withNamed) + { + var table = await CreateTempTable(DataSource, "foo int, bar int"); + var sproc = await GetTempProcedureName(DataSource); + + await DataSource.ExecuteNonQueryAsync(@$" +CREATE PROCEDURE {sproc}(a int, b int) +LANGUAGE SQL +AS $$ + INSERT INTO {table} VALUES (a, b); +$$"); + + await using (var command = DataSource.CreateCommand(sproc)) + { + command.CommandType = CommandType.StoredProcedure; + + command.Parameters.Add(withPositional + ? new() { Value = 8 } + : new() { ParameterName = "a", Value = 8 }); + + command.Parameters.Add(withNamed + ? 
new() { ParameterName = "b", Value = 9 } + : new() { Value = 9 }); + + await command.ExecuteNonQueryAsync(); + } + + await using (var command = DataSource.CreateCommand($"SELECT * FROM {table}")) + await using (var reader = await command.ExecuteReaderAsync()) + { + await reader.ReadAsync(); + Assert.That(reader[0], Is.EqualTo(8)); + Assert.That(reader[1], Is.EqualTo(9)); + } + } + + [Test] + [TestCase(true, false)] + [TestCase(false, true)] + [TestCase(true, true)] + public async Task With_output_parameters(bool withPositional, bool withNamed) + { + MinimumPgVersion(DataSource, "14.0", "Stored procedure OUT parameters are only support starting with version 14"); + + var sproc = await GetTempProcedureName(DataSource); + + await DataSource.ExecuteNonQueryAsync(@$" +CREATE PROCEDURE {sproc}(a int, OUT out1 int, OUT out2 int, b int) +LANGUAGE plpgsql +AS $$ +BEGIN + out1 = a; + out2 = b; +END$$"); + + await using var command = DataSource.CreateCommand(sproc); + command.CommandType = CommandType.StoredProcedure; + + command.Parameters.Add(new() { Value = 8 }); + + command.Parameters.Add(withPositional + ? new() { Direction = ParameterDirection.Output } + : new() { ParameterName = "out1", Direction = ParameterDirection.Output }); + + command.Parameters.Add(withNamed + ? 
new() { ParameterName = "out2", Direction = ParameterDirection.Output } + : new() { Direction = ParameterDirection.Output }); + + command.Parameters.Add(new() { ParameterName = "b", Value = 9 }); + + await using var reader = await command.ExecuteReaderAsync(); + await reader.ReadAsync(); + + Assert.That(reader[0], Is.EqualTo(8)); + Assert.That(reader[1], Is.EqualTo(9)); + } + + [Test] + [TestCase(true, false)] + [TestCase(false, true)] + [TestCase(true, true)] + public async Task With_input_output_parameters(bool withPositional, bool withNamed) + { + var sproc = await GetTempProcedureName(DataSource); + + await DataSource.ExecuteNonQueryAsync(@$" +CREATE PROCEDURE {sproc}(a int, INOUT inout1 int, INOUT inout2 int, b int) +LANGUAGE plpgsql +AS $$ +BEGIN + inout1 = inout1 + a; + inout2 = inout2 + b; +END$$"); + + await using var command = DataSource.CreateCommand(sproc); + command.CommandType = CommandType.StoredProcedure; + + command.Parameters.Add(new() { Value = 8 }); + + command.Parameters.Add(withPositional + ? new() { Value = 1, Direction = ParameterDirection.InputOutput } + : new() { ParameterName = "inout1", Value = 1, Direction = ParameterDirection.InputOutput }); + + command.Parameters.Add(withNamed + ? 
new() { ParameterName = "inout2", Value = 2, Direction = ParameterDirection.InputOutput } + : new() { Value = 2, Direction = ParameterDirection.InputOutput }); + + command.Parameters.Add(new() { ParameterName = "b", Value = 9 }); + + await using var reader = await command.ExecuteReaderAsync(); + await reader.ReadAsync(); + + Assert.That(reader[0], Is.EqualTo(9)); + Assert.That(reader[1], Is.EqualTo(11)); + } + + [Test] + public async Task Batch_positional_parameters_works() + { + var tempname = await GetTempProcedureName(DataSource); + await using var connection = await DataSource.OpenConnectionAsync(); + await using var transaction = await connection.BeginTransactionAsync(IsolationLevel.Serializable); + await using var batch = new NpgsqlBatch(connection, transaction) + { + BatchCommands = + { + new(tempname) + { + CommandType = CommandType.StoredProcedure, + Parameters = + { + new() { Value = "" }, + new() { DbType = DbType.Int64, Direction = ParameterDirection.Output } + } + }, + new ("COMMIT") + } + }; + + Assert.ThrowsAsync(() => batch.ExecuteNonQueryAsync()); + } + + [Test] + public async Task Batch_StoredProcedure_output_parameters_works() + { + // Proper OUT params were introduced in PostgreSQL 14 + MinimumPgVersion(DataSource, "14.0", "Stored procedure OUT parameters are only support starting with version 14"); + var sproc = await GetTempProcedureName(DataSource); + + await using var connection = await DataSource.OpenConnectionAsync(); + await using var transaction = await connection.BeginTransactionAsync(IsolationLevel.Serializable); + var c = connection.CreateCommand(); + c.CommandText = $""" + CREATE OR REPLACE PROCEDURE {sproc} + ( + p_username TEXT, + OUT p_user_id BIGINT + ) + LANGUAGE plpgsql + AS $$ + BEGIN + p_user_id = 1; + return; + END; + $$; + """; + await c.ExecuteNonQueryAsync(); + + await using var batch = new NpgsqlBatch(connection, transaction) + { + BatchCommands = + { + new(sproc) + { + CommandType = CommandType.StoredProcedure, + 
Parameters = + { + new() { Value = "" }, + new() { NpgsqlDbType = NpgsqlDbType.Bigint, Direction = ParameterDirection.Output } + } + }, + new(sproc) + { + CommandType = CommandType.StoredProcedure, + Parameters = + { + new() { Value = "" }, + new() { NpgsqlDbType = NpgsqlDbType.Bigint, Direction = ParameterDirection.Output } + } + } + } + }; + + await batch.ExecuteNonQueryAsync(); + Assert.AreEqual(1, batch.BatchCommands[0].Parameters[1].Value); + Assert.AreEqual(1, batch.BatchCommands[1].Parameters[1].Value); + } + + #region DeriveParameters + + [Test, Description("Tests function parameter derivation with IN, OUT and INOUT parameters")] + public async Task DeriveParameters_procedure_various() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Stored procedure OUT parameters are only support starting with version 14"); + var sproc = await GetTempProcedureName(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE PROCEDURE {sproc}(IN param1 INT, OUT param2 text, INOUT param3 INT) AS $$ +BEGIN + param2 = 'sometext'; + param3 = param1 + param3; +END; +$$ LANGUAGE plpgsql"); + + await using var command = new NpgsqlCommand(sproc, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.That(command.Parameters, Has.Count.EqualTo(3)); + Assert.That(command.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(command.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(command.Parameters[0].PostgresType, Is.TypeOf()); + Assert.That(command.Parameters[0].DataTypeName, Is.EqualTo("integer")); + Assert.That(command.Parameters[0].ParameterName, Is.EqualTo("param1")); + Assert.That(command.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Output)); + Assert.That(command.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); + Assert.That(command.Parameters[1].PostgresType, Is.TypeOf()); + 
Assert.That(command.Parameters[1].DataTypeName, Is.EqualTo("text")); + Assert.That(command.Parameters[1].ParameterName, Is.EqualTo("param2")); + Assert.That(command.Parameters[2].Direction, Is.EqualTo(ParameterDirection.InputOutput)); + Assert.That(command.Parameters[2].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(command.Parameters[2].PostgresType, Is.TypeOf()); + Assert.That(command.Parameters[2].DataTypeName, Is.EqualTo("integer")); + Assert.That(command.Parameters[2].ParameterName, Is.EqualTo("param3")); + command.Parameters[0].Value = 5; + command.Parameters[2].Value = 4; + await command.ExecuteNonQueryAsync(); + Assert.That(command.Parameters[0].Value, Is.EqualTo(5)); + Assert.That(command.Parameters[1].Value, Is.EqualTo("sometext")); + Assert.That(command.Parameters[2].Value, Is.EqualTo(9)); + } + + [Test, Description("Tests function parameter derivation with IN-only parameters")] + public async Task DeriveParameters_procedure_in_only() + { + await using var conn = await OpenConnectionAsync(); + var sproc = await GetTempProcedureName(conn); + + await conn.ExecuteNonQueryAsync($@"CREATE PROCEDURE {sproc}(IN param1 INT, IN param2 INT) AS '' LANGUAGE sql"); + + await using var cmd = new NpgsqlCommand(sproc, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Has.Count.EqualTo(2)); + Assert.That(cmd.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(cmd.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Input)); + cmd.Parameters[0].Value = 5; + cmd.Parameters[1].Value = 4; + Assert.DoesNotThrowAsync(() => cmd.ExecuteNonQueryAsync()); + } + + [Test, Description("Tests function parameter derivation with no parameters")] + public async Task DeriveParameters_procedure_no_params() + { + await using var conn = await OpenConnectionAsync(); + var sproc = await GetTempProcedureName(conn); + + await conn.ExecuteNonQueryAsync($@"CREATE PROCEDURE 
{sproc}() AS '' LANGUAGE sql"); + + await using var cmd = new NpgsqlCommand(sproc, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(cmd); + Assert.That(cmd.Parameters, Is.Empty); + } + + [Test] + public async Task DeriveParameters_procedure_with_case_sensitive_name() + { + await using var conn = await OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync(@"CREATE OR REPLACE PROCEDURE ""ProcedureCaseSensitive""(int4, text) AS '' LANGUAGE sql"); + + try + { + await using var command = new NpgsqlCommand(@"""ProcedureCaseSensitive""", conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); + Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + } + finally + { + await conn.ExecuteNonQueryAsync(@"DROP PROCEDURE ""ProcedureCaseSensitive"""); + } + } + + [Test, Description("Tests function parameter derivation for quoted functions with double quotes in the name works")] + public async Task DeriveParameters_quote_characters_in_function_name() + { + await using var conn = await OpenConnectionAsync(); + var sproc = @"""""""ProcedureQuote""""CharactersInName"""""""; + await conn.ExecuteNonQueryAsync($"CREATE OR REPLACE PROCEDURE {sproc}(int4, text) AS 'SELECT 0' LANGUAGE sql"); + + try + { + await using var command = new NpgsqlCommand(sproc, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); + Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + } + finally + { + await conn.ExecuteNonQueryAsync("DROP PROCEDURE " + sproc); + } + } + + [Test, Description("Tests function parameter derivation for quoted functions with dots in the name works")] + public async Task DeriveParameters_dot_character_in_function_name() + { + await using var conn = await 
OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync( + @"CREATE OR REPLACE PROCEDURE ""My.Dotted.Procedure""(int4, text) AS 'SELECT 0' LANGUAGE sql"); + + try + { + await using var command = new NpgsqlCommand(@"""My.Dotted.Procedure""", conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); + Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + } + finally + { + await conn.ExecuteNonQueryAsync(@"DROP PROCEDURE ""My.Dotted.Procedure"""); + } + } + + [Test] + public async Task DeriveParameters_parameter_name_from_function() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Stored procedure OUT parameters are only support starting with version 14"); + var sproc = await GetTempProcedureName(conn); + + await conn.ExecuteNonQueryAsync( + $"CREATE PROCEDURE {sproc}(x int, y int, out sum int, out product int) AS 'SELECT $1 + $2, $1 * $2' LANGUAGE sql"); + await using var command = new NpgsqlCommand(sproc, conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.AreEqual("x", command.Parameters[0].ParameterName); + Assert.AreEqual("y", command.Parameters[1].ParameterName); + } + + [Test] + public async Task DeriveParameters_non_existing_procedure() + { + await using var conn = await OpenConnectionAsync(); + var invalidCommandName = new NpgsqlCommand("invalidprocedurename", conn) { CommandType = CommandType.StoredProcedure }; + Assert.That(() => NpgsqlCommandBuilder.DeriveParameters(invalidCommandName), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedFunction)); + } + + [Test, Description("Tests if the right function according to search_path is used in function parameter derivation")] + public async Task DeriveParameters_procedure_correct_schema_resolution() + { + 
await using var conn = await OpenConnectionAsync(); + var schema1 = await CreateTempSchema(conn); + var schema2 = await CreateTempSchema(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE PROCEDURE {schema1}.redundantsproc() AS 'SELECT 1' LANGUAGE sql; +CREATE PROCEDURE {schema2}.redundantsproc(IN param1 INT, IN param2 INT) AS 'SELECT param1 + param2' LANGUAGE sql; +SET search_path TO {schema2};"); + await using var command = new NpgsqlCommand("redundantsproc", conn) { CommandType = CommandType.StoredProcedure }; + NpgsqlCommandBuilder.DeriveParameters(command); + Assert.That(command.Parameters, Has.Count.EqualTo(2)); + Assert.That(command.Parameters[0].Direction, Is.EqualTo(ParameterDirection.Input)); + Assert.That(command.Parameters[1].Direction, Is.EqualTo(ParameterDirection.Input)); + } + + [Test, Description("Tests if function parameter derivation throws an exception if the specified function is not in the search_path")] + public async Task DeriveParameters_throws_for_existing_procedure_that_is_not_in_search_path() + { + await using var conn = await OpenConnectionAsync(); + var schema = await CreateTempSchema(conn); + + await conn.ExecuteNonQueryAsync($@" +CREATE PROCEDURE {schema}.schema1sproc() AS 'SELECT 1' LANGUAGE sql; +RESET search_path;"); + await using var command = new NpgsqlCommand("schema1sproc", conn) { CommandType = CommandType.StoredProcedure }; + Assert.That(() => NpgsqlCommandBuilder.DeriveParameters(command), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedFunction)); + } + + [Test, Description("Tests if an exception is thrown if multiple functions with the specified name are in the search_path")] + public async Task DeriveParameters_throws_for_multiple_procedures_name_hits_in_search_path() + { + await using var conn = await OpenConnectionAsync(); + var schema1 = await CreateTempSchema(conn); + var schema2 = await CreateTempSchema(conn); + + await 
conn.ExecuteNonQueryAsync( + $@" +CREATE PROCEDURE {schema1}.redundantsproc() AS 'SELECT 1' LANGUAGE sql; +CREATE PROCEDURE {schema1}.redundantsproc(IN param1 INT, IN param2 INT) AS 'SELECT param1 + param2' LANGUAGE sql; +SET search_path TO {schema1}, {schema2};"); + var command = new NpgsqlCommand("redundantsproc", conn) { CommandType = CommandType.StoredProcedure }; + Assert.That(() => NpgsqlCommandBuilder.DeriveParameters(command), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.AmbiguousFunction)); + } + + #endregion DeriveParameters + + [OneTimeSetUp] + public async Task OneTimeSetup() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "11.0", "Stored procedures were introduced in PostgreSQL 11"); + } +} diff --git a/test/Npgsql.Tests/Support/AssemblySetUp.cs b/test/Npgsql.Tests/Support/AssemblySetUp.cs new file mode 100644 index 0000000000..f1619ecec4 --- /dev/null +++ b/test/Npgsql.Tests/Support/AssemblySetUp.cs @@ -0,0 +1,46 @@ +using Npgsql; +using Npgsql.Tests; +using NUnit.Framework; +using System; +using System.Threading; + +[SetUpFixture] +public class AssemblySetUp +{ + [OneTimeSetUp] + public void Setup() + { + var connString = TestUtil.ConnectionString; + using var conn = new NpgsqlConnection(connString); + try + { + conn.Open(); + } + catch (PostgresException e) + { + if (e.SqlState == PostgresErrorCodes.InvalidPassword && connString == TestUtil.DefaultConnectionString) + throw new Exception("Please create a user npgsql_tests as follows: CREATE USER npgsql_tests PASSWORD 'npgsql_tests' SUPERUSER"); + + if (e.SqlState == PostgresErrorCodes.InvalidCatalogName) + { + var builder = new NpgsqlConnectionStringBuilder(connString) + { + Pooling = false, + Multiplexing = false, + Database = "postgres" + }; + + using var adminConn = new NpgsqlConnection(builder.ConnectionString); + adminConn.Open(); + adminConn.ExecuteNonQuery("CREATE DATABASE " + conn.Database); + 
adminConn.Close(); + Thread.Sleep(1000); + + conn.Open(); + return; + } + + throw; + } + } +} diff --git a/test/Npgsql.Tests/Support/ListLoggerFactory.cs b/test/Npgsql.Tests/Support/ListLoggerFactory.cs new file mode 100644 index 0000000000..2852335df8 --- /dev/null +++ b/test/Npgsql.Tests/Support/ListLoggerFactory.cs @@ -0,0 +1,92 @@ +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Logging; + +namespace Npgsql.Tests.Support; + +public class ListLoggerProvider : ILoggerProvider +{ + readonly ListLogger _logger; + bool _recording; + + public ListLoggerProvider() + => _logger = new ListLogger(this); + + public List<(LogLevel Level, EventId Id, string Message, object? State, Exception? Exception)> Log + => _logger.LoggedEvents; + + public IDisposable Record() + { + _logger.Clear(); + _recording = true; + + return new RecordingDisposable(this); + } + + public void StopRecording() + => _recording = false; + + public ILogger CreateLogger(string categoryName) => _logger; + + public void AddProvider(ILoggerProvider provider) + { + } + + public void Dispose() + => StopRecording(); + + class ListLogger : ILogger + { + readonly ListLoggerProvider _provider; + + public ListLogger(ListLoggerProvider provider) + => _provider = provider; + + public List<(LogLevel, EventId, string, object?, Exception?)> LoggedEvents { get; } + = new(); + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? 
exception, + Func formatter) + { + if (_provider._recording) + { + lock (this) + { + var message = formatter(state, exception).Trim(); + LoggedEvents.Add((logLevel, eventId, message, state, exception)); + } + } + } + + public void Clear() + { + lock (this) + { + LoggedEvents.Clear(); + } + } + + public bool IsEnabled(LogLevel logLevel) => _provider._recording; + + public IDisposable BeginScope(TState state) where TState : notnull + => new Scope(); + + class Scope : IDisposable + { + public void Dispose() + { + } + } + } + + class RecordingDisposable : IDisposable + { + readonly ListLoggerProvider _provider; + + public RecordingDisposable(ListLoggerProvider provider) + => _provider = provider; + + public void Dispose() + => _provider.StopRecording(); + } +} diff --git a/test/Npgsql.Tests/Support/LoggingSetupFixture.cs b/test/Npgsql.Tests/Support/LoggingSetupFixture.cs deleted file mode 100644 index 6ea75d1b6b..0000000000 --- a/test/Npgsql.Tests/Support/LoggingSetupFixture.cs +++ /dev/null @@ -1,37 +0,0 @@ -using System; -using NUnit.Framework; -using NLog.Config; -using NLog.Targets; -using NLog; -using Npgsql.Logging; -using Npgsql.Tests; -using Npgsql.Tests.Support; - -// ReSharper disable once CheckNamespace - -[SetUpFixture] -public class LoggingSetupFixture -{ - [OneTimeSetUp] - public void Setup() - { - var logLevelText = Environment.GetEnvironmentVariable("NPGSQL_TEST_LOGGING"); - if (logLevelText == null) - return; - if (!Enum.TryParse(logLevelText, true, out NpgsqlLogLevel logLevel)) - throw new ArgumentOutOfRangeException($"Invalid loglevel in NPGSQL_TEST_LOGGING: {logLevelText}"); - - var config = new LoggingConfiguration(); - var consoleTarget = new ColoredConsoleTarget - { - Layout = @"${message} ${exception:format=tostring}" - }; - config.AddTarget("console", consoleTarget); - var rule = new LoggingRule("*", LogLevel.Debug, consoleTarget); - config.LoggingRules.Add(rule); - LogManager.Configuration = config; - - NpgsqlLogManager.Provider = new 
NLogLoggingProvider(); - NpgsqlLogManager.IsParameterLoggingEnabled = true; - } -} diff --git a/test/Npgsql.Tests/Support/MultiplexingTestBase.cs b/test/Npgsql.Tests/Support/MultiplexingTestBase.cs new file mode 100644 index 0000000000..892dd79f5e --- /dev/null +++ b/test/Npgsql.Tests/Support/MultiplexingTestBase.cs @@ -0,0 +1,37 @@ +using System.Collections.Concurrent; +using NUnit.Framework; + +namespace Npgsql.Tests; + +[TestFixture(MultiplexingMode.NonMultiplexing)] +[TestFixture(MultiplexingMode.Multiplexing)] +public abstract class MultiplexingTestBase : TestBase +{ + protected bool IsMultiplexing => MultiplexingMode == MultiplexingMode.Multiplexing; + + protected MultiplexingMode MultiplexingMode { get; } + + readonly ConcurrentDictionary<(string ConnString, bool IsMultiplexing), string> _connStringCache + = new(); + + public override string ConnectionString { get; } + + protected MultiplexingTestBase(MultiplexingMode multiplexingMode) + { + MultiplexingMode = multiplexingMode; + + // If the test requires multiplexing to be on or off, use a small cache to avoid reparsing and + // regenerating the connection string every time + ConnectionString = _connStringCache.GetOrAdd((base.ConnectionString, IsMultiplexing), + tup => new NpgsqlConnectionStringBuilder(tup.ConnString) + { + Multiplexing = tup.IsMultiplexing + }.ToString()); + } +} + +public enum MultiplexingMode +{ + NonMultiplexing, + Multiplexing +} diff --git a/test/Npgsql.Tests/Support/NLogLoggingProvider.cs b/test/Npgsql.Tests/Support/NLogLoggingProvider.cs deleted file mode 100644 index b8bf8364b4..0000000000 --- a/test/Npgsql.Tests/Support/NLogLoggingProvider.cs +++ /dev/null @@ -1,51 +0,0 @@ -using System; -using NLog; -using Npgsql.Logging; - -namespace Npgsql.Tests.Support -{ - class NLogLoggingProvider : INpgsqlLoggingProvider - { - public NpgsqlLogger CreateLogger(string name) - { - return new NLogLogger(name); - } - } - - class NLogLogger : NpgsqlLogger - { - readonly Logger _log; - - internal 
NLogLogger(string name) - { - _log = LogManager.GetLogger(name); - } - - public override bool IsEnabled(NpgsqlLogLevel level) - { - return _log.IsEnabled(ToNLogLogLevel(level)); - } - - public override void Log(NpgsqlLogLevel level, int connectorId, string msg, Exception? exception = null) - { - var ev = new LogEventInfo(ToNLogLogLevel(level), "", msg); - if (exception != null) - ev.Exception = exception; - if (connectorId != 0) - ev.Properties["ConnectorId"] = connectorId; - _log.Log(ev); - } - - static LogLevel ToNLogLogLevel(NpgsqlLogLevel level) - => level switch - { - NpgsqlLogLevel.Trace => LogLevel.Trace, - NpgsqlLogLevel.Debug => LogLevel.Debug, - NpgsqlLogLevel.Info => LogLevel.Info, - NpgsqlLogLevel.Warn => LogLevel.Warn, - NpgsqlLogLevel.Error => LogLevel.Error, - NpgsqlLogLevel.Fatal => LogLevel.Fatal, - _ => throw new ArgumentOutOfRangeException(nameof(level)) - }; - } -} diff --git a/test/Npgsql.Tests/Support/PgCancellationRequest.cs b/test/Npgsql.Tests/Support/PgCancellationRequest.cs new file mode 100644 index 0000000000..c07f606bb8 --- /dev/null +++ b/test/Npgsql.Tests/Support/PgCancellationRequest.cs @@ -0,0 +1,38 @@ +using System.IO; +using Npgsql.Internal; + +namespace Npgsql.Tests.Support; + +class PgCancellationRequest +{ + readonly NpgsqlReadBuffer _readBuffer; + readonly NpgsqlWriteBuffer _writeBuffer; + readonly Stream _stream; + + public int ProcessId { get; } + public int Secret { get; } + + bool completed; + + public PgCancellationRequest(NpgsqlReadBuffer readBuffer, NpgsqlWriteBuffer writeBuffer, Stream stream, int processId, int secret) + { + _readBuffer = readBuffer; + _writeBuffer = writeBuffer; + _stream = stream; + + ProcessId = processId; + Secret = secret; + } + + public void Complete() + { + if (completed) + return; + + _readBuffer.Dispose(); + _writeBuffer.Dispose(); + _stream.Dispose(); + + completed = true; + } +} \ No newline at end of file diff --git a/test/Npgsql.Tests/Support/PgPostmasterMock.cs 
b/test/Npgsql.Tests/Support/PgPostmasterMock.cs index 1c93716338..ab3eeab521 100644 --- a/test/Npgsql.Tests/Support/PgPostmasterMock.cs +++ b/test/Npgsql.Tests/Support/PgPostmasterMock.cs @@ -6,185 +6,256 @@ using System.Text; using System.Threading.Channels; using System.Threading.Tasks; -using Npgsql.Util; -using NUnit.Framework.Constraints; +using Npgsql.Internal; -namespace Npgsql.Tests.Support +namespace Npgsql.Tests.Support; + +class PgPostmasterMock : IAsyncDisposable { - class PgPostmasterMock : IAsyncDisposable + const int ReadBufferSize = 8192; + const int WriteBufferSize = 8192; + const int CancelRequestCode = 1234 << 16 | 5678; + const int SslRequest = 80877103; + + static readonly Encoding Encoding = NpgsqlWriteBuffer.UTF8Encoding; + static readonly Encoding RelaxedEncoding = NpgsqlWriteBuffer.RelaxedUTF8Encoding; + + readonly Socket _socket; + readonly List _allServers = new(); + bool _acceptingClients; + Task? _acceptClientsTask; + int _processIdCounter; + + readonly bool _completeCancellationImmediately; + readonly string? _startupErrorCode; + + ChannelWriter> _pendingRequestsWriter { get; } + ChannelReader> _pendingRequestsReader { get; } + + internal string ConnectionString { get; } + internal string Host { get; } + internal int Port { get; } + + volatile MockState _state; + + internal MockState State { - const int ReadBufferSize = 8192; - const int WriteBufferSize = 8192; - const int CancelRequestCode = 1234 << 16 | 5678; + get => _state; + set => _state = value; + } - static readonly Encoding Encoding = PGUtil.UTF8Encoding; - static readonly Encoding RelaxedEncoding = PGUtil.RelaxedUTF8Encoding; + internal static PgPostmasterMock Start( + string? connectionString = null, + bool completeCancellationImmediately = true, + MockState state = MockState.MultipleHostsDisabled, + string? 
startupErrorCode = null) + { + var mock = new PgPostmasterMock(connectionString, completeCancellationImmediately, state, startupErrorCode); + mock.AcceptClients(); + return mock; + } - readonly Socket _socket; - readonly List _allServers = new List(); - bool _acceptingClients; - Task? _acceptClientsTask; - int _processIdCounter; + internal PgPostmasterMock( + string? connectionString = null, + bool completeCancellationImmediately = true, + MockState state = MockState.MultipleHostsDisabled, + string? startupErrorCode = null) + { + var pendingRequestsChannel = Channel.CreateUnbounded>(); + _pendingRequestsReader = pendingRequestsChannel.Reader; + _pendingRequestsWriter = pendingRequestsChannel.Writer; - ChannelWriter _pendingRequestsWriter { get; } - internal ChannelReader PendingRequestsReader { get; } + var connectionStringBuilder = new NpgsqlConnectionStringBuilder(connectionString); - internal string ConnectionString { get; } + _completeCancellationImmediately = completeCancellationImmediately; + State = state; + _startupErrorCode = startupErrorCode; - internal static PgPostmasterMock Start(string? connectionString = null) - { - var mock = new PgPostmasterMock(connectionString); - mock.AcceptClients(); - return mock; - } + _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + var endpoint = new IPEndPoint(IPAddress.Loopback, 0); + _socket.Bind(endpoint); - internal PgPostmasterMock(string? connectionString = null) - { - var pendingRequestsChannel = Channel.CreateUnbounded(); - PendingRequestsReader = pendingRequestsChannel.Reader; - _pendingRequestsWriter = pendingRequestsChannel.Writer; - - var connectionStringBuilder = - new NpgsqlConnectionStringBuilder(connectionString ?? 
TestUtil.ConnectionString); - - _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - var endpoint = new IPEndPoint(IPAddress.Loopback, 0); - _socket.Bind(endpoint); - var localEndPoint = (IPEndPoint)_socket.LocalEndPoint!; - connectionStringBuilder.Host = localEndPoint.Address.ToString(); - connectionStringBuilder.Port = localEndPoint.Port; - connectionStringBuilder.ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading; - ConnectionString = connectionStringBuilder.ConnectionString; - - _socket.Listen(5); - } + var localEndPoint = (IPEndPoint)_socket.LocalEndPoint!; + Host = localEndPoint.Address.ToString(); + Port = localEndPoint.Port; + connectionStringBuilder.Host = Host; + connectionStringBuilder.Port = Port; + connectionStringBuilder.ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading; + ConnectionString = connectionStringBuilder.ConnectionString; - void AcceptClients() - { - _acceptingClients = true; - _acceptClientsTask = DoAcceptClients(); + _socket.Listen(5); + } + + public NpgsqlDataSourceBuilder GetDataSourceBuilder() + => new(ConnectionString); - async Task DoAcceptClients() + public NpgsqlDataSource CreateDataSource() + => NpgsqlDataSource.Create(ConnectionString); + + void AcceptClients() + { + _acceptingClients = true; + _acceptClientsTask = DoAcceptClients(); + + async Task DoAcceptClients() + { + while (true) { - while (true) + var serverOrCancellationRequest = await Accept(_completeCancellationImmediately); + if (serverOrCancellationRequest.Server is { } server) { - var serverOrCancellationRequest = await Accept(); - if (serverOrCancellationRequest.Server is { } server) + // Hand off the new server to the client test only once startup is complete, to avoid reading/writing in parallel + // during startup. Don't wait for all this to complete - continue to accept other connections in case that's needed. 
+ if (string.IsNullOrEmpty(_startupErrorCode)) { - // Hand off the new server to the client test only once startup is complete, to avoid reading/writing in parallel - // during startup. Don't wait for all this to complete - continue to accept other connections in case that's needed. - _ = server.Startup().ContinueWith(t => _pendingRequestsWriter.WriteAsync(serverOrCancellationRequest)); + // We may be accepting (and starting up) multiple connections in parallel, but some tests assume we return + // server connections in FIFO. As a result, we enqueue immediately into the _pendingRequestsWriter channel, + // but we enqueue a Task which represents the Startup completing. + await _pendingRequestsWriter.WriteAsync(Task.Run(async () => + { + await server.Startup(State); + return serverOrCancellationRequest; + })); } else - { - await _pendingRequestsWriter.WriteAsync(serverOrCancellationRequest); - } + _ = server.FailedStartup(_startupErrorCode); + } + else + { + await _pendingRequestsWriter.WriteAsync(Task.FromResult(serverOrCancellationRequest)); } - - // ReSharper disable once FunctionNeverReturns } + + // ReSharper disable once FunctionNeverReturns } + } - internal async Task Accept() - { - var clientSocket = await _socket.AcceptAsync(); + async Task Accept(bool completeCancellationImmediately) + { + var clientSocket = await _socket.AcceptAsync(); + + var stream = new NetworkStream(clientSocket, true); + var readBuffer = new NpgsqlReadBuffer(null!, stream, clientSocket, ReadBufferSize, Encoding, + RelaxedEncoding); + var writeBuffer = new NpgsqlWriteBuffer(null!, stream, clientSocket, WriteBufferSize, Encoding); + writeBuffer.MessageLengthValidation = false; - var stream = new NetworkStream(clientSocket, true); - var readBuffer = new NpgsqlReadBuffer(null!, stream, clientSocket, ReadBufferSize, Encoding, - RelaxedEncoding); - var writeBuffer = new NpgsqlWriteBuffer(null!, stream, clientSocket, WriteBufferSize, Encoding); + await readBuffer.EnsureAsync(4); + var len = 
readBuffer.ReadInt32(); + await readBuffer.EnsureAsync(len - 4); + + var request = readBuffer.ReadInt32(); + if (request == SslRequest) + { + writeBuffer.WriteByte((byte)'N'); + await writeBuffer.Flush(async: true); await readBuffer.EnsureAsync(4); - var len = readBuffer.ReadInt32(); + len = readBuffer.ReadInt32(); await readBuffer.EnsureAsync(len - 4); + request = readBuffer.ReadInt32(); + } - if (readBuffer.ReadInt32() == CancelRequestCode) + if (request == CancelRequestCode) + { + var cancellationRequest = new PgCancellationRequest(readBuffer, writeBuffer, stream, readBuffer.ReadInt32(), readBuffer.ReadInt32()); + if (completeCancellationImmediately) { - readBuffer.Dispose(); - writeBuffer.Dispose(); - stream.Dispose(); - return new ServerOrCancellationRequest((readBuffer.ReadInt32(), readBuffer.ReadInt32())); + cancellationRequest.Complete(); } - // This is not a cancellation, "spawn" a new server - readBuffer.ReadPosition -= 8; - var server = new PgServerMock(stream, readBuffer, writeBuffer, ++_processIdCounter); - _allServers.Add(server); - return new ServerOrCancellationRequest(server); + return new ServerOrCancellationRequest(cancellationRequest); } - internal async Task AcceptServer() - { - if (_acceptingClients) - throw new InvalidOperationException($"Already accepting clients via {nameof(AcceptClients)}"); - var serverOrCancellationRequest = await Accept(); - if (serverOrCancellationRequest.Server is null) - throw new InvalidOperationException("Expected a server connection but got a cancellation request instead"); - return serverOrCancellationRequest.Server; - } + // This is not a cancellation, "spawn" a new server + readBuffer.ReadPosition -= 8; + var server = new PgServerMock(stream, readBuffer, writeBuffer, ++_processIdCounter); + _allServers.Add(server); + return new ServerOrCancellationRequest(server); + } - internal async Task<(int ProcessId, int Secret)> AcceptCancellationRequest() - { - if (_acceptingClients) - throw new 
InvalidOperationException($"Already accepting clients via {nameof(AcceptClients)}"); - var serverOrCancellationRequest = await Accept(); - if (serverOrCancellationRequest.CancellationRequest is null) - throw new InvalidOperationException("Expected a cancellation request but got a server connection instead"); - return serverOrCancellationRequest.CancellationRequest.Value; - } + internal async Task AcceptServer(bool completeCancellationImmediately = true) + { + if (_acceptingClients) + throw new InvalidOperationException($"Already accepting clients via {nameof(AcceptClients)}"); + var serverOrCancellationRequest = await Accept(completeCancellationImmediately); + if (serverOrCancellationRequest.Server is null) + throw new InvalidOperationException("Expected a server connection but got a cancellation request instead"); + return serverOrCancellationRequest.Server; + } + + internal async Task AcceptCancellationRequest() + { + if (_acceptingClients) + throw new InvalidOperationException($"Already accepting clients via {nameof(AcceptClients)}"); + var serverOrCancellationRequest = await Accept(completeCancellationImmediately: true); + if (serverOrCancellationRequest.CancellationRequest is null) + throw new InvalidOperationException("Expected a cancellation request but got a server connection instead"); + return serverOrCancellationRequest.CancellationRequest; + } - internal async ValueTask WaitForServerConnection() + internal async ValueTask WaitForServerConnection() + { + var serverOrCancellationRequest = await await _pendingRequestsReader.ReadAsync(); + if (serverOrCancellationRequest.Server is null) + throw new InvalidOperationException("Expected a server connection but got a cancellation request instead"); + return serverOrCancellationRequest.Server; + } + + internal async ValueTask WaitForCancellationRequest() + { + var serverOrCancellationRequest = await await _pendingRequestsReader.ReadAsync(); + if (serverOrCancellationRequest.CancellationRequest is null) + throw 
new InvalidOperationException("Expected cancellation request but got a server connection instead"); + return serverOrCancellationRequest.CancellationRequest; + } + + public async ValueTask DisposeAsync() + { + var endpoint = _socket.LocalEndPoint as IPEndPoint; + Debug.Assert(endpoint is not null); + + // Stop accepting new connections + _socket.Dispose(); + try { - var serverOrCancellationRequest = await PendingRequestsReader.ReadAsync(); - if (serverOrCancellationRequest.Server is null) - throw new InvalidOperationException("Expected a server connection but got a cancellation request instead"); - return serverOrCancellationRequest.Server; + var acceptTask = _acceptClientsTask; + if (acceptTask != null) + await acceptTask; } - - internal async ValueTask<(int ProcessId, int Secret)> WaitForCancellationRequest() + catch { - var serverOrCancellationRequest = await PendingRequestsReader.ReadAsync(); - if (serverOrCancellationRequest.CancellationRequest is null) - throw new InvalidOperationException("Expected cancellation request but got a server connection instead"); - return serverOrCancellationRequest.CancellationRequest.Value; + // Swallow all exceptions } - public async ValueTask DisposeAsync() - { - // Stop accepting new connections - _socket.Dispose(); - try - { - var acceptTask = _acceptClientsTask; - if (acceptTask != null) - await acceptTask; - } - catch - { - // Swallow all exceptions - } + // Destroy all servers created by this postmaster + foreach (var server in _allServers) + server.Dispose(); + } - // Destroy all servers created by this postmaster - foreach (var server in _allServers) - server.Dispose(); + internal readonly struct ServerOrCancellationRequest + { + public ServerOrCancellationRequest(PgServerMock server) + { + Server = server; + CancellationRequest = null; } - internal readonly struct ServerOrCancellationRequest + public ServerOrCancellationRequest(PgCancellationRequest cancellationRequest) { - public 
ServerOrCancellationRequest(PgServerMock server) - { - Server = server; - CancellationRequest = null; - } - - public ServerOrCancellationRequest((int ProcessId, int Secret) cancellationRequest) - { - Server = null; - CancellationRequest = cancellationRequest; - } - - internal PgServerMock? Server { get; } - internal (int ProcessId, int Secret)? CancellationRequest { get; } + Server = null; + CancellationRequest = cancellationRequest; } + + internal PgServerMock? Server { get; } + internal PgCancellationRequest? CancellationRequest { get; } } } + +public enum MockState +{ + MultipleHostsDisabled = 0, + Primary = 1, + PrimaryReadOnly = 2, + Standby = 3 +} diff --git a/test/Npgsql.Tests/Support/PgServerMock.cs b/test/Npgsql.Tests/Support/PgServerMock.cs index df5cf58a42..9f7a799649 100644 --- a/test/Npgsql.Tests/Support/PgServerMock.cs +++ b/test/Npgsql.Tests/Support/PgServerMock.cs @@ -6,300 +6,405 @@ using System.Text; using System.Threading.Tasks; using Npgsql.BackendMessages; -using Npgsql.Util; +using Npgsql.Internal; +using Npgsql.Internal.Postgres; using NUnit.Framework; -namespace Npgsql.Tests.Support +namespace Npgsql.Tests.Support; + +class PgServerMock : IDisposable { - class PgServerMock : IDisposable - { - static readonly Encoding Encoding = PGUtil.UTF8Encoding; + static uint BoolOid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Bool).Value; + static uint Int4Oid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Int4).Value; + static uint TextOid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Text).Value; - readonly NetworkStream _stream; - readonly NpgsqlReadBuffer _readBuffer; - readonly NpgsqlWriteBuffer _writeBuffer; - bool _disposed; + static readonly Encoding Encoding = NpgsqlWriteBuffer.UTF8Encoding; - const int BackendSecret = 12345; - internal int ProcessId { get; } + readonly NetworkStream _stream; + readonly NpgsqlReadBuffer _readBuffer; + readonly NpgsqlWriteBuffer 
_writeBuffer; + bool _disposed; - internal NpgsqlReadBuffer ReadBuffer => _readBuffer; + const int BackendSecret = 12345; + internal int ProcessId { get; } - internal PgServerMock( - NetworkStream stream, - NpgsqlReadBuffer readBuffer, - NpgsqlWriteBuffer writeBuffer, - int processId) - { - ProcessId = processId; - _stream = stream; - _readBuffer = readBuffer; - _writeBuffer = writeBuffer; - } + internal NpgsqlReadBuffer ReadBuffer => _readBuffer; + internal NpgsqlWriteBuffer WriteBuffer => _writeBuffer; - internal async Task Startup() - { - // Read and skip the startup message - await SkipMessage(); - - WriteAuthenticateOk(); - WriteParameterStatuses(new Dictionary - { - { "server_version", "13" }, - { "server_encoding", "UTF8" }, - { "client_encoding", "UTF8" }, - { "application_name", "Mock" }, - { "is_superuser", "on" }, - { "session_authorization", "foo" }, - { "DateStyle", "ISO, MDY" }, - { "IntervalStyle", "postgres" }, - { "TimeZone", "UTC" }, - { "integer_datetimes", "on" }, - { "standard_conforming_strings", "on" } - - }); - WriteBackendKeyData(ProcessId, BackendSecret); - WriteReadyForQuery(); - - await FlushAsync(); - } + internal PgServerMock( + NetworkStream stream, + NpgsqlReadBuffer readBuffer, + NpgsqlWriteBuffer writeBuffer, + int processId) + { + ProcessId = processId; + _stream = stream; + _readBuffer = readBuffer; + _writeBuffer = writeBuffer; + writeBuffer.MessageLengthValidation = false; + } - internal async Task SkipMessage() - { - await _readBuffer.EnsureAsync(4); - var len = _readBuffer.ReadInt32(); - await _readBuffer.EnsureAsync(len - 4); - _readBuffer.Skip(len - 4); - } + internal async Task Startup(MockState state) + { + // Read and skip the startup message + await SkipMessage(); - internal async Task ExpectMessage(byte expectedCode) + WriteAuthenticateOk(); + var parameters = new Dictionary + { + { "server_version", "14" }, + { "server_encoding", "UTF8" }, + { "client_encoding", "UTF8" }, + { "application_name", "Mock" }, + { 
"is_superuser", "on" }, + { "session_authorization", "foo" }, + { "DateStyle", "ISO, MDY" }, + { "IntervalStyle", "postgres" }, + { "TimeZone", "UTC" }, + { "integer_datetimes", "on" }, + { "standard_conforming_strings", "on" } + }; + // While PostgreSQL 14 always sends default_transaction_read_only and in_hot_standby, we only send them if requested + // To minimize potential issues for tests not requiring multiple hosts + if (state != MockState.MultipleHostsDisabled) { - CheckDisposed(); - - await _readBuffer.EnsureAsync(5); - var actualCode = _readBuffer.ReadByte(); - Assert.That(actualCode, Is.EqualTo(expectedCode), - $"Expected message of type '{(char)expectedCode}' but got '{(char)actualCode}'"); - var len = _readBuffer.ReadInt32(); - _readBuffer.Skip(len - 4); + parameters["default_transaction_read_only"] = state == MockState.Primary ? "off" : "on"; + parameters["in_hot_standby"] = state == MockState.Standby ? "on" : "off"; } + WriteParameterStatuses(parameters); + WriteBackendKeyData(ProcessId, BackendSecret); + WriteReadyForQuery(); + await FlushAsync(); + } - internal Task ExpectExtendedQuery() - => ExpectMessages( - FrontendMessageCode.Parse, - FrontendMessageCode.Bind, - FrontendMessageCode.Describe, - FrontendMessageCode.Execute, - FrontendMessageCode.Sync); + internal async Task FailedStartup(string errorCode) + { + // Read and skip the startup message + await SkipMessage(); + WriteErrorResponse(errorCode); + await FlushAsync(); + } - internal async Task ExpectMessages(params byte[] expectedCodes) - { - foreach (var expectedCode in expectedCodes) - await ExpectMessage(expectedCode); - } + internal Task SendMockState(MockState state) + { + var isStandby = state == MockState.Standby; + var transactionReadOnly = state == MockState.Standby || state == MockState.PrimaryReadOnly + ? 
"on" + : "off"; + + return WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(BoolOid)) + .WriteDataRow(BitConverter.GetBytes(isStandby)) + .WriteCommandComplete() + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(TextOid)) + .WriteDataRow(Encoding.ASCII.GetBytes(transactionReadOnly)) + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + } - internal async Task ExpectSimpleQuery(string expectedSql) - { - CheckDisposed(); - - await _readBuffer.EnsureAsync(5); - var actualCode = _readBuffer.ReadByte(); - Assert.That(actualCode, Is.EqualTo(FrontendMessageCode.Query), $"Expected message of type Query but got '{(char)actualCode}'"); - _ = _readBuffer.ReadInt32(); - var actualSql = _readBuffer.ReadNullTerminatedString(); - Assert.That(actualSql, Is.EqualTo(expectedSql)); - } + internal async Task SkipMessage() + { + await _readBuffer.EnsureAsync(4); + var len = _readBuffer.ReadInt32(); + await _readBuffer.EnsureAsync(len - 4); + _readBuffer.Skip(len - 4); + } - internal Task FlushAsync() - { - CheckDisposed(); - return _writeBuffer.Flush(async: true); - } + internal async Task ExpectMessage(byte expectedCode) + { + CheckDisposed(); + + await _readBuffer.EnsureAsync(5); + var actualCode = _readBuffer.ReadByte(); + Assert.That(actualCode, Is.EqualTo(expectedCode), + $"Expected message of type '{(char)expectedCode}' but got '{(char)actualCode}'"); + var len = _readBuffer.ReadInt32(); + _readBuffer.Skip(len - 4); + } - internal Task WriteScalarResponseAndFlush(int value) - => WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) - .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(value))) - .WriteCommandComplete() - .WriteReadyForQuery() - .FlushAsync(); + internal Task ExpectExtendedQuery() + => ExpectMessages( + FrontendMessageCode.Parse, + FrontendMessageCode.Bind, + FrontendMessageCode.Describe, + 
FrontendMessageCode.Execute, + FrontendMessageCode.Sync); - internal void Close() => _stream.Close(); + internal async Task ExpectMessages(params byte[] expectedCodes) + { + foreach (var expectedCode in expectedCodes) + await ExpectMessage(expectedCode); + } - #region Low-level message writing + internal async Task ExpectSimpleQuery(string expectedSql) + { + CheckDisposed(); + + await _readBuffer.EnsureAsync(5); + var actualCode = _readBuffer.ReadByte(); + Assert.That(actualCode, Is.EqualTo(FrontendMessageCode.Query), $"Expected message of type Query but got '{(char)actualCode}'"); + _ = _readBuffer.ReadInt32(); + var actualSql = _readBuffer.ReadNullTerminatedString(); + Assert.That(actualSql, Is.EqualTo(expectedSql)); + } - internal PgServerMock WriteParseComplete() - { - CheckDisposed(); - _writeBuffer.WriteByte((byte)BackendMessageCode.ParseComplete); - _writeBuffer.WriteInt32(4); - return this; - } + internal Task WaitForData() => _readBuffer.EnsureAsync(1).AsTask(); - internal PgServerMock WriteBindComplete() - { - CheckDisposed(); - _writeBuffer.WriteByte((byte)BackendMessageCode.BindComplete); - _writeBuffer.WriteInt32(4); - return this; - } + internal Task FlushAsync() + { + CheckDisposed(); + return _writeBuffer.Flush(async: true); + } - internal PgServerMock WriteRowDescription(params FieldDescription[] fields) + internal Task WriteScalarResponseAndFlush(int value) + => WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(value))) + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + internal Task WriteScalarResponseAndFlush(bool value) + => WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(BoolOid)) + .WriteDataRow(BitConverter.GetBytes(value)) + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + internal Task WriteScalarResponseAndFlush(string value) + => 
WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(TextOid)) + .WriteDataRow(Encoding.ASCII.GetBytes(value)) + .WriteCommandComplete() + .WriteReadyForQuery() + .FlushAsync(); + + internal void Close() => _stream.Close(); + + #region Low-level message writing + + internal PgServerMock WriteParseComplete() + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.ParseComplete); + _writeBuffer.WriteInt32(4); + return this; + } + + internal PgServerMock WriteBindComplete() + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.BindComplete); + _writeBuffer.WriteInt32(4); + return this; + } + + internal PgServerMock WriteRowDescription(params FieldDescription[] fields) + { + CheckDisposed(); + + _writeBuffer.WriteByte((byte)BackendMessageCode.RowDescription); + _writeBuffer.WriteInt32(4 + 2 + fields.Sum(f => Encoding.GetByteCount(f.Name) + 1 + 18)); + _writeBuffer.WriteInt16((short)fields.Length); + + foreach (var field in fields) { - CheckDisposed(); - - _writeBuffer.WriteByte((byte)BackendMessageCode.RowDescription); - _writeBuffer.WriteInt32(4 + 2 + fields.Sum(f => Encoding.GetByteCount(f.Name) + 1 + 18)); - _writeBuffer.WriteInt16(fields.Length); - - foreach (var field in fields) - { - _writeBuffer.WriteNullTerminatedString(field.Name); - _writeBuffer.WriteUInt32(field.TableOID); - _writeBuffer.WriteInt16(field.ColumnAttributeNumber); - _writeBuffer.WriteUInt32(field.TypeOID); - _writeBuffer.WriteInt16(field.TypeSize); - _writeBuffer.WriteInt32(field.TypeModifier); - _writeBuffer.WriteInt16((short)field.FormatCode); - } - - return this; + _writeBuffer.WriteNullTerminatedString(field.Name); + _writeBuffer.WriteUInt32(field.TableOID); + _writeBuffer.WriteInt16(field.ColumnAttributeNumber); + _writeBuffer.WriteUInt32(field.TypeOID); + _writeBuffer.WriteInt16(field.TypeSize); + _writeBuffer.WriteInt32(field.TypeModifier); + _writeBuffer.WriteInt16(field.DataFormat.ToFormatCode()); } - internal 
PgServerMock WriteDataRow(params byte[][] columnValues) - { - CheckDisposed(); + return this; + } - _writeBuffer.WriteByte((byte)BackendMessageCode.DataRow); - _writeBuffer.WriteInt32(4 + 2 + columnValues.Sum(v => 4 + v.Length)); - _writeBuffer.WriteInt16(columnValues.Length); + internal PgServerMock WriteParameterDescription(params FieldDescription[] fields) + { + CheckDisposed(); - foreach (var field in columnValues) - { - _writeBuffer.WriteInt32(field.Length); - _writeBuffer.WriteBytes(field); - } + _writeBuffer.WriteByte((byte)BackendMessageCode.ParameterDescription); + _writeBuffer.WriteInt32(1 + 4 + 2 + fields.Length * 4); + _writeBuffer.WriteUInt16((ushort)fields.Length); - return this; - } + foreach (var field in fields) + _writeBuffer.WriteUInt32(field.TypeOID); - internal async Task WriteDataRowWithFlush(params byte[][] columnValues) - { - CheckDisposed(); + return this; + } - _writeBuffer.WriteByte((byte) BackendMessageCode.DataRow); - _writeBuffer.WriteInt32(4 + 2 + columnValues.Sum(v => 4 + v.Length)); - _writeBuffer.WriteInt16(columnValues.Length); + internal PgServerMock WriteNoData() + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.NoData); + _writeBuffer.WriteInt32(4); + return this; + } - foreach (var field in columnValues) - { - _writeBuffer.WriteInt32(field.Length); - await _writeBuffer.WriteBytesRaw(field, true); - } - } + internal PgServerMock WriteEmptyQueryResponse() + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.EmptyQueryResponse); + _writeBuffer.WriteInt32(4); + return this; + } - internal PgServerMock WriteCommandComplete(string tag = "") - { - CheckDisposed(); + internal PgServerMock WriteDataRow(params byte[][] columnValues) + { + CheckDisposed(); - _writeBuffer.WriteByte((byte)BackendMessageCode.CommandComplete); - _writeBuffer.WriteInt32(4 + Encoding.GetByteCount(tag) + 1); - _writeBuffer.WriteNullTerminatedString(tag); - return this; - } + 
_writeBuffer.WriteByte((byte)BackendMessageCode.DataRow); + _writeBuffer.WriteInt32(4 + 2 + columnValues.Sum(v => 4 + v.Length)); + _writeBuffer.WriteInt16((short)columnValues.Length); - internal PgServerMock WriteReadyForQuery(TransactionStatus transactionStatus = TransactionStatus.Idle) + foreach (var field in columnValues) { - CheckDisposed(); - _writeBuffer.WriteByte((byte)BackendMessageCode.ReadyForQuery); - _writeBuffer.WriteInt32(4 + 1); - _writeBuffer.WriteByte((byte)transactionStatus); - return this; + _writeBuffer.WriteInt32(field.Length); + _writeBuffer.WriteBytes(field); } - internal PgServerMock WriteAuthenticateOk() - { - CheckDisposed(); - _writeBuffer.WriteByte((byte)BackendMessageCode.AuthenticationRequest); - _writeBuffer.WriteInt32(4 + 4); - _writeBuffer.WriteInt32(0); - return this; - } + return this; + } + + /// + /// Writes the bytes to the buffer and flushes only when the buffer is full + /// + internal async Task WriteDataRowWithFlush(params byte[][] columnValues) + { + CheckDisposed(); + + _writeBuffer.WriteByte((byte)BackendMessageCode.DataRow); + _writeBuffer.WriteInt32(4 + 2 + columnValues.Sum(v => 4 + v.Length)); + _writeBuffer.WriteInt16((short)columnValues.Length); - internal PgServerMock WriteParameterStatuses(Dictionary parameters) + foreach (var field in columnValues) { - foreach (var kv in parameters) - WriteParameterStatus(kv.Key, kv.Value); - return this; + _writeBuffer.WriteInt32(field.Length); + await _writeBuffer.WriteBytesRaw(field, true); } + } - internal PgServerMock WriteParameterStatus(string name, string value) - { - CheckDisposed(); + internal PgServerMock WriteCommandComplete(string tag = "") + { + CheckDisposed(); - _writeBuffer.WriteByte((byte)BackendMessageCode.ParameterStatus); - _writeBuffer.WriteInt32(4 + Encoding.GetByteCount(name) + 1 + Encoding.GetByteCount(value) + 1); - _writeBuffer.WriteNullTerminatedString(name); - _writeBuffer.WriteNullTerminatedString(value); + 
_writeBuffer.WriteByte((byte)BackendMessageCode.CommandComplete); + _writeBuffer.WriteInt32(4 + Encoding.GetByteCount(tag) + 1); + _writeBuffer.WriteNullTerminatedString(tag); + return this; + } - return this; - } + internal PgServerMock WriteReadyForQuery(TransactionStatus transactionStatus = TransactionStatus.Idle) + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.ReadyForQuery); + _writeBuffer.WriteInt32(4 + 1); + _writeBuffer.WriteByte((byte)transactionStatus); + return this; + } - internal PgServerMock WriteBackendKeyData(int processId, int secret) - { - CheckDisposed(); - _writeBuffer.WriteByte((byte)BackendMessageCode.BackendKeyData); - _writeBuffer.WriteInt32(4 + 4 + 4); - _writeBuffer.WriteInt32(processId); - _writeBuffer.WriteInt32(secret); - return this; - } + internal PgServerMock WriteAuthenticateOk() + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.AuthenticationRequest); + _writeBuffer.WriteInt32(4 + 4); + _writeBuffer.WriteInt32(0); + return this; + } - internal PgServerMock WriteCancellationResponse() - => WriteErrorResponse(PostgresErrorCodes.QueryCanceled, "Cancellation", "Query cancelled"); + internal PgServerMock WriteParameterStatuses(Dictionary parameters) + { + foreach (var kv in parameters) + WriteParameterStatus(kv.Key, kv.Value); + return this; + } - internal PgServerMock WriteErrorResponse(string code) - => WriteErrorResponse(code, "ERROR", "MOCK ERROR MESSAGE"); + internal PgServerMock WriteParameterStatus(string name, string value) + { + CheckDisposed(); - internal PgServerMock WriteErrorResponse(string code, string severity, string message) - { - CheckDisposed(); - _writeBuffer.WriteByte((byte)BackendMessageCode.ErrorResponse); - _writeBuffer.WriteInt32( - 4 + - 1 + Encoding.GetByteCount(code) + - 1 + Encoding.GetByteCount(severity) + - 1 + Encoding.GetByteCount(message) + - 1); - _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Code); - 
_writeBuffer.WriteNullTerminatedString(code); - _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Severity); - _writeBuffer.WriteNullTerminatedString(severity); - _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Message); - _writeBuffer.WriteNullTerminatedString(message); - _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Done); - return this; - } + _writeBuffer.WriteByte((byte)BackendMessageCode.ParameterStatus); + _writeBuffer.WriteInt32(4 + Encoding.GetByteCount(name) + 1 + Encoding.GetByteCount(value) + 1); + _writeBuffer.WriteNullTerminatedString(name); + _writeBuffer.WriteNullTerminatedString(value); - #endregion Low-level message writing + return this; + } - void CheckDisposed() - { - if (_stream is null) - throw new ObjectDisposedException(nameof(PgServerMock)); - } + internal PgServerMock WriteBackendKeyData(int processId, int secret) + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.BackendKeyData); + _writeBuffer.WriteInt32(4 + 4 + 4); + _writeBuffer.WriteInt32(processId); + _writeBuffer.WriteInt32(secret); + return this; + } - public void Dispose() - { - if (_disposed) - return; + internal PgServerMock WriteCancellationResponse() + => WriteErrorResponse(PostgresErrorCodes.QueryCanceled, "Cancellation", "Query cancelled"); - _readBuffer.Dispose(); - _writeBuffer.Dispose(); - _stream.Dispose(); + internal PgServerMock WriteCopyInResponse(bool isBinary = false) + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.CopyInResponse); + _writeBuffer.WriteInt32(5); + _writeBuffer.WriteByte(isBinary ? 
(byte)1 : (byte)0); + _writeBuffer.WriteInt16(1); + _writeBuffer.WriteInt16(0); + return this; + } - _disposed = true; - } + internal PgServerMock WriteErrorResponse(string code) + => WriteErrorResponse(code, "ERROR", "MOCK ERROR MESSAGE"); + + internal PgServerMock WriteErrorResponse(string code, string severity, string message) + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.ErrorResponse); + _writeBuffer.WriteInt32( + 4 + + 1 + Encoding.GetByteCount(code) + + 1 + Encoding.GetByteCount(severity) + + 1 + Encoding.GetByteCount(message) + + 1); + _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Code); + _writeBuffer.WriteNullTerminatedString(code); + _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Severity); + _writeBuffer.WriteNullTerminatedString(severity); + _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Message); + _writeBuffer.WriteNullTerminatedString(message); + _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Done); + return this; + } + + #endregion Low-level message writing + + void CheckDisposed() + { + if (_stream is null) + throw new ObjectDisposedException(nameof(PgServerMock)); + } + + public void Dispose() + { + if (_disposed) + return; + + _readBuffer.Dispose(); + _writeBuffer.Dispose(); + _stream.Dispose(); + + _disposed = true; } } diff --git a/test/Npgsql.Tests/Support/SingleThreadSynchronizationContext.cs b/test/Npgsql.Tests/Support/SingleThreadSynchronizationContext.cs new file mode 100644 index 0000000000..a7fedad3d6 --- /dev/null +++ b/test/Npgsql.Tests/Support/SingleThreadSynchronizationContext.cs @@ -0,0 +1,120 @@ +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Threading; + +namespace Npgsql.Tests.Support; + +sealed class SingleThreadSynchronizationContext : SynchronizationContext, IDisposable +{ + readonly BlockingCollection _tasks = new(); + readonly object _lockObject = new(); + volatile 
Thread? _thread; + bool _doingWork; + + const int ThreadStayAliveMs = 10000; + readonly string _threadName; + + internal SingleThreadSynchronizationContext(string threadName) + => _threadName = threadName; + + internal Disposable Enter() => new(this); + + public override void Post(SendOrPostCallback callback, object? state) + { + _tasks.Add(new CallbackAndState { Callback = callback, State = state }); + + lock (_lockObject) + { + if (!_doingWork) + { + // Either there is no thread, or the current thread is exiting + // In which case, wait for it to complete + var currentThread = _thread; + currentThread?.Join(); + Debug.Assert(_thread is null); + _doingWork = true; + _thread = new Thread(WorkLoop) { Name = _threadName, IsBackground = true }; + _thread.Start(); + } + } + } + + public void Dispose() + { + _tasks.CompleteAdding(); + + var thread = _thread; + thread?.Join(); + + _tasks.Dispose(); + } + + void WorkLoop() + { + SetSynchronizationContext(this); + + try + { + while (true) + { + var taken = _tasks.TryTake(out var callbackAndState, ThreadStayAliveMs); + if (!taken) + { + lock (_lockObject) + { + if (_tasks.Count == 0) + { + _doingWork = false; + return; + } + } + + continue; + } + + try + { + Debug.Assert(_doingWork); + callbackAndState.Callback(callbackAndState.State); + } + catch (Exception e) + { + Trace.Write($"Exception caught in {nameof(SingleThreadSynchronizationContext)}:" + Environment.NewLine + e); + } + } + } + catch (Exception e) + { + // Here we attempt to catch any exception coming from BlockingCollection _tasks + Trace.Write($"Exception caught in {nameof(SingleThreadSynchronizationContext)}:" + Environment.NewLine + e); + lock (_lockObject) + _doingWork = false; + } + finally + { + Debug.Assert(!_doingWork); + _thread = null; + } + } + + struct CallbackAndState + { + internal SendOrPostCallback Callback; + internal object? State; + } + + internal readonly struct Disposable : IDisposable + { + readonly SynchronizationContext? 
_synchronizationContext; + + internal Disposable(SynchronizationContext synchronizationContext) + { + _synchronizationContext = Current; + SetSynchronizationContext(synchronizationContext); + } + + public void Dispose() + => SetSynchronizationContext(_synchronizationContext); + } +} diff --git a/test/Npgsql.Tests/Support/TestBase.cs b/test/Npgsql.Tests/Support/TestBase.cs new file mode 100644 index 0000000000..463b132d56 --- /dev/null +++ b/test/Npgsql.Tests/Support/TestBase.cs @@ -0,0 +1,665 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Data; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Npgsql.Tests.Support; +using NpgsqlTypes; +using NUnit.Framework; + +namespace Npgsql.Tests; + +public abstract class TestBase +{ + /// + /// The connection string that will be used when opening the connection to the tests database. + /// May be overridden in fixtures, e.g. to set special connection parameters + /// + public virtual string ConnectionString => TestUtil.ConnectionString; + + static readonly SemaphoreSlim DatabaseCreationLock = new(1); + + static readonly object dataSourceLockObject = new(); + + static ConcurrentDictionary DataSources = new(StringComparer.Ordinal); + + #region Type testing + + public async Task AssertType( + T value, + string sqlLiteral, + string pgTypeName, + NpgsqlDbType? npgsqlDbType, + DbType? dbType = null, + DbType? inferredDbType = null, + bool isDefaultForReading = true, + bool isDefaultForWriting = true, + bool? isDefault = null, + bool isNpgsqlDbTypeInferredFromClrType = true, + Func? 
comparer = null, + bool skipArrayCheck = false) + { + await using var connection = await OpenConnectionAsync(); + return await AssertType( + connection, value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForReading, isDefaultForWriting, + isDefault, isNpgsqlDbTypeInferredFromClrType, comparer, skipArrayCheck); + } + + public async Task AssertType( + NpgsqlDataSource dataSource, + T value, + string sqlLiteral, + string pgTypeName, + NpgsqlDbType? npgsqlDbType, + DbType? dbType = null, + DbType? inferredDbType = null, + bool isDefaultForReading = true, + bool isDefaultForWriting = true, + bool? isDefault = null, + bool isNpgsqlDbTypeInferredFromClrType = true, + Func? comparer = null, + bool skipArrayCheck = false) + { + await using var connection = await dataSource.OpenConnectionAsync(); + + return await AssertType(connection, value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForReading, + isDefaultForWriting, isDefault, isNpgsqlDbTypeInferredFromClrType, comparer, skipArrayCheck); + } + + public async Task AssertType( + NpgsqlConnection connection, + T value, + string sqlLiteral, + string pgTypeName, + NpgsqlDbType? npgsqlDbType, + DbType? dbType = null, + DbType? inferredDbType = null, + bool isDefaultForReading = true, + bool isDefaultForWriting = true, + bool? isDefault = null, + bool isNpgsqlDbTypeInferredFromClrType = true, + Func? 
comparer = null, + bool skipArrayCheck = false) + { + if (isDefault is not null) + isDefaultForReading = isDefaultForWriting = isDefault.Value; + + await AssertTypeWrite(connection, () => value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForWriting, isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); + return await AssertTypeRead(connection, sqlLiteral, pgTypeName, value, isDefaultForReading, comparer, fieldType: null, skipArrayCheck); + } + + public async Task AssertTypeRead(string sqlLiteral, string pgTypeName, T expected, bool isDefault = true, bool skipArrayCheck = false) + { + await using var connection = await OpenConnectionAsync(); + return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer: null, fieldType: null, skipArrayCheck); + } + + public async Task AssertTypeRead(NpgsqlDataSource dataSource, string sqlLiteral, string pgTypeName, T expected, + bool isDefault = true, Func? comparer = null, Type? fieldType = null, bool skipArrayCheck = false) + { + await using var connection = await dataSource.OpenConnectionAsync(); + return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer, fieldType, skipArrayCheck); + } + + public async Task AssertTypeWrite( + NpgsqlDataSource dataSource, + T value, + string expectedSqlLiteral, + string pgTypeName, + NpgsqlDbType npgsqlDbType, + DbType? dbType = null, + DbType? inferredDbType = null, + bool isDefault = true, + bool isNpgsqlDbTypeInferredFromClrType = true, + bool skipArrayCheck = false) + { + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertTypeWrite(connection, () => value, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, + isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); + } + + public Task AssertTypeWrite( + T value, + string expectedSqlLiteral, + string pgTypeName, + NpgsqlDbType npgsqlDbType, + DbType? dbType = null, + DbType? 
inferredDbType = null, + bool isDefault = true, + bool isNpgsqlDbTypeInferredFromClrType = true, + bool skipArrayCheck = false) + => AssertTypeWrite(() => value, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, + isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); + + public async Task AssertTypeWrite( + Func valueFactory, + string expectedSqlLiteral, + string pgTypeName, + NpgsqlDbType npgsqlDbType, + DbType? dbType = null, + DbType? inferredDbType = null, + bool isDefault = true, + bool isNpgsqlDbTypeInferredFromClrType = true, + bool skipArrayCheck = false) + { + await using var connection = await OpenConnectionAsync(); + await AssertTypeWrite(connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); + } + + internal static async Task AssertTypeRead( + NpgsqlConnection connection, + string sqlLiteral, + string pgTypeName, + T expected, + bool isDefault = true, + Func? comparer = null, + Type? fieldType = null, + bool skipArrayCheck = false) + { + var result = await AssertTypeReadCore(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer); + + // Check the corresponding array type as well + if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal)) + { + await AssertTypeReadCore( + connection, + ArrayLiteral(sqlLiteral), + pgTypeName + "[]", + new[] { expected, expected }, + isDefault, + comparer is null ? null : (array1, array2) => comparer(array1[0], array2[0]) && comparer(array1[1], array2[1])); + } + + return result; + } + + internal static async Task AssertTypeReadCore( + NpgsqlConnection connection, + string sqlLiteral, + string pgTypeName, + T expected, + bool isDefault = true, + Func? comparer = null, + Type? 
fieldType = null) + { + if (sqlLiteral.Contains('\'')) + sqlLiteral = sqlLiteral.Replace("'", "''"); + + await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{pgTypeName}", connection); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); + + var truncatedSqlLiteral = sqlLiteral.Length > 40 ? sqlLiteral[..40] + "..." : sqlLiteral; + + var dataTypeName = reader.GetDataTypeName(0); + var dotIndex = dataTypeName.IndexOf('.'); + if (dotIndex > -1 && dataTypeName.Substring(0, dotIndex) is "pg_catalog" or "public") + dataTypeName = dataTypeName.Substring(dotIndex + 1); + + Assert.That(dataTypeName, Is.EqualTo(pgTypeName), + $"Got wrong result from GetDataTypeName when reading '{truncatedSqlLiteral}'"); + + if (isDefault) + { + // For arrays, GetFieldType always returns typeof(Array), since PG arrays can have arbitrary dimensionality + Assert.That(reader.GetFieldType(0), Is.EqualTo(dataTypeName.EndsWith("[]") ? typeof(Array) : fieldType ?? typeof(T)), + $"Got wrong result from GetFieldType when reading '{truncatedSqlLiteral}'"); + } + + var actual = isDefault ? (T)reader.GetValue(0) : reader.GetFieldValue(0); + + Assert.That(actual, comparer is null ? Is.EqualTo(expected) : Is.EqualTo(expected).Using(new SimpleComparer(comparer)), + $"Got wrong result from GetFieldValue value when reading '{truncatedSqlLiteral}'"); + + return actual; + } + + internal static async Task AssertTypeWrite( + NpgsqlConnection connection, + Func valueFactory, + string expectedSqlLiteral, + string pgTypeName, + NpgsqlDbType? npgsqlDbType, + DbType? dbType = null, + DbType? 
inferredDbType = null, + bool isDefault = true, + bool isNpgsqlDbTypeInferredFromClrType = true, + bool skipArrayCheck = false) + { + await AssertTypeWriteCore( + connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, + isNpgsqlDbTypeInferredFromClrType); + + // Check the corresponding array type as well + if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal)) + { + await AssertTypeWriteCore( + connection, + () => new[] { valueFactory(), valueFactory() }, + ArrayLiteral(expectedSqlLiteral), + pgTypeName + "[]", + npgsqlDbType | NpgsqlDbType.Array, + dbType: null, + inferredDbType: null, + isDefault, + isNpgsqlDbTypeInferredFromClrType); + } + } + + internal static async Task AssertTypeWriteCore( + NpgsqlConnection connection, + Func valueFactory, + string expectedSqlLiteral, + string pgTypeName, + NpgsqlDbType? npgsqlDbType, + DbType? dbType = null, + DbType? inferredDbType = null, + bool isDefault = true, + bool isNpgsqlDbTypeInferredFromClrType = true) + { + if (npgsqlDbType is null) + isNpgsqlDbTypeInferredFromClrType = false; + + inferredDbType ??= isNpgsqlDbTypeInferredFromClrType ? dbType ?? DbType.Object : DbType.Object; + + // TODO: Interferes with both multiplexing and connection-specific mapping (used e.g. in NodaTime) + // Reset the type mapper to make sure we're resolving this type with a clean slate (for isolation, just in case) + // connection.TypeMapper.Reset(); + + // Strip any facet information (length/precision/scale) + var parenIndex = pgTypeName.IndexOf('('); + // var pgTypeNameWithoutFacets = parenIndex > -1 ? pgTypeName[..parenIndex] : pgTypeName; + var pgTypeNameWithoutFacets = parenIndex > -1 + ? pgTypeName[..parenIndex] + pgTypeName[(pgTypeName.IndexOf(')') + 1)..] + : pgTypeName; + + // We test the following scenarios (between 2 and 5 in total): + // 1. With NpgsqlDbType explicitly set + // 2. With DataTypeName explicitly set + // 3. 
With DbType explicitly set (if one was provided) + // 4. With only the value set (if it's the default) + // 5. With only the value set, using generic NpgsqlParameter (if it's the default) + + var errorIdentifierIndex = -1; + var errorIdentifier = new Dictionary(); + + await using var cmd = new NpgsqlCommand { Connection = connection }; + NpgsqlParameter p; + // With NpgsqlDbType + if (npgsqlDbType is not null) + { + p = new NpgsqlParameter { Value = valueFactory(), NpgsqlDbType = npgsqlDbType.Value }; + cmd.Parameters.Add(p); + errorIdentifier[++errorIdentifierIndex] = $"NpgsqlDbType={npgsqlDbType}"; + CheckInference(); + } + + // With data type name + p = new NpgsqlParameter { Value = valueFactory(), DataTypeName = pgTypeNameWithoutFacets }; + cmd.Parameters.Add(p); + errorIdentifier[++errorIdentifierIndex] = $"DataTypeName={pgTypeNameWithoutFacets}"; + CheckInference(); + + // With DbType + if (dbType is not null) + { + p = new NpgsqlParameter { Value = valueFactory(), DbType = dbType.Value }; + cmd.Parameters.Add(p); + errorIdentifier[++errorIdentifierIndex] = $"DbType={dbType}"; + CheckInference(); + } + + if (isDefault) + { + // With (non-generic) value only + p = new NpgsqlParameter { Value = valueFactory() }; + cmd.Parameters.Add(p); + errorIdentifier[++errorIdentifierIndex] = $"Value only (type {p.Value!.GetType().Name}, non-generic)"; + CheckInference(valueOnlyInference: true); + + // With (generic) value only + p = new NpgsqlParameter { TypedValue = valueFactory() }; + cmd.Parameters.Add(p); + errorIdentifier[++errorIdentifierIndex] = $"Value only (type {p.Value!.GetType().Name}, generic)"; + CheckInference(valueOnlyInference: true); + } + + Debug.Assert(cmd.Parameters.Count == errorIdentifierIndex + 1); + + cmd.CommandText = "SELECT " + string.Join(", ", Enumerable.Range(1, cmd.Parameters.Count).Select(i => + "pg_typeof($1)::text, $1::text".Replace("$1", $"${i}"))); + + await using var reader = await 
cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); + + for (var i = 0; i < cmd.Parameters.Count * 2; i += 2) + { + Assert.That(reader[i], Is.EqualTo(pgTypeNameWithoutFacets), $"Got wrong PG type name when writing with {errorIdentifier[i / 2]}"); + Assert.That(reader[i+1], Is.EqualTo(expectedSqlLiteral), $"Got wrong SQL literal when writing with {errorIdentifier[i / 2]}"); + } + + void CheckInference(bool valueOnlyInference = false) + { + if (isNpgsqlDbTypeInferredFromClrType && npgsqlDbType is not null) + { + Assert.That(p.NpgsqlDbType, Is.EqualTo(npgsqlDbType), + () => $"Got wrong inferred NpgsqlDbType when inferring with {errorIdentifier[errorIdentifierIndex]}"); + } + + Assert.That(p.DbType, Is.EqualTo(valueOnlyInference ? inferredDbType : isNpgsqlDbTypeInferredFromClrType ? inferredDbType : dbType ?? DbType.Object), + () => $"Got wrong inferred DbType when inferring with {errorIdentifier[errorIdentifierIndex]}"); + + if (isNpgsqlDbTypeInferredFromClrType) + Assert.That(p.DataTypeName, Is.EqualTo(pgTypeNameWithoutFacets), + () => $"Got wrong inferred DataTypeName when inferring with {errorIdentifier[errorIdentifierIndex]}"); + } + } + + public async Task AssertTypeUnsupported(T value, string sqlLiteral, string pgTypeName, NpgsqlDataSource? dataSource = null) + { + await AssertTypeUnsupportedRead(sqlLiteral, pgTypeName, dataSource); + await AssertTypeUnsupportedWrite(value, pgTypeName, dataSource); + } + + public async Task AssertTypeUnsupportedRead(string sqlLiteral, string pgTypeName, NpgsqlDataSource? dataSource = null) + { + dataSource ??= DataSource; + + await using var conn = await dataSource.OpenConnectionAsync(); + // Make sure we don't poison the connection with a fault, potentially terminating other perfectly passing tests as well. + await using var tx = dataSource.Settings.Multiplexing ? 
await conn.BeginTransactionAsync() : null; + await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{pgTypeName}", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + return Assert.Throws(() => reader.GetValue(0))!; + } + + public Task AssertTypeUnsupportedRead(string sqlLiteral, string pgTypeName, + NpgsqlDataSource? dataSource = null, bool skipArrayCheck = false) + => AssertTypeUnsupportedRead(sqlLiteral, pgTypeName, dataSource); + + public async Task AssertTypeUnsupportedRead(string sqlLiteral, string pgTypeName, + NpgsqlDataSource? dataSource = null, bool skipArrayCheck = false) + where TException : Exception + { + var result = await AssertTypeUnsupportedReadCore(sqlLiteral, pgTypeName, dataSource); + + // Check the corresponding array type as well + if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal)) + { + await AssertTypeUnsupportedReadCore(ArrayLiteral(sqlLiteral), pgTypeName + "[]", dataSource); + } + + return result; + } + + async Task AssertTypeUnsupportedReadCore(string sqlLiteral, string pgTypeName, NpgsqlDataSource? dataSource = null) + where TException : Exception + { + dataSource ??= DataSource; + + await using var conn = await dataSource.OpenConnectionAsync(); + // Make sure we don't poison the connection with a fault, potentially terminating other perfectly passing tests as well. + await using var tx = dataSource.Settings.Multiplexing ? await conn.BeginTransactionAsync() : null; + await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{pgTypeName}", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + return Assert.Throws(() => reader.GetFieldValue(0))!; + } + + public Task AssertTypeUnsupportedWrite(T value, string? pgTypeName = null, NpgsqlDataSource? 
dataSource = null, + bool skipArrayCheck = false) + => AssertTypeUnsupportedWrite(value, pgTypeName, dataSource, skipArrayCheck: false); + + public async Task AssertTypeUnsupportedWrite(T value, string? pgTypeName = null, + NpgsqlDataSource? dataSource = null, bool skipArrayCheck = false) + where TException : Exception + { + var result = await AssertTypeUnsupportedWriteCore(value, pgTypeName, dataSource); + + // Check the corresponding array type as well + if (!skipArrayCheck && !pgTypeName?.EndsWith("[]", StringComparison.Ordinal) == true) + { + await AssertTypeUnsupportedWriteCore(new[] { value, value }, pgTypeName + "[]", dataSource); + } + + return result; + } + + async Task AssertTypeUnsupportedWriteCore(T value, string? pgTypeName = null, NpgsqlDataSource? dataSource = null) + where TException : Exception + { + dataSource ??= DataSource; + + await using var conn = await dataSource.OpenConnectionAsync(); + // Make sure we don't poison the connection with a fault, potentially terminating other perfectly passing tests as well. + await using var tx = dataSource.Settings.Multiplexing ? await conn.BeginTransactionAsync() : null; + await using var cmd = new NpgsqlCommand("SELECT $1", conn) + { + Parameters = { new() { Value = value } } + }; + + if (pgTypeName is not null) + cmd.Parameters[0].DataTypeName = pgTypeName; + + return Assert.ThrowsAsync(() => cmd.ExecuteReaderAsync())!; + } + + class SimpleComparer : IEqualityComparer + { + readonly Func _comparerDelegate; + + public SimpleComparer(Func comparerDelegate) + => _comparerDelegate = comparerDelegate; + + public bool Equals(T? x, T? y) + => x is null + ? 
y is null + : y is not null && _comparerDelegate(x, y); + + public int GetHashCode(T obj) => throw new NotSupportedException(); + } + + // For array quoting rules, see array_out in https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c + static string ArrayLiteral(string elementLiteral) + { + switch (elementLiteral) + { + case "": + elementLiteral = "\"\""; + break; + case "NULL": + elementLiteral = "\"NULL\""; + break; + default: + // Escape quotes and backslashes, quote for special chars + elementLiteral = elementLiteral.Replace("\\", "\\\\").Replace("\"", "\\\""); + if (elementLiteral.Any(c => c is '{' or '}' or ',' or '"' or '\\' || char.IsWhiteSpace(c))) + { + elementLiteral = '"' + elementLiteral + '"'; + } + + break; + } + + return $"{{{elementLiteral},{elementLiteral}}}"; + } + + #endregion Type testing + + #region Utilities for use by tests + + protected virtual NpgsqlDataSourceBuilder CreateDataSourceBuilder() + => new(ConnectionString); + + protected virtual NpgsqlDataSource CreateDataSource() + => CreateDataSource(ConnectionString); + + protected NpgsqlDataSource CreateDataSource(string connectionString) + => NpgsqlDataSource.Create(connectionString); + + protected NpgsqlDataSource CreateDataSource(Action connectionStringBuilderAction) + { + var connectionStringBuilder = new NpgsqlConnectionStringBuilder(ConnectionString); + connectionStringBuilderAction(connectionStringBuilder); + return NpgsqlDataSource.Create(connectionStringBuilder); + } + + protected NpgsqlDataSource CreateDataSource(Action configure) + { + var builder = new NpgsqlDataSourceBuilder(ConnectionString); + configure(builder); + return builder.Build(); + } + + protected static NpgsqlDataSource GetDataSource(string connectionString) + { + if (!DataSources.TryGetValue(connectionString, out var dataSource)) + { + lock (dataSourceLockObject) + { + if (!DataSources.TryGetValue(connectionString, out dataSource)) + { + var canonicalConnectionString = new 
NpgsqlConnectionStringBuilder(connectionString).ToString(); + if (!DataSources.TryGetValue(canonicalConnectionString, out dataSource)) + { + DataSources[canonicalConnectionString] = dataSource = NpgsqlDataSource.Create(connectionString); + } + DataSources[connectionString] = dataSource; + } + } + } + + return dataSource; + } + + protected virtual NpgsqlDataSource CreateLoggingDataSource( + out ListLoggerProvider listLoggerProvider, + string? connectionString = null, + bool sensitiveDataLoggingEnabled = true) + { + var builder = new NpgsqlDataSourceBuilder(connectionString ?? ConnectionString); + var provider = listLoggerProvider = new ListLoggerProvider(); + + builder.UseLoggerFactory(LoggerFactory.Create(loggerFactoryBuilder => + { + loggerFactoryBuilder.SetMinimumLevel(LogLevel.Trace); + loggerFactoryBuilder.AddProvider(provider); + })); + + builder.EnableParameterLogging(sensitiveDataLoggingEnabled); + + return builder.Build(); + } + + protected NpgsqlDataSource DefaultDataSource + => GetDataSource(ConnectionString); + + protected virtual NpgsqlDataSource DataSource => DefaultDataSource; + + protected virtual NpgsqlConnection CreateConnection() + => DataSource.CreateConnection(); + + protected virtual NpgsqlConnection OpenConnection() + { + var connection = CreateConnection(); + try + { + OpenConnection(connection, async: false).GetAwaiter().GetResult(); + return connection; + } + catch + { + connection.Dispose(); + throw; + } + } + + protected virtual async ValueTask OpenConnectionAsync() + { + var connection = CreateConnection(); + try + { + await OpenConnection(connection, async: true); + return connection; + } + catch + { + await connection.DisposeAsync(); + throw; + } + } + + static Task OpenConnection(NpgsqlConnection conn, bool async) + { + return OpenConnectionInternal(hasLock: false); + + async Task OpenConnectionInternal(bool hasLock) + { + try + { + if (async) + await conn.OpenAsync(); + else + conn.Open(); + } + catch (PostgresException e) + { + if 
(e.SqlState == PostgresErrorCodes.InvalidPassword) + throw new Exception("Please create a user npgsql_tests as follows: CREATE USER npgsql_tests PASSWORD 'npgsql_tests' SUPERUSER"); + + if (e.SqlState == PostgresErrorCodes.InvalidCatalogName) + { + if (!hasLock) + { + DatabaseCreationLock.Wait(); + try + { + await OpenConnectionInternal(hasLock: true); + } + finally + { + DatabaseCreationLock.Release(); + } + } + + // Database does not exist and we have the lock, proceed to creation + var builder = new NpgsqlConnectionStringBuilder(TestUtil.ConnectionString) + { + Pooling = false, + Multiplexing = false, + Database = "postgres" + }; + + using var adminConn = new NpgsqlConnection(builder.ConnectionString); + adminConn.Open(); + adminConn.ExecuteNonQuery("CREATE DATABASE " + conn.Database); + adminConn.Close(); + Thread.Sleep(1000); + + if (async) + await conn.OpenAsync(); + else + conn.Open(); + return; + } + + throw; + } + } + } + + // In PG under 9.1 you can't do SELECT pg_sleep(2) in binary because that function returns void and PG doesn't know + // how to transfer that. So cast to text server-side. + protected static NpgsqlCommand CreateSleepCommand(NpgsqlConnection conn, int seconds = 1000) + => new($"SELECT pg_sleep({seconds}){(conn.PostgreSqlVersion < new Version(9, 1, 0) ? 
"::TEXT" : "")}", conn); + + #endregion +} diff --git a/test/Npgsql.Tests/SyncOrAsyncTestBase.cs b/test/Npgsql.Tests/SyncOrAsyncTestBase.cs index e152426d9c..0eff0c7488 100644 --- a/test/Npgsql.Tests/SyncOrAsyncTestBase.cs +++ b/test/Npgsql.Tests/SyncOrAsyncTestBase.cs @@ -1,21 +1,20 @@ using NUnit.Framework; -namespace Npgsql.Tests -{ - [TestFixture(SyncOrAsync.Sync)] - [TestFixture(SyncOrAsync.Async)] - public abstract class SyncOrAsyncTestBase : TestBase - { - protected bool IsAsync => SyncOrAsync == SyncOrAsync.Async; +namespace Npgsql.Tests; - protected SyncOrAsync SyncOrAsync { get; } +[TestFixture(SyncOrAsync.Sync)] +[TestFixture(SyncOrAsync.Async)] +public abstract class SyncOrAsyncTestBase : TestBase +{ + protected bool IsAsync => SyncOrAsync == SyncOrAsync.Async; - protected SyncOrAsyncTestBase(SyncOrAsync syncOrAsync) => SyncOrAsync = syncOrAsync; - } + protected SyncOrAsync SyncOrAsync { get; } - public enum SyncOrAsync - { - Sync, - Async - } + protected SyncOrAsyncTestBase(SyncOrAsync syncOrAsync) => SyncOrAsync = syncOrAsync; } + +public enum SyncOrAsync +{ + Sync, + Async +} \ No newline at end of file diff --git a/test/Npgsql.Tests/SystemTransactionTests.cs b/test/Npgsql.Tests/SystemTransactionTests.cs index dc44c7fc77..b71c949259 100644 --- a/test/Npgsql.Tests/SystemTransactionTests.cs +++ b/test/Npgsql.Tests/SystemTransactionTests.cs @@ -1,390 +1,470 @@ using System; using System.Data; +using System.Threading; using System.Transactions; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +// This test suite contains ambient transaction tests, except those involving distributed transactions which are only +// supported on .NET Framework / Windows. Distributed transaction tests are in DistributedTransactionTests. 
+public class SystemTransactionTests : TestBase { - // This test suite contains ambient transaction tests, except those involving distributed transactions which are only - // supported on .NET Framework / Windows. Distributed transaction tests are in DistributedTransactionTests. - [NonParallelizable] - public class SystemTransactionTests : TestBase + [Test, Description("Single connection enlisting explicitly, committing")] + public void Explicit_enlist() { - [Test, Description("Single connection enlisting explicitly, committing")] - public void ExplicitEnlist() + var dataSource = EnlistOffDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using var conn = dataSource.OpenConnection(); + using (var scope = new TransactionScope()) { - using (var conn = new NpgsqlConnection(ConnectionStringEnlistOff)) - { - conn.Open(); - using (var scope = new TransactionScope()) - { - conn.EnlistTransaction(Transaction.Current); - Assert.That(conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); - AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - scope.Complete(); - } - AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - using (var tx = conn.BeginTransaction()) - { - Assert.That(conn.ExecuteScalar(@"SELECT COUNT(*) FROM data"), Is.EqualTo(1), "Unexpected data count"); - tx.Rollback(); - } - } + conn.EnlistTransaction(Transaction.Current); + Assert.That(conn.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + scope.Complete(); } - - [Test, Description("Single connection enlisting implicitly, committing")] - public void ImplicitEnlist() + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + using (var tx = conn.BeginTransaction()) { - var conn = new NpgsqlConnection(ConnectionStringEnlistOn); - using (var scope = new 
TransactionScope()) - { - conn.Open(); - Assert.That(conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); - AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - scope.Complete(); - } - using (var tx = conn.BeginTransaction()) - { - Assert.That(conn.ExecuteScalar(@"SELECT COUNT(*) FROM data"), Is.EqualTo(1), "Unexpected data count"); - tx.Rollback(); - } + Assert.That(conn.ExecuteScalar(@$"SELECT COUNT(*) FROM {tableName}"), Is.EqualTo(1), "Unexpected data count"); + tx.Rollback(); } + } - [Test] - public void EnlistOff() + [Test, Description("Single connection enlisting implicitly, committing")] + public void Implicit_enlist() + { + var dataSource = EnlistOnDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using var conn = dataSource.CreateConnection(); + using (var scope = new TransactionScope()) { - using (new TransactionScope()) - using (var conn1 = OpenConnection(ConnectionStringEnlistOff)) - using (var conn2 = OpenConnection(ConnectionStringEnlistOff)) - { - Assert.That(conn1.EnlistedTransaction, Is.Null); - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); - Assert.That(conn2.ExecuteScalar("SELECT COUNT(*) FROM data"), Is.EqualTo(1), "Unexpected data count"); - } - - // Scope disposed and not completed => rollback, but no enlistment, so changes should still be there. 
- using (var conn3 = OpenConnection(ConnectionStringEnlistOff)) - { - Assert.That(conn3.ExecuteScalar("SELECT COUNT(*) FROM data"), Is.EqualTo(1), "Insert unexpectedly rollback-ed"); - } + conn.Open(); + Assert.That(conn.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + scope.Complete(); } - - [Test, Description("Single connection enlisting explicitly, rollback")] - public void RollbackExplicitEnlist() + using (var tx = conn.BeginTransaction()) { - using (var conn = OpenConnection()) - { - using (new TransactionScope()) - { - conn.EnlistTransaction(Transaction.Current); - Assert.That(conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); - // No commit - } - AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - using (var tx = conn.BeginTransaction()) - { - Assert.That(conn.ExecuteScalar(@"SELECT COUNT(*) FROM data"), Is.EqualTo(0), "Unexpected data count"); - tx.Rollback(); - } - } + Assert.That(conn.ExecuteScalar(@$"SELECT COUNT(*) FROM {tableName}"), Is.EqualTo(1), "Unexpected data count"); + tx.Rollback(); } + } - [Test, Description("Single connection enlisting implicitly, rollback")] - public void RollbackImplicitEnlist() + [Test] + public void Enlist_Off() + { + var dataSource = EnlistOffDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using (new TransactionScope()) + using (var conn1 = dataSource.OpenConnection()) + using (var conn2 = dataSource.OpenConnection()) { - using (new TransactionScope()) - using (var conn = OpenConnection(ConnectionStringEnlistOn)) - { - Assert.That(conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); - AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - // No commit - } - - AssertNumberOfRows(0); + 
Assert.That(conn1.EnlistedTransaction, Is.Null); + Assert.That(conn1.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); + Assert.That(conn2.ExecuteScalar($"SELECT COUNT(*) FROM {tableName}"), Is.EqualTo(1), "Unexpected data count"); } - [Test] - public void TwoConsecutiveConnections() + // Scope disposed and not completed => rollback, but no enlistment, so changes should still be there. + using (var conn3 = dataSource.OpenConnection()) { - using (var scope = new TransactionScope()) - { - using (var conn1 = OpenConnection(ConnectionStringEnlistOn)) - { - Assert.That(conn1.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); - } - - using (var conn2 = OpenConnection(ConnectionStringEnlistOn)) - { - Assert.That(conn2.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test2')"), Is.EqualTo(1), "Unexpected second insert rowcount"); - } - - // Consecutive connections used in same scope should not promote the - // transaction to distributed. 
- AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - scope.Complete(); - } - AssertNumberOfRows(2); + Assert.That(conn3.ExecuteScalar($"SELECT COUNT(*) FROM {tableName}"), Is.EqualTo(1), "Insert unexpectedly rollback-ed"); } + } - [Test] - public void CloseConnection() + [Test, Description("Single connection enlisting explicitly, rollback")] + public void Rollback_explicit_enlist() + { + using var dataSource = CreateDataSource(); + var tableName = CreateTempTable(dataSource, "name TEXT"); + using var conn = dataSource.OpenConnection(); + using (new TransactionScope()) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionStringEnlistOn) - { - ApplicationName = nameof(CloseConnection), - }.ToString(); - using (var scope = new TransactionScope()) - using (var conn = OpenConnection(connString)) - { - Assert.That(conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); - conn.Close(); - AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - scope.Complete(); - } - AssertNumberOfRows(1); - Assert.True(PoolManager.TryGetValue(connString, out var pool)); - Assert.That(pool!.Statistics.Idle, Is.EqualTo(1)); - - using (var conn = new NpgsqlConnection(connString)) - NpgsqlConnection.ClearPool(conn); + conn.EnlistTransaction(Transaction.Current); + Assert.That(conn.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); + // No commit } + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + using (var tx = conn.BeginTransaction()) + { + Assert.That(conn.ExecuteScalar(@$"SELECT COUNT(*) FROM {tableName}"), Is.EqualTo(0), "Unexpected data count"); + tx.Rollback(); + } + } + + [Test, Description("Single connection enlisting implicitly, rollback")] + [IssueLink("https://github.com/npgsql/npgsql/issues/2408")] + public void Rollback_implicit_enlist([Values(true, false)] bool pooling) + { + using var dataSource 
= CreateDataSource(csb => csb.Pooling = pooling); + var tableName = CreateTempTable(dataSource, "name TEXT"); - [Test] - public void EnlistToTwoTransactions() + using (new TransactionScope()) + using (var conn = dataSource.OpenConnection()) { - using (var conn = OpenConnection(ConnectionStringEnlistOff)) - { - var ctx = new CommittableTransaction(); - conn.EnlistTransaction(ctx); - Assert.That(() => conn.EnlistTransaction(new CommittableTransaction()), Throws.Exception.TypeOf()); - ctx.Rollback(); - - using (var tx = conn.BeginTransaction()) - { - Assert.That(conn.ExecuteScalar(@"SELECT COUNT(*) FROM data"), Is.EqualTo(0)); - tx.Rollback(); - } - } + Assert.That(conn.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + // No commit } - [Test] - public void EnlistTwiceToSameTransaction() + AssertNumberOfRows(0, tableName); + } + + [Test] + public void Two_consecutive_connections() + { + var dataSource = EnlistOnDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using (var scope = new TransactionScope()) { - using (var conn = OpenConnection(ConnectionStringEnlistOff)) + using (var conn1 = dataSource.OpenConnection()) { - var ctx = new CommittableTransaction(); - conn.EnlistTransaction(ctx); - conn.EnlistTransaction(ctx); - ctx.Rollback(); - - using (var tx = conn.BeginTransaction()) - { - Assert.That(conn.ExecuteScalar(@"SELECT COUNT(*) FROM data"), Is.EqualTo(0)); - tx.Rollback(); - } + Assert.That(conn1.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test1')"), Is.EqualTo(1), "Unexpected first insert rowcount"); } - } - [Test] - public void ScopeAfterScope() - { - using (var conn = OpenConnection(ConnectionStringEnlistOff)) + using (var conn2 = dataSource.OpenConnection()) { - using (new TransactionScope()) - conn.EnlistTransaction(Transaction.Current); - using (new TransactionScope()) - 
conn.EnlistTransaction(Transaction.Current); - - using (var tx = conn.BeginTransaction()) - { - Assert.That(conn.ExecuteScalar(@"SELECT COUNT(*) FROM data"), Is.EqualTo(0)); - tx.Rollback(); - } + Assert.That(conn2.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test2')"), Is.EqualTo(1), "Unexpected second insert rowcount"); } + + // Consecutive connections used in same scope should not promote the transaction to distributed. + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + scope.Complete(); } + AssertNumberOfRows(2, tableName); + } - [Test] - public void ReuseConnection() + [Test] + public void Close_connection() + { + // We assert the number of idle connections below + using var dataSource = CreateDataSource(csb => csb.Enlist = true); + var tableName = CreateTempTable(dataSource, "name TEXT"); + using (var scope = new TransactionScope()) + using (var conn = dataSource.OpenConnection()) { - using (var scope = new TransactionScope()) - using (var conn = new NpgsqlConnection(ConnectionStringEnlistOn)) - { - conn.Open(); - var processId = conn.ProcessID; - conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"); - conn.Close(); + Assert.That(conn.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test')"), Is.EqualTo(1), "Unexpected insert rowcount"); + conn.Close(); + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + scope.Complete(); + } + AssertNumberOfRows(1, tableName); + Assert.That(dataSource.Statistics.Idle, Is.EqualTo(1)); + } - conn.Open(); - Assert.That(conn.ProcessID, Is.EqualTo(processId)); - conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test2')"); - conn.Close(); + [Test] + public void Enlist_to_two_transactions() + { + var dataSource = EnlistOffDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using var conn = dataSource.OpenConnection(); + var ctx = new CommittableTransaction(); + conn.EnlistTransaction(ctx); + Assert.That(() => 
conn.EnlistTransaction(new CommittableTransaction()), Throws.Exception.TypeOf()); + ctx.Rollback(); + + using var tx = conn.BeginTransaction(); + Assert.That(conn.ExecuteScalar(@$"SELECT COUNT(*) FROM {tableName}"), Is.EqualTo(0)); + tx.Rollback(); + } - scope.Complete(); - } - AssertNumberOfRows(2); + [Test] + public void Enlist_twice_to_same_transaction() + { + var dataSource = EnlistOffDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using var conn = dataSource.OpenConnection(); + var ctx = new CommittableTransaction(); + conn.EnlistTransaction(ctx); + conn.EnlistTransaction(ctx); + ctx.Rollback(); + + using var tx = conn.BeginTransaction(); + Assert.That(conn.ExecuteScalar(@$"SELECT COUNT(*) FROM {tableName}"), Is.EqualTo(0)); + tx.Rollback(); + } + + [Test] + public void Scope_after_scope() + { + var dataSource = EnlistOffDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using var conn = dataSource.OpenConnection(); + using (new TransactionScope()) + conn.EnlistTransaction(Transaction.Current); + using (new TransactionScope()) + conn.EnlistTransaction(Transaction.Current); + + using (var tx = conn.BeginTransaction()) + { + Assert.That(conn.ExecuteScalar(@$"SELECT COUNT(*) FROM {tableName}"), Is.EqualTo(0)); + tx.Rollback(); } + } - [Test] - public void ReuseConnectionRollback() + [Test] + public void Reuse_connection() + { + // We check the ProcessID below + using var dataSource = CreateDataSource(csb => csb.Enlist = true); + var tableName = CreateTempTable(dataSource, "name TEXT"); + using (var scope = new TransactionScope()) + using (var conn = dataSource.CreateConnection()) { - using (new TransactionScope()) - using (var conn = new NpgsqlConnection(ConnectionStringEnlistOn)) - { - conn.Open(); - var processId = conn.ProcessID; - conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test1')"); - conn.Close(); + conn.Open(); + var processId = conn.ProcessID; + conn.ExecuteNonQuery(@$"INSERT INTO {tableName} 
(name) VALUES ('test1')"); + conn.Close(); - conn.Open(); - Assert.That(conn.ProcessID, Is.EqualTo(processId)); - conn.ExecuteNonQuery(@"INSERT INTO data (name) VALUES ('test2')"); - conn.Close(); + conn.Open(); + Assert.That(conn.ProcessID, Is.EqualTo(processId)); + conn.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test2')"); + conn.Close(); - // No commit - } - AssertNumberOfRows(0); + scope.Complete(); } + AssertNumberOfRows(2, tableName); + } - [Test, Ignore("Timeout doesn't seem to fire on .NET Core / Linux")] - public void TimeoutTriggersRollbackWhileBusy() + [Test] + public void Reuse_connection_rollback() + { + // We check the ProcessID below + using var dataSource = CreateDataSource(csb => csb.Enlist = true); + var tableName = CreateTempTable(dataSource, "name TEXT"); + using (new TransactionScope()) + using (var conn = dataSource.CreateConnection()) { - using (var conn = OpenConnection(ConnectionStringEnlistOff)) - { - using (new TransactionScope(TransactionScopeOption.Required, TimeSpan.FromSeconds(1))) - { - conn.EnlistTransaction(Transaction.Current); - Assert.That(() => CreateSleepCommand(conn, 5).ExecuteNonQuery(), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)) - .EqualTo(PostgresErrorCodes.QueryCanceled)); - - } - } - AssertNumberOfRows(0); - } + conn.Open(); + var processId = conn.ProcessID; + conn.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test1')"); + conn.Close(); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1579")] - public void SchemaConnectionShouldntEnlist() - { - using (var tran = new TransactionScope()) - using (var conn = OpenConnection(ConnectionStringEnlistOn)) - { - using (var cmd = new NpgsqlCommand("SELECT * FROM data", conn)) - using (var reader = cmd.ExecuteReader(CommandBehavior.KeyInfo)) - { - reader.GetColumnSchema(); - AssertNoDistributedIdentifier(); - AssertNoPreparedTransactions(); - tran.Complete(); - } - } + conn.Open(); + 
Assert.That(conn.ProcessID, Is.EqualTo(processId)); + conn.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test2')"); + conn.Close(); + + // No commit } + AssertNumberOfRows(0, tableName); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1737")] - public void Bug1737() + [Test, Ignore("Timeout doesn't seem to fire on .NET Core / Linux")] + public void Timeout_triggers_rollback_while_busy() + { + var dataSource = EnlistOffDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using (var conn = dataSource.OpenConnection()) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + using (new TransactionScope(TransactionScopeOption.Required, TimeSpan.FromSeconds(1))) { - Pooling = false, - Enlist = true - }; + conn.EnlistTransaction(Transaction.Current); + Assert.That(() => CreateSleepCommand(conn, 5).ExecuteNonQuery(), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)) + .EqualTo(PostgresErrorCodes.QueryCanceled)); - // Case 1 - using (var scope = new TransactionScope()) - { - using (var conn = OpenConnection(csb)) - using (var cmd = new NpgsqlCommand("SELECT 1", conn)) - cmd.ExecuteNonQuery(); - scope.Complete(); } + } + AssertNumberOfRows(0, tableName); + } - // Case 2 - using (var scope = new TransactionScope()) - { - using (var conn1 = OpenConnection(csb)) - using (var cmd = new NpgsqlCommand("SELECT 1", conn1)) - cmd.ExecuteNonQuery(); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1579")] + public void Schema_connection_should_not_enlist() + { + var dataSource = EnlistOnDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using var tran = new TransactionScope(); + using var conn = dataSource.OpenConnection(); + using var cmd = new NpgsqlCommand($"SELECT * FROM {tableName}", conn); + using var reader = cmd.ExecuteReader(CommandBehavior.KeyInfo); + reader.GetColumnSchema(); + AssertNoDistributedIdentifier(); + AssertNoPreparedTransactions(); + 
tran.Complete(); + } - using (var conn2 = OpenConnection(csb)) - using (var cmd = new NpgsqlCommand("SELECT 1", conn2)) - cmd.ExecuteNonQuery(); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1737")] + public void Single_unpooled_connection() + { + using var dataSource = CreateDataSource(csb => + { + csb.Pooling = false; + csb.Enlist = true; + }); + using var scope = new TransactionScope(); - scope.Complete(); - } - } + using (var conn = dataSource.OpenConnection()) + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) + cmd.ExecuteNonQuery(); - #region Utilities + scope.Complete(); + } - void AssertNoPreparedTransactions() - => Assert.That(GetNumberOfPreparedTransactions(), Is.EqualTo(0), "Prepared transactions found"); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4963")] + public void Single_unpooled_closed_connection() + { + using var dataSource = CreateDataSource(csb => + { + csb.Pooling = false; + csb.Enlist = true; + }); - int GetNumberOfPreparedTransactions() + using (var scope = new TransactionScope()) + using (var conn = dataSource.OpenConnection()) + using (var cmd = new NpgsqlCommand("SELECT 1", conn)) { - using (var conn = OpenConnection(ConnectionStringEnlistOff)) - using (var cmd = new NpgsqlCommand("SELECT COUNT(*) FROM pg_prepared_xacts WHERE database = @database", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("database", conn.Database)); - return (int)(long)cmd.ExecuteScalar()!; - } + cmd.ExecuteNonQuery(); + conn.Close(); + Assert.That(dataSource.Statistics.Total, Is.EqualTo(1)); + scope.Complete(); } - void AssertNumberOfRows(int expected) - => Assert.That(_controlConn.ExecuteScalar(@"SELECT COUNT(*) FROM data"), Is.EqualTo(expected), "Unexpected data count"); - - static void AssertNoDistributedIdentifier() - => Assert.That(Transaction.Current?.TransactionInformation.DistributedIdentifier ?? 
Guid.Empty, Is.EqualTo(Guid.Empty), "Distributed identifier found"); - - public readonly string ConnectionStringEnlistOn; - public readonly string ConnectionStringEnlistOff; + Assert.That(dataSource.Statistics.Total, Is.EqualTo(0)); + } - #endregion Utilities + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/3863")] + public void Break_connector_while_in_transaction_scope_with_rollback([Values] bool pooling) + { + using var dataSource = CreateDataSource(csb => csb.Pooling = pooling); + using var scope = new TransactionScope(); + var conn = dataSource.OpenConnection(); - #region Setup + conn.ExecuteNonQuery("SELECT 1"); + conn.Connector!.Break(new Exception(nameof(Break_connector_while_in_transaction_scope_with_rollback))); + } - public SystemTransactionTests() + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/3863")] + public void Break_connector_while_in_transaction_scope_with_commit([Values] bool pooling) + { + using var dataSource = CreateDataSource(csb => csb.Pooling = pooling); + var ex = Assert.Throws(() => { - ConnectionStringEnlistOn = new NpgsqlConnectionStringBuilder(ConnectionString) { Enlist = true }.ToString(); - ConnectionStringEnlistOff = new NpgsqlConnectionStringBuilder(ConnectionString) { Enlist = false }.ToString(); - } + using var scope = new TransactionScope(); + var conn = dataSource.OpenConnection(); - NpgsqlConnection _controlConn = default!; + conn.ExecuteNonQuery("SELECT 1"); + conn.Connector!.Break(new Exception(nameof(Break_connector_while_in_transaction_scope_with_commit))); - [OneTimeSetUp] - public void OneTimeSetUp() + scope.Complete(); + })!; + Assert.That(ex.InnerException, Is.TypeOf()); + Assert.That(ex.InnerException!.InnerException, Is.TypeOf()); + Assert.That(ex.InnerException!.InnerException!.Message, Is.EqualTo(nameof(Break_connector_while_in_transaction_scope_with_commit))); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4085")] + public void 
Open_connection_with_enlist_and_aborted_TransactionScope() + { + var dataSource = EnlistOnDataSource; + for (var i = 0; i < 2; i++) { - _controlConn = OpenConnection(); + using var outerScope = new TransactionScope(); - // All tests in this fixture should have exclusive access to the database they're running on. - // If we run these tests in parallel (i.e. two builds in parallel) they will interfere. - // Solve this by taking a PostgreSQL advisory lock for the lifetime of the fixture. - _controlConn.ExecuteNonQuery("SELECT pg_advisory_lock(666)"); + try + { + using var innerScope = new TransactionScope(); + throw new Exception("Random exception to abort the transaction scope"); + } + catch (Exception) + { + } - _controlConn.ExecuteNonQuery("DROP TABLE IF EXISTS data"); - _controlConn.ExecuteNonQuery("CREATE TABLE data (name TEXT)"); + var ex = Assert.Throws(() => dataSource.OpenConnection())!; + Assert.That(ex.Message, Is.EqualTo("The operation is not valid for the state of the transaction.")); } + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1594")] + public void Bug1594() + { + var dataSource = EnlistOnDataSource; + var tableName = CreateTempTable(dataSource, "name TEXT"); + using var outerScope = new TransactionScope(); - [SetUp] - public void SetUp() + using (var conn = dataSource.OpenConnection()) + using (var innerScope1 = new TransactionScope()) { - _controlConn.ExecuteNonQuery("TRUNCATE data"); + conn.ExecuteNonQuery(@$"INSERT INTO {tableName} (name) VALUES ('test1')"); + innerScope1.Complete(); } -#pragma warning disable CS8625 - [OneTimeTearDown] - public void OneTimeTearDown() + using (dataSource.OpenConnection()) + using (new TransactionScope()) { - _controlConn?.Close(); - _controlConn = null; + // Don't complete, triggering rollback } -#pragma warning restore CS8625 + } + + #region Utilities + + void AssertNoPreparedTransactions() + => Assert.That(GetNumberOfPreparedTransactions(), Is.EqualTo(0), "Prepared transactions found"); - 
#endregion + int GetNumberOfPreparedTransactions() + { + var dataSource = EnlistOffDataSource; + using var conn = dataSource.OpenConnection(); + using var cmd = new NpgsqlCommand("SELECT COUNT(*) FROM pg_prepared_xacts WHERE database = @database", conn); + cmd.Parameters.Add(new NpgsqlParameter("database", conn.Database)); + return (int)(long)cmd.ExecuteScalar()!; + } + + void AssertNumberOfRows(int expected, string tableName) + { + using var conn = OpenConnection(); + Assert.That(conn.ExecuteScalar(@$"SELECT COUNT(*) FROM {tableName}"), Is.EqualTo(expected), "Unexpected data count"); } + + static void AssertNoDistributedIdentifier() + => Assert.That(Transaction.Current?.TransactionInformation.DistributedIdentifier ?? Guid.Empty, Is.EqualTo(Guid.Empty), "Distributed identifier found"); + + #endregion Utilities + + #region Setup + + NpgsqlDataSource EnlistOnDataSource { get; set; } = default!; + + NpgsqlDataSource EnlistOffDataSource { get; set; } = default!; + + [OneTimeSetUp] + public void OneTimeSetUp() + { + EnlistOnDataSource = CreateDataSource(csb => csb.Enlist = true); + EnlistOffDataSource = CreateDataSource(csb => csb.Enlist = false); + } + + [OneTimeTearDown] + public void OnTimeTearDown() + { + EnlistOnDataSource?.Dispose(); + EnlistOnDataSource = null!; + EnlistOffDataSource?.Dispose(); + EnlistOffDataSource = null!; + } + + internal static string CreateTempTable(NpgsqlDataSource dataSource, string columns) + { + var tableName = "temp_table" + Interlocked.Increment(ref _tempTableCounter); + dataSource.ExecuteNonQuery(@$" +START TRANSACTION; SELECT pg_advisory_xact_lock(0); +DROP TABLE IF EXISTS {tableName} CASCADE; +COMMIT; +CREATE TABLE {tableName} ({columns})"); + return tableName; + } + + #endregion } diff --git a/test/Npgsql.Tests/TaskTimeoutAndCancellationTest.cs b/test/Npgsql.Tests/TaskTimeoutAndCancellationTest.cs new file mode 100644 index 0000000000..e3759d35e9 --- /dev/null +++ b/test/Npgsql.Tests/TaskTimeoutAndCancellationTest.cs @@ -0,0 
+1,162 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using Npgsql.Util; + +namespace Npgsql.Tests; + +[NonParallelizable] // To make sure unobserved tasks from other tests do not leak +public class TaskTimeoutAndCancellationTest : TestBase +{ + const int TestResultValue = 777; + + async Task GetResultTaskAsync(int timeout, CancellationToken ct) + { + await Task.Delay(timeout, ct); + return TestResultValue; + } + + Task GetVoidTaskAsync(int timeout, CancellationToken ct) => Task.Delay(timeout, ct); + + [Test] + public async Task SuccessfulResultTaskAsync() => + Assert.AreEqual(TestResultValue, await TaskTimeoutAndCancellation.ExecuteAsync(ct => GetResultTaskAsync(10, ct), NpgsqlTimeout.Infinite, CancellationToken.None)); + + [Test] + public async Task SuccessfulVoidTaskAsync() => + await TaskTimeoutAndCancellation.ExecuteAsync(ct => GetVoidTaskAsync(10, ct), NpgsqlTimeout.Infinite, CancellationToken.None); + + [Test] + public void InfinitelyLongTaskTimeout() => + Assert.ThrowsAsync(async () => + await TaskTimeoutAndCancellation.ExecuteAsync(ct => GetVoidTaskAsync(Timeout.Infinite, ct), new NpgsqlTimeout(TimeSpan.FromMilliseconds(10)), CancellationToken.None)); + + [Test] + public void InfinitelyLongTaskCancellation() + { + using var cts = new CancellationTokenSource(10); + Assert.ThrowsAsync(async () => + await TaskTimeoutAndCancellation.ExecuteAsync(ct => GetVoidTaskAsync(Timeout.Infinite, ct), NpgsqlTimeout.Infinite, cts.Token)); + } + + /// + /// The test creates a delayed execution Task that is being fake-cancelled and fails subsequently and triggers 'TaskScheduler.UnobservedTaskException event'. + /// + /// + /// The test is based on timing and depends on availability of thread pool threads. Therefore it could become unstable if the environment is under pressure. 
+ /// + [Theory, IssueLink("https://github.com/npgsql/npgsql/issues/4149")] + [TestCase("CancelAndTimeout")] + [TestCase("CancelOnly")] + [TestCase("TimeoutOnly")] + [TestCase("CancelAndTimeout")] + [TestCase("CancelOnly")] + [TestCase("TimeoutOnly")] + public Task DelayedFaultedTaskCancellation(string testCase) => RunDelayedFaultedTaskTestAsync(async getUnobservedTaskException => + { + var cancel = true; + var timeout = true; + switch (testCase) + { + case "TimeoutOnly": + cancel = false; + break; + case "CancelOnly": + timeout = false; + break; + } + + var notifyDelayCompleted = new SemaphoreSlim(0, 1); + + // Invoke the method that creates a delayed execution Task that fails subsequently. + await CreateTaskAndPreemptWithCancellationAsync(500, cancel, timeout, notifyDelayCompleted); + + // Wait enough time for the non-cancelable task to notify us that an exception is thrown. + await notifyDelayCompleted.WaitAsync(); + + // And then wait some more. + var repeatCount = 2; + while (getUnobservedTaskException() is null && repeatCount-- > 0) + { + await Task.Delay(100); + + // Run the garbage collector to collect unobserved Tasks. + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + }); + + static async Task RunDelayedFaultedTaskTestAsync(Func, Task> test) + { + // Run the garbage collector to collect unobserved Tasks from other tests. + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + Exception? unobservedTaskException = null; + + // Subscribe to UnobservedTaskException event to store the Exception, if any. + void OnUnobservedTaskException(object? source, UnobservedTaskExceptionEventArgs args) + { + if (!args.Observed) + { + args.SetObserved(); + } + unobservedTaskException = args.Exception; + } + TaskScheduler.UnobservedTaskException += OnUnobservedTaskException; + + try + { + await test(() => unobservedTaskException); + + // Verify the unobserved Task exception event has not been received. 
+ Assert.IsNull(unobservedTaskException, unobservedTaskException?.Message); + } + finally + { + TaskScheduler.UnobservedTaskException -= OnUnobservedTaskException; + } + } + + /// + /// Create a delayed execution, non-Cancellable Task that fails subsequently after the Task goes out of scope. + /// + static async Task CreateTaskAndPreemptWithCancellationAsync(int delayMs, bool cancel, bool timeout, SemaphoreSlim notifyDelayCompleted) + { + var nonCancellableTask = Task.Delay(delayMs, CancellationToken.None) + .ContinueWith( + async _ => + { + try + { + await Task.FromException(new Exception("Unobserved Task Test Exception")); + } + finally + { + notifyDelayCompleted.Release(); + } + }) + .Unwrap(); + + var timeoutMs = delayMs / 5; + using var cts = cancel ? new CancellationTokenSource(timeoutMs) : null; + try + { + await TaskTimeoutAndCancellation.ExecuteAsync( + _ => nonCancellableTask, + timeout ? new NpgsqlTimeout(TimeSpan.FromMilliseconds(timeoutMs)) : NpgsqlTimeout.Infinite, + cts?.Token ?? CancellationToken.None); + } + catch (TimeoutException) + { + // Expected due to preemptive time out. + } + catch (OperationCanceledException) when (cts?.IsCancellationRequested == true) + { + // Expected due to preemptive cancellation. + } + Assert.False(nonCancellableTask.IsCompleted); + } +} diff --git a/test/Npgsql.Tests/TestBase.cs b/test/Npgsql.Tests/TestBase.cs deleted file mode 100644 index 872d9847e2..0000000000 --- a/test/Npgsql.Tests/TestBase.cs +++ /dev/null @@ -1,88 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; - -namespace Npgsql.Tests -{ - public abstract class TestBase - { - /// - /// The connection string that will be used when opening the connection to the tests database. - /// May be overridden in fixtures, e.g. 
to set special connection parameters - /// - public virtual string ConnectionString => TestUtil.ConnectionString; - - #region Utilities for use by tests - - protected virtual NpgsqlConnection CreateConnection(string? connectionString = null) - => new NpgsqlConnection(connectionString ?? ConnectionString); - - protected virtual NpgsqlConnection CreateConnection(Action builderAction) - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString); - builderAction(builder); - return new NpgsqlConnection(builder.ConnectionString); - } - - protected virtual NpgsqlConnection OpenConnection(string? connectionString = null) - => OpenConnection(connectionString, async: false).GetAwaiter().GetResult(); - - protected virtual NpgsqlConnection OpenConnection(Action builderAction) - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString); - builderAction(builder); - return OpenConnection(builder.ConnectionString, async: false).GetAwaiter().GetResult(); - } - - protected virtual ValueTask OpenConnectionAsync(string? connectionString = null) - => OpenConnection(connectionString, async: true); - - protected virtual ValueTask OpenConnectionAsync( - Action builderAction) - { - var builder = new NpgsqlConnectionStringBuilder(ConnectionString); - builderAction(builder); - return OpenConnection(builder.ConnectionString, async: true); - } - - async ValueTask OpenConnection(string? 
connectionString, bool async) - { - var conn = CreateConnection(connectionString); - try - { - if (async) - await conn.OpenAsync(); - else - conn.Open(); - } - catch (PostgresException e) - { - if (e.SqlState == PostgresErrorCodes.InvalidCatalogName) - TestUtil.IgnoreExceptOnBuildServer("Please create a database npgsql_tests, owned by user npgsql_tests"); - else if (e.SqlState == PostgresErrorCodes.InvalidPassword && connectionString == TestUtil.DefaultConnectionString) - TestUtil.IgnoreExceptOnBuildServer("Please create a user npgsql_tests as follows: create user npgsql_tests with password 'npgsql_tests'"); - else - throw; - } - - return conn; - } - - protected NpgsqlConnection OpenConnection(NpgsqlConnectionStringBuilder csb) - => OpenConnection(csb.ToString()); - - protected virtual ValueTask OpenConnectionAsync(NpgsqlConnectionStringBuilder csb) - => OpenConnectionAsync(csb.ToString()); - - // In PG under 9.1 you can't do SELECT pg_sleep(2) in binary because that function returns void and PG doesn't know - // how to transfer that. So cast to text server-side. - protected static NpgsqlCommand CreateSleepCommand(NpgsqlConnection conn, int seconds = 1000) - => new NpgsqlCommand($"SELECT pg_sleep({seconds}){(conn.PostgreSqlVersion < new Version(9, 1, 0) ? "::TEXT" : "")}", conn); - - protected bool IsRedshift => new NpgsqlConnectionStringBuilder(ConnectionString).ServerCompatibilityMode == ServerCompatibilityMode.Redshift; - - #endregion - } -} diff --git a/test/Npgsql.Tests/TestMetrics.cs b/test/Npgsql.Tests/TestMetrics.cs index a30ee51ad6..52bf2ed935 100644 --- a/test/Npgsql.Tests/TestMetrics.cs +++ b/test/Npgsql.Tests/TestMetrics.cs @@ -1,180 +1,179 @@ using System; using System.Diagnostics; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +/// +/// Keep track of metrics related to performance. 
+/// +internal sealed class TestMetrics : IDisposable { + private static Process process = Process.GetCurrentProcess(); + + private bool running; /// - /// Keep track of metrics related to performance. + /// The number of iterations accumulated. /// - internal sealed class TestMetrics : IDisposable - { - private static Process process = Process.GetCurrentProcess(); + public int Iterations { get; private set; } + private TimeSpan systemCPUTime; + private TimeSpan userCPUTime; + private Stopwatch stopwatch; + private TimeSpan allowedTime; - private bool running; - /// - /// The number of iterations accumulated. - /// - public int Iterations { get; private set; } - private TimeSpan systemCPUTime; - private TimeSpan userCPUTime; - private Stopwatch stopwatch; - private TimeSpan allowedTime; + private bool reportOnStop; - private bool reportOnStop; + private TestMetrics(TimeSpan allowedTime, bool reportOnStop) + { + Iterations = 0; + systemCPUTime = process.PrivilegedProcessorTime; + userCPUTime = process.UserProcessorTime; + stopwatch = Stopwatch.StartNew(); + this.allowedTime = allowedTime; + this.reportOnStop = reportOnStop; + + running = true; + } - private TestMetrics(TimeSpan allowedTime, bool reportOnStop) - { - Iterations = 0; - systemCPUTime = process.PrivilegedProcessorTime; - userCPUTime = process.UserProcessorTime; - stopwatch = Stopwatch.StartNew(); - this.allowedTime = allowedTime; - this.reportOnStop = reportOnStop; - - running = true; - } + /// + /// Construct and start a new TestMetrics object. + /// + /// Length of time the test should run. + /// Report metrics to stdout when stopped. + /// A new running TestMetrics object. + public static TestMetrics Start(TimeSpan allowedTime, bool reportOnStop) + { + return new(allowedTime, reportOnStop); + } - /// - /// Construct and start a new TestMetrics object. - /// - /// Length of time the test should run. - /// Report metrics to stdout when stopped. - /// A new running TestMetrics object. 
- public static TestMetrics Start(TimeSpan allowedTime, bool reportOnStop) - { - return new TestMetrics(allowedTime, reportOnStop); - } + /// + /// Incremnent the Iterations value by one. + /// + public void IncrementIterations() + { + Iterations++; + } - /// - /// Incremnent the Iterations value by one. - /// - public void IncrementIterations() + /// + /// Stop the internal stop watch and record elapsed CPU times. + /// + public void Stop() + { + if (! running) { - Iterations++; + return; } - /// - /// Stop the internal stop watch and record elapsed CPU times. - /// - public void Stop() - { - if (! running) - { - return; - } - - stopwatch.Stop(); - systemCPUTime = process.PrivilegedProcessorTime - systemCPUTime; - userCPUTime = process.UserProcessorTime - userCPUTime; - - running = false; + stopwatch.Stop(); + systemCPUTime = process.PrivilegedProcessorTime - systemCPUTime; + userCPUTime = process.UserProcessorTime - userCPUTime; - if (reportOnStop) - { - Console.WriteLine("Elapsed: {0:mm\\:ss\\.ff}", ElapsedClockTime); - Console.WriteLine("CPU: {0:mm\\:ss\\.ffffff} (User: {1:mm\\:ss\\.ffffff}, System: {2:mm\\:ss\\.ffffff})", ElapsedTotalCPUTime, ElapsedUserCPUTime, ElapsedSystemCPUTime); - Console.WriteLine("Iterations: {0}; {1:0.00}/second, {2:0.00}/CPU second", Iterations, IterationsPerSecond(), IterationsPerCPUSecond()); - } - } + running = false; - /// - /// Stop the internal stop watch and record elapsed CPU times. - /// - public void Dispose() + if (reportOnStop) { - Stop(); + Console.WriteLine("Elapsed: {0:mm\\:ss\\.ff}", ElapsedClockTime); + Console.WriteLine("CPU: {0:mm\\:ss\\.ffffff} (User: {1:mm\\:ss\\.ffffff}, System: {2:mm\\:ss\\.ffffff})", ElapsedTotalCPUTime, ElapsedUserCPUTime, ElapsedSystemCPUTime); + Console.WriteLine("Iterations: {0}; {1:0.00}/second, {2:0.00}/CPU second", Iterations, IterationsPerSecond(), IterationsPerCPUSecond()); } + } - /// - /// Report whether ElapsedClockTime has met or exceeded the maximum run time. 
- /// - public bool TimesUp => (stopwatch.Elapsed >= allowedTime); - - /// - /// Calculate the number of iterations accumulated per the time span provided. - /// - /// - /// The number of iterations accumulated per the time span provided. - public double IterationsPer(TimeSpan timeSpan) - { - return (double)Iterations / ((double)stopwatch.Elapsed.TotalMilliseconds / (double)timeSpan.TotalMilliseconds); - } + /// + /// Stop the internal stop watch and record elapsed CPU times. + /// + public void Dispose() + { + Stop(); + } - /// - /// Calculate the number of iterations accumulated per second. - /// Equivelent to calling IterationsPer(new TimeSpan(0, 0, 1)). - /// - /// The number of iterations accumulated per second. - public double IterationsPerSecond() - { - return IterationsPer(new TimeSpan(0, 0, 1)); - } + /// + /// Report whether ElapsedClockTime has met or exceeded the maximum run time. + /// + public bool TimesUp => (stopwatch.Elapsed >= allowedTime); - /// - /// Calculate the number of iterations accumulated per the CPU time span provided. - /// - /// - /// The number of iterations accumulated per the CPU time span provided. - public double IterationsPerCPU(TimeSpan timeSpan) - { - return (double)Iterations / ((double)ElapsedTotalCPUTime.TotalMilliseconds / (double)timeSpan.TotalMilliseconds); - } + /// + /// Calculate the number of iterations accumulated per the time span provided. + /// + /// + /// The number of iterations accumulated per the time span provided. + public double IterationsPer(TimeSpan timeSpan) + { + return (double)Iterations / ((double)stopwatch.Elapsed.TotalMilliseconds / (double)timeSpan.TotalMilliseconds); + } - /// - /// Calculate the number of iterations accumulated per CPU second. - /// Equivelent to calling IterationsPerCPU(new TimeSpan(0, 0, 1)). - /// - /// - /// The number of iterations accumulated per CPU second. 
- public double IterationsPerCPUSecond() - { - return IterationsPerCPU(new TimeSpan(0, 0, 1)); - } + /// + /// Calculate the number of iterations accumulated per second. + /// Equivelent to calling IterationsPer(new TimeSpan(0, 0, 1)). + /// + /// The number of iterations accumulated per second. + public double IterationsPerSecond() + { + return IterationsPer(new TimeSpan(0, 0, 1)); + } - /// - /// Elapsed time since start. - /// - public TimeSpan ElapsedClockTime => stopwatch.Elapsed; + /// + /// Calculate the number of iterations accumulated per the CPU time span provided. + /// + /// + /// The number of iterations accumulated per the CPU time span provided. + public double IterationsPerCPU(TimeSpan timeSpan) + { + return (double)Iterations / ((double)ElapsedTotalCPUTime.TotalMilliseconds / (double)timeSpan.TotalMilliseconds); + } - /// - /// Elapsed system CPU time since start. - /// - public TimeSpan ElapsedSystemCPUTime + /// + /// Calculate the number of iterations accumulated per CPU second. + /// Equivelent to calling IterationsPerCPU(new TimeSpan(0, 0, 1)). + /// + /// + /// The number of iterations accumulated per CPU second. + public double IterationsPerCPUSecond() + { + return IterationsPerCPU(new TimeSpan(0, 0, 1)); + } + + /// + /// Elapsed time since start. + /// + public TimeSpan ElapsedClockTime => stopwatch.Elapsed; + + /// + /// Elapsed system CPU time since start. + /// + public TimeSpan ElapsedSystemCPUTime + { + get { - get + if (running) { - if (running) - { - return process.PrivilegedProcessorTime - systemCPUTime; - } - else - { - return systemCPUTime; - } + return process.PrivilegedProcessorTime - systemCPUTime; + } + else + { + return systemCPUTime; } } + } - /// - /// Elapsed user CPU time since start. - /// - public TimeSpan ElapsedUserCPUTime + /// + /// Elapsed user CPU time since start. 
+ /// + public TimeSpan ElapsedUserCPUTime + { + get { - get + if (running) + { + return process.UserProcessorTime - userCPUTime; + } + else { - if (running) - { - return process.UserProcessorTime - userCPUTime; - } - else - { - return userCPUTime; - } + return userCPUTime; } } - - /// - /// Elapsed total (system + user) CPU time since start. - /// - public TimeSpan ElapsedTotalCPUTime => ElapsedSystemCPUTime + ElapsedUserCPUTime; } -} + + /// + /// Elapsed total (system + user) CPU time since start. + /// + public TimeSpan ElapsedTotalCPUTime => ElapsedSystemCPUTime + ElapsedUserCPUTime; +} \ No newline at end of file diff --git a/test/Npgsql.Tests/TestUtil.cs b/test/Npgsql.Tests/TestUtil.cs index 9ea47f709a..b2e61317ad 100644 --- a/test/Npgsql.Tests/TestUtil.cs +++ b/test/Npgsql.Tests/TestUtil.cs @@ -1,469 +1,543 @@ using System; using System.Collections.Generic; using System.Data; +using System.Diagnostics.CodeAnalysis; using System.Globalization; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; -using JetBrains.Annotations; +using Microsoft.Extensions.Logging; using NUnit.Framework; -using NUnit.Framework.Interfaces; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public static class TestUtil { - public static class TestUtil - { - /// - /// Unless the NPGSQL_TEST_DB environment variable is defined, this is used as the connection string for the - /// test database. - /// - public const string DefaultConnectionString = "Server=localhost;Username=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests;Timeout=0;Command Timeout=0"; - - /// - /// The connection string that will be used when opening the connection to the tests database. - /// May be overridden in fixtures, e.g. to set special connection parameters - /// - public static string ConnectionString { get; } - = Environment.GetEnvironmentVariable("NPGSQL_TEST_DB") ?? 
DefaultConnectionString; - - public static bool IsOnBuildServer => - Environment.GetEnvironmentVariable("GITHUB_ACTIONS") != null || - Environment.GetEnvironmentVariable("CI") != null; - - /// - /// Calls Assert.Ignore() unless we're on the build server, in which case calls - /// Assert.Fail(). We don't to miss any regressions just because something was misconfigured - /// at the build server and caused a test to be inconclusive. - /// - public static void IgnoreExceptOnBuildServer(string message) + /// + /// Unless the NPGSQL_TEST_DB environment variable is defined, this is used as the connection string for the + /// test database. + /// + public const string DefaultConnectionString = + "Host=localhost;Username=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests;Timeout=0;Command Timeout=0;SSL Mode=Disable;Multiplexing=False"; + + /// + /// The connection string that will be used when opening the connection to the tests database. + /// May be overridden in fixtures, e.g. to set special connection parameters + /// + public static string ConnectionString { get; } + = Environment.GetEnvironmentVariable("NPGSQL_TEST_DB") ?? DefaultConnectionString; + + public static bool IsOnBuildServer => + Environment.GetEnvironmentVariable("GITHUB_ACTIONS") != null || + Environment.GetEnvironmentVariable("CI") != null; + + /// + /// Calls Assert.Ignore() unless we're on the build server, in which case calls + /// Assert.Fail(). We don't to miss any regressions just because something was misconfigured + /// at the build server and caused a test to be inconclusive. 
+ /// + [DoesNotReturn] + public static void IgnoreExceptOnBuildServer(string message) + { + if (IsOnBuildServer) + Assert.Fail(message); + else + Assert.Ignore(message); + + throw new Exception("Should not occur"); + } + + public static void IgnoreExceptOnBuildServer(string message, params object[] args) + => IgnoreExceptOnBuildServer(string.Format(message, args)); + + public static void MinimumPgVersion(NpgsqlDataSource dataSource, string minVersion, string? ignoreText = null) + { + using var connection = dataSource.OpenConnection(); + MinimumPgVersion(connection, minVersion, ignoreText); + } + + public static bool MinimumPgVersion(NpgsqlConnection conn, string minVersion, string? ignoreText = null) + { + var min = new Version(minVersion); + if (conn.PostgreSqlVersion < min) { - if (IsOnBuildServer) - Assert.Fail(message); - else - Assert.Ignore(message); + var msg = $"Postgresql backend version {conn.PostgreSqlVersion} is less than the required {min}"; + if (ignoreText != null) + msg += ": " + ignoreText; + Assert.Ignore(msg); + return false; } - public static void IgnoreExceptOnBuildServer(string message, params object[] args) - => IgnoreExceptOnBuildServer(string.Format(message, args)); + return true; + } - public static void MinimumPgVersion(NpgsqlConnection conn, string minVersion, string? ignoreText = null) + public static void MaximumPgVersionExclusive(NpgsqlConnection conn, string maxVersion, string? 
ignoreText = null) + { + var max = new Version(maxVersion); + if (conn.PostgreSqlVersion >= max) { - var min = new Version(minVersion); - if (conn.PostgreSqlVersion < min) - { - var msg = $"Postgresql backend version {conn.PostgreSqlVersion} is less than the required {min}"; - if (ignoreText != null) - msg += ": " + ignoreText; - Assert.Ignore(msg); - } + var msg = $"Postgresql backend version {conn.PostgreSqlVersion} is greater than or equal to the required (exclusive) maximum of {maxVersion}"; + if (ignoreText != null) + msg += ": " + ignoreText; + Assert.Ignore(msg); } + } + + static readonly Version MinCreateExtensionVersion = new(9, 1); - public static void MaximumPgVersionExclusive(NpgsqlConnection conn, string maxVersion, string? ignoreText = null) + public static void IgnoreOnRedshift(NpgsqlConnection conn, string? ignoreText = null) + { + if (new NpgsqlConnectionStringBuilder(conn.ConnectionString).ServerCompatibilityMode == ServerCompatibilityMode.Redshift) { - var max = new Version(maxVersion); - if (conn.PostgreSqlVersion >= max) - { - var msg = $"Postgresql backend version {conn.PostgreSqlVersion} is greater than or equal to the required (exclusive) maximum of {maxVersion}"; - if (ignoreText != null) - msg += ": " + ignoreText; - Assert.Ignore(msg); - } + var msg = "Test ignored on Redshift"; + if (ignoreText != null) + msg += ": " + ignoreText; + Assert.Ignore(msg); } + } - static readonly Version MinCreateExtensionVersion = new Version(9, 1); - - public static bool IsPgPrerelease(NpgsqlConnection conn) - => ((string)conn.ExecuteScalar("SELECT version()")!).Contains("beta"); + public static bool IsPgPrerelease(NpgsqlConnection conn) + => ((string)conn.ExecuteScalar("SELECT version()")!).Contains("beta"); - public static void EnsureExtension(NpgsqlConnection conn, string extension, string? 
minVersion = null) - => EnsureExtension(conn, extension, minVersion, async: false).GetAwaiter().GetResult(); + public static void EnsureExtension(NpgsqlConnection conn, string extension, string? minVersion = null) + => EnsureExtension(conn, extension, minVersion, async: false).GetAwaiter().GetResult(); - public static Task EnsureExtensionAsync(NpgsqlConnection conn, string extension, string? minVersion = null) - => EnsureExtension(conn, extension, minVersion, async: true); + public static Task EnsureExtensionAsync(NpgsqlConnection conn, string extension, string? minVersion = null) + => EnsureExtension(conn, extension, minVersion, async: true); - static async Task EnsureExtension(NpgsqlConnection conn, string extension, string? minVersion, bool async) - { - if (minVersion != null) - MinimumPgVersion(conn, minVersion, - $"The extension '{extension}' only works for PostgreSQL {minVersion} and higher."); + static async Task EnsureExtension(NpgsqlConnection conn, string extension, string? minVersion, bool async) + { + if (minVersion != null && !MinimumPgVersion(conn, minVersion, $"The extension '{extension}' only works for PostgreSQL {minVersion} and higher.")) + return; - if (conn.PostgreSqlVersion < MinCreateExtensionVersion) - Assert.Ignore($"The 'CREATE EXTENSION' command only works for PostgreSQL {MinCreateExtensionVersion} and higher."); + if (conn.PostgreSqlVersion < MinCreateExtensionVersion) + Assert.Ignore($"The 'CREATE EXTENSION' command only works for PostgreSQL {MinCreateExtensionVersion} and higher."); + try + { if (async) await conn.ExecuteNonQueryAsync($"CREATE EXTENSION IF NOT EXISTS {extension}"); else conn.ExecuteNonQuery($"CREATE EXTENSION IF NOT EXISTS {extension}"); - - conn.ReloadTypes(); } - - public static async Task EnsurePostgis(NpgsqlConnection conn) + catch (PostgresException ex) when (ex.ConstraintName == "pg_extension_name_index") { - try - { - await EnsureExtensionAsync(conn, "postgis"); - } - catch (PostgresException e) when (e.SqlState 
== "58P01") - { - // PostGIS packages aren't available for PostgreSQL prereleases - if (IsPgPrerelease(conn)) - { - Assert.Ignore($"PostGIS could not be installed, but PostgreSQL is prerelease ({conn.ServerVersion}), ignoring test suite."); - } - } + // The extension is already installed, but we can race across threads. + // https://stackoverflow.com/questions/63104126/create-extention-if-not-exists-doesnt-really-check-if-extention-does-not-exis } - public static string GetUniqueIdentifier(string prefix) - => prefix + Interlocked.Increment(ref _counter); - - static int _counter; - - /// - /// Creates a table with a unique name, usable for a single test, and returns an to - /// drop it at the end of the test. - /// - internal static Task CreateTempTable(NpgsqlConnection conn, string columns, out string tableName) - { - tableName = "temp_table" + Interlocked.Increment(ref _tempTableCounter); - return conn.ExecuteNonQueryAsync($"DROP TABLE IF EXISTS {tableName} CASCADE; CREATE TABLE {tableName} ({columns})") - .ContinueWith( - (t, name) => (IAsyncDisposable)new DatabaseObjectDropper(conn, (string)name!, "TABLE"), - tableName, - TaskContinuationOptions.OnlyOnRanToCompletion); - } + conn.ReloadTypes(); + } - /// - /// Creates a schema with a unique name, usable for a single test, and returns an to - /// drop it at the end of the test. - /// - internal static Task CreateTempSchema(NpgsqlConnection conn, out string schemaName) - { - schemaName = "temp_schema" + Interlocked.Increment(ref _tempSchemaCounter); - return conn.ExecuteNonQueryAsync($"DROP SCHEMA IF EXISTS {schemaName} CASCADE; CREATE SCHEMA {schemaName}") - .ContinueWith( - (t, name) => (IAsyncDisposable)new DatabaseObjectDropper(conn, (string)name!, "SCHEMA"), - schemaName, - TaskContinuationOptions.OnlyOnRanToCompletion); - } + /// + /// Causes the test to be ignored if the supplied query fails with SqlState 0A000 (feature_not_supported) + /// + /// The connection to execute the test query. 
The connection needs to be open. + /// The query to test for the feature. + /// This query needs to fail with SqlState 0A000 (feature_not_supported) if the feature isn't present. + public static void IgnoreIfFeatureNotSupported(NpgsqlConnection conn, string testQuery) + => IgnoreIfFeatureNotSupported(conn, testQuery, async: false).GetAwaiter().GetResult(); - /// - /// Generates a unique table name, usable for a single test, and drops it if it already exists. - /// Actual creation of the table is the responsibility of the caller. - /// - /// - /// An to drop the table at the end of the test. - /// - internal static Task GetTempTableName(NpgsqlConnection conn, out string tableName) - { - tableName = "temp_table" + Interlocked.Increment(ref _tempTableCounter); - return conn.ExecuteNonQueryAsync($"DROP TABLE IF EXISTS {tableName} CASCADE") - .ContinueWith( - (t, name) => (IAsyncDisposable)new DatabaseObjectDropper(conn, (string)name!, "TABLE"), - tableName, - TaskContinuationOptions.OnlyOnRanToCompletion); - } + /// + /// Causes the test to be ignored if the supplied query fails with SqlState 0A000 (feature_not_supported) + /// + /// The connection to execute the test query. The connection needs to be open. + /// The query to test for the feature. + /// This query needs to fail with SqlState 0A000 (feature_not_supported) if the feature isn't present. + public static Task IgnoreIfFeatureNotSupportedAsync(NpgsqlConnection conn, string testQuery) + => IgnoreIfFeatureNotSupported(conn, testQuery, async: true); - /// - /// Generates a unique view name, usable for a single test, and drops it if it already exists. - /// Actual creation of the view is the responsibility of the caller. - /// - /// - /// An to drop the view at the end of the test. 
- /// - internal static Task GetTempViewName(NpgsqlConnection conn, out string viewName) + static async Task IgnoreIfFeatureNotSupported(NpgsqlConnection conn, string testQuery, bool async) + { + try { - viewName = "temp_view" + Interlocked.Increment(ref _tempViewCounter); - return conn.ExecuteNonQueryAsync($"DROP VIEW IF EXISTS {viewName} CASCADE") - .ContinueWith( - (t, name) => (IAsyncDisposable)new DatabaseObjectDropper(conn, (string)name!, "VIEW"), - viewName, - TaskContinuationOptions.OnlyOnRanToCompletion); + if (async) + await conn.ExecuteNonQueryAsync(testQuery); + else + conn.ExecuteNonQuery(testQuery); } - - /// - /// Generates a unique function name, usable for a single test. - /// Actual creation of the function is the responsibility of the caller. - /// - /// - /// An to drop the function at the end of the test. - /// - internal static IAsyncDisposable GetTempFunctionName(NpgsqlConnection conn, out string functionName) + catch (PostgresException e) when (e.SqlState == PostgresErrorCodes.FeatureNotSupported) { - functionName = "temp_func" + Interlocked.Increment(ref _tempFunctionCounter); - return new DatabaseObjectDropper(conn, functionName, "FUNCTION"); + Assert.Ignore(e.Message); } + } - /// - /// Generates a unique function name, usable for a single test. - /// Actual creation of the function is the responsibility of the caller. - /// - /// - /// An to drop the function at the end of the test. 
- /// - internal static Task GetTempTypeName(NpgsqlConnection conn, out string typeName) + public static async Task EnsurePostgis(NpgsqlConnection conn) + { + var isPreRelease = IsPgPrerelease(conn); + try { - typeName = "temp_type" + Interlocked.Increment(ref _tempTypeCounter); - return conn.ExecuteNonQueryAsync($"DROP TYPE IF EXISTS {typeName} CASCADE") - .ContinueWith( - (t, name) => (IAsyncDisposable)new DatabaseObjectDropper(conn, (string)name!, "TYPE"), - typeName, - TaskContinuationOptions.OnlyOnRanToCompletion); + await EnsureExtensionAsync(conn, "postgis"); } - - static volatile int _tempTableCounter; - static volatile int _tempViewCounter; - static volatile int _tempFunctionCounter; - static volatile int _tempSchemaCounter; - static volatile int _tempTypeCounter; - - readonly struct DatabaseObjectDropper : IAsyncDisposable + catch (PostgresException e) when (e.SqlState == PostgresErrorCodes.UndefinedFile) { - readonly NpgsqlConnection _conn; - readonly string _type; - readonly string _name; - - internal DatabaseObjectDropper(NpgsqlConnection conn, string name, string type) - => (_conn, _name, _type) = (conn, name, type); - - public async ValueTask DisposeAsync() + // PostGIS packages aren't available for PostgreSQL prereleases + if (isPreRelease) { - try - { - await _conn.ExecuteNonQueryAsync($"DROP {_type} {_name} CASCADE"); - } - catch - { - // Swallow to allow triggering exceptions to surface - } + Assert.Ignore($"PostGIS could not be installed, but PostgreSQL is prerelease ({conn.ServerVersion}), ignoring test suite."); } - } - - /// - /// Creates a pool with a unique application name, usable for a single test, and returns an - /// to drop it at the end of the test. 
- /// - internal static IDisposable CreateTempPool(string origConnectionString, out string tempConnectionString) - => CreateTempPool(new NpgsqlConnectionStringBuilder(origConnectionString), out tempConnectionString); - - /// - /// Creates a pool with a unique application name, usable for a single test, and returns an - /// to drop it at the end of the test. - /// - internal static IDisposable CreateTempPool(NpgsqlConnectionStringBuilder builder, out string tempConnectionString) - { - builder.ApplicationName = (builder.ApplicationName ?? "TempPool") + Interlocked.Increment(ref _tempPoolCounter); - tempConnectionString = builder.ConnectionString; - return new PoolDisposer(tempConnectionString); - } - - static volatile int _tempPoolCounter; - - readonly struct PoolDisposer : IDisposable - { - readonly string _connectionString; - - internal PoolDisposer(string connectionString) => _connectionString = connectionString; - - public void Dispose() + else { - var conn = new NpgsqlConnection(_connectionString); - NpgsqlConnection.ClearPool(conn); + throw; } } + } - /// - /// Utility to generate a bytea literal in Postgresql hex format - /// See https://www.postgresql.org/docs/current/static/datatype-binary.html - /// - internal static string EncodeByteaHex(ICollection buf) - { - var hex = new StringBuilder(@"E'\\x", buf.Count * 2 + 3); - foreach (var b in buf) - hex.Append($"{b:x2}"); - hex.Append("'"); - return hex.ToString(); - } + public static string GetUniqueIdentifier(string prefix) + => prefix + Interlocked.Increment(ref _counter); - internal static IDisposable SetEnvironmentVariable(string name, string? 
value) - { - var resetter = new EnvironmentVariableResetter(name, Environment.GetEnvironmentVariable(name)); - Environment.SetEnvironmentVariable(name, value); - return resetter; - } + static int _counter; - internal static IDisposable SetCurrentCulture(CultureInfo culture) => - new CultureSetter(culture); - - class EnvironmentVariableResetter : IDisposable - { - readonly string _name; - readonly string? _value; + /// + /// Creates a table with a unique name, usable for a single test. + /// + internal static async Task CreateTempTable(NpgsqlConnection conn, string columns) + { + var tableName = "temp_table" + Interlocked.Increment(ref _tempTableCounter); - internal EnvironmentVariableResetter(string name, string? value) - { - _name = name; - _value = value; - } + await conn.ExecuteNonQueryAsync(@$" +START TRANSACTION; +SELECT pg_advisory_xact_lock(0); +DROP TABLE IF EXISTS {tableName} CASCADE; +COMMIT; +CREATE TABLE {tableName} ({columns});"); - public void Dispose() => - Environment.SetEnvironmentVariable(_name, _value); - } + return tableName; + } - class CultureSetter : IDisposable - { - readonly CultureInfo _oldCulture; + /// + /// Generates a unique table name, usable for a single test, and drops it if it already exists. + /// Actual creation of the table is the responsibility of the caller. + /// + internal static async Task GetTempTableName(NpgsqlConnection conn) + { + var tableName = "temp_table" + Interlocked.Increment(ref _tempTableCounter); + await conn.ExecuteNonQueryAsync(@$" +START TRANSACTION; +SELECT pg_advisory_xact_lock(0); +DROP TABLE IF EXISTS {tableName} CASCADE; +COMMIT"); + return tableName; + } - internal CultureSetter(CultureInfo newCulture) - { - _oldCulture = CultureInfo.CurrentCulture; - CultureInfo.CurrentCulture = newCulture; - } + /// + /// Creates a table with a unique name, usable for a single test, and returns an to + /// drop it at the end of the test. 
+ /// + internal static async Task CreateTempTable(NpgsqlDataSource dataSource, string columns) + { + var tableName = "temp_table" + Interlocked.Increment(ref _tempTableCounter); + await dataSource.ExecuteNonQueryAsync(@$" +START TRANSACTION; +SELECT pg_advisory_xact_lock(0); +DROP TABLE IF EXISTS {tableName} CASCADE; +COMMIT; +CREATE TABLE {tableName} ({columns});"); + return tableName; + } - public void Dispose() => - CultureInfo.CurrentCulture = _oldCulture; - } + /// + /// Creates a schema with a unique name, usable for a single test. + /// + internal static async Task CreateTempSchema(NpgsqlConnection conn) + { + var schemaName = "temp_schema" + Interlocked.Increment(ref _tempSchemaCounter); + await conn.ExecuteNonQueryAsync($"DROP SCHEMA IF EXISTS {schemaName} CASCADE; CREATE SCHEMA {schemaName}"); + return schemaName; } - public static class NpgsqlConnectionExtensions + /// + /// Generates a unique view name, usable for a single test, and drops it if it already exists. + /// Actual creation of the view is the responsibility of the caller. + /// + internal static async Task GetTempViewName(NpgsqlConnection conn) { - public static int ExecuteNonQuery(this NpgsqlConnection conn, string sql, NpgsqlTransaction? tx = null) - { - var cmd = tx == null ? new NpgsqlCommand(sql, conn) : new NpgsqlCommand(sql, conn, tx); - using (cmd) - return cmd.ExecuteNonQuery(); - } + var viewName = "temp_view" + Interlocked.Increment(ref _tempViewCounter); + await conn.ExecuteNonQueryAsync($"DROP VIEW IF EXISTS {viewName} CASCADE"); + return viewName; + } - public static object? ExecuteScalar(this NpgsqlConnection conn, string sql, NpgsqlTransaction? tx = null) - { - var cmd = tx == null ? new NpgsqlCommand(sql, conn) : new NpgsqlCommand(sql, conn, tx); - using (cmd) - return cmd.ExecuteScalar(); - } + /// + /// Generates a unique materialized view name, usable for a single test, and drops it if it already exists. 
+ /// Actual creation of the materialized view is the responsibility of the caller. + /// + internal static async Task GetTempMaterializedViewName(NpgsqlConnection conn) + { + var viewName = "temp_materialized_view" + Interlocked.Increment(ref _tempViewCounter); + await conn.ExecuteNonQueryAsync($"DROP MATERIALIZED VIEW IF EXISTS {viewName} CASCADE"); + return viewName; + } - public static async Task ExecuteNonQueryAsync(this NpgsqlConnection conn, string sql, NpgsqlTransaction? tx = null) - { - var cmd = tx == null ? new NpgsqlCommand(sql, conn) : new NpgsqlCommand(sql, conn, tx); - using (cmd) - return await cmd.ExecuteNonQueryAsync(); - } + /// + /// Generates a unique function name, usable for a single test. + /// Actual creation of the function is the responsibility of the caller. + /// + internal static async Task GetTempFunctionName(NpgsqlConnection conn) + { + var functionName = "temp_func" + Interlocked.Increment(ref _tempFunctionCounter); + await conn.ExecuteNonQueryAsync($"DROP FUNCTION IF EXISTS {functionName} CASCADE"); + return functionName; + } - public static async Task ExecuteScalarAsync(this NpgsqlConnection conn, string sql, NpgsqlTransaction? tx = null) - { - var cmd = tx == null ? new NpgsqlCommand(sql, conn) : new NpgsqlCommand(sql, conn, tx); - using (cmd) - return await cmd.ExecuteScalarAsync(); - } + /// + /// Generates a unique function name, usable for a single test. + /// Actual creation of the function is the responsibility of the caller. + /// + /// + /// An to drop the function at the end of the test. + /// + internal static async Task GetTempProcedureName(NpgsqlDataSource dataSource) + { + var procedureName = "temp_procedure" + Interlocked.Increment(ref _tempProcedureCounter); + await dataSource.ExecuteNonQueryAsync($"DROP PROCEDURE IF EXISTS {procedureName} CASCADE"); + return procedureName; } - public static class CommandBehaviorExtensions + /// + /// Generates a unique function name, usable for a single test. 
+ /// Actual creation of the function is the responsibility of the caller. + /// + /// + /// An to drop the function at the end of the test. + /// + internal static async Task GetTempProcedureName(NpgsqlConnection connection) { - public static bool IsSequential(this CommandBehavior behavior) - => (behavior & CommandBehavior.SequentialAccess) != 0; + var procedureName = "temp_procedure" + Interlocked.Increment(ref _tempProcedureCounter); + await connection.ExecuteNonQueryAsync($"DROP PROCEDURE IF EXISTS {procedureName} CASCADE"); + return procedureName; } /// - /// Semantic attribute that points to an issue linked with this test (e.g. this - /// test reproduces the issue) + /// Generates a unique type name, usable for a single test. + /// Actual creation of the type is the responsibility of the caller. /// - [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] - public class IssueLink : Attribute + internal static async Task GetTempTypeName(NpgsqlConnection conn) { - public string LinkAddress { get; private set; } - public IssueLink(string linkAddress) - { - LinkAddress = linkAddress; - } + var typeName = "temp_type" + Interlocked.Increment(ref _tempTypeCounter); + await conn.ExecuteNonQueryAsync($"DROP TYPE IF EXISTS {typeName} CASCADE"); + return typeName; } + internal static volatile int _tempTableCounter; + static volatile int _tempViewCounter; + static volatile int _tempFunctionCounter; + static volatile int _tempProcedureCounter; + static volatile int _tempSchemaCounter; + static volatile int _tempTypeCounter; + /// - /// Causes the test to be ignored on mono + /// Creates a pool with a unique application name, usable for a single test, and returns an + /// to drop it at the end of the test. 
/// - [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly, AllowMultiple = false)] - public class MonoIgnore : Attribute, ITestAction + internal static IDisposable CreateTempPool(string origConnectionString, out string tempConnectionString) + => CreateTempPool(new NpgsqlConnectionStringBuilder(origConnectionString), out tempConnectionString); + + /// + /// Creates a pool with a unique application name, usable for a single test, and returns an + /// to drop it at the end of the test. + /// + internal static IDisposable CreateTempPool(NpgsqlConnectionStringBuilder builder, out string tempConnectionString) + { + builder.ApplicationName = (builder.ApplicationName ?? "TempPool") + Interlocked.Increment(ref _tempPoolCounter); + tempConnectionString = builder.ConnectionString; + return new PoolDisposer(tempConnectionString); + } + + static volatile int _tempPoolCounter; + + readonly struct PoolDisposer : IDisposable { - readonly string? _ignoreText; + readonly string _connectionString; - public MonoIgnore(string? 
ignoreText = null) { _ignoreText = ignoreText; } + internal PoolDisposer(string connectionString) => _connectionString = connectionString; - public void BeforeTest(ITest test) + public void Dispose() { - if (Type.GetType("Mono.Runtime") != null) - { - var msg = "Ignored on mono"; - if (_ignoreText != null) - msg += ": " + _ignoreText; - Assert.Ignore(msg); - } + var conn = new NpgsqlConnection(_connectionString); + NpgsqlConnection.ClearPool(conn); } - - public void AfterTest(ITest test) { } - public ActionTargets Targets => ActionTargets.Test; } /// - /// Causes the test to be ignored on Linux + /// Utility to generate a bytea literal in Postgresql hex format + /// See https://www.postgresql.org/docs/current/static/datatype-binary.html /// - [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly, AllowMultiple = false)] - public class LinuxIgnore : Attribute, ITestAction + internal static string EncodeByteaHex(ICollection buf) + { + var hex = new StringBuilder(@"E'\\x", buf.Count * 2 + 3); + foreach (var b in buf) + hex.Append($"{b:x2}"); + hex.Append("'"); + return hex.ToString(); + } + + internal static IDisposable SetEnvironmentVariable(string name, string? value) { - readonly string? _ignoreText; + var oldValue = Environment.GetEnvironmentVariable(name); + Environment.SetEnvironmentVariable(name, value); + return new DeferredExecutionDisposable(() => Environment.SetEnvironmentVariable(name, oldValue)); + } - public LinuxIgnore(string? 
ignoreText = null) { _ignoreText = ignoreText; } + internal static IDisposable SetCurrentCulture(CultureInfo culture) + { + var oldCulture = CultureInfo.CurrentCulture; + CultureInfo.CurrentCulture = culture; - public void BeforeTest(ITest test) - { - var osEnvVar = Environment.GetEnvironmentVariable("OS"); - if (osEnvVar == null || osEnvVar != "Windows_NT") - { - var msg = "Ignored on Linux"; - if (_ignoreText != null) - msg += ": " + _ignoreText; - Assert.Ignore(msg); - } - } + return new DeferredExecutionDisposable(() => CultureInfo.CurrentCulture = oldCulture); + } - public void AfterTest(ITest test) { } - public ActionTargets Targets => ActionTargets.Test; + internal static IDisposable DisableSqlRewriting() + { +#if DEBUG + NpgsqlCommand.EnableSqlRewriting = false; + return new DeferredExecutionDisposable(() => NpgsqlCommand.EnableSqlRewriting = true); +#else + Assert.Ignore("Cannot disable SQL rewriting in RELEASE builds"); + throw new NotSupportedException("Cannot disable SQL rewriting in RELEASE builds"); +#endif } - /// - /// Causes the test to be ignored on Windows - /// - [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly, AllowMultiple = false)] - public class WindowsIgnore : Attribute, ITestAction + class DeferredExecutionDisposable : IDisposable { - readonly string? _ignoreText; + readonly Action _action; + + internal DeferredExecutionDisposable(Action action) => _action = action; - public WindowsIgnore(string? ignoreText = null) { _ignoreText = ignoreText; } + public void Dispose() + => _action(); + } - public void BeforeTest(ITest test) + internal static object AssertLoggingStateContains( + (LogLevel Level, EventId Id, string Message, object? State, Exception? 
Exception) log, + string key) + { + if (log.State is not IEnumerable> keyValuePairs || keyValuePairs.All(kvp => kvp.Key != key)) { - var osEnvVar = Environment.GetEnvironmentVariable("OS"); - if (osEnvVar == "Windows_NT") - { - var msg = "Ignored on Windows"; - if (_ignoreText != null) - msg += ": " + _ignoreText; - Assert.Ignore(msg); - } + Assert.Fail($@"Dod not find logging state key ""{key}"""); + throw new Exception(); } - public void AfterTest(ITest test) { } - public ActionTargets Targets => ActionTargets.Test; + return keyValuePairs.Single(kvp => kvp.Key == key).Value; + } + + internal static void AssertLoggingStateContains( + (LogLevel Level, EventId Id, string Message, object? State, Exception? Exception) log, + string key, + T value) + => Assert.That(log.State, Contains.Item(new KeyValuePair(key, value))); + + internal static void AssertLoggingStateDoesNotContain( + (LogLevel Level, EventId Id, string Message, object? State, Exception? Exception) log, + string key) + { + var value = log.State is IEnumerable> keyValuePairs && + keyValuePairs.FirstOrDefault(kvp => kvp.Key == key) is { } kvpPair + ? kvpPair.Value + : null; + + Assert.That(value, Is.Null, $@"Found logging state (""{key}"", {value}"); + } +} + +public static class NpgsqlConnectionExtensions +{ + public static int ExecuteNonQuery(this NpgsqlConnection conn, string sql, NpgsqlTransaction? tx = null) + { + using var command = tx == null ? new NpgsqlCommand(sql, conn) : new NpgsqlCommand(sql, conn, tx); + return command.ExecuteNonQuery(); + } + + public static object? ExecuteScalar(this NpgsqlConnection conn, string sql, NpgsqlTransaction? tx = null) + { + using var command = tx == null ? new NpgsqlCommand(sql, conn) : new NpgsqlCommand(sql, conn, tx); + return command.ExecuteScalar(); } - public enum PrepareOrNot + public static async Task ExecuteNonQueryAsync( + this NpgsqlConnection conn, string sql, NpgsqlTransaction? 
tx = null, CancellationToken cancellationToken = default) { - Prepared, - NotPrepared + await using var command = tx == null ? new NpgsqlCommand(sql, conn) : new NpgsqlCommand(sql, conn, tx); + return await command.ExecuteNonQueryAsync(cancellationToken); } - public enum PooledOrNot + public static async Task ExecuteScalarAsync( + this NpgsqlConnection conn, string sql, NpgsqlTransaction? tx = null, CancellationToken cancellationToken = default) { - Pooled, - Unpooled + await using var command = tx == null ? new NpgsqlCommand(sql, conn) : new NpgsqlCommand(sql, conn, tx); + return await command.ExecuteScalarAsync(cancellationToken); } +} + +public static class NpgsqlDataSourceExtensions +{ + public static int ExecuteNonQuery(this NpgsqlDataSource dataSource, string sql) + { + using var command = dataSource.CreateCommand(sql); + return command.ExecuteNonQuery(); + } + + public static object? ExecuteScalar(this NpgsqlDataSource dataSource, string sql) + { + using var command = dataSource.CreateCommand(sql); + return command.ExecuteScalar(); + } + + public static async Task ExecuteNonQueryAsync( + this NpgsqlDataSource dataSource, string sql, CancellationToken cancellationToken = default) + { + await using var command = dataSource.CreateCommand(sql); + return await command.ExecuteNonQueryAsync(cancellationToken); + } + + public static async Task ExecuteScalarAsync( + this NpgsqlDataSource dataSource, string sql, CancellationToken cancellationToken = default) + { + await using var command = dataSource.CreateCommand(sql); + return await command.ExecuteScalarAsync(cancellationToken); + } +} + +public static class CommandBehaviorExtensions +{ + public static bool IsSequential(this CommandBehavior behavior) + => (behavior & CommandBehavior.SequentialAccess) != 0; +} + +public static class NpgsqlCommandExtensions +{ + public static void WaitUntilCommandIsInProgress(this NpgsqlCommand command) + { + while (command.State != CommandState.InProgress) + Thread.Sleep(50); + } +} 
+ +/// +/// Semantic attribute that points to an issue linked with this test (e.g. this +/// test reproduces the issue) +/// +[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] +public class IssueLink : Attribute +{ + public string LinkAddress { get; private set; } + public IssueLink(string linkAddress) + { + LinkAddress = linkAddress; + } +} + +public enum PrepareOrNot +{ + Prepared, + NotPrepared +} + +public enum PooledOrNot +{ + Pooled, + Unpooled +} #if NETSTANDARD2_0 static class QueueExtensions @@ -481,4 +555,3 @@ public static bool TryDequeue(this Queue queue, out T result) } } #endif -} diff --git a/test/Npgsql.Tests/TransactionTests.cs b/test/Npgsql.Tests/TransactionTests.cs index a43ea99f52..e0e61f95b4 100644 --- a/test/Npgsql.Tests/TransactionTests.cs +++ b/test/Npgsql.Tests/TransactionTests.cs @@ -1,8 +1,7 @@ using System; -using System.Buffers.Binary; using System.Data; using System.Threading.Tasks; -using Npgsql.BackendMessages; +using Npgsql.Internal; using Npgsql.Tests.Support; using Npgsql.Util; using NUnit.Framework; @@ -11,35 +10,80 @@ // ReSharper disable MethodHasAsyncOverload // ReSharper disable UseAwaitUsing -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class TransactionTests : MultiplexingTestBase { - public class TransactionTests : MultiplexingTestBase + [Test, Description("Basic insert within a commited transaction")] + public async Task Commit([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - [Test] - public async Task Commit() + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + + var tx = await conn.BeginTransactionAsync(); + await using (tx) { - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); + var cmd = new NpgsqlCommand($"INSERT INTO {table} (name) VALUES ('X')", 
conn, tx); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + cmd.ExecuteNonQuery(); + Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); + tx.Commit(); + Assert.That(tx.IsCompleted); + Assert.That(() => tx.Connection, Throws.Nothing); + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); + } - var tx = conn.BeginTransaction(); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); + // With multiplexing we can't assume that disposed NpgsqlTransaction will throw ObjectDisposedException + // Because disposed NpgsqlTransaction might be reused by another thread + if (!IsMultiplexing) + Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); + } + + [Test, Description("Basic insert within a commited transaction")] + public async Task CommitAsync([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + + var tx = await conn.BeginTransactionAsync(); + await using (tx) + { + var cmd = new NpgsqlCommand($"INSERT INTO {table} (name) VALUES ('X')", conn, tx); + if (prepare == PrepareOrNot.Prepared) + cmd.Prepare(); + await cmd.ExecuteNonQueryAsync(); + Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); await tx.CommitAsync(); + Assert.That(tx.IsCompleted); + Assert.That(() => tx.Connection, Throws.Nothing); Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); - await tx.DisposeAsync(); - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); } - [Test, Description("Basic insert within a rolled back transaction")] - public async Task Rollback([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) - 
{ - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + // With multiplexing we can't assume that disposed NpgsqlTransaction will throw ObjectDisposedException + // Because disposed NpgsqlTransaction might be reused by another thread + if (!IsMultiplexing) + Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); + } - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); + [Test, Description("Basic insert within a rolled back transaction")] + public async Task Rollback([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; + + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); - var tx = conn.BeginTransaction(); + var tx = await conn.BeginTransactionAsync(); + await using (tx) + { var cmd = new NpgsqlCommand($"INSERT INTO {table} (name) VALUES ('X')", conn, tx); if (prepare == PrepareOrNot.Prepared) cmd.Prepare(); @@ -47,450 +91,461 @@ public async Task Rollback([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepar Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); tx.Rollback(); Assert.That(tx.IsCompleted); + Assert.That(() => tx.Connection, Throws.Nothing); Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); - await tx.DisposeAsync(); - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); } - [Test, Description("Basic insert within a rolled back transaction")] - public async Task RollbackAsync([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) - { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; + // With multiplexing we can't assume that disposed NpgsqlTransaction will throw ObjectDisposedException + // Because disposed 
NpgsqlTransaction might be reused by another thread + if (!IsMultiplexing) + Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); + } + + [Test, Description("Basic insert within a rolled back transaction")] + public async Task RollbackAsync([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + if (prepare == PrepareOrNot.Prepared && IsMultiplexing) + return; - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); - var tx = conn.BeginTransaction(); + var tx = await conn.BeginTransactionAsync(); + await using (tx) + { var cmd = new NpgsqlCommand($"INSERT INTO {table} (name) VALUES ('X')", conn, tx); if (prepare == PrepareOrNot.Prepared) cmd.Prepare(); - cmd.ExecuteNonQuery(); + await cmd.ExecuteNonQueryAsync(); Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); await tx.RollbackAsync(); Assert.That(tx.IsCompleted); + Assert.That(() => tx.Connection, Throws.Nothing); Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); - await tx.DisposeAsync(); - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); } - [Test, Description("Dispose a transaction in progress, should roll back")] - public async Task RollbackOnDispose() - { - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); + // With multiplexing we can't assume that disposed NpgsqlTransaction will throw ObjectDisposedException + // Because disposed NpgsqlTransaction might be reused by another thread + if (!IsMultiplexing) + Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); + } + + [Test, Description("Dispose a transaction in progress, should roll back")] + public async 
Task Rollback_on_Dispose() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); - var tx = conn.BeginTransaction(); + await using (var tx = await conn.BeginTransactionAsync()) + { await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); - await tx.DisposeAsync(); - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } - [Test] - public async Task RollbackOnClose() + [Test] + public async Task Rollback_on_Close() + { + await using var conn1 = await OpenConnectionAsync(); + var table = await CreateTempTable(conn1, "name TEXT"); + + using (var conn2 = await OpenConnectionAsync()) { - await using var conn1 = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn1, "name TEXT", out var table); + var tx = await conn2.BeginTransactionAsync(); + await conn2.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx); + } - NpgsqlTransaction tx; - using (var conn2 = await OpenConnectionAsync()) - { - tx = conn2.BeginTransaction(); - await conn2.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx); - } + Assert.That(await conn1.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } - Assert.That(await conn1.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); - } + [Test, Description("Intentionally generates an error, putting us in a failed transaction block. 
Rolls back.")] + public async Task Rollback_failed() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + + await using var tx = await conn.BeginTransactionAsync(); + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); + Assert.That(async () => await conn.ExecuteNonQueryAsync("BAD QUERY"), Throws.Exception.TypeOf()); + tx.Rollback(); + Assert.That(tx.IsCompleted); + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } - [Test, Description("Intentionally generates an error, putting us in a failed transaction block. Rolls back.")] - public async Task RollbackFailed() - { - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); + [Test, Description("Commits an empty transaction")] + public async Task Empty_commit() + { + await using var conn = await OpenConnectionAsync(); + await conn.BeginTransaction().CommitAsync(); + } - var tx = conn.BeginTransaction(); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); - Assert.That(async () => await conn.ExecuteNonQueryAsync("BAD QUERY"), Throws.Exception.TypeOf()); - tx.Rollback(); - Assert.That(tx.IsCompleted); - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - } + [Test, Description("Rolls back an empty transaction")] + public async Task Empty_rollback() + { + await using var conn = await OpenConnectionAsync(); + await conn.BeginTransaction().RollbackAsync(); + } - [Test, Description("Commits an empty transaction")] - public async Task EmptyCommit() - { - await using var conn = await OpenConnectionAsync(); - await conn.BeginTransaction().CommitAsync(); - } + [Test, Description("Disposes an empty transaction")] + public async Task Empty_Dispose() + { + await using var dataSource = CreateDataSource(); + + using (var conn = await 
dataSource.OpenConnectionAsync()) + using (conn.BeginTransaction()) + { } - [Test, Description("Rolls back an empty transaction")] - public async Task EmptyRollback() + using (var conn = await dataSource.OpenConnectionAsync()) { - await using var conn = await OpenConnectionAsync(); - await conn.BeginTransaction().RollbackAsync(); + // Make sure the pending BEGIN TRANSACTION didn't leak from the previous open + Assert.That(async () => await conn.ExecuteNonQueryAsync("SAVEPOINT foo"), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.NoActiveSqlTransaction)); } + } - [Test, Description("Disposes an empty transaction")] - public async Task EmptyDisposeTransaction() - { - using var _ = CreateTempPool(ConnectionString, out var connString); + [Test, Description("Tests that the isolation levels are properly supported")] + [TestCase(IsolationLevel.ReadCommitted, "read committed")] + [TestCase(IsolationLevel.ReadUncommitted, "read uncommitted")] + [TestCase(IsolationLevel.RepeatableRead, "repeatable read")] + [TestCase(IsolationLevel.Serializable, "serializable")] + [TestCase(IsolationLevel.Snapshot, "repeatable read")] + [TestCase(IsolationLevel.Unspecified, "read committed")] + public async Task Isolation_levels(IsolationLevel level, string expectedName) + { + await using var conn = await OpenConnectionAsync(); + var tx = conn.BeginTransaction(level); + Assert.That(conn.ExecuteScalar("SHOW TRANSACTION ISOLATION LEVEL"), Is.EqualTo(expectedName)); + await tx.CommitAsync(); + } - using (var conn = await OpenConnectionAsync(connString)) - using (conn.BeginTransaction()) - { } + [Test] + public async Task IsolationLevel_Chaos_is_unsupported() + { + await using var conn = await OpenConnectionAsync(); + Assert.That(() => conn.BeginTransaction(IsolationLevel.Chaos), Throws.Exception.TypeOf()); + } - using (var conn = await OpenConnectionAsync(connString)) - { - // Make sure the pending BEGIN TRANSACTION didn't leak from 
the previous open - Assert.That(async () => await conn.ExecuteNonQueryAsync("SAVEPOINT foo"), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("25P01")); - } - } + [Test, Description("Rollback of an already rolled back transaction")] + public async Task Rollback_twice() + { + await using var conn = await OpenConnectionAsync(); + var transaction = conn.BeginTransaction(); + transaction.Rollback(); + Assert.That(() => transaction.Rollback(), Throws.Exception.TypeOf()); + } - [Test, Description("Tests that the isolation levels are properly supported")] - [TestCase(IsolationLevel.ReadCommitted, "read committed")] - [TestCase(IsolationLevel.ReadUncommitted, "read uncommitted")] - [TestCase(IsolationLevel.RepeatableRead, "repeatable read")] - [TestCase(IsolationLevel.Serializable, "serializable")] - [TestCase(IsolationLevel.Snapshot, "repeatable read")] - [TestCase(IsolationLevel.Unspecified, "read committed")] - public async Task IsolationLevels(IsolationLevel level, string expectedName) - { - await using var conn = await OpenConnectionAsync(); - var tx = conn.BeginTransaction(level); - Assert.That(conn.ExecuteScalar("SHOW TRANSACTION ISOLATION LEVEL"), Is.EqualTo(expectedName)); - await tx.CommitAsync(); - } + [Test, Description("Makes sure the creating a transaction via DbConnection sets the proper isolation level")] + [IssueLink("https://github.com/npgsql/npgsql/issues/559")] + public async Task Default_IsolationLevel() + { + await using var conn = await OpenConnectionAsync(); + var tx = conn.BeginTransaction(); + Assert.That(tx.IsolationLevel, Is.EqualTo(IsolationLevel.ReadCommitted)); + tx.Rollback(); + + tx = conn.BeginTransaction(IsolationLevel.Unspecified); + Assert.That(tx.IsolationLevel, Is.EqualTo(IsolationLevel.ReadCommitted)); + tx.Rollback(); + } - [Test] - public async Task IsolationLevelChaosUnsupported() - { - await using var conn = await OpenConnectionAsync(); - Assert.That((TestDelegate)(() => 
conn.BeginTransaction(IsolationLevel.Chaos)), Throws.Exception.TypeOf()); - } + [Test, Description("Makes sure that transactions started in SQL work, except in multiplexing")] + public async Task Via_sql() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: not implemented"); - [Test, Description("Rollback of an already rolled back transaction")] - public async Task RollbackTwice() - { - await using var conn = await OpenConnectionAsync(); - var transaction = conn.BeginTransaction(); - transaction.Rollback(); - Assert.That(() => transaction.Rollback(), Throws.Exception.TypeOf()); - } + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); - [Test, Description("Makes sure the creating a transaction via DbConnection sets the proper isolation level")] - [IssueLink("https://github.com/npgsql/npgsql/issues/559")] - public async Task DbConnectionDefaultIsolation() + if (IsMultiplexing) { - await using var conn = await OpenConnectionAsync(); - var tx = conn.BeginTransaction(); - Assert.That(tx.IsolationLevel, Is.EqualTo(IsolationLevel.ReadCommitted)); - tx.Rollback(); - - tx = conn.BeginTransaction(IsolationLevel.Unspecified); - Assert.That(tx.IsolationLevel, Is.EqualTo(IsolationLevel.ReadCommitted)); - tx.Rollback(); + Assert.That(async () => await conn.ExecuteNonQueryAsync("BEGIN"), Throws.Exception.TypeOf()); + return; } - [Test, Description("Makes sure that transactions started in SQL work, except in multiplexing")] - public async Task ViaSql() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: not implemented"); + await conn.ExecuteNonQueryAsync("BEGIN"); + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')"); + await conn.ExecuteNonQueryAsync("ROLLBACK"); + Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); + [Test] 
+ public async Task Nested() + { + await using var conn = await OpenConnectionAsync(); + conn.BeginTransaction(); + Assert.That(() => conn.BeginTransaction(), Throws.TypeOf()); + } - if (IsMultiplexing) - { - Assert.That(async () => await conn.ExecuteNonQueryAsync("BEGIN"), Throws.Exception.TypeOf()); - return; - } + [Test] + public void Begin_transaction_on_closed_connection_throws() + { + using var conn = new NpgsqlConnection(); + Assert.That(() => conn.BeginTransaction(), Throws.Exception.TypeOf()); + } - await conn.ExecuteNonQueryAsync("BEGIN"); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')"); - await conn.ExecuteNonQueryAsync("ROLLBACK"); - Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); - } + [Test] + public async Task Rollback_failed_transaction_with_timeout() + { + await using var conn = await OpenConnectionAsync(); - [Test] - public async Task Nested() + var tx = conn.BeginTransaction(); + using var cmd = new NpgsqlCommand("BAD QUERY", conn, tx); + Assert.That(cmd.CommandTimeout != 1); + cmd.CommandTimeout = 1; + try { - await using var conn = await OpenConnectionAsync(); - conn.BeginTransaction(); - Assert.That(() => conn.BeginTransaction(), Throws.TypeOf()); + cmd.ExecuteScalar(); + Assert.Fail(); } - - [Test] - public void BeginTransactionBeforeOpen() + catch (PostgresException) { - using var conn = new NpgsqlConnection(); - Assert.That((TestDelegate)(() => conn.BeginTransaction()), Throws.Exception.TypeOf()); + // Timeout at the backend is now 1 + await tx.RollbackAsync(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); } + } - [Test] - public async Task RollbackFailedTransactionWithTimeout() - { - await using var conn = await OpenConnectionAsync(); + [Test, Description("If a custom command timeout is set, a failed transaction could not be rollbacked to a previous savepoint")] + [IssueLink("https://github.com/npgsql/npgsql/issues/363")] + 
[IssueLink("https://github.com/npgsql/npgsql/issues/184")] + public async Task Failed_transaction_cannot_rollback_to_savepoint_with_custom_timeout() + { + await using var conn = await OpenConnectionAsync(); - var tx = conn.BeginTransaction(); - using var cmd = new NpgsqlCommand("BAD QUERY", conn, tx); - Assert.That(cmd.CommandTimeout != 1); - cmd.CommandTimeout = 1; - try - { - cmd.ExecuteScalar(); - Assert.Fail(); - } - catch (PostgresException) - { - // Timeout at the backend is now 1 - await tx.RollbackAsync(); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } + var transaction = conn.BeginTransaction(); + transaction.Save("TestSavePoint"); - [Test, Description("If a custom command timeout is set, a failed transaction could not be rollbacked to a previous savepoint")] - [IssueLink("https://github.com/npgsql/npgsql/issues/363")] - [IssueLink("https://github.com/npgsql/npgsql/issues/184")] - public async Task FailedTransactionCantRollbackToSavepointWithCustomTimeout() + using var cmd = new NpgsqlCommand("SELECT unknown_thing", conn); + cmd.CommandTimeout = 1; + try + { + cmd.ExecuteScalar(); + } + catch (PostgresException) { - await using var conn = await OpenConnectionAsync(); + transaction.Rollback("TestSavePoint"); + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); + } + } - var transaction = conn.BeginTransaction(); - transaction.Save("TestSavePoint"); + [Test, Description("Closes a (pooled) connection with a failed transaction and a custom timeout")] + [IssueLink("https://github.com/npgsql/npgsql/issues/719")] + public async Task Failed_transaction_on_close_with_custom_timeout() + { + await using var dataSource = CreateDataSource(csb => csb.Pooling = true); - using var cmd = new NpgsqlCommand("SELECT unknown_thing", conn); - cmd.CommandTimeout = 1; - try - { - cmd.ExecuteScalar(); - } - catch (PostgresException) - { - transaction.Rollback("TestSavePoint"); - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); 
- } - } + await using var conn = await dataSource.OpenConnectionAsync(); - [Test, Description("Closes a (pooled) connection with a failed transaction and a custom timeout")] - [IssueLink("https://github.com/npgsql/npgsql/issues/719")] - public async Task FailedTransactionOnCloseWithCustomTimeout() + conn.BeginTransaction(); + var backendProcessId = conn.ProcessID; + using (var badCmd = new NpgsqlCommand("SEL", conn)) { - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = true - }.ToString(); + badCmd.CommandTimeout = NpgsqlCommand.DefaultTimeout + 1; + Assert.That(() => badCmd.ExecuteNonQuery(), Throws.Exception.TypeOf()); + } + // Connection now in failed transaction state, and a custom timeout is in place + conn.Close(); + conn.Open(); + conn.BeginTransaction(); + Assert.That(conn.ProcessID, Is.EqualTo(backendProcessId)); + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); + } - await using var conn = await OpenConnectionAsync(connString); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/555")] + public async Task Transaction_on_recycled_connection() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing: fails"); + + // Use application name to make sure we have our very own private connection pool + await using var conn = new NpgsqlConnection(ConnectionString + $";Application Name={GetUniqueIdentifier(nameof(Transaction_on_recycled_connection))}"); + conn.Open(); + var prevConnectorId = conn.Connector!.Id; + conn.Close(); + conn.Open(); + Assert.That(conn.Connector.Id, Is.EqualTo(prevConnectorId), "Connection pool returned a different connector, can't test"); + var tx = conn.BeginTransaction(); + conn.ExecuteScalar("SELECT 1"); + await tx.CommitAsync(); + NpgsqlConnection.ClearPool(conn); + } - conn.BeginTransaction(); - var backendProcessId = conn.ProcessID; - using (var badCmd = new NpgsqlCommand("SEL", conn)) - { - badCmd.CommandTimeout = NpgsqlCommand.DefaultTimeout + 1; - Assert.That(() => 
badCmd.ExecuteNonQuery(), Throws.Exception.TypeOf()); - } - // Connection now in failed transaction state, and a custom timeout is in place - conn.Close(); - conn.Open(); - conn.BeginTransaction(); - Assert.That(conn.ProcessID, Is.EqualTo(backendProcessId)); - Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - } + [Test] + public async Task Savepoint() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + const string name = "theSavePoint"; - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/555")] - public async Task TransactionOnRecycledConnection() + using (var tx = conn.BeginTransaction()) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); + tx.Save(name); + + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('savepointtest')", tx: tx); + Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(1)); + tx.Rollback(name); + Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(0)); + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('savepointtest')", tx: tx); + tx.Release(name); + Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(1)); - // Use application name to make sure we have our very own private connection pool - await using var conn = new NpgsqlConnection(ConnectionString + $";Application Name={GetUniqueIdentifier(nameof(TransactionOnRecycledConnection))}"); - conn.Open(); - var prevConnectorId = conn.Connector!.Id; - conn.Close(); - conn.Open(); - Assert.That(conn.Connector.Id, Is.EqualTo(prevConnectorId), "Connection pool returned a different connector, can't test"); - var tx = conn.BeginTransaction(); - conn.ExecuteScalar("SELECT 1"); await tx.CommitAsync(); - NpgsqlConnection.ClearPool(conn); } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); + } - [Test] - public async Task Savepoint() - 
{ - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - const string name = "theSavePoint"; + [Test] + public async Task Savepoint_async() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + const string name = "theSavePoint"; - using (var tx = conn.BeginTransaction()) - { - tx.Save(name); + using (var tx = conn.BeginTransaction()) + { + await tx.SaveAsync(name); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('savepointtest')", tx: tx); - Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(1)); - tx.Rollback(name); - Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(0)); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('savepointtest')", tx: tx); - tx.Release(name); - Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(1)); + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('savepointtest')", tx: tx); + Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(1)); + await tx.RollbackAsync(name); + Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(0)); + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('savepointtest')", tx: tx); + await tx.ReleaseAsync(name); + Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(1)); - await tx.CommitAsync(); - } - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); + await tx.CommitAsync(); } + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); + } - [Test] - public async Task SavepointAsync() - { - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - const 
string name = "theSavePoint"; + [Test] + public async Task Savepoint_quoted() + { + await using var conn = await OpenConnectionAsync(); + await using var tx = conn.BeginTransaction(); + tx.Save("a;b"); + tx.Rollback("a;b"); + } - using (var tx = conn.BeginTransaction()) - { - await tx.SaveAsync(name); + [Test(Description = "Makes sure that creating a savepoint doesn't perform an additional roundtrip, but prepends to the next command")] + public async Task Savepoint_prepends() + { + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + var pgMock = await postmasterMock.WaitForServerConnection(); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('savepointtest')", tx: tx); - Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(1)); - await tx.RollbackAsync(name); - Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(0)); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('savepointtest')", tx: tx); - await tx.ReleaseAsync(name); - Assert.That(conn.ExecuteScalar($"SELECT COUNT(*) FROM {table}", tx: tx), Is.EqualTo(1)); + using var tx = conn.BeginTransaction(); + var saveTask = tx.SaveAsync("foo"); + Assert.That(saveTask.Status, Is.EqualTo(TaskStatus.RanToCompletion)); - await tx.CommitAsync(); - } - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); - } + // If we're here, SaveAsync above didn't wait for any response, which is the right behavior - [Test] - public async Task SavepointQuoted() - { - await using var conn = await OpenConnectionAsync(); - await using var tx = conn.BeginTransaction(); - tx.Save("a;b"); - tx.Rollback("a;b"); - } + await pgMock + .WriteCommandComplete() + .WriteReadyForQuery() // BEGIN response + .WriteCommandComplete() + 
.WriteReadyForQuery() // SAVEPOINT response + .WriteScalarResponseAndFlush(1); - [Test(Description = "Makes sure that creating a savepoint doesn't perform an additional roundtrip, but prepends to the next command")] - public async Task SavepointPrepends() - { - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); - using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); - await using var conn = await OpenConnectionAsync(connectionString); - var pgMock = await postmasterMock.WaitForServerConnection(); + await conn.ExecuteScalarAsync("SELECT 1"); - using var tx = conn.BeginTransaction(); - var saveTask = tx.SaveAsync("foo"); - Assert.That(saveTask.Status, Is.EqualTo(TaskStatus.RanToCompletion)); + await pgMock.ExpectSimpleQuery("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"); + await pgMock.ExpectSimpleQuery("SAVEPOINT foo"); + await pgMock.ExpectExtendedQuery(); + } - // If we're here, SaveAsync above didn't wait for any response, which is the right behavior + [Test, Description("Check IsCompleted before, during and after a normal committed transaction")] + [IssueLink("https://github.com/npgsql/npgsql/issues/985")] + public async Task IsCompleted_commit() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + var tx = conn.BeginTransaction(); + Assert.That(!tx.IsCompleted); + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); + Assert.That(!tx.IsCompleted); + await tx.CommitAsync(); + Assert.That(tx.IsCompleted); + } - await pgMock - .WriteCommandComplete() - .WriteReadyForQuery() // BEGIN response - .WriteCommandComplete() - .WriteReadyForQuery() // SAVEPOINT response - .WriteScalarResponseAndFlush(1); + [Test, Description("Check IsCompleted before, during, and after a successful but rolled back transaction")] + [IssueLink("https://github.com/npgsql/npgsql/issues/985")] + public async Task IsCompleted_rollback() + 
{ + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + var tx = conn.BeginTransaction(); + Assert.That(!tx.IsCompleted); + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); + Assert.That(!tx.IsCompleted); + tx.Rollback(); + Assert.That(tx.IsCompleted); + } - await conn.ExecuteScalarAsync("SELECT 1"); + [Test, Description("Check IsCompleted before, during, and after a failed then rolled back transaction")] + [IssueLink("https://github.com/npgsql/npgsql/issues/985")] + public async Task IsCompleted_rollback_failed() + { + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + var tx = conn.BeginTransaction(); + Assert.That(!tx.IsCompleted); + await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); + Assert.That(!tx.IsCompleted); + Assert.That(async () => await conn.ExecuteNonQueryAsync("BAD QUERY"), Throws.Exception.TypeOf()); + Assert.That(!tx.IsCompleted); + tx.Rollback(); + Assert.That(tx.IsCompleted); + Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + } - await pgMock.ExpectSimpleQuery("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"); - await pgMock.ExpectSimpleQuery("SAVEPOINT foo"); - await pgMock.ExpectExtendedQuery(); - } + [Test, Description("Tests that a if a DatabaseInfoFactory is registered for a database that doesn't support transactions, no transactions are created")] + [Parallelizable(ParallelScope.None)] + public async Task Transaction_not_supported() + { + // TODO: rewrite to DataSource + if (IsMultiplexing) + Assert.Ignore("Need to rethink/redo dummy transaction mode"); - [Test, Description("Check IsCompleted before, during and after a normal committed transaction")] - [IssueLink("https://github.com/npgsql/npgsql/issues/985")] - public async Task IsCompletedCommit() + var connString = new 
NpgsqlConnectionStringBuilder(ConnectionString) { - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - var tx = conn.BeginTransaction(); - Assert.That(!tx.IsCompleted); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); - Assert.That(!tx.IsCompleted); - await tx.CommitAsync(); - Assert.That(tx.IsCompleted); - } + ApplicationName = nameof(Transaction_not_supported) + IsMultiplexing + }.ToString(); - [Test, Description("Check IsCompleted before, during, and after a successful but rolled back transaction")] - [IssueLink("https://github.com/npgsql/npgsql/issues/985")] - public async Task IsCompletedRollback() + NpgsqlDatabaseInfo.RegisterFactory(new NoTransactionDatabaseInfoFactory()); + try { - await using var conn = await OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - var tx = conn.BeginTransaction(); - Assert.That(!tx.IsCompleted); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); - Assert.That(!tx.IsCompleted); - tx.Rollback(); - Assert.That(tx.IsCompleted); - } + using var conn = new NpgsqlConnection(connString); + await conn.OpenAsync(); + using var tx = conn.BeginTransaction(); - [Test, Description("Check IsCompleted before, during, and after a failed then rolled back transaction")] - [IssueLink("https://github.com/npgsql/npgsql/issues/985")] - public async Task IsCompletedRollbackFailed() + // Detect that we're not really in a transaction + var prevTxId = conn.ExecuteScalar("SELECT txid_current()"); + var nextTxId = conn.ExecuteScalar("SELECT txid_current()"); + // If we're in an actual transaction, the two IDs should be the same + // https://stackoverflow.com/questions/1651219/how-to-check-for-pending-operations-in-a-postgresql-transaction + Assert.That(nextTxId, Is.Not.EqualTo(prevTxId)); + conn.Close(); + } + finally { - await using var conn = await 
OpenConnectionAsync(); - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - var tx = conn.BeginTransaction(); - Assert.That(!tx.IsCompleted); - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')", tx: tx); - Assert.That(!tx.IsCompleted); - Assert.That(async () => await conn.ExecuteNonQueryAsync("BAD QUERY"), Throws.Exception.TypeOf()); - Assert.That(!tx.IsCompleted); - tx.Rollback(); - Assert.That(tx.IsCompleted); - Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); + NpgsqlDatabaseInfo.ResetFactories(); } - [Test, Description("Tests that a if a DatabaseInfoFactory is registered for a database that doesn't support transactions, no transactions are created")] - [Parallelizable(ParallelScope.None)] - public async Task TransactionNotSupported() + using (var conn = new NpgsqlConnection(connString)) { - if (IsMultiplexing) - Assert.Ignore("Need to rethink/redo dummy transaction mode"); - - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(TransactionNotSupported) + IsMultiplexing - }.ToString(); - - NpgsqlDatabaseInfo.RegisterFactory(new NoTransactionDatabaseInfoFactory()); - try - { - using var conn = await OpenConnectionAsync(connString); - using var tx = conn.BeginTransaction(); - - // Detect that we're not really in a transaction - var prevTxId = conn.ExecuteScalar("SELECT txid_current()"); - var nextTxId = conn.ExecuteScalar("SELECT txid_current()"); - // If we're in an actual transaction, the two IDs should be the same - // https://stackoverflow.com/questions/1651219/how-to-check-for-pending-operations-in-a-postgresql-transaction - Assert.That(nextTxId, Is.Not.EqualTo(prevTxId)); - conn.Close(); - } - finally - { - NpgsqlDatabaseInfo.ResetFactories(); - } - - using (var conn = await OpenConnectionAsync(connString)) - { - NpgsqlConnection.ClearPool(conn); - conn.ReloadTypes(); - } + await conn.OpenAsync(); + 
NpgsqlConnection.ClearPool(conn); + conn.ReloadTypes(); + } - // Check that everything is back to normal - using (var conn = await OpenConnectionAsync(connString)) + // Check that everything is back to normal + using (var conn = new NpgsqlConnection(connString)) + { + await conn.OpenAsync(); using (var tx = conn.BeginTransaction()) { var prevTxId = conn.ExecuteScalar("SELECT txid_current()"); @@ -498,106 +553,199 @@ public async Task TransactionNotSupported() Assert.That(nextTxId, Is.EqualTo(prevTxId)); } } + } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/3248")] - // More at #3254 - public async Task Bug3248DisposeTransactionRollback() - { - if (!IsMultiplexing) - return; - - using var conn = await OpenConnectionAsync(); - await using (var tx = conn.BeginTransaction()) - { - Assert.That(conn.Connector, Is.Not.Null); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT * FROM \"unknown_table\"", tx: tx), - Throws.Exception.TypeOf()); - Assert.That(conn.Connector, Is.Not.Null); - } - - Assert.That(conn.Connector, Is.Null); - } + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/3248")] + // More at #3254 + public async Task Bug3248_Dispose_transaction_Rollback() + { + if (!IsMultiplexing) + return; - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/3248")] - // More at #3254 - public async Task Bug3248DisposeConnectionRollback() + using var conn = await OpenConnectionAsync(); + await using (var tx = await conn.BeginTransactionAsync()) { - if (!IsMultiplexing) - return; - - var conn = await OpenConnectionAsync(); - var tx = conn.BeginTransaction(); Assert.That(conn.Connector, Is.Not.Null); Assert.That(async () => await conn.ExecuteScalarAsync("SELECT * FROM \"unknown_table\"", tx: tx), Throws.Exception.TypeOf()); Assert.That(conn.Connector, Is.Not.Null); - - await conn.DisposeAsync(); - Assert.That(conn.Connector, Is.Null); } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/3306")] - [TestCase(true)] - 
[TestCase(false)] - public async Task Bug3306(bool inTransactionBlock) - { - var conn = await OpenConnectionAsync(); - var tx = await conn.BeginTransactionAsync(); - await conn.ExecuteNonQueryAsync("SELECT 1", tx); - if (!inTransactionBlock) - await tx.RollbackAsync(); - await conn.CloseAsync(); + Assert.That(conn.Connector, Is.Null); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/3248")] + // More at #3254 + public async Task Bug3248_Dispose_connection_Rollback() + { + if (!IsMultiplexing) + return; + + var conn = await OpenConnectionAsync(); + var tx = conn.BeginTransaction(); + Assert.That(conn.Connector, Is.Not.Null); + Assert.That(async () => await conn.ExecuteScalarAsync("SELECT * FROM \"unknown_table\"", tx: tx), + Throws.Exception.TypeOf()); + Assert.That(conn.Connector, Is.Not.Null); + + await conn.DisposeAsync(); + Assert.That(conn.Connector, Is.Null); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/3306")] + [TestCase(true)] + [TestCase(false)] + public async Task Bug3306(bool inTransactionBlock) + { + var conn = await OpenConnectionAsync(); + var tx = await conn.BeginTransactionAsync(); + await conn.ExecuteNonQueryAsync("SELECT 1", tx); + if (!inTransactionBlock) + await tx.RollbackAsync(); + await conn.CloseAsync(); - conn = await OpenConnectionAsync(); - var tx2 = await conn.BeginTransactionAsync(); + conn = await OpenConnectionAsync(); + var tx2 = await conn.BeginTransactionAsync(); - await tx.DisposeAsync(); + await tx.DisposeAsync(); - Assert.That(tx.IsDisposed, Is.True); - Assert.That(tx2.IsDisposed, Is.False); + Assert.That(tx.IsDisposed, Is.True); + Assert.That(tx2.IsDisposed, Is.False); - await conn.DisposeAsync(); - } + await conn.DisposeAsync(); + } + + [Test, IssueLink("https://github.com/npgsql/efcore.pg/issues/1593")] + public async Task Access_connection_on_completed_transaction() + { + using var conn = await OpenConnectionAsync(); + using var tx = await conn.BeginTransactionAsync(); + 
tx.Commit(); + Assert.That(tx.Connection, Is.SameAs(conn)); + } + + [Test] + public async Task Unbound_transaction_reuse() + { + await using var dataSource = CreateDataSource(csb => + { + csb.MinPoolSize = 1; + csb.MaxPoolSize = 1; + }); - class NoTransactionDatabaseInfoFactory : INpgsqlDatabaseInfoFactory + await using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "name TEXT"); + + await using var conn1 = await dataSource.OpenConnectionAsync(); + var tx1 = conn1.BeginTransaction(); + await using (var ___ = tx1) { - public async Task Load(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async) + using var cmd1 = conn1.CreateCommand(); + cmd1.CommandText = $"INSERT INTO {table} (name) VALUES ('X'); SELECT 1"; + await using (var reader1 = await cmd1.ExecuteReaderAsync()) { - var db = new NoTransactionDatabaseInfo(conn); - await db.LoadPostgresInfo(conn, timeout, async); - return db; + Assert.That(async () => await reader1.ReadAsync(), Is.EqualTo(true)); + Assert.That(() => reader1.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader1.RecordsAffected, Is.EqualTo(1)); } + await tx1.CommitAsync(); + Assert.That(await conn1.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); + await conn1.CloseAsync(); } - class NoTransactionDatabaseInfo : PostgresDatabaseInfo + await using var conn2 = await dataSource.OpenConnectionAsync(); + var tx2 = conn2.BeginTransaction(); + await using (var ___ = tx2) { - public override bool SupportsTransactions => false; + Assert.That(tx2, Is.Not.SameAs(tx1)); + using var cmd2 = conn2.CreateCommand(); + cmd2.CommandText = $"INSERT INTO {table} (name) VALUES ('Y'); SELECT 2"; + await using (var reader2 = await cmd2.ExecuteReaderAsync()) + { + Assert.That(async () => await reader2.ReadAsync(), Is.EqualTo(true)); + Assert.That(() => reader2.GetInt32(0), Is.EqualTo(2)); + Assert.That(reader2.RecordsAffected, Is.EqualTo(1)); + } + await tx2.CommitAsync(); + Assert.That(await 
conn2.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(2)); + await conn2.CloseAsync(); + } - internal NoTransactionDatabaseInfo(NpgsqlConnection conn) : base(conn) {} + await using var conn3 = await dataSource.OpenConnectionAsync(); + var tx3 = conn3.BeginTransaction(); + await using (var ___ = tx3) + { + Assert.That(tx3, Is.SameAs(tx1)); + using var cmd3 = conn3.CreateCommand(); + cmd3.CommandText = $"INSERT INTO {table} (name) VALUES ('Z'); SELECT 3"; + await using (var reader3 = await cmd3.ExecuteReaderAsync()) + { + Assert.That(async () => await reader3.ReadAsync(), Is.EqualTo(true)); + Assert.That(() => reader3.GetInt32(0), Is.EqualTo(3)); + Assert.That(reader3.RecordsAffected, Is.EqualTo(1)); + } + await tx3.CommitAsync(); + Assert.That(await conn3.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(3)); + await conn3.CloseAsync(); } + } - // Older tests + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3686")] + public async Task Bug3686() + { + if (IsMultiplexing) + return; + + await using var dataSource = CreateDataSource(csb => csb.Pooling = false); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var tx = await conn.BeginTransactionAsync(); + await conn.ExecuteNonQueryAsync("SELECT 1", tx); + await tx.CommitAsync(); + await conn.CloseAsync(); + Assert.DoesNotThrow(() => + { + _ = tx.Connection; + }); + } - [Test] - public void Bug184RollbackFailsOnAbortedTransaction() + class NoTransactionDatabaseInfoFactory : INpgsqlDatabaseInfoFactory + { + public async Task Load(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString); - csb.CommandTimeout = 100000; - - using var connTimeoutChanged = new NpgsqlConnection(csb.ToString()); - connTimeoutChanged.Open(); - using var t = connTimeoutChanged.BeginTransaction(); - try { - var command = new NpgsqlCommand("select count(*) from dta", connTimeoutChanged, t); - _ = 
command.ExecuteScalar(); - } catch (Exception) { - t.Rollback(); - } + var db = new NoTransactionDatabaseInfo(conn); + await db.LoadPostgresInfo(conn, timeout, async); + return db; } + } - public TransactionTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + class NoTransactionDatabaseInfo : PostgresDatabaseInfo + { + public override bool SupportsTransactions => false; + + internal NoTransactionDatabaseInfo(NpgsqlConnector conn) : base(conn) {} } + + // Older tests + + [Test] + public void Bug184_Rollback_fails_on_aborted_transaction() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString); + csb.CommandTimeout = 100000; + + using var connTimeoutChanged = new NpgsqlConnection(csb.ToString()); + connTimeoutChanged.Open(); + using var t = connTimeoutChanged.BeginTransaction(); + try { + var command = new NpgsqlCommand("select count(*) from dta", connTimeoutChanged, t); + _ = command.ExecuteScalar(); + } catch (Exception) { + t.Rollback(); + } + } + + public TransactionTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/TypeMapperTests.cs b/test/Npgsql.Tests/TypeMapperTests.cs index ab3fb35e2c..d0d1e36587 100644 --- a/test/Npgsql.Tests/TypeMapperTests.cs +++ b/test/Npgsql.Tests/TypeMapperTests.cs @@ -1,258 +1,115 @@ -using System; -using System.Data; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using Npgsql.PostgresTypes; -using Npgsql.TypeHandlers; -using Npgsql.TypeHandlers.NumericHandlers; -using Npgsql.TypeHandling; -using Npgsql.TypeMapping; -using NpgsqlTypes; +using Npgsql.Internal; using NUnit.Framework; +using System; +using System.Threading.Tasks; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +public class TypeMapperTests : TestBase { - [NonParallelizable] - public class TypeMapperTests : TestBase + [Test] + public async Task 
ReloadTypes_across_connections_in_data_source() { - [Test] - public void GlobalMapping() - { - var myFactory = MapMyIntGlobally(); - using (var pool = CreateTempPool(ConnectionString, out var connectionString)) - using (var conn = OpenConnection(connectionString)) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var range = new NpgsqlRange(8, true, false, 0, false, true); - var parameters = new[] - { - // Base - new NpgsqlParameter("p", NpgsqlDbType.Integer) { Value = 8 }, - new NpgsqlParameter("p", DbType.Int32) { Value = 8 }, - new NpgsqlParameter { ParameterName = "p", Value = 8 }, - // Array - new NpgsqlParameter { ParameterName = "p", Value = new[] { 8 } }, - new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Integer) { Value = new[] { 8 } }, - // Range - new NpgsqlParameter { ParameterName = "p", Value = range }, - new NpgsqlParameter("p", NpgsqlDbType.Range | NpgsqlDbType.Integer) { Value = range }, - }; + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + // Note that we don't actually create the type in the database at this point; we want to exercise the type being created later, + // via the data source. 
+ + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection1 = await dataSource.OpenConnectionAsync(); + await using var connection2 = await dataSource.OpenConnectionAsync(); + + await connection1.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + await connection1.ReloadTypesAsync(); + + // The data source type mapper has been replaced and connection1 should have the new mapper, but connection2 should retain the older + // type mapper - where there's no mapping - as long as it's still open + Assert.ThrowsAsync(async () => await connection2.ExecuteScalarAsync($"SELECT 'happy'::{type}")); + Assert.DoesNotThrowAsync(async () => await connection1.ExecuteScalarAsync($"SELECT 'happy'::{type}")); + + // Close connection2 and reopen to make sure it picks up the new type and mapping from the data source + var connId = connection2.ProcessID; + await connection2.CloseAsync(); + await connection2.OpenAsync(); + Assert.That(connection2.ProcessID, Is.EqualTo(connId), "Didn't get the same connector back"); + + Assert.DoesNotThrowAsync(async () => await connection2.ExecuteScalarAsync($"SELECT 'happy'::{type}")); + } - for (var i = 0; i < parameters.Length; i++) - { - cmd.Parameters.Add(parameters[i]); - cmd.ExecuteScalar(); - Assert.That(myFactory.Reads, Is.EqualTo(i+1)); - Assert.That(myFactory.Writes, Is.EqualTo(i+1)); - cmd.Parameters.Clear(); - } - } - } + [Test] + [NonParallelizable] // Depends on citext which could be dropped concurrently + public async Task String_to_citext() + { + await using var adminConnection = await OpenConnectionAsync(); + await EnsureExtensionAsync(adminConnection, "citext"); - [Test] - public void LocalMapping() - { - MyInt32HandlerFactory myFactory; - using var _ = CreateTempPool(ConnectionString, out var connectionString); + var dataSourceBuilder = CreateDataSourceBuilder(); + 
dataSourceBuilder.AddTypeInfoResolverFactory(new CitextToStringTypeHandlerResolverFactory()); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); - using (var conn = OpenConnection(connectionString)) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - myFactory = MapMyIntLocally(conn); - cmd.Parameters.AddWithValue("p", 8); - cmd.ExecuteScalar(); - Assert.That(myFactory.Reads, Is.EqualTo(1)); - Assert.That(myFactory.Writes, Is.EqualTo(1)); - } + await using var command = new NpgsqlCommand("SELECT @p = 'hello'::citext", connection); + command.Parameters.AddWithValue("p", "HeLLo"); + Assert.That(command.ExecuteScalar(), Is.True); + } - // Make sure reopening (same physical connection) reverts the mapping - using (var conn = OpenConnection(connectionString)) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", 8); - cmd.ExecuteScalar(); - Assert.That(myFactory.Reads, Is.EqualTo(1)); - Assert.That(myFactory.Writes, Is.EqualTo(1)); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4582")] + [NonParallelizable] // Drops extension + public async Task Type_in_non_default_schema() + { + await using var conn = await OpenConnectionAsync(); - [Test] - public void RemoveGlobalMapping() - { - NpgsqlConnection.GlobalTypeMapper.RemoveMapping("integer"); - using var _ = CreateTempPool(ConnectionString, out var connectionString); - using var conn = OpenConnection(connectionString); - Assert.That(() => conn.ExecuteScalar("SELECT 8"), Throws.TypeOf()); - } + var schemaName = await CreateTempSchema(conn); - [Test] - public void RemoveLocalMapping() - { - using var _ = CreateTempPool(ConnectionString, out var connectionString); - using (var conn = OpenConnection(connectionString)) - { - conn.TypeMapper.RemoveMapping("integer"); - Assert.That(() => conn.ExecuteScalar("SELECT 8"), Throws.TypeOf()); - } - // Make sure reopening (same physical 
connection) reverts the mapping - using (var conn = OpenConnection(connectionString)) - Assert.That(conn.ExecuteScalar("SELECT 8"), Is.EqualTo(8)); - } + await conn.ExecuteNonQueryAsync(@$" +DROP EXTENSION IF EXISTS citext; +CREATE EXTENSION citext SCHEMA ""{schemaName}"""); - [Test] - public void GlobalReset() + try { - var myFactory = MapMyIntGlobally(); - using var _ = CreateTempPool(ConnectionString, out var connectionString); + await conn.ReloadTypesAsync(); - using (OpenConnection(connectionString)) {} - // We now have a connector in the pool with our custom mapping + var tableName = await CreateTempTable(conn, $"created_by {schemaName}.citext NOT NULL"); - NpgsqlConnection.GlobalTypeMapper.Reset(); - using (var conn = OpenConnection(connectionString)) - { - // Should be the pooled connector from before, but it should have picked up the reset - conn.ExecuteScalar("SELECT 1"); - Assert.That(myFactory.Reads, Is.Zero); + const string expected = "SomeValue"; + await conn.ExecuteNonQueryAsync($"INSERT INTO \"{tableName}\" VALUES('{expected}')"); - // Now create a second *physical* connection to make sure it picks up the new mapping as well - using (var conn2 = OpenConnection(connectionString)) - { - conn2.ExecuteScalar("SELECT 1"); - Assert.That(myFactory.Reads, Is.Zero); - } - NpgsqlConnection.ClearPool(conn); - } + var value = (string?)await conn.ExecuteScalarAsync($"SELECT created_by FROM \"{tableName}\" LIMIT 1"); + Assert.That(value, Is.EqualTo(expected)); } - - [Test] - public void DomainMappingNotSupported() + finally { - // PostgreSQL sends RowDescription with the OID of the base type, not the domain, - // it's not possible to map domains - using (CreateTempPool(ConnectionString, out var connectionString)) - using (var conn = OpenConnection(connectionString)) - { - conn.ExecuteNonQuery(@"CREATE DOMAIN pg_temp.us_postal_code AS TEXT -CHECK -( - VALUE ~ '^\d{5}$' - OR VALUE ~ '^\d{5}-\d{4}$' -); -"); - conn.ReloadTypes(); - Assert.That(() => 
conn.TypeMapper.AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "us_postal_code", - TypeHandlerFactory = new DummyTypeHandlerFactory() - }.Build()), Throws.TypeOf()); - } + await conn.ExecuteNonQueryAsync(@"DROP EXTENSION citext CASCADE"); } + } - class DummyTypeHandlerFactory : NpgsqlTypeHandlerFactory - { - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => throw new Exception(); - } + #region Support - [Test] - public void MandatoryMappingFields() - { - Assert.That(() => new NpgsqlTypeMappingBuilder().Build(), Throws.ArgumentException); - Assert.That(() => new NpgsqlTypeMappingBuilder{ PgTypeName = "foo" }.Build(), Throws.ArgumentException); - } + class CitextToStringTypeHandlerResolverFactory : PgTypeInfoResolverFactory + { + public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver? CreateArrayResolver() => null; - [Test] - public async Task StringToCitext() + sealed class Resolver : IPgTypeInfoResolver { - using (CreateTempPool(ConnectionString, out var connectionString)) - using (var conn = OpenConnection(connectionString)) + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) { - await EnsureExtensionAsync(conn, "citext"); - - conn.TypeMapper.RemoveMapping("text"); - conn.TypeMapper.AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "citext", - NpgsqlDbType = NpgsqlDbType.Citext, - DbTypes = new[] { DbType.String }, - ClrTypes = new[] { typeof(string) }, - TypeHandlerFactory = new TextHandlerFactory() - }.Build()); + if (type == typeof(string) || dataTypeName?.UnqualifiedName == "citext") + if (options.DatabaseInfo.TryGetPostgresTypeByName("citext", out var pgType)) + return new(options, new StringTextConverter(options.TextEncoding), options.ToCanonicalTypeId(pgType)); - using (var cmd = new NpgsqlCommand("SELECT @p = 'hello'::citext", conn)) - { - cmd.Parameters.AddWithValue("p", "HeLLo"); - Assert.That(cmd.ExecuteScalar(), Is.True); - } + return null; } } - #region Support - - MyInt32HandlerFactory MapMyIntGlobally() - { - var myFactory = new MyInt32HandlerFactory(); - NpgsqlConnection.GlobalTypeMapper.AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "integer", - NpgsqlDbType = NpgsqlDbType.Integer, - DbTypes = new[] { DbType.Int32 }, - ClrTypes = new[] { typeof(int) }, - TypeHandlerFactory = myFactory - }.Build()); - return myFactory; - } - - MyInt32HandlerFactory MapMyIntLocally(NpgsqlConnection conn) - { - var myFactory = new MyInt32HandlerFactory(); - conn.TypeMapper.AddMapping(new NpgsqlTypeMappingBuilder - { - PgTypeName = "integer", - NpgsqlDbType = NpgsqlDbType.Integer, - DbTypes = new[] { DbType.Int32 }, - ClrTypes = new[] { typeof(int) }, - TypeHandlerFactory = myFactory - }.Build()); - return myFactory; - } - - class MyInt32HandlerFactory : NpgsqlTypeHandlerFactory - { - internal int Reads, Writes; - - public override NpgsqlTypeHandler Create(PostgresType postgresType, NpgsqlConnection conn) - => new MyInt32Handler(postgresType, this); - } - - class MyInt32Handler : Int32Handler - { - readonly MyInt32HandlerFactory _factory; - - public MyInt32Handler(PostgresType 
postgresType, MyInt32HandlerFactory factory) - : base(postgresType) - { - _factory = factory; - } - - public override int Read(NpgsqlReadBuffer buf, int len, FieldDescription? fieldDescription = null) - { - _factory.Reads++; - return base.Read(buf, len, fieldDescription); - } - - public override void Write(int value, NpgsqlWriteBuffer buf, NpgsqlParameter? parameter) - { - _factory.Writes++; - base.Write(value, buf, parameter); - } - } + } - #endregion Support + enum Mood { Sad, Ok, Happy } - [TearDown] - public void TearDown() => NpgsqlConnection.GlobalTypeMapper.Reset(); - } + #endregion Support } diff --git a/test/Npgsql.Tests/Types/ArrayTests.cs b/test/Npgsql.Tests/Types/ArrayTests.cs index c4f75b5156..a567e4891e 100644 --- a/test/Npgsql.Tests/Types/ArrayTests.cs +++ b/test/Npgsql.Tests/Types/ArrayTests.cs @@ -1,599 +1,420 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Collections.Immutable; using System.Data; using System.Linq; using System.Text; -using System.Threading; using System.Threading.Tasks; -using Npgsql.TypeHandlers; +using Npgsql.Internal.Converters; using NpgsqlTypes; using NUnit.Framework; -using NUnit.Framework.Internal; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +// ReSharper disable BitwiseOperatorOnEnumWithoutFlags + +/// +/// Tests on PostgreSQL arrays +/// +/// +/// https://www.postgresql.org/docs/current/static/arrays.html +/// +public class ArrayTests : MultiplexingTestBase { - /// - /// Tests on PostgreSQL arrays - /// - /// - /// https://www.postgresql.org/docs/current/static/arrays.html - /// - public class ArrayTests : MultiplexingTestBase + static readonly TestCaseData[] ArrayTestCases = { - [Test, Description("Resolves an array type handler via the different pathways")] - public async Task ArrayTypeResolution() + new TestCaseData(new[] { 1, 2, 3 }, "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array) + 
.SetName("Integer_array"), + new TestCaseData(Array.Empty(), "{}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array) + .SetName("Empty_array"), + new TestCaseData(new[,] { { 1, 2, 3 }, { 7, 8, 9 } }, "{{1,2,3},{7,8,9}}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array) + .SetName("Two_dimensional_array"), + new TestCaseData(new[] { new byte[] { 1, 2 }, new byte[] { 3, 4 } }, """{"\\x0102","\\x0304"}""", "bytea[]", NpgsqlDbType.Bytea | NpgsqlDbType.Array) + .SetName("Bytea_array") + }; + + [Test, TestCaseSource(nameof(ArrayTestCases))] + public Task Arrays(T array, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType) + => AssertType(array, sqlLiteral, pgTypeName, npgsqlDbType); + + [Test] + public async Task NullableInts() + { + var connectionStringBuilder = new NpgsqlConnectionStringBuilder(ConnectionString) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); - - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(ArrayTypeResolution), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; - - using (var conn = await OpenConnectionAsync(csb)) - { - // Resolve type by NpgsqlDbType - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Array | NpgsqlDbType.Integer, DBNull.Value); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer[]")); - } - } - - // Resolve type by ClrType (type inference) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = new int[0] }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer[]")); - } - } - - // Resolve type by OID (read) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT '{1, 3}'::INTEGER[]", 
conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer[]")); - } - } - } + ArrayNullabilityMode = ArrayNullabilityMode.Always + }; + var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionStringBuilder.ToString()); + await using var dataSource = dataSourceBuilder.Build(); - [Test, Description("Roundtrips a simple, one-dimensional array of ints")] - public async Task Ints() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - var expected = new[] { 1, 5, 9 }; - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer); - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = expected }; - var p3 = new NpgsqlParameter("p3", expected); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - p1.Value = expected; - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.TypeOf()); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Array))); - Assert.That(reader.GetProviderSpecificFieldType(i), Is.EqualTo(typeof(Array))); - } - } - } + await AssertType(dataSource, new int?[] { 1, 2, null, 3 }, "{1,2,NULL,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array); + } - [Test, Description("Roundtrips a simple, one-dimensional array of int? 
values")] - public async Task NullableInts() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn); - - var expected = new int?[] { 1, 5, null, 9 }; - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer); - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = expected }; - var p3 = new NpgsqlParameter("p3", expected); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - p1.Value = expected; - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); + [Test, Description("Checks that PG arrays containing nulls can't be read as CLR arrays of non-nullable value types (the default).")] + public async Task Nullable_ints_cannot_be_read_as_non_nullable() + => await AssertTypeUnsupportedRead("{1,NULL,2}", "int[]"); - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue>(i), Is.EqualTo(expected.ToList())); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Array))); - Assert.That(reader.GetProviderSpecificFieldType(i), Is.EqualTo(typeof(Array))); - } - } + [Test] + public async Task Throws_too_many_dimensions() + { + await using var conn = CreateConnection(); + await conn.OpenAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + cmd.Parameters.AddWithValue("p", new int[1, 1, 1, 1, 1, 1, 1, 1, 1]); // 9 dimensions + Assert.That( + () => cmd.ExecuteScalarAsync(), + Throws.Exception.TypeOf().With.Message.EqualTo("values (Parameter 'Postgres arrays can have at most 8 dimensions.')")); + } - [Test, Description("Checks that PG arrays containing nulls can't be read as CLR arrays of non-nullable value types.")] - public async Task NullableIntsCannotBeReadAsNonNullable() + [Test, Description("Checks that PG arrays containing nulls are returned as set via ValueTypeArrayMode.")] + [TestCase(ArrayNullabilityMode.Always)] + 
[TestCase(ArrayNullabilityMode.Never)] + [TestCase(ArrayNullabilityMode.PerInstance)] + public async Task Value_type_array_nullabilities(ArrayNullabilityMode mode) + { + await using var dataSource = CreateDataSource(csb => csb.ArrayNullabilityMode = mode); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand( +""" +SELECT onedim, twodim FROM (VALUES +('{1, 2, 3, 4}'::int[],'{{1, 2},{3, 4}}'::int[][]), +('{5, NULL, 6, 7}'::int[],'{{5, NULL},{6, 7}}'::int[][])) AS x(onedim,twodim) +""", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + + switch (mode) { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT '{1, NULL, 2}'::integer[]", conn); - using var reader = await cmd.ExecuteReaderAsync(); + case ArrayNullabilityMode.Never: reader.Read(); - - Assert.That(() => reader.GetFieldValue(0), Throws.Exception.TypeOf()); - Assert.That(() => reader.GetFieldValue>(0), Throws.Exception.TypeOf()); - Assert.That(() => reader.GetValue(0), Throws.Exception.TypeOf()); - } - - [Test] - public async Task EmptyArray() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p", conn); - - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Integer) { Value = new int[0] }); - var reader = await cmd.ExecuteReaderAsync(); + var value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(int[]))); + Assert.That(value, Is.EqualTo(new []{1, 2, 3, 4})); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(reader.GetValue(1).GetType(), Is.EqualTo(typeof(int[,]))); + Assert.That(reader.GetValue(1), Is.EqualTo(new [,]{{1, 2}, {3, 4}})); reader.Read(); - - Assert.That(reader.GetFieldValue(0), Is.SameAs(Array.Empty())); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(() => 
reader.GetValue(0), Throws.Exception.TypeOf()); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(() => reader.GetValue(1), Throws.Exception.TypeOf()); + break; + case ArrayNullabilityMode.Always: + reader.Read(); + value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(int?[]))); + Assert.That(value, Is.EqualTo(new int?[]{1, 2, 3, 4})); + value = reader.GetValue(1); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(int?[,]))); + Assert.That(value, Is.EqualTo(new int?[,]{{1, 2}, {3, 4}})); + reader.Read(); + value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(int?[]))); + Assert.That(value, Is.EqualTo(new int?[]{5, null, 6, 7})); + value = reader.GetValue(1); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(int?[,]))); + Assert.That(value, Is.EqualTo(new int?[,]{{5, null},{6, 7}})); + break; + case ArrayNullabilityMode.PerInstance: + reader.Read(); + value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(int[]))); + Assert.That(value, Is.EqualTo(new []{1, 2, 3, 4})); + value = reader.GetValue(1); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(int[,]))); + Assert.That(value, Is.EqualTo(new [,]{{1, 2}, {3, 4}})); + reader.Read(); + value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(int?[]))); + Assert.That(value, Is.EqualTo(new int?[]{5, null, 6, 7})); + value = reader.GetValue(1); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), 
Is.EqualTo(typeof(int?[,]))); + Assert.That(value, Is.EqualTo(new int?[,]{{5, null},{6, 7}})); + break; + default: + throw new ArgumentOutOfRangeException(nameof(mode), mode, null); } + } - [Test, Description("Roundtrips an empty multi-dimensional array.")] - public async Task EmptyMultidimensionalArray() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p", conn); - - var expected = new int[0, 0]; - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Array | NpgsqlDbType.Integer, expected); + // Note that PG normalizes empty multidimensional arrays to single-dimensional, e.g. ARRAY[[], []]::integer[] returns {}. + [Test] + public async Task Write_empty_multidimensional_array() + => await AssertTypeWrite(new int[0, 0], "{}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array); - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); + [Test] + public async Task Generic_List() + => await AssertType( + new List { 1, 2, 3 }, "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array, isDefaultForReading: false); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - } + [Test] + public async Task Write_IList_implementation() + => await AssertTypeWrite( + ImmutableArray.Create(1, 2, 3), "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array); - [Test, Description("Verifies that an InvalidOperationException is thrown when the returned array has a different number of dimensions from what was requested.")] - public async Task WrongArrayDimensions() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT ARRAY[[1], [2]]", conn); + [Test] + public void Read_IList_implementation_throws() + { + Assert.ThrowsAsync(() => + AssertTypeRead("{1,2,3}", "integer[]", ImmutableArray.Create(1, 2, 3), isDefault: false)); + } - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); + [Test] + public async Task Generic_IList() + { + await using var conn = await 
OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1", conn); - var ex = Assert.Throws(() => reader.GetFieldValue(0)); - Assert.That(ex.Message, Is.EqualTo("Cannot read an array with 1 dimension(s) from an array with 2 dimension(s)")); - } + var expected = ImmutableArray.Create(1,2,3); + cmd.Parameters.Add(new NpgsqlParameter>("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer) { TypedValue = expected }); - [Test, Description("Verifies that an attempt to read an Array of value types that contains null values as array of a non-nullable type fails.")] - public async Task ReadNullAsNonNullableArrayFails() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p1", conn); + var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.AreEqual(expected, reader.GetFieldValue(0)); + } - var expected = new int?[] { 1, 5, null, 9 }; - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, expected); + [Test, Description("Verifies that an InvalidOperationException is thrown when the returned array has a different number of dimensions from what was requested.")] + public async Task Wrong_array_dimensions_throws() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT ARRAY[[1], [2]]", conn); - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); + var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); - Assert.That( - () => reader.GetFieldValue(0), - Throws.Exception.TypeOf() - .With.Message.EqualTo(ArrayHandler.ReadNonNullableCollectionWithNullsExceptionMessage)); - } + var ex = Assert.Throws(() => reader.GetFieldValue(0))!; + Assert.That(ex.Message, Does.StartWith("Cannot read an array value with 2 dimensions into a collection type with 1 dimension")); + } + [Test, Description("Verifies that an attempt to read an Array of value types that contains null values as array of a non-nullable type fails.")] + public 
async Task Read_null_as_non_nullable_array_throws() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1", conn); - [Test, Description("Verifies that an attempt to read an Array of value types that contains null values as List of a non-nullable type fails.")] - public async Task ReadNullAsNonNullableListFails() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p1", conn); + var expected = new int?[] { 1, 5, null, 9 }; + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, expected); - var expected = new int?[] { 1, 5, null, 9 }; - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, expected); + var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); + Assert.That( + () => reader.GetFieldValue(0), + Throws.Exception.TypeOf() + .With.Message.EqualTo(PgArrayConverter.ReadNonNullableCollectionWithNullsExceptionMessage)); + } - Assert.That( - () => reader.GetFieldValue>(0), - Throws.Exception.TypeOf() - .With.Message.EqualTo(ArrayHandler.ReadNonNullableCollectionWithNullsExceptionMessage)); - } - [Test, Description("Roundtrips a large, one-dimensional array of ints that will be chunked")] - public async Task LongOneDimensional() - { - using (var conn = await OpenConnectionAsync()) - { - var expected = new int[conn.Settings.WriteBufferSize/4 + 100]; - for (var i = 0; i < expected.Length; i++) - expected[i] = i; - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var p = new NpgsqlParameter {ParameterName = "p", Value = expected}; - cmd.Parameters.Add(p); - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess)) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(expected)); - } - } - } - } + [Test, Description("Verifies that an attempt to read an Array of value types that contains null values as List of a 
non-nullable type fails.")] + public async Task Read_null_as_non_nullable_list_throws() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1", conn); - [Test, Description("Roundtrips a large, two-dimensional array of ints that will be chunked")] - public async Task LongTwoDimensional() - { - using (var conn = await OpenConnectionAsync()) - { - var len = conn.Settings.WriteBufferSize/2 + 100; - var expected = new int[2, len]; - for (var i = 0; i < len; i++) - expected[0, i] = i; - for (var i = 0; i < len; i++) - expected[1, i] = i; - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var p = new NpgsqlParameter {ParameterName = "p", Value = expected}; - cmd.Parameters.Add(p); - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess)) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(expected)); - } - } - } - } + var expected = new int?[] { 1, 5, null, 9 }; + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, expected); - [Test, Description("Roundtrips a long, one-dimensional array of strings, including a null")] - public async Task StringsWithNull() - { - using (var conn = await OpenConnectionAsync()) - { - var largeString = new StringBuilder(); - largeString.Append('a', conn.Settings.WriteBufferSize); - var expected = new[] {"value1", null, largeString.ToString(), "val3"}; - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Text) {Value = expected}; - cmd.Parameters.Add(p); - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess)) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - } - } - } - } + var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); - [Test, Description("Roundtrips a zero-dimensional array of ints, should return empty one-dimensional")] - public async Task ZeroDimensional() - 
{ - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var expected = new int[0]; - var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Integer) { Value = expected }; - cmd.Parameters.Add(p); - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - cmd.Dispose(); - } - } + Assert.That( + () => reader.GetFieldValue>(0), + Throws.Exception.TypeOf() + .With.Message.EqualTo(PgArrayConverter.ReadNonNullableCollectionWithNullsExceptionMessage)); + } - [Test, Description("Roundtrips a two-dimensional array of ints")] - public async Task TwoDimensionalInts() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var expected = new[,] { { 1, 2, 3 }, { 7, 8, 9 } }; - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer); - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = expected }; - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - p1.Value = expected; - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - } - } + [Test, Description("Roundtrips a large, one-dimensional array of ints that will be chunked")] + public async Task Long_one_dimensional() + { + await using var conn = await OpenConnectionAsync(); - [Test, Description("Reads a one-dimensional array dates, both as DateTime and as the provider-specific NpgsqlDate")] - public async Task ReadProviderSpecificType() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand(@"SELECT '{ ""2014-01-04"", ""2014-01-08"" 
}'::DATE[]", conn)) - { - var expectedRegular = new[] { new DateTime(2014, 1, 4), new DateTime(2014, 1, 8) }; - var expectedPsv = new[] { new NpgsqlDate(2014, 1, 4), new NpgsqlDate(2014, 1, 8) }; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo(expectedRegular)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expectedRegular)); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(expectedPsv)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expectedPsv)); - } - } - } + var expected = new int[conn.Settings.WriteBufferSize/4 + 100]; + for (var i = 0; i < expected.Length; i++) + expected[i] = i; - [Test, Description("Reads an one-dimensional array with lower bound != 0")] - public async Task ReadNonZeroLowerBounded() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT '[2:3]={ 8, 9 }'::INT[]", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(new[] {8, 9})); - } - - using (var cmd = new NpgsqlCommand("SELECT '[2:3][2:3]={ {8,9}, {1,2} }'::INT[][]", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(new[,] {{8, 9}, {1, 2}})); - } - } - } + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + var p = new NpgsqlParameter {ParameterName = "p", Value = expected}; + cmd.Parameters.Add(p); - [Test, Description("Roundtrips a one-dimensional array of bytea values")] - public async Task Byteas() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var expected = new[] { new byte[] { 1, 2 }, new byte[] { 3, 4, } }; - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Array | NpgsqlDbType.Bytea); - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = expected }; - cmd.Parameters.Add(p1); - 
cmd.Parameters.Add(p2); - p1.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); - Assert.That(reader.GetProviderSpecificFieldType(0), Is.EqualTo(typeof(Array))); - } - } - } + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + reader.Read(); + Assert.That(reader[0], Is.EqualTo(expected)); + } + [Test, Description("Roundtrips a large, two-dimensional array of ints that will be chunked")] + public async Task Long_two_dimensional() + { + await using var conn = await OpenConnectionAsync(); + var len = conn.Settings.WriteBufferSize/2 + 100; + var expected = new int[2, len]; + for (var i = 0; i < len; i++) + expected[0, i] = i; + for (var i = 0; i < len; i++) + expected[1, i] = i; + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + var p = new NpgsqlParameter {ParameterName = "p", Value = expected}; + cmd.Parameters.Add(p); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + reader.Read(); + Assert.That(reader[0], Is.EqualTo(expected)); + } - [Test, Description("Roundtrips a non-generic IList as an array")] - // ReSharper disable once InconsistentNaming - public async Task IListNonGeneric() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var expected = new ArrayList(new[] { 1, 2, 3 }); - var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Integer) { Value = expected }; - cmd.Parameters.Add(p); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(expected.ToArray())); - } - } + [Test, Description("Reads an one-dimensional array with lower bound != 0")] + public Task Read_non_zero_lower_bounded() + => AssertTypeRead("[2:3]={ 8, 9 }", "integer[]", new[] { 8, 9 }); - [Test, 
Description("Roundtrips a generic List as an array")] - // ReSharper disable once InconsistentNaming - public async Task IListGeneric() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var expected = new[] { 1, 2, 3 }.ToList(); - var p1 = new NpgsqlParameter { ParameterName = "p1", Value = expected }; - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = expected }; - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue>(1), Is.EqualTo(expected)); - } - } - } + [Test, Description("Reads an one-dimensional array with lower bound != 0")] + public Task Read_non_zero_lower_bounded_multidimensional() + => AssertTypeRead("[2:3][2:3]={ {8,9}, {1,2} }", "integer[]", new[,] { { 8, 9 }, { 1, 2 }}); - [Test, Description("Tests for failure when reading a generic IList from a multidimensional array")] - // ReSharper disable once InconsistentNaming - public async Task IListGenericFailsForMultidimensionalArray() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1", conn)) - { - var expected = new[,] { { 1, 2 }, { 3, 4 } }; - var p1 = new NpgsqlParameter { ParameterName = "p1", Value = expected }; - cmd.Parameters.Add(p1); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - var exception = Assert.Throws(() => - { - reader.GetFieldValue>(0); - }); - Assert.That(exception.Message, Is.EqualTo("Can't read multidimensional array as List")); - } - } - } + [Test, Description("Roundtrips a long, one-dimensional array of strings, including a null")] + public async Task Strings_with_null() + { + await using var conn = await OpenConnectionAsync(); + var largeString = new StringBuilder(); + largeString.Append('a', 
conn.Settings.WriteBufferSize); + var expected = new[] {"value1", null, largeString.ToString(), "val3"}; + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Text) {Value = expected}; + cmd.Parameters.Add(p); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + reader.Read(); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/844")] - public async Task IEnumerableThrowsFriendlyException() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1", conn)) - { - cmd.Parameters.AddWithValue("p1", Enumerable.Range(1, 3)); - Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf().With.Message.Contains("use .ToList()/.ToArray() instead")); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/844")] + public async Task Writing_IEnumerable_is_not_supported() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1", conn); + cmd.Parameters.AddWithValue("p1", new EnumerableOnly()); + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf().With.Property("InnerException").Message.Contains("array or some implementation of IList")); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/960")] - public async Task MixedElementTypes() - { - var mixedList = new ArrayList { 1, "yo" }; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, mixedList); - Assert.That(async () => await cmd.ExecuteNonQueryAsync(), Throws.Exception - .TypeOf() - .With.Message.Contains("mix")); - } - } + class EnumerableOnly : IEnumerable + { + public IEnumerator GetEnumerator() => throw new 
NotImplementedException(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/960")] - public async Task JaggedArraysNotSupported() - { - var jagged = new int[2][]; - jagged[0] = new[] { 8 }; - jagged[1] = new[] { 8, 10 }; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, jagged); - Assert.That(async () => await cmd.ExecuteNonQueryAsync(), Throws.Exception - .TypeOf() - .With.Message.Contains("jagged")); - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/960")] + public async Task Jagged_arrays_not_supported() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1", conn); + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, new[] { new[] { 8 }, new[] { 8, 10 } }); + Assert.That(async () => await cmd.ExecuteNonQueryAsync(), Throws.Exception + .TypeOf() + .With.Property("InnerException").Message.Contains("jagged")); + } - [Test, Description("Checks that ILists are properly serialized as arrays of their underlying types")] - public async Task ListTypeResolution() - { - using (var conn = await OpenConnectionAsync(ConnectionString)) - { - await AssertIListRoundtrips(conn, new[] { 1, 2, 3 }); - await AssertIListRoundtrips(conn, new IntList { 1, 2, 3 }); - await AssertIListRoundtrips(conn, new MisleadingIntList() { 1, 2, 3 }); - } - } + [Test, Description("Roundtrips one-dimensional and two-dimensional arrays of a PostgreSQL domain.")] + public async Task Array_of_domain() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing, ReloadTypes"); + + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "11.0", "Arrays of domains were introduced in PostgreSQL 11"); + await conn.ExecuteNonQueryAsync("CREATE DOMAIN pg_temp.posint AS integer 
CHECK (VALUE > 0);"); + await conn.ReloadTypesAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1::posint[], @p2::posint[][]", conn); + var oneDim = new[] { 1, 3, 5, 9 }; + var twoDim = new[,] { { 1, 3 }, { 5, 9 } }; + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Integer | NpgsqlDbType.Array, oneDim); + cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Integer | NpgsqlDbType.Array, twoDim); + await using var reader = cmd.ExecuteReader(); + reader.Read(); + + Assert.That(reader.GetValue(0), Is.EqualTo(oneDim)); + Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(oneDim)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(oneDim)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(oneDim)); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(reader.GetProviderSpecificFieldType(0), Is.EqualTo(typeof(Array))); + + Assert.That(reader.GetValue(1), Is.EqualTo(twoDim)); + Assert.That(reader.GetProviderSpecificValue(1), Is.EqualTo(twoDim)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(twoDim)); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(reader.GetProviderSpecificFieldType(1), Is.EqualTo(typeof(Array))); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1546")] - public void GenericListGetNpgsqlDbType() - { - var p = new NpgsqlParameter - { - ParameterName = "p1", - Value = new List { 1, 2, 3 } - }; - Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Array | NpgsqlDbType.Integer)); - } + [Test, Description("Roundtrips a PostgreSQL domain over a one-dimensional and a two-dimensional array.")] + public async Task Domain_of_array() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing, ReloadTypes"); + + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "11.0", "Domains over arrays were introduced in PostgreSQL 11"); + await conn.ExecuteNonQueryAsync( +""" +CREATE DOMAIN pg_temp.int_array_1d AS int[] CHECK(array_length(VALUE, 1) = 4); +CREATE 
DOMAIN pg_temp.int_array_2d AS int[][] CHECK(array_length(VALUE, 2) = 2); +"""); + await conn.ReloadTypesAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1::int_array_1d, @p2::int_array_2d", conn); + var oneDim = new[] { 1, 3, 5, 9 }; + var twoDim = new[,] { { 1, 3 }, { 5, 9 } }; + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Integer | NpgsqlDbType.Array, oneDim); + cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Integer | NpgsqlDbType.Array, twoDim); + await using var reader = cmd.ExecuteReader(); + reader.Read(); + + Assert.That(reader.GetValue(0), Is.EqualTo(oneDim)); + Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(oneDim)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(oneDim)); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(reader.GetProviderSpecificFieldType(0), Is.EqualTo(typeof(Array))); + + Assert.That(reader.GetValue(1), Is.EqualTo(twoDim)); + Assert.That(reader.GetProviderSpecificValue(1), Is.EqualTo(twoDim)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(twoDim)); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(reader.GetProviderSpecificFieldType(1), Is.EqualTo(typeof(Array))); + } - [Test, Description("Roundtrips one-dimensional and two-dimensional arrays of a PostgreSQL domain.")] - public async Task ArrayOfDomain() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); - - using (var conn = await OpenConnectionAsync()) - { - TestUtil.MinimumPgVersion(conn, "11.0", "Arrays of domains were introduced in PostgreSQL 11"); - conn.ExecuteNonQuery("CREATE DOMAIN pg_temp.posint AS integer CHECK (VALUE > 0);"); - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p1::posint[], @p2::posint[][]", conn)) - { - var oneDim = new[] { 1, 3, 5, 9 }; - var twoDim = new[,] { { 1, 3 }, { 5, 9 } }; - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Integer | NpgsqlDbType.Array, oneDim); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Integer | 
NpgsqlDbType.Array, twoDim); - using var reader = cmd.ExecuteReader(); - reader.Read(); - - Assert.That(reader.GetValue(0), Is.EqualTo(oneDim)); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(oneDim)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(oneDim)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(oneDim)); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); - Assert.That(reader.GetProviderSpecificFieldType(0), Is.EqualTo(typeof(Array))); - - Assert.That(reader.GetValue(1), Is.EqualTo(twoDim)); - Assert.That(reader.GetProviderSpecificValue(1), Is.EqualTo(twoDim)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(twoDim)); - Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); - Assert.That(reader.GetProviderSpecificFieldType(1), Is.EqualTo(typeof(Array))); - } - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3417")] + public async Task Read_two_empty_arrays() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT '{}'::INT[], '{}'::INT[]", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + Assert.AreSame(reader.GetFieldValue(0), reader.GetFieldValue(1)); + // Unlike T[], List is mutable so we should not return the same instance + Assert.AreNotSame(reader.GetFieldValue>(0), reader.GetFieldValue>(1)); + } - [Test, Description("Roundtrips a PostgreSQL domain over a one-dimensional and a two-dimensional array.")] - public async Task DomainOfArray() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); - - using (var conn = await OpenConnectionAsync()) - { - TestUtil.MinimumPgVersion(conn, "11.0", "Domains over arrays were introduced in PostgreSQL 11"); - conn.ExecuteNonQuery("CREATE DOMAIN pg_temp.int_array_1d AS int[] CHECK(array_length(VALUE, 1) = 4);" + - "CREATE DOMAIN pg_temp.int_array_2d AS int[][] CHECK(array_length(VALUE, 2) = 2);"); - conn.ReloadTypes(); - using (var cmd = 
new NpgsqlCommand("SELECT @p1::int_array_1d, @p2::int_array_2d", conn)) - { - var oneDim = new[] { 1, 3, 5, 9 }; - var twoDim = new[,] { { 1, 3 }, { 5, 9 } }; - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Integer | NpgsqlDbType.Array, oneDim); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Integer | NpgsqlDbType.Array, twoDim); - using var reader = cmd.ExecuteReader(); - reader.Read(); - - Assert.That(reader.GetValue(0), Is.EqualTo(oneDim)); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(oneDim)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(oneDim)); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); - Assert.That(reader.GetProviderSpecificFieldType(0), Is.EqualTo(typeof(Array))); - - Assert.That(reader.GetValue(1), Is.EqualTo(twoDim)); - Assert.That(reader.GetProviderSpecificValue(1), Is.EqualTo(twoDim)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(twoDim)); - Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); - Assert.That(reader.GetProviderSpecificFieldType(1), Is.EqualTo(typeof(Array))); - } - } - } + [Test] + public async Task Arrays_not_supported_by_default_on_NpgsqlSlimSourceBuilder() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + await using var dataSource = dataSourceBuilder.Build(); - async Task AssertIListRoundtrips(NpgsqlConnection conn, IEnumerable value) - { - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = value }); - - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer[]")); - Assert.That(reader[0], Is.EqualTo(value.ToArray())); - } - } - } + await AssertTypeUnsupportedRead("{1,2,3}", "integer[]", dataSource); + await AssertTypeUnsupportedWrite(new[] { 1, 2, 3 }, "integer[]", dataSource); + } - class IntList : List { } - class MisleadingIntList : List { } + [Test] + public async Task 
NpgsqlSlimSourceBuilder_EnableArrays() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableArrays(); + await using var dataSource = dataSourceBuilder.Build(); - public ArrayTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + await AssertType(dataSource, new[] { 1, 2, 3 }, "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array); } + + public ArrayTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/BitStringTests.cs b/test/Npgsql.Tests/Types/BitStringTests.cs index 068d914618..95c81ffb41 100644 --- a/test/Npgsql.Tests/Types/BitStringTests.cs +++ b/test/Npgsql.Tests/Types/BitStringTests.cs @@ -1,266 +1,128 @@ using System; using System.Collections; using System.Collections.Specialized; -using System.Data; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +/// +/// Tests on the PostgreSQL BitString type +/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-bit.html +/// +public class BitStringTests : MultiplexingTestBase { - /// - /// Tests on the PostgreSQL BitString type - /// - /// - /// https://www.postgresql.org/docs/current/static/datatype-bit.html - /// - public class BitStringTests : MultiplexingTestBase + [Test] + [TestCase("10110110", TestName = "BitArray")] + [TestCase("1011011000101111010110101101011011", TestName = "BitArray_with_34_bits")] + [TestCase("", TestName = "BitArray_empty")] + public async Task BitArray(string sqlLiteral) { - [Test] - public async Task RoundtripBitArray( - [Values( - "1011011000101111010110101101011011", // 34 bits - "10110110", - "" - )] - string bits - ) - { - var expected = new BitArray(bits.Length); - for (var i = 0; i < bits.Length; i++) - expected[i] = bits[i] == '1'; - - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4", 
conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Varbit); - var p2 = new NpgsqlParameter("p2", NpgsqlDbType.Bit); - var p3 = new NpgsqlParameter("p3", NpgsqlDbType.Varbit) {Value = bits}; - var p4 = new NpgsqlParameter {ParameterName = "p4", Value = expected}; - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - cmd.Parameters.Add(p4); - p1.Value = p2.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - Assert.That(() => reader.GetFieldValue(i), Throws.Exception.TypeOf()); - } - } - } - } - - [Test] - public async Task Long() - { - using (var conn = await OpenConnectionAsync()) - { - var bitLen = (conn.Settings.WriteBufferSize + 10) * 8; - var chars = new char[bitLen]; - for (var i = 0; i < bitLen; i++) - chars[i] = i % 2 == 0 ? '0' : '1'; - await RoundtripBitArray(new string(chars)); - } - } - - [Test] - public async Task RoundtripBitVector32([Values(15, 0)] int bits) - { - var expected = new BitVector32(bits); + var len = sqlLiteral.Length; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", expected); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - } - } - } + var bitArray = new BitArray(len); + for (var i = 0; i < sqlLiteral.Length; i++) + bitArray[i] = sqlLiteral[i] == '1'; - [Test] - public async Task BitVector32TooLong() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand($"SELECT B'{new string('0', 34)}'", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(() => reader.GetFieldValue(0), Throws.Exception.TypeOf()); - } - Assert.That(await 
conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } + await AssertType(bitArray, sqlLiteral, "bit varying", NpgsqlDbType.Varbit); - [Test, Description("Roundtrips a single bit")] - public async Task SingleBit() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p::BIT(1), B'01'::BIT(2)", conn)) - { - const bool expected = true; - var p = new NpgsqlParameter("p", NpgsqlDbType.Bit); - // Type inference? But bool is mapped to PG bool - cmd.Parameters.Add(p); - p.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - Assert.That(reader.GetBoolean(0), Is.EqualTo(true)); - Assert.That(reader.GetValue(0), Is.EqualTo(true)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(true)); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(bool))); - } - } - } - - [Test, Description("BIT(N) shouldn't be accessible as bool")] - public async Task BitstringAsSingleBit() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT B'01'::BIT(2)", conn)) - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess)) - { - reader.Read(); - Assert.That(() => reader.GetBoolean(0), Throws.Exception.TypeOf()); - - } - // Connection should still be OK - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } + if (len > 0) + await AssertType(bitArray, sqlLiteral, $"bit({len})", NpgsqlDbType.Bit, isDefaultForWriting: false); + } - [Test] - public async Task Array() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var expected = new[] { new BitArray(new[] { true, false, true }), new BitArray(new[] { false }) }; - var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Varbit) { Value = expected }; - cmd.Parameters.Add(p); - p.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); + [Test] + 
public async Task BitArray_long() + { + await using var conn = await OpenConnectionAsync(); + var bitLen = (conn.Settings.WriteBufferSize + 10) * 8; + var chars = new char[bitLen]; + for (var i = 0; i < bitLen; i++) + chars[i] = i % 2 == 0 ? '0' : '1'; + await BitArray(new string(chars)); + } - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); - } - } - } + [Test] + public Task BitVector32() + => AssertType( + new BitVector32(4), "00000000000000000000000000000100", "bit varying", NpgsqlDbType.Varbit, isDefaultForReading: false); - [Test] - public async Task SingleBitArray() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p::BIT(1)[]", conn)) - { - var expected = new[] { true, false }; - var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Bit) {Value = expected}; - cmd.Parameters.Add(p); - p.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - var x = reader.GetValue(0); - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); - } - } - } + [Test] + public Task BitVector32_too_long() + => AssertTypeUnsupportedRead(new string('0', 34), "bit varying"); - [Test] - public async Task Validation() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1::BIT VARYING", conn)) - { - var p = new NpgsqlParameter("p1", NpgsqlDbType.Bit); - cmd.Parameters.Add(p); - p.Value = "001q0"; - Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); + [Test] + public Task Bool() + => AssertType(true, "1", "bit(1)", NpgsqlDbType.Bit, isDefault: false); - // Make sure the connection state is OK - Assert.That(await conn.ExecuteScalarAsync("SELECT 8"), 
Is.EqualTo(8)); - } - } + [Test] + public async Task Bitstring_with_multiple_bits_as_bool_throws() + { + await AssertTypeUnsupportedRead("01", "varbit"); + await AssertTypeUnsupportedRead("01", "bit(2)"); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2766")] - [Timeout(3000)] - public async Task SequentialReadOfOversizedBitArray() - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT 1::bit(100000)", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + [Test] + public async Task Array() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p", conn); + var expected = new[] { new BitArray(new[] { true, false, true }), new BitArray(new[] { false }) }; + var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Varbit) { Value = expected }; + cmd.Parameters.Add(p); + p.Value = expected; + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + + Assert.That(reader.GetValue(0), Is.EqualTo(expected)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + } - reader.Read(); + [Test] + public async Task Array_of_single_bits() + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p::BIT(1)[]", conn); + var expected = new[] { true, false }; + var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Bit) {Value = expected}; + cmd.Parameters.Add(p); + p.Value = expected; + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var x = reader.GetValue(0); + Assert.That(reader.GetValue(0), Is.EqualTo(expected)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + } - var actual = reader.GetFieldValue(0); - Assert.That(actual, Has.Length.EqualTo(100000)); - } + [Test] + public async Task 
Array_of_single_bits_and_null() + { + var dataSource = CreateDataSource(builder => builder.ArrayNullabilityMode = ArrayNullabilityMode.Always); + using var conn = await dataSource.OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p::BIT(1)[]", conn); + var expected = new bool?[] { true, false, null }; + var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Bit) {Value = expected}; + cmd.Parameters.Add(p); + p.Value = expected; + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var x = reader.GetValue(0); + Assert.That(reader.GetValue(0), Is.EqualTo(expected)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + } - // Older tests from here + [Test] + public Task As_string() + => AssertType("010101", "010101", "bit varying", NpgsqlDbType.Varbit, isDefault: false); - // TODO: Bring this test back -#if FIX - [Test] - public async Task BitString([Values(true, false)] bool prepareCommand) - { - using (var cmd = Conn.CreateCommand()) - { - cmd.CommandText = "Select :bs1 as output, :bs2, :bs3, :bs4, :bs5, array [1::bit, 0::bit], array [bit '10', bit '01'], :ba1, :ba2, :ba3"; - var output = new NpgsqlParameter() { ParameterName = "output", Direction = ParameterDirection.Output }; - cmd.Parameters.Add(output); - cmd.Parameters.Add(new NpgsqlParameter("bs1", NpgsqlDbType.Bit) { Value = new BitString("1011") }); - cmd.Parameters.Add(new NpgsqlParameter("bs2", NpgsqlDbType.Bit, 1) { Value = true }); - cmd.Parameters.Add(new NpgsqlParameter("bs3", NpgsqlDbType.Bit, 1) { Value = false }); - cmd.Parameters.Add(new NpgsqlParameter("bs4", NpgsqlDbType.Bit, 2) { Value = new BitString("01") }); - cmd.Parameters.Add(new NpgsqlParameter("bs5", NpgsqlDbType.Varbit) { Value = new BitString("01") }); - cmd.Parameters.Add(new NpgsqlParameter("ba1", NpgsqlDbType.Varbit | NpgsqlDbType.Array) { Value = new BitString[] { new BitString("10"), new BitString("01") } 
}); - cmd.Parameters.Add(new NpgsqlParameter("ba2", NpgsqlDbType.Bit | NpgsqlDbType.Array, 1) { Value = new bool[] { true, false } }); - cmd.Parameters.Add(new NpgsqlParameter("ba3", NpgsqlDbType.Bit | NpgsqlDbType.Array, 1) { Value = new BitString[] { new BitString("1"), new BitString("0") } }); - if (prepareCommand) - cmd.Prepare(); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.IsTrue(new BitString("1011") == (BitString)output.Value); - Assert.IsTrue(new BitString("1011") == (BitString)reader.GetValue(0)); - Assert.AreEqual(true, reader.GetValue(1)); - Assert.AreEqual(false, reader.GetValue(2)); - Assert.IsTrue(new BitString("01") == (BitString)reader.GetValue(3)); - Assert.IsTrue(new BitString("01") == (BitString)reader.GetValue(4)); - Assert.AreEqual(true, ((bool[])reader.GetValue(5))[0]); - Assert.AreEqual(false, ((bool[])reader.GetValue(5))[1]); - for (int i = 6; i <= 7; i++) - { - Assert.AreEqual(new BitString("10"), ((BitString[])reader.GetValue(i))[0]); - Assert.AreEqual(new BitString("01"), ((BitString[])reader.GetValue(i))[1]); - } - for (int i = 8; i <= 9; i++) - { - Assert.AreEqual(true, ((bool[])reader.GetValue(i))[0]); - Assert.AreEqual(false, ((bool[])reader.GetValue(i))[1]); - } - } - } - } -#endif + [Test] + public Task Write_as_string_validation() + => AssertTypeUnsupportedWrite("001q0", "bit varying"); - public BitStringTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} - } + public BitStringTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/ByteaTests.cs b/test/Npgsql.Tests/Types/ByteaTests.cs index 38c0fa80d8..c34bce04ff 100644 --- a/test/Npgsql.Tests/Types/ByteaTests.cs +++ b/test/Npgsql.Tests/Types/ByteaTests.cs @@ -1,275 +1,297 @@ using System; +using System.Collections.Generic; using System.Data; -using System.Linq; +using System.IO; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; -using static 
Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +/// +/// Tests on the PostgreSQL bytea type +/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-binary.html +/// +public class ByteaTests : MultiplexingTestBase { - /// - /// Tests on the PostgreSQL bytea type - /// - /// - /// https://www.postgresql.org/docs/current/static/datatype-binary.html - /// - public class ByteaTests : MultiplexingTestBase + [Test] + [TestCase(new byte[] { 1, 2, 3, 4, 5 }, "\\x0102030405", TestName = "Bytea")] + [TestCase(new byte[] { }, "\\x", TestName = "Bytea_empty")] + public Task Bytea(byte[] byteArray, string sqlLiteral) + => AssertType(byteArray, sqlLiteral, "bytea", NpgsqlDbType.Bytea, DbType.Binary); + + [Test] + public async Task Bytea_long() { - [Test, Description("Roundtrips a bytea")] - public async Task Roundtrip() + await using var conn = await OpenConnectionAsync(); + var array = new byte[conn.Settings.WriteBufferSize + 100]; + var sqlLiteral = "\\x" + new string('1', (conn.Settings.WriteBufferSize + 100) * 2); + for (var i = 0; i < array.Length; i++) + array[i] = 17; + + await Bytea(array, sqlLiteral); + } + + [Test] + public Task AsMemory() + => AssertType( + new Memory(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, + comparer: (left, right) => left.Span.SequenceEqual(right.Span)); + + [Test] + public Task AsReadOnlyMemory() + => AssertType( + new ReadOnlyMemory(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, + comparer: (left, right) => left.Span.SequenceEqual(right.Span)); + + [Test] + public Task AsArraySegment() + => AssertType( + new ArraySegment(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + + [Test] + public Task Write_as_MemoryStream() + => AssertTypeWrite( + () => new MemoryStream(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, 
DbType.Binary, isDefault: false); + + [Test] + public Task Write_as_MemoryStream_truncated() + { + var msFactory = () => { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - byte[] expected = { 1, 2, 3, 4, 5 }; - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Bytea); - var p2 = new NpgsqlParameter("p2", DbType.Binary); - var p3 = new NpgsqlParameter { ParameterName = "p3", Value = expected }; - Assert.That(p3.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); - Assert.That(p3.DbType, Is.EqualTo(DbType.Binary)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - p1.Value = p2.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(byte[]))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - } - } - } - } + var ms = new MemoryStream(new byte[] { 1, 2, 3, 4 }); + ms.ReadByte(); + return ms; + }; - [Test] - public async Task RoundtripLarge() + return AssertTypeWrite( + msFactory, "\\x020304", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + } + + [Test] + public Task Write_as_MemoryStream_exposableArray() + { + var msFactory = () => { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p::BYTEA", conn)) - { - var expected = new byte[conn.Settings.WriteBufferSize + 100]; - for (var i = 0; i < expected.Length; i++) - expected[i] = 8; - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Bytea) { Value = expected }); - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(byte[]))); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - } - } + var ms = new MemoryStream(20); + ms.WriteByte(1); + ms.WriteByte(2); + 
ms.WriteByte(3); + ms.WriteByte(4); + ms.Position = 1; + return ms; + }; + + return AssertTypeWrite( + msFactory, "\\x020304", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + } - [Test] - public async Task Read([Values(CommandBehavior.Default, CommandBehavior.SequentialAccess)] CommandBehavior behavior) + [Test] + public async Task Write_as_MemoryStream_long() + { + var rnd = new Random(1); + var bytes = new byte[8192 * 4]; + rnd.NextBytes(bytes); + var expectedSql = "\\x" + ToHex(bytes); + + await AssertTypeWrite( + () => new MemoryStream(bytes), expectedSql, "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + } + + [Test] + public async Task Write_as_FileStream() + { + var filePath = Path.GetTempFileName(); + var fsList = new List(); + try { - using (var conn = await OpenConnectionAsync()) - await using (await CreateTempTable(conn, "bytes BYTEA", out var table)) - { - // TODO: This is too small to actually test any interesting sequential behavior - byte[] expected = {1, 2, 3, 4, 5}; - await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (bytes) VALUES ({EncodeByteaHex(expected)})"); - - string queryText = $"SELECT bytes, 'foo', bytes, bytes, bytes FROM {table}"; - using (var cmd = new NpgsqlCommand(queryText, conn)) - using (var reader = await cmd.ExecuteReaderAsync(behavior)) - { - reader.Read(); - - var actual = reader.GetFieldValue(0); - Assert.That(actual, Is.EqualTo(expected)); - - if (behavior.IsSequential()) - Assert.That(() => reader[0], Throws.Exception.TypeOf(), "Seek back sequential"); - else - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - - Assert.That(reader.GetString(1), Is.EqualTo("foo")); - - Assert.That(reader[2], Is.EqualTo(expected)); - Assert.That(reader.GetValue(3), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(4), Is.EqualTo(expected)); - } - } - } + await File.WriteAllBytesAsync(filePath, new byte[] { 1, 2, 3 }); - [Test] - public async Task EmptyRoundtrip() + await AssertTypeWrite( 
+ () => FileStreamFactory(filePath, fsList), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + } + finally { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT :val::BYTEA", conn)) + foreach (var fs in fsList) + await fs.DisposeAsync(); + + try { - var expected = new byte[0]; - cmd.Parameters.Add("val", NpgsqlDbType.Bytea); - cmd.Parameters["val"].Value = expected; - var result = (byte[]?)await cmd.ExecuteScalarAsync(); - Assert.That(result, Is.EqualTo(expected)); + File.Delete(filePath); } + catch {} } - [Test, Description("Tests that bytea values are truncated when the NpgsqlParameter's Size is set")] - public async Task Truncate() + FileStream FileStreamFactory(string filePath, List fsList) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - byte[] data = { 1, 2, 3, 4, 5, 6 }; - var p = new NpgsqlParameter("p", data) { Size = 4 }; - cmd.Parameters.Add(p); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 1, 2, 3, 4 })); - - // NpgsqlParameter.Size needs to persist when value is changed - byte[] data2 = { 11, 12, 13, 14, 15, 16 }; - p.Value = data2; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 11, 12, 13, 14 })); - - // NpgsqlParameter.Size larger than the value size should mean the value size, as well as 0 and -1 - p.Size = data2.Length + 10; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); - p.Size = 0; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); - p.Size = -1; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); - - Assert.That(() => p.Size = -2, Throws.Exception.TypeOf()); - } + var fs = File.OpenRead(filePath); + fsList.Add(fs); + return fs; } + } - [Test] - public async Task ByteaOverArrayOfBytes() + [Test] + public async Task Write_as_FileStream_long() + { + var filePath = Path.GetTempFileName(); + var fsList = new List(); + var 
rnd = new Random(1); + try { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", new byte[3]); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("bytea")); - } - } - } + var bytes = new byte[8192 * 4]; + rnd.NextBytes(bytes); + await File.WriteAllBytesAsync(filePath, bytes); + var expectedSql = "\\x" + ToHex(bytes); - [Test] - public async Task ArrayOfBytea() + await AssertTypeWrite( + () => FileStreamFactory(filePath, fsList), expectedSql, "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + } + finally { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT :p1", conn)) + foreach (var fs in fsList) + await fs.DisposeAsync(); + + try { - var bytes = new byte[] { 1, 2, 3, 4, 5, 34, 39, 48, 49, 50, 51, 52, 92, 127, 128, 255, 254, 253, 252, 251 }; - var inVal = new[] { bytes, bytes }; - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Bytea | NpgsqlDbType.Array, inVal); - var retVal = (byte[][]?)await cmd.ExecuteScalarAsync(); - Assert.AreEqual(inVal.Length, retVal!.Length); - Assert.AreEqual(inVal[0], retVal[0]); - Assert.AreEqual(inVal[1], retVal[1]); + File.Delete(filePath); } + catch {} } -#if !NETSTANDARD2_0 - [Test] - public async Task Memory() + FileStream FileStreamFactory(string filePath, List fsList) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var bytes = new byte[] { 1, 2, 3 }; - cmd.Parameters.AddWithValue("p1", new ReadOnlyMemory(bytes)); - cmd.Parameters.AddWithValue("p2", new Memory(bytes)); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(bytes)); - Assert.That(reader[1], Is.EqualTo(bytes)); - Assert.That(() => reader.GetFieldValue>(0), Throws.Exception.TypeOf()); - Assert.That(() => 
reader.GetFieldValue>(0), Throws.Exception.TypeOf()); - } - } + var fs = File.OpenRead(filePath); + fsList.Add(fs); + return fs; } -#endif + } - // Older tests from here + static string ToHex(ReadOnlySpan bytes) + { + var c = new char[bytes.Length * 2]; - [Test] - public async Task Insert1() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @bytes", conn)) - { - byte[] toStore = { 0, 1, 255, 254 }; - cmd.Parameters.AddWithValue("@bytes", toStore); - var result = (byte[]?)await cmd.ExecuteScalarAsync(); - Assert.AreEqual(toStore, result!); - } - } + byte b; - [Test] - public async Task ArraySegment() + for (int bx = 0, cx = 0; bx < bytes.Length; ++bx, ++cx) { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("select :bytearr", conn)) - { - var arr = new byte[20000]; - for (var i = 0; i < arr.Length; i++) - { - arr[i] = (byte) (i & 0xff); - } - - // Big value, should go through "direct buffer" - var segment = new ArraySegment(arr, 17, 18000); - cmd.Parameters.Add(new NpgsqlParameter("bytearr", DbType.Binary) {Value = segment}); - var returned = (byte[]?)await cmd.ExecuteScalarAsync(); - Assert.That(segment.SequenceEqual(returned!)); - - cmd.Parameters[0].Size = 17000; - returned = (byte[]?)await cmd.ExecuteScalarAsync(); - Assert.That(returned!.SequenceEqual(new ArraySegment(segment.Array!, segment.Offset, 17000))); - - // Small value, should be written normally through the NpgsqlBuffer - segment = new ArraySegment(arr, 6, 10); - cmd.Parameters[0].Value = segment; - returned = (byte[]?)await cmd.ExecuteScalarAsync(); - Assert.That(segment.SequenceEqual(returned!)); - - cmd.Parameters[0].Size = 2; - returned = (byte[]?)await cmd.ExecuteScalarAsync(); - Assert.That(returned!.SequenceEqual(new ArraySegment(segment.Array!, segment.Offset, 2))); - } - - using (var cmd = new NpgsqlCommand("select :bytearr", conn)) - { - var segment = new ArraySegment(new byte[] {1, 2, 3}, 1, 2); - 
cmd.Parameters.AddWithValue("bytearr", segment); - Assert.That(segment.SequenceEqual((byte[])(await cmd.ExecuteScalarAsync())!)); - } - } + b = (byte)(bytes[bx] >> 4); + c[cx] = (char)(b > 9 ? b - 10 + 'a' : b + '0'); + + b = (byte)(bytes[bx] & 0x0F); + c[++cx] = (char)(b > 9 ? b - 10 + 'a' : b + '0'); } - [Test, Description("Writes a bytea that doesn't fit in a partially-full buffer, but does fit in an empty buffer")] - [IssueLink("https://github.com/npgsql/npgsql/issues/654")] - public async Task WriteDoesntFitInitiallyButFitsLater() + return new string(c); + } + + [Test, Description("Tests that bytea array values are truncated when the NpgsqlParameter's Size is set")] + public async Task Truncate_array() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + byte[] data = { 1, 2, 3, 4, 5, 6 }; + var p = new NpgsqlParameter("p", data) { Size = 4 }; + cmd.Parameters.Add(p); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 1, 2, 3, 4 })); + Assert.That(p.Value, Is.EqualTo(new byte[] { 1, 2, 3, 4 }), "Truncated parameter value should be persisted on the parameter per DbParameter.Size docs"); + + // NpgsqlParameter.Size needs to persist when value is changed + byte[] data2 = { 11, 12, 13, 14, 15, 16 }; + p.Value = data2; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 11, 12, 13, 14 })); + + // NpgsqlParameter.Size larger than the value size should mean the value size, as well as 0 and -1 + p.Value = data2; + p.Size = data2.Length + 10; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); + p.Size = 0; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); + p.Size = -1; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); + + Assert.That(() => p.Size = -2, Throws.Exception.TypeOf()); + } + + [Test, Description("Tests that bytea stream values are truncated when the NpgsqlParameter's Size is set")] + public async Task 
Truncate_stream() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + byte[] data = { 1, 2, 3, 4, 5, 6 }; + var p = new NpgsqlParameter("p", new MemoryStream(data)) { Size = 4 }; + cmd.Parameters.Add(p); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 1, 2, 3, 4 })); + + // NpgsqlParameter.Size needs to persist when value is changed + byte[] data2 = { 11, 12, 13, 14, 15, 16 }; + p.Value = new MemoryStream(data2); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 11, 12, 13, 14 })); + + // Handle with offset + var data2ms = new MemoryStream(data2); + data2ms.ReadByte(); + p.Value = data2ms; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 12, 13, 14, 15 })); + + p.Size = 0; + p.Value = new MemoryStream(data2); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); + p.Size = -1; + p.Value = new MemoryStream(data2); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); + + Assert.That(() => p.Size = -2, Throws.Exception.TypeOf()); + + p.Value = new MemoryStream(data2); + p.Size = data2.Length + 10; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); + } + + [Test] + public async Task Write_as_NonSeekable_stream() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + byte[] data = { 1, 2, 3, 4, 5, 6 }; + var p = new NpgsqlParameter("p", new NonSeekableStream(data)) { Size = 4 }; + cmd.Parameters.Add(p); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 1, 2, 3, 4 })); + + var streamWithOffset = new NonSeekableStream(data); + streamWithOffset.ReadByte(); + p.Value = streamWithOffset; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 2, 3, 4, 5 })); + + p.Value = new NonSeekableStream(data); + p.Size = 0; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data)); + Assert.That(conn.State, 
Is.EqualTo(ConnectionState.Open)); + } + + [Test] + public async Task Array_of_bytea() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT :p1", conn); + var bytes = new byte[] { 1, 2, 3, 4, 5, 34, 39, 48, 49, 50, 51, 52, 92, 127, 128, 255, 254, 253, 252, 251 }; + var inVal = new[] { bytes, bytes }; + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Bytea | NpgsqlDbType.Array, inVal); + var retVal = (byte[][]?)await cmd.ExecuteScalarAsync(); + Assert.AreEqual(inVal.Length, retVal!.Length); + Assert.AreEqual(inVal[0], retVal[0]); + Assert.AreEqual(inVal[1], retVal[1]); + } + + sealed class NonSeekableStream : MemoryStream + { + public override bool CanSeek => false; + + public NonSeekableStream(byte[] data) : base(data) { - using (var conn = await OpenConnectionAsync()) - await using (await CreateTempTable(conn, "field BYTEA", out var table)) - { - var bytea = new byte[8180]; - for (var i = 0; i < bytea.Length; i++) - { - bytea[i] = (byte) (i%256); - } - - using (var cmd = new NpgsqlCommand($"INSERT INTO {table} (field) VALUES (@p)", conn)) - { - cmd.Parameters.AddWithValue("@p", bytea); - await cmd.ExecuteNonQueryAsync(); - } - } } - - public ByteaTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } + + public ByteaTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs b/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs index 304459d916..2188569a49 100644 --- a/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs +++ b/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs @@ -1,148 +1,157 @@ using System; +using System.Threading.Tasks; using NUnit.Framework; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public partial class CompositeHandlerTests { - public partial class CompositeHandlerTests + async Task Read(Action, T> assert, string? 
schema = null) + where T : IComposite, IInitializable, new() { - void Read(Action, T> assert, string? schema = null) - where T : IComposite, IInitializable, new() - { - var composite = new T(); - composite.Initialize(); - Read(composite, assert, schema); - } + var composite = new T(); + composite.Initialize(); + await Read(composite, assert, schema); + } - void Read(T composite, Action, T> assert, string? schema = null) - where T : IComposite - { - using var connection = OpenAndMapComposite(composite, schema, nameof(Read), out var name); - using var command = new NpgsqlCommand($"SELECT ROW({composite.GetValues()})::{name}", connection); - using var reader = command.ExecuteReader(); - - reader.Read(); - assert(() => reader.GetFieldValue(0), composite); - } - - [Test] - public void Read_ClassWithProperty_Succeeds() => - Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); - - [Test] - public void Read_ClassWithField_Succeeds() => - Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); - - [Test] - public void Read_StructWithProperty_Succeeds() => - Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); - - [Test] - public void Read_StructWithField_Succeeds() => - Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); - - [Test] - public void Read_TypeWithTwoProperties_Succeeds() => - Read((execute, expected) => - { - var actual = execute(); - Assert.AreEqual(expected.IntValue, actual.IntValue); - Assert.AreEqual(expected.StringValue, actual.StringValue); - }); - - [Test] - public void Read_TypeWithTwoPropertiesInverted_Succeeds() => - Read((execute, expected) => - { - var actual = execute(); - Assert.AreEqual(expected.IntValue, actual.IntValue); - Assert.AreEqual(expected.StringValue, actual.StringValue); - }); - - [Test] - public void Read_TypeWithPrivateProperty_ThrowsInvalidOperationException() => - Read(new TypeWithPrivateProperty(), (execute, expected) => 
Assert.Throws(() => execute())); - - [Test] - public void Read_TypeWithPrivateGetter_Succeeds() => - Read(new TypeWithPrivateGetter(), (execute, expected) => execute()); - - [Test] - public void Read_TypeWithPrivateSetter_ThrowsInvalidOperationException() => - Read(new TypeWithPrivateSetter(), (execute, expected) => Assert.Throws(() => execute())); - - [Test] - public void Read_TypeWithoutGetter_Succeeds() => - Read(new TypeWithoutGetter(), (execute, expected) => execute()); - - [Test] - public void Read_TypeWithoutSetter_ThrowsInvalidOperationException() => - Read(new TypeWithoutSetter(), (execute, expected) => Assert.Throws(() => execute())); - - [Test] - public void Read_TypeWithExplicitPropertyName_Succeeds() => - Read(new TypeWithExplicitPropertyName { MyValue = HelloSlonik }, (execute, expected) => Assert.That(execute().MyValue, Is.EqualTo(expected.MyValue))); - - [Test] - public void Read_TypeWithExplicitParameterName_Succeeds() => - Read(new TypeWithExplicitParameterName(HelloSlonik), (execute, expected) => Assert.That(execute().Value, Is.EqualTo(expected.Value))); - - [Test] - public void Read_TypeWithMorePropertiesThanAttributes_Succeeds() => - Read(new TypeWithMorePropertiesThanAttributes(), (execute, expected) => - { - var actual = execute(); - Assert.That(actual.IntValue, Is.Not.Null); - Assert.That(actual.StringValue, Is.Null); - }); - - [Test] - public void Read_TypeWithLessPropertiesThanAttributes_ThrowsInvalidOperationException() => - Read(new TypeWithLessPropertiesThanAttributes(), (execute, expected) => Assert.Throws(() => execute())); - - [Test] - public void Read_TypeWithLessParametersThanAttributes_ThrowsInvalidOperationException() => - Read(new TypeWithLessParametersThanAttributes(TheAnswer), (execute, expected) => Assert.Throws(() => execute())); - - [Test] - public void Read_TypeWithMoreParametersThanAttributes_ThrowsInvalidOperationException() => - Read(new TypeWithMoreParametersThanAttributes(TheAnswer, HelloSlonik), (execute, expected) 
=> Assert.Throws(() => execute())); - - [Test] - public void Read_TypeWithOneParameter_Succeeds() => - Read(new TypeWithOneParameter(1), (execute, expected) => Assert.That(execute().Value1, Is.EqualTo(expected.Value1))); - - [Test] - public void Read_TypeWithTwoParameters_Succeeds() => - Read(new TypeWithTwoParameters(TheAnswer, HelloSlonik), (execute, expected) => - { - var actual = execute(); - Assert.That(actual.IntValue, Is.EqualTo(expected.IntValue)); - Assert.That(actual.StringValue, Is.EqualTo(expected.StringValue)); - }); - - [Test] - public void Read_TypeWithTwoParametersReversed_Succeeds() => - Read(new TypeWithTwoParametersReversed(HelloSlonik, TheAnswer), (execute, expected) => - { - var actual = execute(); - Assert.That(actual.IntValue, Is.EqualTo(expected.IntValue)); - Assert.That(actual.StringValue, Is.EqualTo(expected.StringValue)); - }); - - [Test] - public void Read_TypeWithNineParameters_Succeeds() => - Read(new TypeWithNineParameters(1, 2, 3, 4, 5, 6, 7, 8, 9), (execute, expected) => - { - var actual = execute(); - Assert.That(actual.Value1, Is.EqualTo(expected.Value1)); - Assert.That(actual.Value2, Is.EqualTo(expected.Value2)); - Assert.That(actual.Value3, Is.EqualTo(expected.Value3)); - Assert.That(actual.Value4, Is.EqualTo(expected.Value4)); - Assert.That(actual.Value5, Is.EqualTo(expected.Value5)); - Assert.That(actual.Value6, Is.EqualTo(expected.Value6)); - Assert.That(actual.Value7, Is.EqualTo(expected.Value7)); - Assert.That(actual.Value8, Is.EqualTo(expected.Value8)); - Assert.That(actual.Value9, Is.EqualTo(expected.Value9)); - }); + async Task Read(T composite, Action, T> assert, string? 
schema = null) + where T : IComposite + { + await using var dataSource = await OpenAndMapComposite(composite, schema, nameof(Read), out var name); + await using var connection = await dataSource.OpenConnectionAsync(); + + var literal = $"ROW({composite.GetValues()})::{name}"; + var arrayLiteral = $"ARRAY[{literal}]::{name}[]"; + await using var command = new NpgsqlCommand($"SELECT {literal}, {arrayLiteral}", connection); + await using var reader = command.ExecuteReader(); + + await reader.ReadAsync(); + assert(() => reader.GetFieldValue(0), composite); + assert(() => reader.GetFieldValue(1)[0], composite); } + + [Test] + public Task Read_class_with_property() => + Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); + + [Test] + public Task Read_class_with_field() => + Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); + + [Test] + public Task Read_struct_with_property() => + Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); + + [Test] + public Task Read_struct_with_field() => + Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); + + [Test] + public Task Read_type_with_two_properties() => + Read((execute, expected) => + { + var actual = execute(); + Assert.AreEqual(expected.IntValue, actual.IntValue); + Assert.AreEqual(expected.StringValue, actual.StringValue); + }); + + [Test] + public Task Read_type_with_two_properties_inverted() => + Read((execute, expected) => + { + var actual = execute(); + Assert.AreEqual(expected.IntValue, actual.IntValue); + Assert.AreEqual(expected.StringValue, actual.StringValue); + }); + + [Test] + public Task Read_type_with_private_property_throws() => + Read(new TypeWithPrivateProperty(), (execute, expected) => + Assert.That(() => execute(), Throws.Exception.TypeOf().With.Property("InnerException").TypeOf())); + + [Test] + public Task Read_type_with_private_getter() => + Read(new TypeWithPrivateGetter(), (execute, expected) => 
execute()); + + [Test] + public Task Read_type_with_private_setter_throws() => + Read(new TypeWithPrivateSetter(), (execute, expected) => Assert.Throws(() => execute())); + + [Test] + public Task Read_type_without_getter() => + Read(new TypeWithoutGetter(), (execute, expected) => execute()); + + [Test] + public Task Read_type_without_setter_throws() => + Read(new TypeWithoutSetter(), (execute, expected) => Assert.Throws(() => execute())); + + [Test] + public Task Read_type_with_explicit_property_name() => + Read(new TypeWithExplicitPropertyName { MyValue = HelloSlonik }, (execute, expected) => Assert.That(execute().MyValue, Is.EqualTo(expected.MyValue))); + + [Test] + public Task Read_type_with_explicit_parameter_name() => + Read(new TypeWithExplicitParameterName(HelloSlonik), (execute, expected) => Assert.That(execute().Value, Is.EqualTo(expected.Value))); + + [Test] + public Task Read_type_with_more_properties_than_attributes() => + Read(new TypeWithMorePropertiesThanAttributes(), (execute, expected) => + { + var actual = execute(); + Assert.That(actual.IntValue, Is.Not.Null); + Assert.That(actual.StringValue, Is.Null); + }); + + [Test] + public Task Read_type_with_less_properties_than_attributes_throws() => + Read(new TypeWithLessPropertiesThanAttributes(), (execute, expected) => + Assert.That(() => execute(), Throws.Exception.TypeOf().With.Property("InnerException").TypeOf())); + + [Test] + public Task Read_type_with_less_parameters_than_attributes_throws() => + Read(new TypeWithLessParametersThanAttributes(TheAnswer), (execute, expected) => + Assert.That(() => execute(), Throws.Exception.TypeOf().With.Property("InnerException").TypeOf())); + + [Test] + public Task Read_type_with_more_parameters_than_attributes_throws() => + Read(new TypeWithMoreParametersThanAttributes(TheAnswer, HelloSlonik), (execute, expected) => + Assert.That(() => execute(), Throws.Exception.TypeOf().With.Property("InnerException").TypeOf())); + + [Test] + public Task 
Read_type_with_one_parameter() => + Read(new TypeWithOneParameter(1), (execute, expected) => Assert.That(execute().Value1, Is.EqualTo(expected.Value1))); + + [Test] + public Task Read_type_with_two_parameters() => + Read(new TypeWithTwoParameters(TheAnswer, HelloSlonik), (execute, expected) => + { + var actual = execute(); + Assert.That(actual.IntValue, Is.EqualTo(expected.IntValue)); + Assert.That(actual.StringValue, Is.EqualTo(expected.StringValue)); + }); + + [Test] + public Task Read_type_with_two_parameters_reversed() => + Read(new TypeWithTwoParametersReversed(HelloSlonik, TheAnswer), (execute, expected) => + { + var actual = execute(); + Assert.That(actual.IntValue, Is.EqualTo(expected.IntValue)); + Assert.That(actual.StringValue, Is.EqualTo(expected.StringValue)); + }); + + [Test] + public Task Read_type_with_nine_parameters() => + Read(new TypeWithNineParameters(1, 2, 3, 4, 5, 6, 7, 8, 9), (execute, expected) => + { + var actual = execute(); + Assert.That(actual.Value1, Is.EqualTo(expected.Value1)); + Assert.That(actual.Value2, Is.EqualTo(expected.Value2)); + Assert.That(actual.Value3, Is.EqualTo(expected.Value3)); + Assert.That(actual.Value4, Is.EqualTo(expected.Value4)); + Assert.That(actual.Value5, Is.EqualTo(expected.Value5)); + Assert.That(actual.Value6, Is.EqualTo(expected.Value6)); + Assert.That(actual.Value7, Is.EqualTo(expected.Value7)); + Assert.That(actual.Value8, Is.EqualTo(expected.Value8)); + Assert.That(actual.Value9, Is.EqualTo(expected.Value9)); + }); } diff --git a/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs b/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs index bdfd3799ab..160b037a97 100644 --- a/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs +++ b/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs @@ -1,109 +1,129 @@ using System; +using System.Threading.Tasks; using NUnit.Framework; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public partial class CompositeHandlerTests { - public 
partial class CompositeHandlerTests + async Task Write(Action assert, string? schema = null) + where T : IComposite, IInitializable, new() { - void Write(Action, T> assert, string? schema = null) - where T : IComposite, IInitializable, new() - { - var composite = new T(); - composite.Initialize(); - Write(composite, assert, schema); - } + var composite = new T(); + composite.Initialize(); + await Write(composite, assert, schema); + } - void Write(T composite, Action, T> assert, string? schema = null) - where T : IComposite + async Task Write(T composite, Action? assert = null, string? schema = null) + where T : IComposite + { + await using var dataSource = await OpenAndMapComposite(composite, schema, nameof(Write), out var _); + await using var connection = await dataSource.OpenConnectionAsync(); { - using var connection = OpenAndMapComposite(composite, schema, nameof(Write), out var _); - using var command = new NpgsqlCommand("SELECT (@c).*", connection); + await using var command = new NpgsqlCommand("SELECT (@c).*", connection); command.Parameters.AddWithValue("c", composite); - assert(() => - { - var reader = command.ExecuteReader(); - reader.Read(); - return reader; - }, composite); + await using var reader = await command.ExecuteReaderAsync(); + await reader.ReadAsync(); + + if (assert is not null) + assert(reader, composite); } - [Test] - public void Write_ClassWithProperty_Succeeds() => - Write((execute, expected) => Assert.AreEqual(expected.Value, execute().GetString(0))); - - [Test] - public void Write_ClassWithField_Succeeds() => - Write((execute, expected) => Assert.AreEqual(expected.Value, execute().GetString(0))); - - [Test] - public void Write_StructWithProperty_Succeeds() => - Write((execute, expected) => Assert.AreEqual(expected.Value, execute().GetString(0))); - - [Test] - public void Write_StructWithField_Succeeds() => - Write((execute, expected) => Assert.AreEqual(expected.Value, execute().GetString(0))); - - [Test] - public void 
Write_TypeWithTwoProperties_Succeeds() => - Write((execute, expected) => - { - var actual = execute(); - Assert.AreEqual(expected.IntValue, actual.GetInt32(0)); - Assert.AreEqual(expected.StringValue, actual.GetString(1)); - }); - - [Test] - public void Write_TypeWithTwoPropertiesInverted_Succeeds() => - Write((execute, expected) => - { - var actual = execute(); - Assert.AreEqual(expected.IntValue, actual.GetInt32(1)); - Assert.AreEqual(expected.StringValue, actual.GetString(0)); - }); - - [Test] - public void Write_TypeWithPrivateProperty_ThrowsInvalidOperationException() => - Write(new TypeWithPrivateProperty(), (execute, expected) => Assert.Throws(() => execute())); - - [Test] - public void Write_TypeWithPrivateGetter_ThrowsInvalidOperationException() => - Write(new TypeWithPrivateGetter(), (execute, expected) => Assert.Throws(() => execute())); - - [Test] - public void Write_TypeWithPrivateSetter_Succeeds() => - Write(new TypeWithPrivateSetter(), (execute, expected) => execute()); - - [Test] - public void Write_TypeWithoutGetter_ThrowsInvalidOperationException() => - Write(new TypeWithoutGetter(), (execute, expected) => Assert.Throws(() => execute())); - - [Test] - public void Write_TypeWithoutSetter_Succeeds() => - Write(new TypeWithoutSetter(), (execute, expected) => execute()); - - [Test] - public void Write_TypeWithExplicitPropertyName_Succeeds() => - Write(new TypeWithExplicitPropertyName { MyValue = HelloSlonik }, (execute, expected) => Assert.That(execute().GetString(0), Is.EqualTo(expected.MyValue))); - - [Test] - public void Write_TypeWithExplicitParameterName_Succeeds() => - Write(new TypeWithExplicitParameterName(HelloSlonik), (execute, expected) => Assert.That(execute().GetString(0), Is.EqualTo(expected.Value))); - - [Test] - public void Write_TypeWithMorePropertiesThanAttributes_Succeeds() => - Write(new TypeWithMorePropertiesThanAttributes(), (execute, expected) => execute()); - - [Test] - public void 
Write_TypeWithLessPropertiesThanAttributes_ThrowsInvalidOperationException() => - Write(new TypeWithLessPropertiesThanAttributes(), (execute, expected) => Assert.Throws(() => execute())); - - [Test] - public void Write_TypeWithLessParametersThanAttributes_ThrowsInvalidOperationException() => - Write(new TypeWithLessParametersThanAttributes(TheAnswer), (execute, expected) => Assert.Throws(() => execute())); - - [Test] - public void Write_TypeWithMoreParametersThanAttributes_ThrowsInvalidOperationException() => - Write(new TypeWithMoreParametersThanAttributes(TheAnswer, HelloSlonik), (execute, expected) => Assert.Throws(() => execute())); + { + await using var command = new NpgsqlCommand("SELECT (@arrayc)[1].*", connection); + + command.Parameters.AddWithValue("arrayc", new[] { composite }); + await using var reader = await command.ExecuteReaderAsync(); + await reader.ReadAsync(); + + + if (assert is not null) + assert(reader, composite); + } } + + [Test] + public Task Write_class_with_property() + => Write((reader, expected) => Assert.AreEqual(expected.Value, reader.GetString(0))); + + [Test] + public Task Write_class_with_field() + => Write((reader, expected) => Assert.AreEqual(expected.Value, reader.GetString(0))); + + [Test] + public Task Write_struct_with_property() + => Write((reader, expected) => Assert.AreEqual(expected.Value, reader.GetString(0))); + + [Test] + public Task Write_struct_with_field() + => Write((reader, expected) => Assert.AreEqual(expected.Value, reader.GetString(0))); + + [Test] + public Task Write_type_with_two_properties() + => Write((reader, expected) => + { + Assert.AreEqual(expected.IntValue, reader.GetInt32(0)); + Assert.AreEqual(expected.StringValue, reader.GetString(1)); + }); + + [Test] + public Task Write_type_with_two_properties_inverted() + => Write((reader, expected) => + { + Assert.AreEqual(expected.IntValue, reader.GetInt32(1)); + Assert.AreEqual(expected.StringValue, reader.GetString(0)); + }); + + [Test] + public void 
Write_type_with_private_property_throws() + => Assert.ThrowsAsync( + Is.TypeOf().With.Property("InnerException").TypeOf(), + async () => await Write(new TypeWithPrivateProperty())); + + [Test] + public void Write_type_with_private_getter_throws() + => Assert.ThrowsAsync(async () => await Write(new TypeWithPrivateGetter())); + + [Test] + public Task Write_type_with_private_setter() + => Write(new TypeWithPrivateSetter()); + + [Test] + public void Write_type_without_getter_throws() + => Assert.ThrowsAsync(async () => await Write(new TypeWithoutGetter())); + + [Test] + public Task Write_type_without_setter() => + Write(new TypeWithoutSetter()); + + [Test] + public Task Write_type_with_explicit_property_name() + => Write(new TypeWithExplicitPropertyName { MyValue = HelloSlonik }, (reader, expected) => Assert.That(reader.GetString(0), Is.EqualTo(expected.MyValue))); + + [Test] + public Task Write_type_with_explicit_parameter_name() + => Write(new TypeWithExplicitParameterName(HelloSlonik), (reader, expected) => Assert.That(reader.GetString(0), Is.EqualTo(expected.Value))); + + [Test] + public Task Write_type_with_more_properties_than_attributes() + => Write(new TypeWithMorePropertiesThanAttributes()); + + [Test] + public void Write_type_with_less_properties_than_attributes_throws() + => Assert.ThrowsAsync( + Is.TypeOf().With.Property("InnerException").TypeOf(), + async () => await Write(new TypeWithLessPropertiesThanAttributes())); + + [Test] + public void Write_type_with_less_parameters_than_attributes_throws() + => Assert.ThrowsAsync( + Is.TypeOf().With.Property("InnerException").TypeOf(), + async () => await Write(new TypeWithMoreParametersThanAttributes(TheAnswer, HelloSlonik))); + + [Test] + public void Write_type_with_more_parameters_than_attributes_throws() + => Assert.ThrowsAsync( + Is.TypeOf().With.Property("InnerException").TypeOf(), + async () => await Write(new TypeWithLessParametersThanAttributes(TheAnswer))); } diff --git 
a/test/Npgsql.Tests/Types/CompositeHandlerTests.cs b/test/Npgsql.Tests/Types/CompositeHandlerTests.cs index 66f82c487f..1df95980a3 100644 --- a/test/Npgsql.Tests/Types/CompositeHandlerTests.cs +++ b/test/Npgsql.Tests/Types/CompositeHandlerTests.cs @@ -1,260 +1,258 @@ -using System; +using System.Threading.Tasks; using Npgsql.NameTranslation; using NpgsqlTypes; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public partial class CompositeHandlerTests : TestBase { - public partial class CompositeHandlerTests : TestBase + Task OpenAndMapComposite(T composite, string? schema, string nameSuffix, out string nameQualified) + where T : IComposite { - NpgsqlConnection OpenAndMapComposite(T composite, string? schema, string nameSuffix, out string nameQualified) - where T : IComposite - { - var nameTranslator = new NpgsqlSnakeCaseNameTranslator(); - var name = nameTranslator.TranslateTypeName(typeof(T).Name + nameSuffix); - - if (schema == null) - nameQualified = name; - else - { - schema = nameTranslator.TranslateTypeName(schema); - nameQualified = schema + "." + name; - } - - var connection = OpenConnection(); - - try - { - connection.ExecuteNonQuery(schema is null ? $"DROP TYPE IF EXISTS {name}" : $"DROP SCHEMA IF EXISTS {schema} CASCADE; CREATE SCHEMA {schema}"); - connection.ExecuteNonQuery($"CREATE TYPE {nameQualified} AS ({composite.GetAttributes()})"); - - connection.ReloadTypes(); - connection.TypeMapper.MapComposite(nameQualified, nameTranslator); - - return connection; - } - catch - { - connection.Dispose(); - throw; - } - } + var nameTranslator = new NpgsqlSnakeCaseNameTranslator(); + var name = nameTranslator.TranslateTypeName(typeof(T).Name + nameSuffix); - interface IComposite + if (schema == null) + nameQualified = name; + else { - string GetAttributes(); - string GetValues(); + schema = nameTranslator.TranslateTypeName(schema); + nameQualified = schema + "." 
+ name; } - interface IInitializable - { - void Initialize(); - } - - const string HelloSlonik = "Hello, Slonik"; - const int TheAnswer = 42; + return OpenAndMapCompositeCore(nameQualified); - public class ClassWithProperty : IComposite, IInitializable + async Task OpenAndMapCompositeCore(string nameQualified) { - public string? Value { get; set; } - public string GetAttributes() => "value text"; - public string GetValues() => $"'{Value}'"; - public void Initialize() => Value = HelloSlonik; - } + await using var adminConnection = await OpenConnectionAsync(); - public class ClassWithField : IComposite, IInitializable - { - public string? Value; - public string GetAttributes() => "value text"; - public string GetValues() => $"'{Value}'"; - public void Initialize() => Value = HelloSlonik; - } + await adminConnection.ExecuteNonQueryAsync(schema is null ? $"DROP TYPE IF EXISTS {name}" : $"DROP SCHEMA IF EXISTS {schema} CASCADE; CREATE SCHEMA {schema}"); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {nameQualified} AS ({composite.GetAttributes()})"); - public struct StructWithProperty : IComposite, IInitializable - { - public string? Value { get; set; } - public string GetAttributes() => "value text"; - public string GetValues() => $"'{Value}'"; - public void Initialize() => Value = HelloSlonik; - } + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(nameQualified, nameTranslator); + var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); - public struct StructWithField : IComposite, IInitializable - { - public string? Value; - public string GetAttributes() => "value text"; - public string GetValues() => $"'{Value}'"; - public void Initialize() => Value = HelloSlonik; + return dataSource; } + } - public class TypeWithTwoProperties : IComposite, IInitializable - { - public int IntValue { get; set; } - public string? 
StringValue { get; set; } + interface IComposite + { + string GetAttributes(); + string GetValues(); + } - public string GetAttributes() => "int_value integer, string_value text"; - public string GetValues() => $"{IntValue}, '{StringValue}'"; + interface IInitializable + { + void Initialize(); + } - public void Initialize() - { - IntValue = TheAnswer; - StringValue = HelloSlonik; - } - } + const string HelloSlonik = "Hello, Slonik"; + const int TheAnswer = 42; - public class TypeWithTwoPropertiesReversed : IComposite, IInitializable - { - public int IntValue { get; set; } - public string? StringValue { get; set; } + public class ClassWithProperty : IComposite, IInitializable + { + public string? Value { get; set; } + public string GetAttributes() => "value text"; + public string GetValues() => $"'{Value}'"; + public void Initialize() => Value = HelloSlonik; + } - public string GetAttributes() => "string_value text, int_value integer"; - public string GetValues() => $"'{StringValue}', {IntValue}"; + public class ClassWithField : IComposite, IInitializable + { + public string? Value; + public string GetAttributes() => "value text"; + public string GetValues() => $"'{Value}'"; + public void Initialize() => Value = HelloSlonik; + } - public void Initialize() - { - IntValue = TheAnswer; - StringValue = HelloSlonik; - } - } + public struct StructWithProperty : IComposite, IInitializable + { + public string? Value { get; set; } + public string GetAttributes() => "value text"; + public string GetValues() => $"'{Value}'"; + public void Initialize() => Value = HelloSlonik; + } - public abstract class SimpleComposite : IComposite - { - public string GetAttributes() => "value text"; - public string GetValues() => $"'{GetValue()}'"; + public struct StructWithField : IComposite, IInitializable + { + public string? 
Value; + public string GetAttributes() => "value text"; + public string GetValues() => $"'{Value}'"; + public void Initialize() => Value = HelloSlonik; + } - protected virtual string GetValue() => HelloSlonik; - } + public class TypeWithTwoProperties : IComposite, IInitializable + { + public int IntValue { get; set; } + public string? StringValue { get; set; } - public class TypeWithPrivateProperty : SimpleComposite - { - private string? Value { get; set; } - } + public string GetAttributes() => "int_value integer, string_value text"; + public string GetValues() => $"{IntValue}, '{StringValue}'"; - public class TypeWithPrivateGetter : SimpleComposite + public void Initialize() { - public string? Value { private get; set; } + IntValue = TheAnswer; + StringValue = HelloSlonik; } + } + + public class TypeWithTwoPropertiesReversed : IComposite, IInitializable + { + public int IntValue { get; set; } + public string? StringValue { get; set; } + + public string GetAttributes() => "string_value text, int_value integer"; + public string GetValues() => $"'{StringValue}', {IntValue}"; - public class TypeWithPrivateSetter : SimpleComposite + public void Initialize() { - public string? Value { get; private set; } + IntValue = TheAnswer; + StringValue = HelloSlonik; } + } - public class TypeWithoutGetter : SimpleComposite - { + public abstract class SimpleComposite : IComposite + { + public string GetAttributes() => "value text"; + public string GetValues() => $"'{GetValue()}'"; - public string? Value { set { } } - } + protected virtual string GetValue() => HelloSlonik; + } - public class TypeWithoutSetter : SimpleComposite - { - public string? Value { get; } - } + public class TypeWithPrivateProperty : SimpleComposite + { + private string? 
Value { get; set; } + } - public class TypeWithExplicitPropertyName : SimpleComposite - { - [PgName("value")] - public string MyValue { get; set; } = string.Empty; - protected override string GetValue() => MyValue; - } + public class TypeWithPrivateGetter : SimpleComposite + { + public string? Value { private get; set; } + } - public class TypeWithExplicitParameterName : SimpleComposite - { - public TypeWithExplicitParameterName([PgName("value")] string myValue) => Value = myValue; - public string Value { get; } - protected override string GetValue() => Value; - } + public class TypeWithPrivateSetter : SimpleComposite + { + public string? Value { get; private set; } + } - public class TypeWithMorePropertiesThanAttributes : IComposite - { - public string GetAttributes() => "int_value integer"; - public string GetValues() => $"{IntValue}"; + public class TypeWithoutGetter : SimpleComposite + { - public int IntValue { get; set; } - public string? StringValue { get; set; } - } + public string? Value { set { } } + } - public class TypeWithLessPropertiesThanAttributes : IComposite - { - public string GetAttributes() => "int_value integer, string_value text"; - public string GetValues() => $"{IntValue}, NULL"; + public class TypeWithoutSetter : SimpleComposite + { + public string? Value { get; } + } - public int IntValue { get; set; } - } - public class TypeWithMoreParametersThanAttributes : IComposite - { - public string GetAttributes() => "int_value integer"; - public string GetValues() => $"{IntValue}"; + public class TypeWithExplicitPropertyName : SimpleComposite + { + [PgName("value")] + public string MyValue { get; set; } = string.Empty; + protected override string GetValue() => MyValue; + } - public TypeWithMoreParametersThanAttributes(int intValue, string? 
stringValue) - { - IntValue = intValue; - StringValue = stringValue; - } + public class TypeWithExplicitParameterName : SimpleComposite + { + public TypeWithExplicitParameterName([PgName("value")] string myValue) => Value = myValue; + public string Value { get; } + protected override string GetValue() => Value; + } - public int IntValue { get; set; } - public string? StringValue { get; set; } - } + public class TypeWithMorePropertiesThanAttributes : IComposite + { + public string GetAttributes() => "int_value integer"; + public string GetValues() => $"{IntValue}"; - public class TypeWithLessParametersThanAttributes : IComposite - { - public string GetAttributes() => "int_value integer, string_value text"; - public string GetValues() => $"{IntValue}, NULL"; + public int IntValue { get; set; } + public string? StringValue { get; set; } + } - public TypeWithLessParametersThanAttributes(int intValue) => - IntValue = intValue; + public class TypeWithLessPropertiesThanAttributes : IComposite + { + public string GetAttributes() => "int_value integer, string_value text"; + public string GetValues() => $"{IntValue}, NULL"; - public int IntValue { get; } - } + public int IntValue { get; set; } + } + public class TypeWithMoreParametersThanAttributes : IComposite + { + public string GetAttributes() => "int_value integer"; + public string GetValues() => $"{IntValue}"; - public class TypeWithOneParameter : IComposite + public TypeWithMoreParametersThanAttributes(int intValue, string? stringValue) { - public string GetAttributes() => "value1 integer"; - public string GetValues() => $"{Value1}"; - - public TypeWithOneParameter(int value1) => Value1 = value1; - public int Value1 { get; } + IntValue = intValue; + StringValue = stringValue; } - public class TypeWithTwoParameters : IComposite - { - public string GetAttributes() => "int_value integer, string_value text"; - public string GetValues() => $"{IntValue}, '{StringValue}'"; + public int IntValue { get; set; } + public string? 
StringValue { get; set; } + } - public TypeWithTwoParameters(int intValue, string stringValue) => - (IntValue, StringValue) = (intValue, stringValue); + public class TypeWithLessParametersThanAttributes : IComposite + { + public string GetAttributes() => "int_value integer, string_value text"; + public string GetValues() => $"{IntValue}, NULL"; - public int IntValue { get; } - public string? StringValue { get; } - } + public TypeWithLessParametersThanAttributes(int intValue) => + IntValue = intValue; - public class TypeWithTwoParametersReversed : IComposite - { - public string GetAttributes() => "int_value integer, string_value text"; - public string GetValues() => $"{IntValue}, '{StringValue}'"; + public int IntValue { get; } + } - public TypeWithTwoParametersReversed(string stringValue, int intValue) => - (StringValue, IntValue) = (stringValue, intValue); + public class TypeWithOneParameter : IComposite + { + public string GetAttributes() => "value1 integer"; + public string GetValues() => $"{Value1}"; - public int IntValue { get; } - public string? 
StringValue { get; } - } + public TypeWithOneParameter(int value1) => Value1 = value1; + public int Value1 { get; } + } - public class TypeWithNineParameters : IComposite - { - public string GetAttributes() => "value1 integer, value2 integer, value3 integer, value4 integer, value5 integer, value6 integer, value7 integer, value8 integer, value9 integer"; - public string GetValues() => $"{Value1}, {Value2}, {Value3}, {Value4}, {Value5}, {Value6}, {Value7}, {Value8}, {Value9}"; - - public TypeWithNineParameters(int value1, int value2, int value3, int value4, int value5, int value6, int value7, int value8, int value9) - => (Value1, Value2, Value3, Value4, Value5, Value6, Value7, Value8, Value9) = (value1, value2, value3, value4, value5, value6, value7, value8, value9); - - public int Value1 { get; } - public int Value2 { get; } - public int Value3 { get; } - public int Value4 { get; } - public int Value5 { get; } - public int Value6 { get; } - public int Value7 { get; } - public int Value8 { get; } - public int Value9 { get; } - } + public class TypeWithTwoParameters : IComposite + { + public string GetAttributes() => "int_value integer, string_value text"; + public string GetValues() => $"{IntValue}, '{StringValue}'"; + + public TypeWithTwoParameters(int intValue, string stringValue) => + (IntValue, StringValue) = (intValue, stringValue); + + public int IntValue { get; } + public string? StringValue { get; } + } + + public class TypeWithTwoParametersReversed : IComposite + { + public string GetAttributes() => "int_value integer, string_value text"; + public string GetValues() => $"{IntValue}, '{StringValue}'"; + + public TypeWithTwoParametersReversed(string stringValue, int intValue) => + (StringValue, IntValue) = (stringValue, intValue); + + public int IntValue { get; } + public string? 
StringValue { get; } + } + + public class TypeWithNineParameters : IComposite + { + public string GetAttributes() => "value1 integer, value2 integer, value3 integer, value4 integer, value5 integer, value6 integer, value7 integer, value8 integer, value9 integer"; + public string GetValues() => $"{Value1}, {Value2}, {Value3}, {Value4}, {Value5}, {Value6}, {Value7}, {Value8}, {Value9}"; + + public TypeWithNineParameters(int value1, int value2, int value3, int value4, int value5, int value6, int value7, int value8, int value9) + => (Value1, Value2, Value3, Value4, Value5, Value6, Value7, Value8, Value9) = (value1, value2, value3, value4, value5, value6, value7, value8, value9); + + public int Value1 { get; } + public int Value2 { get; } + public int Value3 { get; } + public int Value4 { get; } + public int Value5 { get; } + public int Value6 { get; } + public int Value7 { get; } + public int Value8 { get; } + public int Value9 { get; } } } diff --git a/test/Npgsql.Tests/Types/CompositeTests.cs b/test/Npgsql.Tests/Types/CompositeTests.cs index 22a4cf9e06..713d5220a3 100644 --- a/test/Npgsql.Tests/Types/CompositeTests.cs +++ b/test/Npgsql.Tests/Types/CompositeTests.cs @@ -1,749 +1,723 @@ using System; -using System.Data; -using System.Dynamic; using System.Linq; +using System.Reflection; +using System.Threading.Tasks; using Npgsql.PostgresTypes; using NpgsqlTypes; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public class CompositeTests : MultiplexingTestBase { - [NonParallelizable] - public class CompositeTests : TestBase + [Test] + public async Task Basic() { - #region Test Types + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (x int, some_text text)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await 
using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeComposite { SomeText = "foo", X = 8 }, + "(8,foo)", + type, + npgsqlDbType: null); + } -#pragma warning disable CS8618 - class SomeComposite - { - public int X { get; set; } - public string SomeText { get; set; } - } + [Test] + public async Task Basic_with_custom_default_translator() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (x int, s text)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.DefaultNameTranslator = new CustomTranslator(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeComposite { SomeText = "foo", X = 8 }, + "(8,foo)", + type, + npgsqlDbType: null); + } - class SomeCompositeContainer - { - public int A { get; set; } - public SomeComposite Contained { get; set; } - } + [Test] + public async Task Basic_with_custom_translator() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (x int, s text)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type, new CustomTranslator()); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeComposite { SomeText = "foo", X = 8 }, + "(8,foo)", + type, + npgsqlDbType: null); + } - struct CompositeStruct - { - public int X { get; set; } - public string SomeText { get; set; } - } -#pragma warning restore CS8618 + class 
CustomTranslator : INpgsqlNameTranslator + { + public string TranslateTypeName(string clrName) => throw new NotImplementedException(); - #endregion + public string TranslateMemberName(string clrName) => clrName[0].ToString().ToLowerInvariant(); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1779")] - public void CompositePostgresType() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(PostgresType), - Pooling = false - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.comp1 AS (x int, some_text text)"); - conn.ExecuteNonQuery("CREATE TYPE pg_temp.comp2 AS (comp comp1, comps comp1[])"); - conn.ReloadTypes(); - - using (var cmd = new NpgsqlCommand("SELECT ROW(ROW(8, 'foo')::comp1, ARRAY[ROW(9, 'bar')::comp1, ROW(10, 'baz')::comp1])::comp2", conn)) - { - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var comp2Type = (PostgresCompositeType)reader.GetPostgresType(0); - Assert.That(comp2Type.Name, Is.EqualTo("comp2")); - Assert.That(comp2Type.FullName, Does.StartWith("pg_temp_") & Does.EndWith(".comp2")); - Assert.That(comp2Type.Fields, Has.Count.EqualTo(2)); - var field1 = comp2Type.Fields[0]; - var field2 = comp2Type.Fields[1]; - Assert.That(field1.Name, Is.EqualTo("comp")); - Assert.That(field2.Name, Is.EqualTo("comps")); - var comp1Type = (PostgresCompositeType)field1.Type; - Assert.That(comp1Type.Name, Is.EqualTo("comp1")); - var arrType = (PostgresArrayType)field2.Type; - Assert.That(arrType.Name, Is.EqualTo("comp1[]")); - var elemType = arrType.Element; - Assert.That(elemType, Is.SameAs(comp1Type)); - } - } - } - } +#pragma warning disable CS0618 // GlobalTypeMapper is obsolete + [Test, NonParallelizable] + public async Task Global_mapping() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); - [Test, Description("Resolves an enum type handler via the different pathways, 
with global mapping")] - public void CompositeTypeResolutionWithGlobalMapping() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(CompositeTypeResolutionWithGlobalMapping), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; - - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.composite1 AS (x int, some_text text)"); - NpgsqlConnection.GlobalTypeMapper.MapComposite("composite1"); - try - { - conn.ReloadTypes(); - - // Resolve type by DataTypeName - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "p", - DataTypeName = "composite1", - Value = DBNull.Value - }); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Does.StartWith("pg_temp").And.EndWith(".composite1")); - Assert.That(reader.IsDBNull(0), Is.True); - } - } - - // Resolve type by ClrType (type inference) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = new SomeComposite { X = 8, SomeText = "foo" }}); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Does.StartWith("pg_temp").And.EndWith(".composite1")); - } - } - - // Resolve type by OID (read) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT ROW(1, 'foo')::COMPOSITE1", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Does.StartWith("pg_temp").And.EndWith(".composite1")); - } - } - finally - { - NpgsqlConnection.GlobalTypeMapper.UnmapComposite("composite1"); - } - } - } + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (x int, some_text text)"); + NpgsqlConnection.GlobalTypeMapper.MapComposite(type); - [Test, Description("Resolves a composite type handler via the 
different pathways, with late mapping")] - public void CompositeTypeResolutionWithLateMapping() + try { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(CompositeTypeResolutionWithLateMapping), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; - - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.composite2 AS (x int, some_text text)"); - // Resolve type by NpgsqlDbType - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("composite2"); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "p", - DataTypeName = "composite2", - Value = DBNull.Value - }); - - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Does.StartWith("pg_temp").And.EndWith(".composite2")); - Assert.That(reader.IsDBNull(0), Is.True); - } - } - - // Resolve type by ClrType (type inference) - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("composite2"); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = new SomeComposite { X = 8, SomeText = "foo" } }); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Does.StartWith("pg_temp").And.EndWith(".composite2")); - } - } - - // Resolve type by OID (read) - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("composite2"); - using (var cmd = new NpgsqlCommand("SELECT ROW(1, 'foo')::COMPOSITE2", conn)) - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Does.StartWith("pg_temp").And.EndWith(".composite2")); - } - } + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + 
await connection.ReloadTypesAsync(); + + await AssertType( + connection, + new SomeComposite { SomeText = "foo", X = 8 }, + "(8,foo)", + type, + npgsqlDbType: null); } - - [Test, Parallelizable(ParallelScope.None)] - public void LateMapping() + finally { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(LateMapping), - Pooling = false - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.composite3 AS (x int, some_text text)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("composite3"); - - var expected = new SomeComposite {X = 8, SomeText = "foo"}; - using (var cmd = new NpgsqlCommand("SELECT @p1::composite3, @p2::composite3", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "p1", - DataTypeName = "composite3", - Value = expected - }); - cmd.Parameters.AddWithValue("p2", expected); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - for (var i = 0; i < cmd.Parameters.Count; i++) - { - var actual = reader.GetFieldValue(i); - Assert.That(actual.X, Is.EqualTo(8)); - Assert.That(actual.SomeText, Is.EqualTo("foo")); - } - } - } - } + NpgsqlConnection.GlobalTypeMapper.Reset(); } + } +#pragma warning restore CS0618 // GlobalTypeMapper is obsolete - [Test] - public void GlobalMapping() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(LateMapping), - Pooling = false - }; - try - { - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("DROP TYPE IF EXISTS composite4"); - conn.ExecuteNonQuery("CREATE TYPE composite4 AS (x int, some_text text)"); - NpgsqlConnection.GlobalTypeMapper.MapComposite("composite4"); - conn.ReloadTypes(); - - var expected = new SomeComposite { X = 8, SomeText = "foo" }; - using (var cmd = new NpgsqlCommand($"SELECT @p::composite4", conn)) - { - cmd.Parameters.AddWithValue("p", expected); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var 
actual = reader.GetFieldValue(0); - Assert.That(actual.X, Is.EqualTo(8)); - Assert.That(actual.SomeText, Is.EqualTo("foo")); - } - } - } - - // Unmap - NpgsqlConnection.GlobalTypeMapper.UnmapComposite("composite4"); - - using (var conn = OpenConnection(csb)) - { - Assert.That(() => conn.ExecuteScalar("SELECT '(8, \"foo\")'::composite4"), Throws.TypeOf()); - } - } - finally - { - using (var conn = OpenConnection(csb)) - conn.ExecuteNonQuery("DROP TYPE IF EXISTS composite4"); - } - } + [Test] + public async Task Nested() + { + await using var adminConnection = await OpenConnectionAsync(); + var containerType = await GetTempTypeName(adminConnection); + var containeeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {containeeType} AS (x int, some_text text); +CREATE TYPE {containerType} AS (a int, containee {containeeType});"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + // Registration in inverse dependency order should work + dataSourceBuilder + .MapComposite(containerType) + .MapComposite(containeeType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeCompositeContainer { A = 8, Containee = new() { SomeText = "foo", X = 9 } }, + @"(8,""(9,foo)"")", + containerType, + npgsqlDbType: null); + } - [Test, Description("Tests a composite within another composite")] - public void Recursive() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(Recursive), - Pooling = false - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.composite_contained AS (x int, some_text text)"); - conn.ExecuteNonQuery("CREATE TYPE pg_temp.composite_container AS (a int, contained composite_contained)"); - conn.ReloadTypes(); - // Registration in inverse dependency order should work - 
conn.TypeMapper.MapComposite("composite_container"); - conn.TypeMapper.MapComposite("composite_contained"); - - var expected = new SomeCompositeContainer { - A = 4, - Contained = new SomeComposite {X = 8, SomeText = "foo"} - }; - - using (var cmd = new NpgsqlCommand("SELECT @p::composite_container", conn)) - { - cmd.Parameters.AddWithValue("p", expected); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var actual = reader.GetFieldValue(0); - Assert.That(actual.A, Is.EqualTo(4)); - Assert.That(actual.Contained.X, Is.EqualTo(8)); - Assert.That(actual.Contained.SomeText, Is.EqualTo("foo")); - } - } - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1168")] + public async Task With_schema() + { + await using var adminConnection = await OpenConnectionAsync(); + var schema = await CreateTempSchema(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {schema}.some_composite AS (x int, some_text text)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite($"{schema}.some_composite"); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeComposite { SomeText = "foo", X = 8 }, + "(8,foo)", + $"{schema}.some_composite", + npgsqlDbType: null); + } - [Test] - public void Struct() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(Struct), - Pooling = false - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.composite_struct AS (x int, some_text text)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("composite_struct"); - - var expected = new CompositeStruct {X = 8, SomeText = "foo"}; - using (var cmd = new NpgsqlCommand("SELECT @p::composite_struct", conn)) - { - cmd.Parameters.AddWithValue("p", expected); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); 
- var actual = reader.GetFieldValue(0); - Assert.That(actual.X, Is.EqualTo(8)); - Assert.That(actual.SomeText, Is.EqualTo("foo")); - } - } - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4365")] + public async Task In_different_schemas_same_type_with_nested() + { + await using var adminConnection = await OpenConnectionAsync(); + var firstSchemaName = await CreateTempSchema(adminConnection); + var secondSchemaName = await CreateTempSchema(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {firstSchemaName}.containee AS (x int, some_text text); +CREATE TYPE {firstSchemaName}.container AS (a int, containee {firstSchemaName}.containee); +CREATE TYPE {secondSchemaName}.containee AS (x int, some_text text); +CREATE TYPE {secondSchemaName}.container AS (a int, containee {secondSchemaName}.containee);"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder + .MapComposite($"{firstSchemaName}.containee") + .MapComposite($"{firstSchemaName}.container") + .MapComposite($"{secondSchemaName}.containee") + .MapComposite($"{secondSchemaName}.container"); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeCompositeContainer { A = 8, Containee = new() { SomeText = "foo", X = 9 } }, + @"(8,""(9,foo)"")", + $"{secondSchemaName}.container", + npgsqlDbType: null, + isDefaultForWriting: false); + + await AssertType( + connection, + new SomeCompositeContainer { A = 8, Containee = new() { SomeText = "foo", X = 9 } }, + @"(8,""(9,foo)"")", + $"{firstSchemaName}.container", + npgsqlDbType: null, + isDefaultForWriting: true); + } - [Test] - public void Array() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(Array), - Pooling = false - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.composite5 AS (x int, 
some_text text)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("composite5"); - - var expected = new[] { - new SomeComposite {X = 8, SomeText = "foo"}, - new SomeComposite {X = 9, SomeText = "bar"} - }; - - using (var cmd = new NpgsqlCommand("SELECT @p1::composite5[], @p2::composite5[]", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "p1", - DataTypeName = "composite5[]", - Value = expected - }); - cmd.Parameters.AddWithValue("p2", expected); // Infer - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - for (var i = 0; i < cmd.Parameters.Count; i++) - { - var actual = reader.GetFieldValue(i); - Assert.That(actual[0].X, Is.EqualTo(expected[0].X)); - Assert.That(actual[0].SomeText, Is.EqualTo(expected[0].SomeText)); - Assert.That(actual[1].X, Is.EqualTo(expected[1].X)); - Assert.That(actual[1].SomeText, Is.EqualTo(expected[1].SomeText)); - } - } - } - } - } + [Test] + public async Task Struct() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (x int, some_text text)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeCompositeStruct { SomeText = "foo", X = 8 }, + "(8,foo)", + type, + npgsqlDbType: null); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")] - public void NameTranslation() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(LateMapping), - Pooling = false - }; - var expected = new NameTranslationComposite { Simple = 2, TwoWords = 3, SomeClrName = 4 }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.name_translation_composite AS (simple int, 
two_words int, some_database_name int)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite(); - - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", expected); - var actual = (NameTranslationComposite)cmd.ExecuteScalar()!; - Assert.That(actual.Simple, Is.EqualTo(expected.Simple)); - Assert.That(actual.TwoWords, Is.EqualTo(expected.TwoWords)); - Assert.That(actual.SomeClrName, Is.EqualTo(expected.SomeClrName)); - } - } - } + [Test] + public async Task Array() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (x int, some_text text)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeComposite[] { new() { SomeText = "foo", X = 8 }, new() { SomeText = "bar", X = 9 }}, + @"{""(8,foo)"",""(9,bar)""}", + type + "[]", + npgsqlDbType: null); + } - class NameTranslationComposite - { - public int Simple { get; set; } - public int TwoWords { get; set; } - [PgName("some_database_name")] - public int SomeClrName { get; set; } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")] + public async Task Name_translation() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync(@$" +CREATE TYPE {type} AS (simple int, two_words int, some_database_name int)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new NameTranslationComposite { Simple 
= 2, TwoWords = 3, SomeClrName = 4 }, + "(2,3,4)", + type, + npgsqlDbType: null); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/856")] - public void Domain() - { - var setupSql = @"SET search_path=pg_temp; - -CREATE DOMAIN us_postal_code AS TEXT -CHECK -( - VALUE ~ '^\d{5}$' - OR VALUE ~ '^\d{5}-\d{4}$' -); - -CREATE TYPE address AS -( - street TEXT, - postal_code us_postal_code -)"; - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(Domain), - Pooling = false - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery(setupSql); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite
(); - - var expected = new Address { PostalCode = "12345", Street = "Main St."}; - using (var cmd = new NpgsqlCommand(@"SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value=expected }); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - var actual = reader.GetFieldValue
(0); - Assert.That(actual.Street, Is.EqualTo(expected.Street)); - Assert.That(actual.PostalCode, Is.EqualTo(expected.PostalCode)); - } - } - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/856")] + public async Task Composite_containing_domain_type() + { + await using var adminConnection = await OpenConnectionAsync(); + var domainType = await GetTempTypeName(adminConnection); + var compositeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($@" +CREATE DOMAIN {domainType} AS TEXT; +CREATE TYPE {compositeType} AS (street TEXT, postal_code {domainType})"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite
(compositeType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new Address { PostalCode = "12345", Street = "Main St." }, + @"(""Main St."",12345)", + compositeType, + npgsqlDbType: null); + } - public class Address - { - public string Street { get; set; } = default!; - public string PostalCode { get; set; } = default!; - } + [Test] + public async Task Composite_containing_array_type() + { + await using var adminConnection = await OpenConnectionAsync(); + var compositeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {compositeType} AS (ints int4[])"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(compositeType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeCompositeWithArray { Ints = new[] { 1, 2, 3, 4 } }, + @"(""{1,2,3,4}"")", + compositeType, + npgsqlDbType: null, + comparer: (actual, expected) => actual.Ints!.SequenceEqual(expected.Ints!)); + } - class TableAsCompositeType - { - public int Foo { get; set; } - } + [Test] + public async Task Composite_containing_converter_resolver_type() + { + await using var adminConnection = await OpenConnectionAsync(); + var compositeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {compositeType} AS (date_times timestamp[])"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.Timezone = "Europe/Berlin"; + dataSourceBuilder.MapComposite(compositeType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeCompositeWithConverterResolverType { DateTimes 
= new [] { new DateTime(DateTime.UnixEpoch.Ticks, DateTimeKind.Unspecified), new DateTime(DateTime.UnixEpoch.Ticks, DateTimeKind.Unspecified).AddDays(1) } }, + """("{""1970-01-01 00:00:00"",""1970-01-02 00:00:00""}")""", + compositeType, + npgsqlDbType: null, + comparer: (actual, expected) => actual.DateTimes!.SequenceEqual(expected.DateTimes!)); + } - #region Table as Composite + [Test] + public async Task Composite_containing_converter_resolver_type_throws() + { + await using var adminConnection = await OpenConnectionAsync(); + var compositeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {compositeType} AS (date_times timestamp[])"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.Timezone = "Europe/Berlin"; + dataSourceBuilder.MapComposite(compositeType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + Assert.ThrowsAsync(() => AssertType( + connection, + new SomeCompositeWithConverterResolverType { DateTimes = new[] { DateTime.UnixEpoch } }, // UTC DateTime + """("{""1970-01-01 01:00:00"",""1970-01-02 01:00:00""}")""", + compositeType, + npgsqlDbType: null, + comparer: (actual, expected) => actual.DateTimes!.SequenceEqual(expected.DateTimes!))); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/990")] - public void TableAsCompositeNotSupportedByDefault() + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/990")] + public async Task Table_as_composite([Values] bool enabled) + { + await using var adminConnection = await OpenConnectionAsync(); + var table = await CreateTempTable(adminConnection, "x int, some_text text"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(table); + if (enabled) + dataSourceBuilder.ConnectionStringBuilder.LoadTableComposites = true; + await using var dataSource = 
dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + if (enabled) + await DoAssertion(); + else { - using (var conn = OpenConnection()) - { - conn.ExecuteNonQuery("CREATE TEMP TABLE table_as_composite (foo int); INSERT INTO table_as_composite (foo) VALUES (8)"); - conn.ReloadTypes(); - Assert.That(() => conn.TypeMapper.MapComposite("table_as_composite"), Throws.Exception.TypeOf()); - } + Assert.ThrowsAsync(DoAssertion); + // Start a transaction specifically for multiplexing (to bind a connector to the connection) + await using var tx = await connection.BeginTransactionAsync(); + Assert.Null(connection.Connector!.DatabaseInfo.CompositeTypes.SingleOrDefault(c => c.Name.Contains(table))); + Assert.Null(connection.Connector!.DatabaseInfo.ArrayTypes.SingleOrDefault(c => c.Name.Contains(table))); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/990")] - public void TableAsComposite() + Task DoAssertion() + => AssertType( + connection, + new SomeComposite { SomeText = "foo", X = 8 }, + "(8,foo)", + table, + npgsqlDbType: null); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1267")] + public async Task Table_as_composite_with_deleted_columns() + { + await using var adminConnection = await OpenConnectionAsync(); + var table = await CreateTempTable(adminConnection, "x int, some_text text, bar int"); + await adminConnection.ExecuteNonQueryAsync($"ALTER TABLE {table} DROP COLUMN bar;"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.LoadTableComposites = true; + dataSourceBuilder.MapComposite(table); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeComposite { SomeText = "foo", X = 8 }, + "(8,foo)", + table, + npgsqlDbType: null); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1125")] + public async 
Task Nullable_property_in_class_composite() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (foo INT)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new ClassWithNullableProperty { Foo = 8 }, + "(8)", + type, + npgsqlDbType: null); + + await AssertType( + connection, + new ClassWithNullableProperty { Foo = null }, + "()", + type, + npgsqlDbType: null); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1125")] + public async Task Nullable_property_in_struct_composite() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (foo INT)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new StructWithNullableProperty { Foo = 8 }, + "(8)", + type, + npgsqlDbType: null); + + await AssertType( + connection, + new StructWithNullableProperty { Foo = null }, + "()", + type, + npgsqlDbType: null); + } + + [Test] + public async Task PostgresType() + { + await using var connection = await OpenConnectionAsync(); + var type1 = await GetTempTypeName(connection); + var type2 = await GetTempTypeName(connection); + + await connection.ExecuteNonQueryAsync(@$" +CREATE TYPE {type1} AS (x int, some_text text); +CREATE TYPE {type2} AS (comp {type1}, comps {type1}[]);"); + await connection.ReloadTypesAsync(); + + await using var cmd = new NpgsqlCommand( + $"SELECT 
ROW(ROW(8, 'foo')::{type1}, ARRAY[ROW(9, 'bar')::{type1}, ROW(10, 'baz')::{type1}])::{type2}", + connection); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + var comp2Type = (PostgresCompositeType)reader.GetPostgresType(0); + Assert.That(comp2Type.Name, Is.EqualTo(type2)); + Assert.That(comp2Type.FullName, Is.EqualTo($"public.{type2}")); + Assert.That(comp2Type.Fields, Has.Count.EqualTo(2)); + var field1 = comp2Type.Fields[0]; + var field2 = comp2Type.Fields[1]; + Assert.That(field1.Name, Is.EqualTo("comp")); + Assert.That(field2.Name, Is.EqualTo("comps")); + var comp1Type = (PostgresCompositeType)field1.Type; + Assert.That(comp1Type.Name, Is.EqualTo(type1)); + var arrType = (PostgresArrayType)field2.Type; + Assert.That(arrType.Name, Is.EqualTo(type1 + "[]")); + var elemType = arrType.Element; + Assert.That(elemType, Is.SameAs(comp1Type)); + } + + [Test] + public async Task DuplicateConstructorParameters() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (long int8, boolean bool)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var ex = Assert.ThrowsAsync(async () => await AssertType( + connection, + new DuplicateOneLongOneBool(true, 1), + "(1,t)", + type, + npgsqlDbType: null)); + Assert.That(ex!.InnerException, Is.TypeOf()); + } + + [Test] + public async Task PartialConstructorMissingSetter() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (long int8, boolean bool)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + 
dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var ex = Assert.ThrowsAsync(async () => await AssertTypeRead( + connection, + "(1,t)", + type, + new MissingSetterOneLongOneBool(true, 1))); + Assert.That(ex, Is.TypeOf().With.Message.Contains("No (public) setter for")); + } + + [Test] + public async Task PartialConstructorWorks() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (long int8, boolean bool)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new OneLongOneBool(1) { BooleanValue = true }, + "(1,t)", + type, + npgsqlDbType: null); + } + + [Test] + public async Task CompositeOverRange() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + var rangeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (x int, some_text text); CREATE TYPE {rangeType} AS RANGE(subtype={type})"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + dataSourceBuilder.EnableUnmappedTypes(); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var composite1 = new SomeComposite { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = false, - ApplicationName = nameof(TableAsComposite), - LoadTableComposites = true - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery( - "CREATE TEMP TABLE table_as_composite (foo 
int);" + - "INSERT INTO table_as_composite (foo) VALUES (8)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("table_as_composite"); - var value = (TableAsCompositeType)conn.ExecuteScalar(@"SELECT t.*::table_as_composite FROM table_as_composite AS t")!; - Assert.That(value.Foo, Is.EqualTo(8)); - } - } + SomeText = "foo", + X = 8 + }; - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1267")] - public void TableAsCompositeWithDeleteColumns() + var composite2 = new SomeComposite { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = false, - ApplicationName = nameof(TableAsCompositeWithDeleteColumns), - LoadTableComposites = true - }; - - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery(@" - CREATE TEMP TABLE table_as_composite2 (foo int, bar int); - ALTER TABLE table_as_composite2 DROP COLUMN bar; - INSERT INTO table_as_composite2 (foo) VALUES (8)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("table_as_composite2"); - var value = (TableAsCompositeType)conn.ExecuteScalar(@"SELECT t.*::table_as_composite2 FROM table_as_composite2 AS t")!; - Assert.That(value.Foo, Is.EqualTo(8)); - } - } + SomeText = "bar", + X = 42 + }; + + await AssertType( + connection, + new NpgsqlRange(composite1, composite2), + "[\"(8,foo)\",\"(42,bar)\"]", + rangeType, + npgsqlDbType: null, + isDefaultForWriting: false); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2668")] - public void TableCompositesNotLoadedIfNotRequested() + #region Test Types + + readonly struct DuplicateOneLongOneBool + { + public DuplicateOneLongOneBool(bool boolean, [PgName("boolean")]int @bool) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = false, - ApplicationName = nameof(TableCompositesNotLoadedIfNotRequested) - }; - - using var conn = OpenConnection(csb); - conn.ExecuteNonQuery("CREATE TEMP TABLE table_as_composite3 (foo int, bar int)"); - conn.ReloadTypes(); - - Assert.Throws(() => 
conn.TypeMapper.MapComposite("table_as_composite3")); - Assert.Null(conn.Connector!.DatabaseInfo.CompositeTypes.SingleOrDefault(c => c.Name.Contains("table_as_composite3"))); - Assert.Null(conn.Connector!.DatabaseInfo.ArrayTypes.SingleOrDefault(c => c.Name.Contains("table_as_composite3"))); } - #endregion Table as Composite + [PgName("long")] + public long LongValue { get; } + + [PgName("boolean")] + public bool BooleanValue { get; } + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1125")] - public void NullablePropertyInClassComposite() + readonly struct MissingSetterOneLongOneBool + { + public MissingSetterOneLongOneBool(long @long) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = false, - ApplicationName = nameof(NullablePropertyInClassComposite) - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.nullable_property_type AS (foo INT)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("nullable_property_type"); - - var expected1 = new ClassWithNullableProperty { Foo = 8 }; - var expected2 = new ClassWithNullableProperty { Foo = null }; - using (var cmd = new NpgsqlCommand(@"SELECT @p1, @p2", conn)) - { - cmd.Parameters.AddWithValue("p1", expected1); - cmd.Parameters.AddWithValue("p2", expected2); - - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0).Foo, Is.EqualTo(8)); - Assert.That(reader.GetFieldValue(1).Foo, Is.Null); - } - } - } + LongValue = @long; } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1125")] - public void NullablePropertyInStructComposite() + public MissingSetterOneLongOneBool(bool boolean, [PgName("boolean")]int @bool) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = false, - ApplicationName = nameof(NullablePropertyInStructComposite) - }; - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE pg_temp.nullable_property_type AS (foo 
INT)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("nullable_property_type"); - - var expected1 = new StructWithNullableProperty { Foo = 8 }; - var expected2 = new StructWithNullableProperty { Foo = null }; - using (var cmd = new NpgsqlCommand(@"SELECT @p1, @p2", conn)) - { - cmd.Parameters.AddWithValue("p1", expected1); - cmd.Parameters.AddWithValue("p2", expected2); - - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0).Foo, Is.EqualTo(8)); - Assert.That(reader.GetFieldValue(1).Foo, Is.Null); - } - } - } } - class ClassWithNullableProperty + [PgName("long")] + public long LongValue { get; } + + [PgName("boolean")] + public bool BooleanValue { get; } + } + + struct OneLongOneBool + { + public OneLongOneBool(bool boolean, [PgName("boolean")]int @bool) { - public int? Foo { get; set; } } - struct StructWithNullableProperty + public OneLongOneBool(long @long) { - public int? Foo { get; set; } + LongValue = @long; } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1168")] - public void WithSchema() + public OneLongOneBool(double other) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(WithSchema), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; - - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("DROP SCHEMA IF EXISTS composite_schema CASCADE; CREATE SCHEMA composite_schema"); - try - { - conn.ExecuteNonQuery("CREATE TYPE composite_schema.composite AS (foo int)"); - conn.ReloadTypes(); - conn.TypeMapper.MapComposite("composite_schema.composite"); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", new Composite1 { Foo = 8 }); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("composite_schema.composite")); - Assert.That(reader.GetFieldValue(0).Foo, Is.EqualTo(8)); - } - } - } - finally - { - if 
(conn.State == ConnectionState.Open) - conn.ExecuteNonQuery("DROP SCHEMA IF EXISTS composite_schema CASCADE"); - } - } } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1168")] - public void InDifferentSchemas() + public OneLongOneBool(int boolean, [PgName("boolean")]bool @bool) { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(InDifferentSchemas), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; - - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("DROP SCHEMA IF EXISTS composite_schema1 CASCADE; CREATE SCHEMA composite_schema1"); - conn.ExecuteNonQuery("DROP SCHEMA IF EXISTS composite_schema2 CASCADE; CREATE SCHEMA composite_schema2"); - try - { - conn.ExecuteNonQuery("CREATE TYPE composite_schema1.composite AS (foo int)"); - conn.ExecuteNonQuery("CREATE TYPE composite_schema2.composite AS (bar int)"); - conn.ReloadTypes(); - // Attempting to map without a fully-qualified name should fail - Assert.That(() => conn.TypeMapper.MapComposite("composite"), - Throws.Exception.TypeOf() - ); - conn.TypeMapper - .MapComposite("composite_schema1.composite") - .MapComposite("composite_schema2.composite"); - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - cmd.Parameters.AddWithValue("p1", new Composite1 { Foo = 8 }); - cmd.Parameters.AddWithValue("p2", new Composite2 { Bar = 9 }); - using (var reader = cmd.ExecuteReader()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("composite_schema1.composite")); - Assert.That(reader.GetFieldValue(0).Foo, Is.EqualTo(8)); - Assert.That(reader.GetDataTypeName(1), Is.EqualTo("composite_schema2.composite")); - Assert.That(reader.GetFieldValue(1).Bar, Is.EqualTo(9)); - } - } - } - finally - { - if (conn.State == ConnectionState.Open) - { - conn.ExecuteNonQuery("DROP SCHEMA IF EXISTS composite_schema1 CASCADE"); - conn.ExecuteNonQuery("DROP SCHEMA IF EXISTS composite_schema2 CASCADE"); - } - } - 
} } - class Composite1 { public int Foo { get; set; } } - class Composite2 { public int Bar { get; set; } } - class Composite3 { public int Bar { get; set; } } + [PgName("long")] + public long LongValue { get; } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1612")] - public void LocalMappingDontLeak() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Pooling = false, - ApplicationName = nameof(LocalMappingDontLeak) - }; - NpgsqlConnection.GlobalTypeMapper.MapComposite("composite"); - try - { - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("CREATE TYPE composite AS (bar int)"); - conn.ReloadTypes(); - Assert.That(conn.ExecuteScalar("SELECT '(8)'::composite"), Is.TypeOf()); - conn.TypeMapper.MapComposite("composite"); - Assert.That(conn.ExecuteScalar("SELECT '(8)'::composite"), Is.TypeOf()); - } - using (var conn = OpenConnection(csb)) - Assert.That(conn.ExecuteScalar("SELECT '(8)'::composite"), Is.TypeOf()); - } - finally - { - NpgsqlConnection.GlobalTypeMapper.UnmapComposite("composite"); - using (var conn = OpenConnection(csb)) - { - conn.ExecuteNonQuery("DROP TYPE IF EXISTS composite"); - NpgsqlConnection.ClearPool(conn); - } - } - } + [PgName("boolean")] + public bool BooleanValue { get; set; } } + + + record SomeComposite + { + public int X { get; set; } + public string SomeText { get; set; } = null!; + } + + record SomeCompositeContainer + { + public int A { get; set; } + public SomeComposite Containee { get; set; } = null!; + } + + struct SomeCompositeStruct + { + public int X { get; set; } + public string SomeText { get; set; } + } + + class SomeCompositeWithArray + { + public int[]? Ints { get; set; } + } + + class SomeCompositeWithConverterResolverType + { + public DateTime[]? 
DateTimes { get; set; } + } + + record NameTranslationComposite + { + public int Simple { get; set; } + public int TwoWords { get; set; } + [PgName("some_database_name")] + public int SomeClrName { get; set; } + } + + record Address + { + public string Street { get; set; } = default!; + public string PostalCode { get; set; } = default!; + } + + record ClassWithNullableProperty + { + public int? Foo { get; set; } + } + + struct StructWithNullableProperty + { + public int? Foo { get; set; } + } + + public CompositeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + + #endregion } diff --git a/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs b/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs new file mode 100644 index 0000000000..8508979b31 --- /dev/null +++ b/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs @@ -0,0 +1,129 @@ +using System; +using System.Data; +using System.Threading.Tasks; +using NpgsqlTypes; +using NUnit.Framework; +using static Npgsql.Util.Statics; + +namespace Npgsql.Tests.Types; + +[TestFixture(true)] +#if DEBUG +[TestFixture(false)] +[NonParallelizable] +#endif +public sealed class DateTimeInfinityTests : TestBase, IDisposable +{ + static readonly TestCaseData[] TimestampDateTimeValues = + { + new TestCaseData(DateTime.MinValue.AddYears(1), "0002-01-01 00:00:00", "0002-01-01 00:00:00") + .SetName("MinValue_AddYear"), + new TestCaseData(DateTime.MinValue, "0001-01-01 00:00:00", "-infinity") + .SetName("MinValue"), + new TestCaseData(DateTime.MaxValue, "9999-12-31 23:59:59.999999", "infinity") + .SetName("MaxValue"), + }; + + static readonly TestCaseData[] TimestampTzDateTimeValues = + { + new TestCaseData(DateTime.MinValue.AddYears(1), "0002-01-01 00:00:00+00", "0002-01-01 00:00:00+00") + .SetName("MinValue_AddYear"), + new TestCaseData(DateTime.MinValue, "0001-01-01 00:00:00+00", "-infinity") + .SetName("MinValue"), + new TestCaseData(DateTime.MaxValue, "9999-12-31 23:59:59.999999+00", "infinity") + .SetName("MaxValue"), + }; 
+ + static readonly TestCaseData[] TimestampTzDateTimeOffsetValues = + { + new TestCaseData(DateTimeOffset.MinValue.ToUniversalTime().AddYears(1), "0002-01-01 00:00:00+00", "0002-01-01 00:00:00+00") + .SetName("MinValue_AddYear"), + new TestCaseData(DateTimeOffset.MinValue, "0001-01-01 00:00:00+00", "-infinity") + .SetName("MinValue"), + new TestCaseData(DateTimeOffset.MaxValue, "9999-12-31 23:59:59.999999+00", "infinity") + .SetName("MaxValue"), + }; + + static readonly TestCaseData[] DateDateTimeValues = + { + new TestCaseData(DateTime.MinValue.AddYears(1), "0002-01-01", "0002-01-01") + .SetName("MinValue_AddYear"), + new TestCaseData(DateTime.MinValue, "0001-01-01", "-infinity") + .SetName("MinValue"), + new TestCaseData(DateTime.MaxValue, "9999-12-31", "infinity") + .SetName("MaxValue"), + }; + + // As we can't roundtrip DateTime.MaxValue due to precision differences with postgres we are lenient with equality for this particular value. + static readonly Func MaxValuePrecisionLenientComparer = + (expected, actual) => expected == DateTime.MaxValue && actual == new DateTime(expected.Ticks - 9) || actual == expected; + + [Test, TestCaseSource(nameof(TimestampDateTimeValues))] + public Task Timestamp_DateTime(DateTime dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) + => AssertType(dateTime, DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, + "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2, + comparer: MaxValuePrecisionLenientComparer, + isDefault: true); + + [Test, TestCaseSource(nameof(TimestampTzDateTimeValues))] + public Task TimestampTz_DateTime(DateTime dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) + => AssertType(new(dateTime.Ticks, DateTimeKind.Utc), DisableDateTimeInfinityConversions ? 
sqlLiteral : infinityConvertedSqlLiteral, + "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, DbType.DateTime, + comparer: MaxValuePrecisionLenientComparer, + isDefault: true, isNpgsqlDbTypeInferredFromClrType: false); + + [Test, TestCaseSource(nameof(TimestampTzDateTimeOffsetValues))] + public Task TimestampTz_DateTimeOffset(DateTimeOffset dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) + => AssertType(dateTime, DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, + "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, DbType.DateTime, + comparer: (expected, actual) => MaxValuePrecisionLenientComparer(expected.DateTime, actual.DateTime), + isDefault: false); + + [Test, TestCaseSource(nameof(DateDateTimeValues))] + public Task Date_DateTime(DateTime dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) + => AssertType(DisableDateTimeInfinityConversions ? dateTime.Date : dateTime, DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, + "date", NpgsqlDbType.Date, DbType.Date, + isDefault: false); + +#if NET6_0_OR_GREATER + static readonly TestCaseData[] DateOnlyDateTimeValues = + { + new TestCaseData(DateOnly.MinValue.AddYears(1), "0002-01-01", "0002-01-01") + .SetName("MinValue_AddYear"), + new TestCaseData(DateOnly.MinValue, "0001-01-01", "-infinity") + .SetName("MinValue"), + new TestCaseData(DateOnly.MaxValue, "9999-12-31", "infinity") + .SetName("MaxValue"), + }; + + [Test, TestCaseSource(nameof(DateOnlyDateTimeValues))] + public Task Date_DateOnly(DateOnly dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) + => AssertType(dateTime, + DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, "date", NpgsqlDbType.Date, DbType.Date, + isDefault: false); +#endif + + NpgsqlDataSource? 
_dataSource; + protected override NpgsqlDataSource DataSource => _dataSource ??= CreateDataSource(csb => csb.Timezone = "UTC"); + + public DateTimeInfinityTests(bool disableDateTimeInfinityConversions) + { +#if DEBUG + DisableDateTimeInfinityConversions = disableDateTimeInfinityConversions; +#else + if (disableDateTimeInfinityConversions) + { + Assert.Ignore( + "DateTimeInfinityTests rely on the Npgsql.DisableDateTimeInfinityConversions AppContext switch and can only be run in DEBUG builds"); + } +#endif + } + + public void Dispose() + { +#if DEBUG + DisableDateTimeInfinityConversions = false; +#endif + DataSource.Dispose(); + } +} diff --git a/test/Npgsql.Tests/Types/DateTimeTests.cs b/test/Npgsql.Tests/Types/DateTimeTests.cs index 7d69433ec0..434b87705f 100644 --- a/test/Npgsql.Tests/Types/DateTimeTests.cs +++ b/test/Npgsql.Tests/Types/DateTimeTests.cs @@ -1,421 +1,546 @@ using System; +using System.Collections.Generic; using System.Data; -using System.Linq; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +// Since this test suite manipulates TimeZone, it is incompatible with multiplexing +public class DateTimeTests : TestBase { - /// - /// Tests on PostgreSQL date/time types - /// - /// - /// https://www.postgresql.org/docs/current/static/datatype-datetime.html - /// - public class DateTimeTests : MultiplexingTestBase + #region Date + + [Test] + public Task Date_as_DateTime() + => AssertType(new DateTime(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefaultForWriting: false); + + [Test] + public Task Date_as_DateTime_with_date_and_time_before_2000() + => AssertTypeWrite(new DateTime(1980, 10, 1, 11, 0, 0), "1980-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefault: false); + + // Internal PostgreSQL representation (days since 2020-01-01), for out-of-range values. 
+ [Test] + public Task Date_as_int() + => AssertType(7579, "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefault: false); + + [Test] + public Task Daterange_as_NpgsqlRange_of_DateTime() + => AssertType( + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + "[2002-03-04,2002-03-06)", + "daterange", + NpgsqlDbType.DateRange, + isDefaultForWriting: false); + + [Test] + public async Task Datemultirange_as_array_of_NpgsqlRange_of_DateTime() { - #region Date + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); - [Test] - public async Task Date() - { - using (var conn = await OpenConnectionAsync()) + await AssertType( + new[] { - var dateTime = new DateTime(2002, 3, 4, 0, 0, 0, 0, DateTimeKind.Unspecified); - var npgsqlDate = new NpgsqlDate(dateTime); - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Date) {Value = npgsqlDate}; - var p2 = new NpgsqlParameter {ParameterName = "p2", Value = npgsqlDate}; - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Date)); - Assert.That(p2.DbType, Is.EqualTo(DbType.Date)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - // Regular type (DateTime) - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(DateTime))); - Assert.That(reader.GetDateTime(i), Is.EqualTo(dateTime)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(dateTime)); - Assert.That(reader[i], Is.EqualTo(dateTime)); - Assert.That(reader.GetValue(i), Is.EqualTo(dateTime)); - - // Provider-specific type (NpgsqlDate) - Assert.That(reader.GetDate(i), Is.EqualTo(npgsqlDate)); - Assert.That(reader.GetProviderSpecificFieldType(i), Is.EqualTo(typeof(NpgsqlDate))); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(npgsqlDate)); - 
Assert.That(reader.GetFieldValue(i), Is.EqualTo(npgsqlDate)); - } - } - } - } - } - - static readonly TestCaseData[] DateSpecialCases = { - new TestCaseData(NpgsqlDate.Infinity).SetName(nameof(DateSpecial) + "Infinity"), - new TestCaseData(NpgsqlDate.NegativeInfinity).SetName(nameof(DateSpecial) + "NegativeInfinity"), - new TestCaseData(new NpgsqlDate(-5, 3, 3)).SetName(nameof(DateSpecial) +"BC"), - }; - - [Test, TestCaseSource(nameof(DateSpecialCases))] - public async Task DateSpecial(NpgsqlDate value) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = value }); - using (var reader = await cmd.ExecuteReaderAsync()) { - reader.Read(); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(value)); - Assert.That(() => reader.GetDateTime(0), Throws.Exception.TypeOf()); - } - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } - - [Test, Description("Makes sure that when ConvertInfinityDateTime is true, infinity values are properly converted")] - public async Task DateConvertInfinity() - { - using (var conn = new NpgsqlConnection(ConnectionString + ";ConvertInfinityDateTime=true")) + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) + }, + "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", + "datemultirange", + NpgsqlDbType.DateMultirange, + isDefaultForWriting: false); + } + +#if NET6_0_OR_GREATER + [Test] + public Task Date_as_DateOnly() + => AssertType(new DateOnly(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefaultForReading: false); + + [Test] + public Task Daterange_as_NpgsqlRange_of_DateOnly() + => AssertType( + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + "[2002-03-04,2002-03-06)", + "daterange", + NpgsqlDbType.DateRange, + isDefaultForReading: false, + skipArrayCheck: 
true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + [Test] + public Task Daterange_array_as_NpgsqlRange_of_DateOnly_array() + => AssertType( + new[] { - conn.Open(); - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Date, DateTime.MaxValue); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Date, DateTime.MinValue); - using (var reader = await cmd.ExecuteReaderAsync()) { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(NpgsqlDate.Infinity)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(NpgsqlDate.NegativeInfinity)); - Assert.That(reader.GetDateTime(0), Is.EqualTo(DateTime.MaxValue)); - Assert.That(reader.GetDateTime(1), Is.EqualTo(DateTime.MinValue)); - } - } - } - } - - #endregion - - #region Time - - [Test] - public async Task Time() - { - using (var conn = await OpenConnectionAsync()) + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 9), false) + }, + """{"[2002-03-04,2002-03-06)","[2002-03-08,2002-03-09)"}""", + "daterange[]", + NpgsqlDbType.DateRange | NpgsqlDbType.Array, + isDefault: false); + + [Test] + public async Task Datemultirange_as_array_of_NpgsqlRange_of_DateOnly() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + + await AssertType( + new[] { - var expected = new TimeSpan(0, 10, 45, 34, 500); - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Time) {Value = expected}); - cmd.Parameters.Add(new NpgsqlParameter("p2", DbType.Time) {Value = expected}); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(TimeSpan))); - Assert.That(reader.GetTimeSpan(i), 
Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader[i], Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - } - } - } - } - } - - #endregion - - #region Time with timezone - - [Test] - [MonoIgnore] - public async Task TimeTz() - { - using (var conn = await OpenConnectionAsync()) + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) + }, + "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", + "datemultirange", + NpgsqlDbType.DateMultirange, + isDefaultForReading: false); + } +#endif + + #endregion + + #region Time + + [Test] + public Task Time_as_TimeSpan() + => AssertType( + new TimeSpan(0, 10, 45, 34, 500), + "10:45:34.5", + "time without time zone", + NpgsqlDbType.Time, + DbType.Time, + isDefaultForWriting: false); + +#if NET6_0_OR_GREATER + [Test] + public Task Time_as_TimeOnly() + => AssertType( + new TimeOnly(10, 45, 34, 500), + "10:45:34.5", + "time without time zone", + NpgsqlDbType.Time, + DbType.Time, + isDefaultForReading: false); +#endif + + #endregion + + #region Time with timezone + + static readonly TestCaseData[] TimeTzValues = + { + new TestCaseData(new DateTimeOffset(1, 1, 2, 13, 3, 45, 510, TimeSpan.FromHours(2)), "13:03:45.51+02") + .SetName("Timezone"), + new TestCaseData(new DateTimeOffset(1, 1, 2, 1, 0, 45, 510, TimeSpan.FromHours(-3)), "01:00:45.51-03") + .SetName("Negative_timezone"), + new TestCaseData(new DateTimeOffset(1212720130000, TimeSpan.Zero), "09:41:12.013+00") + .SetName("Utc"), + new TestCaseData(new DateTimeOffset(1, 1, 2, 1, 0, 0, new TimeSpan(0, 2, 0, 0)), "01:00:00+02") + .SetName("Before_utc_zero"), + }; + + [Test, TestCaseSource(nameof(TimeTzValues))] + public Task TimeTz_as_DateTimeOffset(DateTimeOffset time, string sqlLiteral) + => AssertType(time, sqlLiteral, "time with time zone", NpgsqlDbType.TimeTz, isDefault: false); + + #endregion + + #region Timestamp + + 
static readonly TestCaseData[] TimestampValues = + { + new TestCaseData(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), "1998-04-12 13:26:38") + .SetName("Timestamp_pre2000"), + new TestCaseData(new DateTime(2015, 1, 27, 8, 45, 12, 345, DateTimeKind.Unspecified), "2015-01-27 08:45:12.345") + .SetName("Timestamp_post2000"), + new TestCaseData(new DateTime(2013, 7, 25, 0, 0, 0, DateTimeKind.Unspecified), "2013-07-25 00:00:00") + .SetName("Timestamp_date_only") + }; + + [Test, TestCaseSource(nameof(TimestampValues))] + public async Task Timestamp_as_DateTime(DateTime dateTime, string sqlLiteral) + { + await AssertType(dateTime, sqlLiteral, "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2, + // Explicitly check kind as well. + comparer: (actual, expected) => actual.Kind == expected.Kind && actual.Equals(expected)); + + await AssertType( + new List { dateTime, dateTime }, $$"""{"{{sqlLiteral}}","{{sqlLiteral}}"}""", "timestamp without time zone[]", NpgsqlDbType.Timestamp | NpgsqlDbType.Array, + isDefaultForReading: false); + } + + [Test] + public Task Timestamp_cannot_write_utc_DateTime() + => AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), "timestamp without time zone"); + + [Test] + public Task Timestamp_as_long() + => AssertType( + -54297202000000, + "1998-04-12 13:26:38", + "timestamp without time zone", + NpgsqlDbType.Timestamp, + DbType.DateTime2, + isDefault: false); + + [Test] + public Task Timestamp_cannot_use_as_DateTimeOffset() + => AssertTypeUnsupported( + new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.Zero), + "1998-04-12 13:26:38", + "timestamp without time zone"); + + [Test] + public Task Tsrange_as_NpgsqlRange_of_DateTime() + => AssertType( + new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), + new(1998, 4, 12, 15, 26, 38, DateTimeKind.Local)), + @"[""1998-04-12 13:26:38"",""1998-04-12 15:26:38""]", + "tsrange", + NpgsqlDbType.TimestampRange, + 
skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + [Test] + public Task Tsrange_array_as_NpgsqlRange_of_DateTime_array() + => AssertType( + new[] { - var tzOffset = TimeZoneInfo.Local.BaseUtcOffset; - if (tzOffset == TimeSpan.Zero) - Assert.Ignore("Test cannot run when machine timezone is UTC"); - - // Note that the date component of the below is ignored - var dto = new DateTimeOffset(5, 5, 5, 13, 3, 45, 510, tzOffset); - var dtUtc = new DateTime(dto.Year, dto.Month, dto.Day, dto.Hour, dto.Minute, dto.Second, dto.Millisecond, DateTimeKind.Utc) - tzOffset; - var dtLocal = new DateTime(dto.Year, dto.Month, dto.Day, dto.Hour, dto.Minute, dto.Second, dto.Millisecond, DateTimeKind.Local); - var dtUnspecified = new DateTime(dto.Year, dto.Month, dto.Day, dto.Hour, dto.Minute, dto.Second, dto.Millisecond, DateTimeKind.Unspecified); - var ts = dto.TimeOfDay; - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4, @p5", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.TimeTz, dto); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.TimeTz, dtUtc); - cmd.Parameters.AddWithValue("p3", NpgsqlDbType.TimeTz, dtLocal); - cmd.Parameters.AddWithValue("p4", NpgsqlDbType.TimeTz, dtUnspecified); - cmd.Parameters.AddWithValue("p5", NpgsqlDbType.TimeTz, ts); - Assert.That(cmd.Parameters.All(p => p.DbType == DbType.Object)); - - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(DateTimeOffset))); - - Assert.That(reader.GetFieldValue(i), Is.EqualTo(new DateTimeOffset(1, 1, 2, dto.Hour, dto.Minute, dto.Second, dto.Millisecond, dto.Offset))); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(DateTimeOffset))); - Assert.That(reader.GetFieldValue(i).Kind, Is.EqualTo(DateTimeKind.Local)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(reader.GetFieldValue(i).LocalDateTime)); - 
Assert.That(reader.GetFieldValue(i), Is.EqualTo(reader.GetFieldValue(i).LocalDateTime.TimeOfDay)); - } - } - } - } - } - - [Test] - public async Task TimeWithTimeZoneBeforeUtcZero() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT TIME WITH TIME ZONE '01:00:00+02'", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) + new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), + new(1998, 4, 12, 15, 26, 38, DateTimeKind.Local)), + new NpgsqlRange( + new(1998, 4, 13, 13, 26, 38, DateTimeKind.Local), + new(1998, 4, 13, 15, 26, 38, DateTimeKind.Local)), + }, + """{"[\"1998-04-12 13:26:38\",\"1998-04-12 15:26:38\"]","[\"1998-04-13 13:26:38\",\"1998-04-13 15:26:38\"]"}""", + "tsrange[]", + NpgsqlDbType.TimestampRange | NpgsqlDbType.Array, + isDefault: false); + + [Test] + public async Task Tsmultirange_as_array_of_NpgsqlRange_of_DateTime() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + + await AssertType( + new[] { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(new DateTimeOffset(1, 1, 2, 1, 0, 0, new TimeSpan(0, 2, 0, 0)))); - } - } + new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), + new(1998, 4, 12, 15, 26, 38, DateTimeKind.Local)), + new NpgsqlRange( + new(1998, 4, 13, 13, 26, 38, DateTimeKind.Local), + new(1998, 4, 13, 15, 26, 38, DateTimeKind.Local)), + }, + @"{[""1998-04-12 13:26:38"",""1998-04-12 15:26:38""],[""1998-04-13 13:26:38"",""1998-04-13 15:26:38""]}", + "tsmultirange", + NpgsqlDbType.TimestampMultirange); + } - #endregion + #endregion - #region Timestamp + #region Timestamp with timezone - static readonly TestCaseData[] TimeStampCases = { - new TestCaseData(new DateTime(1998, 4, 12, 13, 26, 38)).SetName(nameof(Timestamp) + "Pre2000"), - new TestCaseData(new DateTime(2015, 1, 27, 8, 45, 12, 345)).SetName(nameof(Timestamp) + "Post2000"), - new 
TestCaseData(new DateTime(2013, 7, 25)).SetName(nameof(Timestamp) + "DateOnly"), - }; + // Note that the below text representations are local (according to TimeZone, which is set to Europe/Berlin in this test class), + // because that's how PG does timestamptz *text* representation. + static readonly TestCaseData[] TimestampTzWriteValues = + { + new TestCaseData(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), "1998-04-12 15:26:38+02") + .SetName("Timestamptz_write_pre2000"), + new TestCaseData(new DateTime(2015, 1, 27, 8, 45, 12, 345, DateTimeKind.Utc), "2015-01-27 09:45:12.345+01") + .SetName("Timestamptz_write_post2000"), + new TestCaseData(new DateTime(2013, 7, 25, 0, 0, 0, DateTimeKind.Utc), "2013-07-25 02:00:00+02") + .SetName("Timestamptz_write_date_only") + }; + + [Test, TestCaseSource(nameof(TimestampTzWriteValues))] + public async Task Timestamptz_as_DateTime(DateTime dateTime, string sqlLiteral) + { + await AssertType(dateTime, sqlLiteral, "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, + // Explicitly check kind as well. 
+ comparer: (actual, expected) => actual.Kind == expected.Kind && actual.Equals(expected)); - [Test, TestCaseSource(nameof(TimeStampCases))] - public async Task Timestamp(DateTime dateTime) - { - using (var conn = await OpenConnectionAsync()) + await AssertType( + new List { dateTime, dateTime }, $$"""{"{{sqlLiteral}}","{{sqlLiteral}}"}""", "timestamp with time zone[]", NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, + isDefaultForReading: false); + + } + + [Test] + public async Task Timestamptz_infinity_as_DateTime() + { + await AssertType(DateTime.MinValue, "-infinity", "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, + isDefault: false); + await AssertType(DateTime.MaxValue, "infinity", "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, + isDefault: false); + } + + [Test] + public async Task Timestamptz_cannot_write_non_utc_DateTime() + { + await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), "timestamp with time zone"); + await AssertTypeUnsupportedWrite(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), "timestamp with time zone"); + } + + [Test] + public async Task Timestamptz_as_DateTimeOffset_utc() + { + var dateTimeOffset = await AssertType( + new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.Zero), + "1998-04-12 15:26:38+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTime, + isDefaultForReading: false); + + Assert.That(dateTimeOffset.Offset, Is.EqualTo(TimeSpan.Zero)); + } + + [Test] + public Task Timestamptz_as_DateTimeOffset_utc_with_DbType_DateTimeOffset() + => AssertTypeWrite( + new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.Zero), + "1998-04-12 15:26:38+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTimeOffset, + inferredDbType: DbType.DateTime, + isDefault: false); + + [Test] + public Task Timestamptz_cannot_write_non_utc_DateTimeOffset() + => AssertTypeUnsupportedWrite(new 
DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.FromHours(2))); + + [Test] + public Task Timestamptz_as_long() + => AssertType( + -54297202000000, + "1998-04-12 15:26:38+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTime, + isDefault: false); + + [Test] + public async Task Timestamptz_array_as_DateTimeOffset_array() + { + var dateTimeOffsets = await AssertType( + new[] { - var npgsqlDateTime = new NpgsqlDateTime(dateTime.Ticks); - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4, @p5, @p6", conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Timestamp); - var p2 = new NpgsqlParameter("p2", DbType.DateTime); - var p3 = new NpgsqlParameter("p3", DbType.DateTime2); - var p4 = new NpgsqlParameter { ParameterName = "p4", Value = npgsqlDateTime }; - var p5 = new NpgsqlParameter { ParameterName = "p5", Value = dateTime }; - var p6 = new NpgsqlParameter { ParameterName = "p6", TypedValue = dateTime }; - Assert.That(p4.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp)); - Assert.That(p4.DbType, Is.EqualTo(DbType.DateTime)); - Assert.That(p5.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp)); - Assert.That(p5.DbType, Is.EqualTo(DbType.DateTime)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - cmd.Parameters.Add(p4); - cmd.Parameters.Add(p5); - cmd.Parameters.Add(p6); - p1.Value = p2.Value = p3.Value = npgsqlDateTime; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - // Regular type (DateTime) - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(DateTime))); - Assert.That(reader.GetDateTime(i), Is.EqualTo(dateTime)); - Assert.That(reader.GetDateTime(i).Kind, Is.EqualTo(DateTimeKind.Unspecified)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(dateTime)); - Assert.That(reader[i], Is.EqualTo(dateTime)); - Assert.That(reader.GetValue(i), Is.EqualTo(dateTime)); - - // Provider-specific type (NpgsqlTimeStamp) 
- Assert.That(reader.GetTimeStamp(i), Is.EqualTo(npgsqlDateTime)); - Assert.That(reader.GetProviderSpecificFieldType(i), Is.EqualTo(typeof(NpgsqlDateTime))); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(npgsqlDateTime)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(npgsqlDateTime)); - - // DateTimeOffset - Assert.That(() => reader.GetFieldValue(i), Throws.Exception.TypeOf()); - } - } - } - } - } - - static readonly TestCaseData[] TimeStampSpecialCases = { - new TestCaseData(NpgsqlDateTime.Infinity).SetName(nameof(TimeStampSpecial) + "Infinity"), - new TestCaseData(NpgsqlDateTime.NegativeInfinity).SetName(nameof(TimeStampSpecial) + "NegativeInfinity"), - new TestCaseData(new NpgsqlDateTime(-5, 3, 3, 1, 0, 0)).SetName(nameof(TimeStampSpecial) + "BC"), - }; - - [Test, TestCaseSource(nameof(TimeStampSpecialCases))] - public async Task TimeStampSpecial(NpgsqlDateTime value) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = value }); - using (var reader = await cmd.ExecuteReaderAsync()) { - reader.Read(); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(value)); - Assert.That(() => reader.GetDateTime(0), Throws.Exception.TypeOf()); - } - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } - } - - [Test, Description("Makes sure that when ConvertInfinityDateTime is true, infinity values are properly converted")] - public async Task TimeStampConvertInfinity() - { - using (var conn = new NpgsqlConnection(ConnectionString + ";ConvertInfinityDateTime=true")) + new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.Zero), + new DateTimeOffset(1999, 4, 12, 13, 26, 38, TimeSpan.Zero) + }, + """{"1998-04-12 15:26:38+02","1999-04-12 15:26:38+02"}""", + "timestamp with time zone[]", + NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, + isDefaultForReading: false); + + 
Assert.That(dateTimeOffsets[0].Offset, Is.EqualTo(TimeSpan.Zero)); + Assert.That(dateTimeOffsets[1].Offset, Is.EqualTo(TimeSpan.Zero)); + } + + [Test] + public Task Tstzrange_as_NpgsqlRange_of_DateTime() + => AssertType( + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + @"[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""]", + "tstzrange", + NpgsqlDbType.TimestampTzRange, + skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + + [Test] + public Task Tstzrange_array_as_NpgsqlRange_of_DateTime_array() + => AssertType( + new[] { - conn.Open(); - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Timestamp, DateTime.MaxValue); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Timestamp, DateTime.MinValue); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(NpgsqlDateTime.Infinity)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(NpgsqlDateTime.NegativeInfinity)); - Assert.That(reader.GetDateTime(0), Is.EqualTo(DateTime.MaxValue)); - Assert.That(reader.GetDateTime(1), Is.EqualTo(DateTime.MinValue)); - } - } - } - } - - #endregion - - #region Timestamp with timezone - - [Test] - public async Task TimestampTz() - { - using (var conn = await OpenConnectionAsync()) + new NpgsqlRange( + new(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new(1998, 4, 13, 13, 26, 38, DateTimeKind.Utc), + new(1998, 4, 13, 15, 26, 38, DateTimeKind.Utc)), + }, + """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\"]","[\"1998-04-13 15:26:38+02\",\"1998-04-13 17:26:38+02\"]"}""", + "tstzrange[]", + NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, + isDefault: false); + + [Test] + public async Task 
Tstzmultirange_as_array_of_NpgsqlRange_of_DateTime() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + + await AssertType( + new[] { - var tzOffset = TimeZoneInfo.Local.BaseUtcOffset; - if (tzOffset == TimeSpan.Zero) - Assert.Ignore("Test cannot run when machine timezone is UTC"); - - var dateTimeUtc = new DateTime(2015, 6, 27, 8, 45, 12, 345, DateTimeKind.Utc); - var dateTimeLocal = dateTimeUtc.ToLocalTime(); - var dateTimeUnspecified = new DateTime(dateTimeUtc.Ticks, DateTimeKind.Unspecified); - - var nDateTimeUtc = new NpgsqlDateTime(dateTimeUtc); - var nDateTimeLocal = nDateTimeUtc.ToLocalTime(); - var nDateTimeUnspecified = new NpgsqlDateTime(nDateTimeUtc.Ticks, DateTimeKind.Unspecified); - - //var dateTimeOffset = new DateTimeOffset(dateTimeLocal, dateTimeLocal - dateTimeUtc); - var dateTimeOffset = new DateTimeOffset(dateTimeLocal); - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4, @p5, @p6, @p7", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.TimestampTz, dateTimeUtc); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.TimestampTz, dateTimeLocal); - cmd.Parameters.AddWithValue("p3", NpgsqlDbType.TimestampTz, dateTimeUnspecified); - cmd.Parameters.AddWithValue("p4", NpgsqlDbType.TimestampTz, nDateTimeUtc); - cmd.Parameters.AddWithValue("p5", NpgsqlDbType.TimestampTz, nDateTimeLocal); - cmd.Parameters.AddWithValue("p6", NpgsqlDbType.TimestampTz, nDateTimeUnspecified); - cmd.Parameters.AddWithValue("p7", dateTimeOffset); - Assert.That(cmd.Parameters["p7"].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.TimestampTz)); - - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - // Regular type (DateTime) - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(DateTime))); - Assert.That(reader.GetDateTime(i), Is.EqualTo(dateTimeLocal)); - 
Assert.That(reader.GetFieldValue(i).Kind, Is.EqualTo(DateTimeKind.Local)); - Assert.That(reader[i], Is.EqualTo(dateTimeLocal)); - Assert.That(reader.GetValue(i), Is.EqualTo(dateTimeLocal)); - - // Provider-specific type (NpgsqlDateTime) - Assert.That(reader.GetTimeStamp(i), Is.EqualTo(nDateTimeLocal)); - Assert.That(reader.GetProviderSpecificFieldType(i), Is.EqualTo(typeof(NpgsqlDateTime))); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(nDateTimeLocal)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(nDateTimeLocal)); - - // DateTimeOffset - Assert.That(reader.GetFieldValue(i), Is.EqualTo(dateTimeOffset)); - var x = reader.GetFieldValue(i); - } - } - } - - Assert.AreEqual(nDateTimeUtc, nDateTimeLocal.ToUniversalTime()); - Assert.AreEqual(nDateTimeUtc, new NpgsqlDateTime(nDateTimeLocal.Ticks, DateTimeKind.Unspecified).ToUniversalTime()); - Assert.AreEqual(nDateTimeLocal, nDateTimeUnspecified.ToLocalTime()); - } - } - - #endregion - - #region Interval - - [Test] - public async Task Interval() + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 13, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 13, 15, 26, 38, DateTimeKind.Utc)), + }, + @"{[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""],[""1998-04-13 15:26:38+02"",""1998-04-13 17:26:38+02""]}", + "tstzmultirange", + NpgsqlDbType.TimestampTzMultirange); + } + + [Test] + public Task Cannot_mix_DateTime_Kinds_in_array() + => AssertTypeUnsupportedWrite(new[] { - using (var conn = await OpenConnectionAsync()) + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), + }); + + + [Test] + public Task Cannot_mix_DateTime_Kinds_in_range() + => AssertTypeUnsupportedWrite, ArgumentException>(new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 13, 26, 38, 
DateTimeKind.Local))); + + [Test] + public async Task Cannot_mix_DateTime_Kinds_in_multirange() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + + await AssertTypeUnsupportedWrite[], ArgumentException>(new[] + { + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), + new NpgsqlRange( + new DateTime(1998, 4, 13, 13, 26, 38, DateTimeKind.Local), + new DateTime(1998, 4, 13, 15, 26, 38, DateTimeKind.Local)), + }); + } + + [Test] + public void NpgsqlParameterDbType_is_value_dependent_datetime_or_datetime2() + { + var localtimestamp = new NpgsqlParameter { Value = DateTime.Now }; + var unspecifiedtimestamp = new NpgsqlParameter { Value = new DateTime() }; + Assert.AreEqual(DbType.DateTime2, localtimestamp.DbType); + Assert.AreEqual(DbType.DateTime2, unspecifiedtimestamp.DbType); + + // We don't support any DateTimeOffset other than offset 0 
which maps to timestamptz, + // we might add an exception for offset == DateTimeOffset.Now.Offset (local offset) mapping to timestamp at some point. + // var dtotimestamp = new NpgsqlParameter { Value = DateTimeOffset.Now }; + // Assert.AreEqual(DbType.DateTime2, dtotimestamp.DbType); + + var timestamptz = new NpgsqlParameter { Value = DateTime.UtcNow }; + var dtotimestamptz = new NpgsqlParameter { Value = DateTimeOffset.UtcNow }; + Assert.AreEqual(DbType.DateTime, timestamptz.DbType); + Assert.AreEqual(DbType.DateTime, dtotimestamptz.DbType); + } + + [Test] + public void NpgsqlParameterNpgsqlDbType_is_value_dependent_timestamp_or_timestamptz() + { + var localtimestamp = new NpgsqlParameter { Value = DateTime.Now }; + var unspecifiedtimestamp = new NpgsqlParameter { Value = new DateTime() }; + Assert.AreEqual(NpgsqlDbType.Timestamp, localtimestamp.NpgsqlDbType); + Assert.AreEqual(NpgsqlDbType.Timestamp, unspecifiedtimestamp.NpgsqlDbType); + + var timestamptz = new NpgsqlParameter { Value = DateTime.UtcNow }; + var dtotimestamptz = new NpgsqlParameter { Value = DateTimeOffset.UtcNow }; + Assert.AreEqual(NpgsqlDbType.TimestampTz, timestamptz.NpgsqlDbType); + Assert.AreEqual(NpgsqlDbType.TimestampTz, dtotimestamptz.NpgsqlDbType); + } + + [Test] + public async Task Array_of_nullable_timestamptz() + => await AssertType( + new DateTime?[] { - var expectedNpgsqlInterval = new NpgsqlTimeSpan(1, 2, 3, 4, 5); - var expectedTimeSpan = new TimeSpan(1, 2, 3, 4, 5); - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Interval); - var p2 = new NpgsqlParameter("p2", expectedTimeSpan); - var p3 = new NpgsqlParameter("p3", expectedNpgsqlInterval); - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Interval)); - Assert.That(p2.DbType, Is.EqualTo(DbType.Object)); - Assert.That(p3.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Interval)); - Assert.That(p3.DbType, Is.EqualTo(DbType.Object)); - cmd.Parameters.Add(p1); - 
cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - p1.Value = expectedNpgsqlInterval; - - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - // Regular type (TimeSpan) - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(TimeSpan))); - Assert.That(reader.GetTimeSpan(0), Is.EqualTo(expectedTimeSpan)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expectedTimeSpan)); - Assert.That(reader[0], Is.EqualTo(expectedTimeSpan)); - Assert.That(reader.GetValue(0), Is.EqualTo(expectedTimeSpan)); - - // Provider-specific type (NpgsqlInterval) - Assert.That(reader.GetInterval(0), Is.EqualTo(expectedNpgsqlInterval)); - Assert.That(reader.GetProviderSpecificFieldType(0), Is.EqualTo(typeof(NpgsqlTimeSpan))); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(expectedNpgsqlInterval)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expectedNpgsqlInterval)); - } - } - } - } - - #endregion - - public DateTimeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + null + }, + @"{""1998-04-12 15:26:38+02"",NULL}", + "timestamp with time zone[]", + NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, + isDefault: false); + + #endregion + + #region Interval + + static readonly TestCaseData[] IntervalValues = + { + new TestCaseData(new TimeSpan(0, 2, 3, 4, 5), "02:03:04.005") + .SetName("Interval_time_only"), + new TestCaseData(new TimeSpan(1, 2, 3, 4, 5), "1 day 02:03:04.005") + .SetName("Interval_with_day"), + new TestCaseData(new TimeSpan(61, 2, 3, 4, 5), "61 days 02:03:04.005") + .SetName("Interval_with_many_days"), + new TestCaseData(new TimeSpan(new TimeSpan(2, 3, 4).Ticks + 10), "02:03:04.000001") + .SetName("Interval_with_microsecond") + }; + + [Test, TestCaseSource(nameof(IntervalValues))] + public Task Interval_as_TimeSpan(TimeSpan timeSpan, string sqlLiteral) + => AssertType(timeSpan, sqlLiteral, "interval", NpgsqlDbType.Interval); + + [Test] + public Task 
Interval_write_as_TimeSpan_truncates_ticks() + => AssertTypeWrite( + new TimeSpan(new TimeSpan(2, 3, 4).Ticks + 1), + "02:03:04", + "interval", + NpgsqlDbType.Interval); + + [Test] + public Task Interval_as_NpgsqlInterval() + => AssertType( + new NpgsqlInterval(2, 15, 7384005000), + "2 mons 15 days 02:03:04.005", "interval", + NpgsqlDbType.Interval, + isDefaultForReading: false); + + [Test] + public Task Interval_with_months_cannot_read_as_TimeSpan() + => AssertTypeUnsupportedRead("1 month 2 days", "interval"); + + #endregion + + protected override async ValueTask OpenConnectionAsync() + { + var conn = await base.OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync("SET TimeZone='Europe/Berlin'"); + return conn; } + + protected override NpgsqlConnection OpenConnection() + => throw new NotSupportedException(); } diff --git a/test/Npgsql.Tests/Types/DomainTests.cs b/test/Npgsql.Tests/Types/DomainTests.cs new file mode 100644 index 0000000000..4faaceb212 --- /dev/null +++ b/test/Npgsql.Tests/Types/DomainTests.cs @@ -0,0 +1,79 @@ +using System; +using System.Threading.Tasks; +using NUnit.Framework; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests.Types; + +public class DomainTests : MultiplexingTestBase +{ + [Test, Description("Resolves a domain type handler via the different pathways")] + public async Task Domain_resolution() + { + if (IsMultiplexing) + Assert.Ignore("Multiplexing, ReloadTypes"); + + await using var dataSource = CreateDataSource(csb => csb.Pooling = false); + await using var conn = await dataSource.OpenConnectionAsync(); + var type = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($"CREATE DOMAIN {type} AS text"); + + // Resolve type by DataTypeName + conn.ReloadTypes(); + using (var cmd = new NpgsqlCommand("SELECT @p", conn)) + { + cmd.Parameters.Add(new NpgsqlParameter { ParameterName="p", DataTypeName = type, Value = DBNull.Value }); + using (var reader = await cmd.ExecuteReaderAsync()) + { + reader.Read(); + 
Assert.That(reader.GetDataTypeName(0), Is.EqualTo("text")); + } + } + + // When sending back domain types, PG sends back the type OID of their base type. So we never need to resolve domains from + // a type OID. + conn.ReloadTypes(); + using (var cmd = new NpgsqlCommand($"SELECT 'foo'::{type}", conn)) + using (var reader = await cmd.ExecuteReaderAsync()) + { + reader.Read(); + Assert.That(reader.GetDataTypeName(0), Is.EqualTo("text")); + Assert.That(reader.GetString(0), Is.EqualTo("foo")); + } + } + + [Test] + public async Task Domain() + { + using var conn = await OpenConnectionAsync(); + var type = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($"CREATE DOMAIN {type} AS text"); + Assert.That(await conn.ExecuteScalarAsync($"SELECT 'foo'::{type}"), Is.EqualTo("foo")); + } + + [Test] + public async Task Domain_in_composite() + { + await using var adminConnection = await OpenConnectionAsync(); + var domainType = await GetTempTypeName(adminConnection); + var compositeType = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($@" +CREATE DOMAIN {domainType} AS text; +CREATE TYPE {compositeType} AS (value {domainType});"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(compositeType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var result = (SomeComposite)(await connection.ExecuteScalarAsync($"SELECT ROW('foo')::{compositeType}"))!; + Assert.That(result.Value, Is.EqualTo("foo")); + } + + class SomeComposite + { + public string? 
Value { get; set; } + } + + public DomainTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} +} diff --git a/test/Npgsql.Tests/Types/EnumTests.cs b/test/Npgsql.Tests/Types/EnumTests.cs index a164d8a846..c36514d6d3 100644 --- a/test/Npgsql.Tests/Types/EnumTests.cs +++ b/test/Npgsql.Tests/Types/EnumTests.cs @@ -3,622 +3,246 @@ using System.Threading.Tasks; using Npgsql.NameTranslation; using Npgsql.PostgresTypes; +using Npgsql.Properties; using NpgsqlTypes; using NUnit.Framework; -using static Npgsql.Util.Statics; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public class EnumTests : MultiplexingTestBase { - [NonParallelizable] - public class EnumTests : TestBase + enum Mood { Sad, Ok, Happy } + enum AnotherEnum { Value1, Value2 } + + [Test] + public async Task Data_source_mapping() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, Mood.Happy, "happy", type, npgsqlDbType: null); + } + + [Test] + public async Task Data_source_unmap() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type); + + var isUnmapSuccessful = dataSourceBuilder.UnmapEnum(type); + await using var dataSource = dataSourceBuilder.Build(); + + Assert.IsTrue(isUnmapSuccessful); + Assert.ThrowsAsync(() => AssertType(dataSource, Mood.Happy, "happy", type, npgsqlDbType: null)); + } + + [Test] + public async Task 
Data_source_mapping_non_generic() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(typeof(Mood), type); + await using var dataSource = dataSourceBuilder.Build(); + await AssertType(dataSource, Mood.Happy, "happy", type, npgsqlDbType: null); + } + + [Test] + public async Task Data_source_unmap_non_generic() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(typeof(Mood), type); + + var isUnmapSuccessful = dataSourceBuilder.UnmapEnum(typeof(Mood), type); + await using var dataSource = dataSourceBuilder.Build(); + + Assert.IsTrue(isUnmapSuccessful); + Assert.ThrowsAsync(() => AssertType(dataSource, Mood.Happy, "happy", type, npgsqlDbType: null)); + } + + [Test] + public async Task Dual_enums() { - enum Mood { Sad, Ok, Happy }; - - [Test] - public async Task UnmappedEnum() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(UnmappedEnum), - Pooling = false - }; - using (var conn = await OpenConnectionAsync(csb)) - await using (var _ = await GetTempTypeName(conn, out var type)) - { - await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - conn.ReloadTypes(); - - using (var cmd = new NpgsqlCommand("SELECT @scalar1, @scalar2, @scalar3, @scalar4", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "scalar1", - Value = Mood.Happy, - DataTypeName = type - }); - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "scalar2", - Value = "happy", - DataTypeName = type - 
}); - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "scalar3", - TypedValue = Mood.Happy, - DataTypeName = type - }); - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "scalar4", - TypedValue = "happy", - DataTypeName = type - }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < 4; i++) - { - Assert.That(reader.GetDataTypeName(i), Is.EqualTo($"public.{type}")); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(Mood.Happy)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo("happy")); - Assert.That(reader.GetValue(i), Is.EqualTo("happy")); - } - } - } - } - } - - [Test, Description("Resolves an enum type handler via the different pathways, with global mapping")] - public async Task EnumTypeResolutionWithGlobalMapping() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(EnumTypeResolutionWithGlobalMapping), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; - - using (var conn = await OpenConnectionAsync(csb)) - await using (var _ = await GetTempTypeName(conn, out var type)) - { - await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - NpgsqlConnection.GlobalTypeMapper.MapEnum(type); - try - { - conn.ReloadTypes(); - - // Resolve type by DataTypeName - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "p", - DataTypeName = type, - Value = DBNull.Value - }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"public.{type}")); - Assert.That(reader.IsDBNull(0), Is.True); - } - } - - // Resolve type by ClrType (type inference) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = Mood.Ok }); - using (var reader = await 
cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"public.{type}")); - } - } - - // Resolve type by OID (read) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand($"SELECT 'happy'::{type}", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"public.{type}")); - } - } - finally - { - NpgsqlConnection.GlobalTypeMapper.UnmapEnum(type); - } - } - } - - [Test, Description("Resolves an enum type handler via the different pathways, with late mapping")] - public async Task EnumTypeResolutionWithLateMapping() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(EnumTypeResolutionWithLateMapping), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; - - using (var conn = await OpenConnectionAsync(csb)) - await using (var _ = await GetTempTypeName(conn, out var type)) - { - await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - - // Resolve type by NpgsqlDbType - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(type); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter - { - ParameterName = "p", - DataTypeName = type, - Value = DBNull.Value - }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"public.{type}")); - Assert.That(reader.IsDBNull(0), Is.True); - } - } - - // Resolve type by ClrType (type inference) - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(type); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = Mood.Ok }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"public.{type}")); - } - } - - // Resolve type by OID (read) - 
conn.ReloadTypes(); - conn.TypeMapper.MapEnum(type); - using (var cmd = new NpgsqlCommand($"SELECT 'happy'::{type}", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"public.{type}")); - } - } - } - - [Test] - public async Task LateMapping() - { - using (var conn = await OpenConnectionAsync()) - await using (var _ = await GetTempTypeName(conn, out var type)) - { - await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(type); - const Mood expected = Mood.Ok; - var cmd = new NpgsqlCommand($"SELECT @p1::{type}, @p2::{type}", conn); - var p1 = new NpgsqlParameter - { - ParameterName = "p1", - DataTypeName = type, - Value = expected - }; - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = expected }; - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Mood))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - } - } - } - - [Test] - public async Task DualEnums() - { - using (var conn = await OpenConnectionAsync()) - await using (var _ = await GetTempTypeName(conn, out var type1)) - await using (var __ = await GetTempTypeName(conn, out var type2)) - { - await conn.ExecuteNonQueryAsync($@" + await using var adminConnection = await OpenConnectionAsync(); + var type1 = await GetTempTypeName(adminConnection); + var type2 = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($@" CREATE TYPE {type1} AS ENUM ('sad', 'ok', 'happy'); CREATE TYPE {type2} AS ENUM ('label1', 'label2', 'label3')"); - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(type1); - conn.TypeMapper.MapEnum(type2); - var cmd = new NpgsqlCommand("SELECT @p1", conn); 
- var expected = new[] { Mood.Ok, Mood.Sad }; - var p = new NpgsqlParameter - { - ParameterName = "p1", - DataTypeName = $"{type1}[]", - Value = expected - }; - cmd.Parameters.Add(p); - var result = await cmd.ExecuteScalarAsync(); - Assert.AreEqual(expected, result); - } - } - - [Test] - public async Task GlobalMapping() - { - using var adminConn = await OpenConnectionAsync(); - await using var _ = await GetTempTypeName(adminConn, out var type); - - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - NpgsqlConnection.GlobalTypeMapper.MapEnum(type); - conn.ReloadTypes(); - const Mood expected = Mood.Ok; - using (var cmd = new NpgsqlCommand($"SELECT @p::{type}", conn)) - { - var p = new NpgsqlParameter { ParameterName = "p", Value = expected }; - cmd.Parameters.Add(p); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Mood))); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - } - } - } - - // Unmap - NpgsqlConnection.GlobalTypeMapper.UnmapEnum(type); - - using (var conn = await OpenConnectionAsync()) - { - // Enum should have been unmapped and so will return as text - Assert.That(await conn.ExecuteScalarAsync($"SELECT 'ok'::{type}"), Is.EqualTo("ok")); - } - } - - [Test] - public async Task GlobalMappingWhenTypeNotFound() - { - using (var conn = await OpenConnectionAsync()) - { - NpgsqlConnection.GlobalTypeMapper.MapEnum("unknown_enum"); - try - { - Assert.That(conn.ReloadTypes, Throws.Nothing); - } - finally - { - NpgsqlConnection.GlobalTypeMapper.UnmapEnum("unknown_enum"); - } - } - } - - [Test] - public async Task Array() - { - using (var conn = await OpenConnectionAsync()) - await using (var _ = await GetTempTypeName(conn, out var type)) - { - await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 
'happy')"); - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(type); - var expected = new[] {Mood.Ok, Mood.Happy}; - using (var cmd = new NpgsqlCommand($"SELECT @p1::{type}[], @p2::{type}[]", conn)) - { - var p1 = new NpgsqlParameter - { - ParameterName = "p1", - DataTypeName = $"{type}[]", - Value = expected - }; - var p2 = new NpgsqlParameter {ParameterName = "p2", Value = expected}; - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Array))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - } - } - } - } - } - - [Test] - public async Task ReadUnmappedEnumsAsString() - { - using (var conn = new NpgsqlConnection(ConnectionString)) - { - conn.Open(); - await using var _ = await GetTempTypeName(conn, out var type); - - await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('Sad', 'Ok', 'Happy')"); - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand($"SELECT 'Sad'::{type}, ARRAY['Ok', 'Happy']::{type}[]", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo("Sad")); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"public.{type}")); - Assert.That(reader[1], Is.EqualTo(new[] { "Ok", "Happy" })); - } - } - } - - [Test, Description("Test that a c# string can be written to a backend enum when DbType is unknown")] - public async Task WriteStringToBackendEnum() - { - using (var conn = await OpenConnectionAsync()) - await using (var _ = await GetTempTypeName(conn, out var type)) - await using (var __ = await GetTempTableName(conn, out var table)) - { - await conn.ExecuteNonQueryAsync($@" -CREATE TYPE {type} AS ENUM ('Banana', 'Apple', 'Orange'); -CREATE TABLE {table} (id SERIAL, value1 {type}, value2 {type});"); - conn.ReloadTypes(); - 
const string expected = "Banana"; - using (var cmd = new NpgsqlCommand($"INSERT INTO {table} (id, value1, value2) VALUES (default, @p1, @p2);", conn)) - { - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Unknown, expected); - var p2 = new NpgsqlParameter("p1", NpgsqlDbType.Unknown) {Value = expected}; - cmd.Parameters.Add(p2); - cmd.ExecuteNonQuery(); - } - } - } - - [Test, Description("Tests that a a C# enum an be written to an enum backend when passed as dbUnknown")] - public async Task WriteEnumAsDbUnknwown() - { - using (var conn = await OpenConnectionAsync()) - await using (var _ = await GetTempTypeName(conn, out var type)) - await using (var __ = await GetTempTableName(conn, out var table)) - { - await conn.ExecuteNonQueryAsync($@" -CREATE TYPE {type} AS ENUM ('Sad', 'Ok', 'Happy'); -CREATE TABLE {table} (value1 {type})"); - conn.ReloadTypes(); - var expected = Mood.Happy; - using (var cmd = new NpgsqlCommand($"INSERT INTO {table} (value1) VALUES (@p1);", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Unknown, expected); - cmd.ExecuteNonQuery(); - } - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")] - public async Task NameTranslationDefaultSnakeCase() - { - // Per-connection mapping - using (var conn = await OpenConnectionAsync()) - await using (var _ = await GetTempTypeName(conn, out var enumName1)) - { - await conn.ExecuteNonQueryAsync($"CREATE TYPE {enumName1} AS ENUM ('simple', 'two_words', 'some_database_name')"); - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(enumName1); - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - cmd.Parameters.AddWithValue("p1", NameTranslationEnum.Simple); - cmd.Parameters.AddWithValue("p2", NameTranslationEnum.TwoWords); - cmd.Parameters.AddWithValue("p3", NameTranslationEnum.SomeClrName); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(actual: reader.GetFieldValue(0), Is.EqualTo(NameTranslationEnum.Simple)); - 
Assert.That(reader.GetFieldValue(1), Is.EqualTo(NameTranslationEnum.TwoWords)); - Assert.That(reader.GetFieldValue(2), Is.EqualTo(NameTranslationEnum.SomeClrName)); - } - } - } - - // Global mapping - using var dropConn = await OpenConnectionAsync(); - await using var __ = await GetTempTypeName(dropConn, out var enumName2); - NpgsqlConnection.GlobalTypeMapper.MapEnum(enumName2); - try - { - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync($"CREATE TYPE {enumName2} AS ENUM ('simple', 'two_words', 'some_database_name')"); - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - cmd.Parameters.AddWithValue("p1", NameTranslationEnum.Simple); - cmd.Parameters.AddWithValue("p2", NameTranslationEnum.TwoWords); - cmd.Parameters.AddWithValue("p3", NameTranslationEnum.SomeClrName); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(NameTranslationEnum.Simple)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(NameTranslationEnum.TwoWords)); - Assert.That(reader.GetFieldValue(2), Is.EqualTo(NameTranslationEnum.SomeClrName)); - } - } - } - } - finally - { - NpgsqlConnection.GlobalTypeMapper.UnmapEnum(); - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")] - public async Task NameTranslationNull() - { - // Per-connection mapping - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync(@"CREATE TYPE pg_temp.""NameTranslationEnum"" AS ENUM ('Simple', 'TwoWords', 'some_database_name')"); - conn.ReloadTypes(); - conn.TypeMapper.MapEnum(nameTranslator: new NpgsqlNullNameTranslator()); - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - cmd.Parameters.AddWithValue("p1", NameTranslationEnum.Simple); - cmd.Parameters.AddWithValue("p2", NameTranslationEnum.TwoWords); - cmd.Parameters.AddWithValue("p3", NameTranslationEnum.SomeClrName); - using (var reader = await 
cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(NameTranslationEnum.Simple)); - Assert.That(reader.GetFieldValue(1), - Is.EqualTo(NameTranslationEnum.TwoWords)); - Assert.That(reader.GetFieldValue(2), - Is.EqualTo(NameTranslationEnum.SomeClrName)); - } - } - } - } - - enum NameTranslationEnum - { - Simple, - TwoWords, - [PgName("some_database_name")] - SomeClrName - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/632")] - public async Task Schemas() - { - using var adminConn = await OpenConnectionAsync(); - await using var _ = await CreateTempSchema(adminConn, out var schema1); - await using var __ = await CreateTempSchema(adminConn, out var schema2); - - try - { - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync($@" + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type1); + dataSourceBuilder.MapEnum(type2); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, new[] { Mood.Ok, Mood.Sad }, "{ok,sad}", type1 + "[]", npgsqlDbType: null); + } + + [Test] + public async Task Array() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, new[] { Mood.Ok, Mood.Happy }, "{ok,happy}", type + "[]", npgsqlDbType: null); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")] + public async Task Name_translation_default_snake_case() + { + await using var adminConnection = await OpenConnectionAsync(); + var enumName1 = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {enumName1} AS ENUM ('simple', 
'two_words', 'some_database_name')"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(enumName1); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, NameTranslationEnum.Simple, "simple", enumName1, npgsqlDbType: null); + await AssertType(dataSource, NameTranslationEnum.TwoWords, "two_words", enumName1, npgsqlDbType: null); + await AssertType(dataSource, NameTranslationEnum.SomeClrName, "some_database_name", enumName1, npgsqlDbType: null); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")] + public async Task Name_translation_null() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('Simple', 'TwoWords', 'some_database_name')"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type, nameTranslator: new NpgsqlNullNameTranslator()); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, NameTranslationEnum.Simple, "Simple", type, npgsqlDbType: null); + await AssertType(dataSource, NameTranslationEnum.TwoWords, "TwoWords", type, npgsqlDbType: null); + await AssertType(dataSource, NameTranslationEnum.SomeClrName, "some_database_name", type, npgsqlDbType: null); + } + + [Test] + public async Task Unmapped_enum_as_clr_enum() + { + await using var dataSource = CreateDataSource(b => b.EnableUnmappedTypes()); + await using var connection = await dataSource.OpenConnectionAsync(); + var type1 = await GetTempTypeName(connection); + var type2 = await GetTempTypeName(connection); + await connection.ExecuteNonQueryAsync(@$" +CREATE TYPE {type1} AS ENUM ('sad', 'ok', 'happy'); +CREATE TYPE {type2} AS ENUM ('value1', 'value2');"); + await connection.ReloadTypesAsync(); + + await AssertType(connection, Mood.Happy, "happy", type1, npgsqlDbType: null, isDefault: false); + await 
AssertType(connection, AnotherEnum.Value2, "value2", type2, npgsqlDbType: null, isDefault: false); + } + + [Test] + public async Task Unmapped_enum_as_clr_enum_supported_only_with_EnableUnmappedTypes() + { + await using var connection = await DataSource.OpenConnectionAsync(); + var enumType = await GetTempTypeName(connection); + await connection.ExecuteNonQueryAsync($"CREATE TYPE {enumType} AS ENUM ('sad', 'ok', 'happy')"); + await connection.ReloadTypesAsync(); + + var errorMessage = string.Format( + NpgsqlStrings.UnmappedEnumsNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableUnmappedTypes), + nameof(NpgsqlDataSourceBuilder)); + + var exception = await AssertTypeUnsupportedWrite(Mood.Happy, enumType); + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + + exception = await AssertTypeUnsupportedRead("happy", enumType); + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + } + + [Test] + public async Task Unmapped_enum_as_string() + { + await using var connection = await OpenConnectionAsync(); + var type = await GetTempTypeName(connection); + await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + await connection.ReloadTypesAsync(); + + await AssertType(connection, "happy", "happy", type, npgsqlDbType: null, isDefaultForWriting: false); + } + + enum NameTranslationEnum + { + Simple, + TwoWords, + [PgName("some_database_name")] + SomeClrName + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/632")] + public async Task Same_name_in_different_schemas() + { + await using var adminConnection = await OpenConnectionAsync(); + var schema1 = await CreateTempSchema(adminConnection); + var schema2 = await CreateTempSchema(adminConnection); + await adminConnection.ExecuteNonQueryAsync($@" CREATE TYPE {schema1}.my_enum AS ENUM ('one'); CREATE TYPE {schema2}.my_enum AS ENUM 
('alpha');"); - conn.ReloadTypes(); - conn.TypeMapper - .MapEnum($"{schema1}.my_enum") - .MapEnum($"{schema2}.my_enum"); - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - cmd.Parameters.AddWithValue("p1", Enum1.One); - cmd.Parameters.AddWithValue("p2", Enum2.Alpha); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(Enum1.One)); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"{schema1}.my_enum")); - Assert.That(reader[1], Is.EqualTo(Enum2.Alpha)); - Assert.That(reader.GetDataTypeName(1), Is.EqualTo($"{schema2}.my_enum")); - } - } - } - - // Global mapping - NpgsqlConnection.GlobalTypeMapper.MapEnum($"{schema1}.my_enum"); - NpgsqlConnection.GlobalTypeMapper.MapEnum($"{schema2}.my_enum"); - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - cmd.Parameters.AddWithValue("p1", Enum1.One); - cmd.Parameters.AddWithValue("p2", Enum2.Alpha); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(Enum1.One)); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"{schema1}.my_enum")); - Assert.That(reader[1], Is.EqualTo(Enum2.Alpha)); - Assert.That(reader.GetDataTypeName(1), Is.EqualTo($"{schema2}.my_enum")); - } - } - } - } - finally - { - NpgsqlConnection.GlobalTypeMapper.UnmapEnum($"{schema1}.my_enum"); - NpgsqlConnection.GlobalTypeMapper.UnmapEnum($"{schema2}.my_enum"); - } - } - - enum Enum1 { One } - enum Enum2 { Alpha } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1017")] - public async Task GlobalMappingsAndPooling() - { - using var adminConn = await OpenConnectionAsync(); - using var _ = CreateTempPool(ConnectionString, out var connectionString); - await using var __ = await GetTempTypeName(adminConn, out var type); - - int serverId; - using (var conn = await OpenConnectionAsync(connectionString)) - { - serverId = conn.ProcessID; - await 
conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - conn.ReloadTypes(); - } - // At this point the backend type for the enum is loaded, but no global mapping - // has been made. Reopening the same pooled connector should learn about the new - // global mapping - NpgsqlConnection.GlobalTypeMapper.MapEnum(type); - try - { - using (var conn = await OpenConnectionAsync(connectionString)) - { - Assert.That(conn.ProcessID, Is.EqualTo(serverId)); - Assert.That(await conn.ExecuteScalarAsync($"SELECT 'sad'::{type}"), Is.EqualTo(Mood.Sad)); - } - } - finally - { - NpgsqlConnection.GlobalTypeMapper.UnmapEnum("mood1"); - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1779")] - public async Task EnumPostgresType() - { - using var _ = CreateTempPool(ConnectionString, out var connectionString); - using (var conn = await OpenConnectionAsync(connectionString)) - await using (var __ = await GetTempTypeName(conn, out var type)) - { - await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); - conn.ReloadTypes(); - - using (var cmd = new NpgsqlCommand($"SELECT 'ok'::{type}", conn)) - { - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - var enumType = (PostgresEnumType)reader.GetPostgresType(0); - Assert.That(enumType.Name, Is.EqualTo(type)); - Assert.That(enumType.Labels, Is.EqualTo(new List { "sad", "ok", "happy" })); - } - } - } - } - - enum TestEnum - { - label1, - label2, - [PgName("label3")] - Label3 - } + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum($"{schema1}.my_enum"); + dataSourceBuilder.MapEnum($"{schema2}.my_enum"); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, Enum1.One, "one", $"{schema1}.my_enum", npgsqlDbType: null); + await AssertType(dataSource, Enum2.Alpha, "alpha", $"{schema2}.my_enum", npgsqlDbType: null); } + + enum Enum1 { One } + enum Enum2 { Alpha } + + [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/1779")] + public async Task GetPostgresType() + { + await using var dataSource = CreateDataSource(); + using var conn = await dataSource.OpenConnectionAsync(); + var type = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + conn.ReloadTypes(); + + using var cmd = new NpgsqlCommand($"SELECT 'ok'::{type}", conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var enumType = (PostgresEnumType)reader.GetPostgresType(0); + Assert.That(enumType.Name, Is.EqualTo(type)); + Assert.That(enumType.Labels, Is.EqualTo(new List { "sad", "ok", "happy" })); + } + + enum TestEnum + { + label1, + label2, + [PgName("label3")] + Label3 + } + + public EnumTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/FullTextSearchTests.cs b/test/Npgsql.Tests/Types/FullTextSearchTests.cs index e04c699646..eda874b12a 100644 --- a/test/Npgsql.Tests/Types/FullTextSearchTests.cs +++ b/test/Npgsql.Tests/Types/FullTextSearchTests.cs @@ -1,44 +1,103 @@ using System; +using System.Collections; using System.Threading.Tasks; +using Npgsql.Properties; using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests.Types +#pragma warning disable CS0618 // NpgsqlTsVector.Parse is obsolete + +namespace Npgsql.Tests.Types; + +public class FullTextSearchTests : MultiplexingTestBase { - public class FullTextSearchTests : MultiplexingTestBase + public FullTextSearchTests(MultiplexingMode multiplexingMode) + : base(multiplexingMode) { } + + [Test] + public Task TsVector() + => AssertType( + NpgsqlTsVector.Parse("'1' '2' 'a':24,25A,26B,27,28,12345C 'b' 'c' 'd'"), + "'1' '2' 'a':24,25A,26B,27,28,12345C 'b' 'c' 'd'", + "tsvector", + NpgsqlDbType.TsVector); + + public static IEnumerable TsQueryTestCases() => new[] { - [Test] - public async Task TsVector() + new object[] { - using (var conn = await OpenConnectionAsync()) 
- using (var cmd = conn.CreateCommand()) - { - var inputVec = NpgsqlTsVector.Parse(" a:12345C a:24D a:25B b c d 1 2 a:25A,26B,27,28"); - - cmd.CommandText = "Select :p"; - cmd.Parameters.AddWithValue("p", inputVec); - var outputVec = await cmd.ExecuteScalarAsync(); - Assert.AreEqual(inputVec.ToString(), outputVec!.ToString()); - } - } - - [Test] - public async Task TsQuery() + "'a'", + new NpgsqlTsQueryLexeme("a") + }, + new object[] + { + "!'a'", + new NpgsqlTsQueryNot( + new NpgsqlTsQueryLexeme("a")) + }, + new object[] + { + "'a' | 'b'", + new NpgsqlTsQueryOr( + new NpgsqlTsQueryLexeme("a"), + new NpgsqlTsQueryLexeme("b")) + }, + new object[] { - using (var conn = await OpenConnectionAsync()) - using (var cmd = conn.CreateCommand()) - { - var query = conn.PostgreSqlVersion < new Version(9, 6) - ? NpgsqlTsQuery.Parse("(a & !(c | d)) & (!!a&b) | ä | f") - : NpgsqlTsQuery.Parse("(a & !(c | d)) & (!!a&b) | ä | x <-> y | x <10> y | d <0> e | f"); - - cmd.CommandText = "Select :p"; - cmd.Parameters.AddWithValue("p", query); - var output = await cmd.ExecuteScalarAsync(); - Assert.AreEqual(query.ToString(), output!.ToString()); - } + "'a' & 'b'", + new NpgsqlTsQueryAnd( + new NpgsqlTsQueryLexeme("a"), + new NpgsqlTsQueryLexeme("b")) + }, + new object[] + { + "'a' <-> 'b'", + new NpgsqlTsQueryFollowedBy( + new NpgsqlTsQueryLexeme("a"), 1, new NpgsqlTsQueryLexeme("b")) } + }; + + [Test] + [TestCaseSource(nameof(TsQueryTestCases))] + public Task TsQuery(string sqlLiteral, NpgsqlTsQuery query) + => AssertType(query, sqlLiteral, "tsquery", NpgsqlDbType.TsQuery); + + [Test] + public async Task Full_text_search_not_supported_by_default_on_NpgsqlSlimSourceBuilder() + { + var errorMessage = string.Format( + NpgsqlStrings.FullTextSearchNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableFullTextSearch), + nameof(NpgsqlSlimDataSourceBuilder)); + + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + await using var dataSource = dataSourceBuilder.Build(); 
+ + var exception = await AssertTypeUnsupportedRead("a", "tsquery", dataSource); + Assert.IsInstanceOf(exception.InnerException); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); + + exception = await AssertTypeUnsupportedWrite(new NpgsqlTsQueryLexeme("a"), pgTypeName: null, dataSource); + Assert.IsInstanceOf(exception.InnerException); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); + + exception = await AssertTypeUnsupportedRead("1", "tsvector", dataSource); + Assert.IsInstanceOf(exception.InnerException); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); + + exception = await AssertTypeUnsupportedWrite(NpgsqlTsVector.Parse("'1'"), pgTypeName: null, dataSource); + Assert.IsInstanceOf(exception.InnerException); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); + } + + [Test] + public async Task NpgsqlSlimSourceBuilder_EnableFullTextSearch() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableFullTextSearch(); + await using var dataSource = dataSourceBuilder.Build(); - public FullTextSearchTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + await AssertType(new NpgsqlTsQueryLexeme("a"), "'a'", "tsquery", NpgsqlDbType.TsQuery); + await AssertType(NpgsqlTsVector.Parse("'1'"), "'1'", "tsvector", NpgsqlDbType.TsVector); } } diff --git a/test/Npgsql.Tests/Types/GeometricTypeTests.cs b/test/Npgsql.Tests/Types/GeometricTypeTests.cs index 45a42bef3e..c4d8d53b0e 100644 --- a/test/Npgsql.Tests/Types/GeometricTypeTests.cs +++ b/test/Npgsql.Tests/Types/GeometricTypeTests.cs @@ -2,189 +2,139 @@ using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +/// +/// Tests on PostgreSQL geometric types +/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-geometric.html +/// +class GeometricTypeTests : MultiplexingTestBase { - /// - /// Tests on PostgreSQL geometric types - 
/// - /// - /// https://www.postgresql.org/docs/current/static/datatype-geometric.html - /// - class GeometricTypeTests : MultiplexingTestBase + [Test] + public Task Point() + => AssertType(new NpgsqlPoint(1.2, 3.4), "(1.2,3.4)", "point", NpgsqlDbType.Point); + + [Test] + public Task Line() + => AssertType(new NpgsqlLine(1, 2, 3), "{1,2,3}", "line", NpgsqlDbType.Line); + + [Test] + public Task LineSegment() + => AssertType(new NpgsqlLSeg(1, 2, 3, 4), "[(1,2),(3,4)]", "lseg", NpgsqlDbType.LSeg); + + [Test] + public async Task Box() { - [Test] - public async Task Point() - { - using (var conn = await OpenConnectionAsync()) - { - var expected = new NpgsqlPoint(1.2, 3.4); - var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Point) {Value = expected}; - var p2 = new NpgsqlParameter {ParameterName = "p2", Value = expected}; - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Point)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(NpgsqlPoint))); - var actual = reader.GetFieldValue(i); - AssertPointsEqual(actual, expected); - } - } - } - } - - [Test] - public async Task LineSegment() - { - using (var conn = await OpenConnectionAsync()) - { - var expected = new NpgsqlLSeg(1, 2, 3, 4); - var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.LSeg) {Value = expected}; - var p2 = new NpgsqlParameter {ParameterName = "p2", Value = expected}; - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.LSeg)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(NpgsqlLSeg))); - var actual = 
reader.GetFieldValue(i); - AssertPointsEqual(actual.Start, expected.Start); - AssertPointsEqual(actual.End, expected.End); - } - } - } - } - - [Test] - public async Task Box() - { - using (var conn = await OpenConnectionAsync()) - { - var expected = new NpgsqlBox(2, 4, 1, 3); - var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Box) {Value = expected}; - var p2 = new NpgsqlParameter {ParameterName = "p2", Value = expected}; - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Box)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(NpgsqlBox))); - var actual = reader.GetFieldValue(i); - AssertPointsEqual(actual.UpperRight, expected.UpperRight); - } - } - } - } - - [Test] - public async Task Path() - { - using (var conn = await OpenConnectionAsync()) - { - var expectedOpen = new NpgsqlPath(new[] {new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4)}, true); - var expectedClosed = new NpgsqlPath(new[] {new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4)}, false); - var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Path) {Value = expectedOpen}; - var p2 = new NpgsqlParameter("p2", NpgsqlDbType.Path) {Value = expectedClosed}; - var p3 = new NpgsqlParameter {ParameterName = "p3", Value = expectedClosed}; - Assert.That(p3.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Path)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - var expected = i == 0 ? 
expectedOpen : expectedClosed; - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(NpgsqlPath))); - var actual = reader.GetFieldValue(i); - Assert.That(actual.Open, Is.EqualTo(expected.Open)); - Assert.That(actual, Has.Count.EqualTo(expected.Count)); - for (var j = 0; j < actual.Count; j++) - AssertPointsEqual(actual[j], expected[j]); - } - } - } - } - - [Test] - public async Task Polygon() - { - using (var conn = await OpenConnectionAsync()) - { - var expected = new NpgsqlPolygon(new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4)); - var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Polygon) {Value = expected}; - var p2 = new NpgsqlParameter {ParameterName = "p2", Value = expected}; - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Polygon)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(NpgsqlPolygon))); - var actual = reader.GetFieldValue(i); - Assert.That(actual, Has.Count.EqualTo(expected.Count)); - for (var j = 0; j < actual.Count; j++) - AssertPointsEqual(actual[j], expected[j]); - } - } - } - } - - [Test] - public async Task Circle() + await AssertType( + new NpgsqlBox(top: 3, right: 4, bottom: 1, left: 2), + "(4,3),(2,1)", + "box", + NpgsqlDbType.Box, + skipArrayCheck: true); // Uses semicolon instead of comma as separator + + await AssertType( + new NpgsqlBox(top: -10, right: 0, bottom: -20, left: -10), + "(0,-10),(-10,-20)", + "box", + NpgsqlDbType.Box, + skipArrayCheck: true); // Uses semicolon instead of comma as separator + + await AssertType( + new NpgsqlBox(top: 1, right: 2, bottom: 3, left: 4), + "(4,3),(2,1)", + "box", + NpgsqlDbType.Box, + skipArrayCheck: true); // Uses semicolon instead of comma as separator + + var swapped = new NpgsqlBox(top: -20, right: -10, bottom: -10, left: 0); + + await 
AssertType( + swapped, + "(0,-10),(-10,-20)", + "box", + NpgsqlDbType.Box, + skipArrayCheck: true); // Uses semicolon instead of comma as separator + + await AssertType( + swapped with { UpperRight = new NpgsqlPoint(-20,-10) }, + "(-10,-10),(-20,-20)", + "box", + NpgsqlDbType.Box, + skipArrayCheck: true); // Uses semicolon instead of comma as separator + + await AssertType( + swapped with { LowerLeft = new NpgsqlPoint(10, 10) }, + "(10,10),(0,-10)", + "box", + NpgsqlDbType.Box, + skipArrayCheck: true); // Uses semicolon instead of comma as separator + } + + [Test] + public async Task Box_array() + { + var data = new[] { - using (var conn = await OpenConnectionAsync()) - { - var expected = new NpgsqlCircle(1, 2, 0.5); - var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Circle) {Value = expected}; - var p2 = new NpgsqlParameter {ParameterName = "p2", Value = expected}; - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Circle)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(NpgsqlCircle))); - var actual = reader.GetFieldValue(i); - Assert.That(actual.X, Is.EqualTo(expected.X).Within(1).Ulps); - Assert.That(actual.Y, Is.EqualTo(expected.Y).Within(1).Ulps); - Assert.That(actual.Radius, Is.EqualTo(expected.Radius).Within(1).Ulps); - } - } - } - } - - void AssertPointsEqual(NpgsqlPoint actual, NpgsqlPoint expected) + new NpgsqlBox(top: 3, right: 4, bottom: 1, left: 2), + new NpgsqlBox(top: 5, right: 6, bottom: 3, left: 4), + new NpgsqlBox(top: -10, right: 0, bottom: -20, left: -10) + }; + + await AssertType( + data, + "{(4,3),(2,1);(6,5),(4,3);(0,-10),(-10,-20)}", + "box[]", + NpgsqlDbType.Box | NpgsqlDbType.Array + ); + + var swappedData = new[] { - Assert.That(actual.X, Is.EqualTo(expected.X).Within(1).Ulps); - 
Assert.That(actual.Y, Is.EqualTo(expected.Y).Within(1).Ulps); - } + new NpgsqlBox(top: 1, right: 2, bottom: 3, left: 4), + new NpgsqlBox(top: 3, right: 4, bottom: 5, left: 6), + new NpgsqlBox(top: -20, right: -10, bottom: -10, left: 0) + }; - public GeometricTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + await AssertType( + swappedData, + "{(4,3),(2,1);(6,5),(4,3);(0,-10),(-10,-20)}", + "box[]", + NpgsqlDbType.Box | NpgsqlDbType.Array + ); } + + [Test] + public Task Path_closed() + => AssertType( + new NpgsqlPath(new[] { new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4) }, false), + "((1,2),(3,4))", + "path", + NpgsqlDbType.Path); + + [Test] + public Task Path_open() + => AssertType( + new NpgsqlPath(new[] { new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4) }, true), + "[(1,2),(3,4)]", + "path", + NpgsqlDbType.Path); + + [Test] + public Task Polygon() + => AssertType( + new NpgsqlPolygon(new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4)), + "((1,2),(3,4))", + "polygon", + NpgsqlDbType.Polygon); + + [Test] + public Task Circle() + => AssertType( + new NpgsqlCircle(1, 2, 0.5), + "<(1,2),0.5>", + "circle", + NpgsqlDbType.Circle); + + public GeometricTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/HstoreTests.cs b/test/Npgsql.Tests/Types/HstoreTests.cs index 8d181b619d..5696cad98b 100644 --- a/test/Npgsql.Tests/Types/HstoreTests.cs +++ b/test/Npgsql.Tests/Types/HstoreTests.cs @@ -4,89 +4,65 @@ using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests.Types -{ - public class HstoreTests : MultiplexingTestBase - { - [Test] - public async Task Basic() - { - using var conn = await OpenConnectionAsync(); +namespace Npgsql.Tests.Types; - var expected = new Dictionary { +public class HstoreTests : MultiplexingTestBase +{ + [Test] + public Task Hstore() + => AssertType( + new Dictionary + { {"a", "3"}, {"b", null}, {"cd", "hello"} - }; - - using var cmd = new NpgsqlCommand("SELECT @p1, @p2", 
conn); - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Hstore, expected); - cmd.Parameters.AddWithValue("p2", expected); - - using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Dictionary))); - Assert.That(reader.GetFieldValue>(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue>(i), Is.EqualTo(expected)); - } - } - - [Test] - public async Task Empty() - { - using var conn = await OpenConnectionAsync(); - - var expected = new Dictionary(); + }, + @"""a""=>""3"", ""b""=>NULL, ""cd""=>""hello""", + "hstore", + NpgsqlDbType.Hstore, isNpgsqlDbTypeInferredFromClrType: false); - using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Hstore, expected); - cmd.Parameters.AddWithValue("p2", expected); + [Test] + public Task Hstore_empty() + => AssertType(new Dictionary(), @"", "hstore", NpgsqlDbType.Hstore, isNpgsqlDbTypeInferredFromClrType: false); - using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Dictionary))); - Assert.That(reader.GetFieldValue>(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue>(i), Is.EqualTo(expected)); - } - } - - [Test] - public async Task ImmutableDictionary() - { - using var conn = await OpenConnectionAsync(); - - var builder = ImmutableDictionary.Empty; - builder.Add("a", "3"); - builder.Add("b", null); - builder.Add("cd", "hello"); - var expected = builder.ToImmutableDictionary(); + [Test] + public Task Hstore_as_ImmutableDictionary() + { + var builder = ImmutableDictionary.Empty.ToBuilder(); + builder.Add("a", "3"); + builder.Add("b", null); + builder.Add("cd", "hello"); + var immutableDictionary = builder.ToImmutableDictionary(); - using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - 
cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Hstore, expected); - cmd.Parameters.AddWithValue("p2", expected); + return AssertType( + immutableDictionary, + @"""a""=>""3"", ""b""=>NULL, ""cd""=>""hello""", + "hstore", + NpgsqlDbType.Hstore, + isDefaultForReading: false, isNpgsqlDbTypeInferredFromClrType: false); + } - using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - for (var i = 0; i < cmd.Parameters.Count; i++) + [Test] + public Task Hstore_as_IDictionary() + => AssertType>( + new Dictionary { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Dictionary))); - Assert.That(reader.GetFieldValue>(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue>(i), Is.EqualTo(expected)); - } - } - - [OneTimeSetUp] - public async Task SetUp() - { - using var conn = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(conn, "9.1", "Hstore introduced in PostgreSQL 9.1"); - await TestUtil.EnsureExtensionAsync(conn, "hstore", "9.1"); - } + { "a", "3" }, + { "b", null }, + { "cd", "hello" } + }, + @"""a""=>""3"", ""b""=>NULL, ""cd""=>""hello""", + "hstore", + NpgsqlDbType.Hstore, + isDefaultForReading: false, isNpgsqlDbTypeInferredFromClrType: false); - public HstoreTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + [OneTimeSetUp] + public async Task SetUp() + { + using var conn = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(conn, "9.1", "Hstore introduced in PostgreSQL 9.1"); + await TestUtil.EnsureExtensionAsync(conn, "hstore", "9.1"); } + + public HstoreTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/InternalTypeTests.cs b/test/Npgsql.Tests/Types/InternalTypeTests.cs index 451ef58ca2..a5d69664a4 100644 --- a/test/Npgsql.Tests/Types/InternalTypeTests.cs +++ b/test/Npgsql.Tests/Types/InternalTypeTests.cs @@ -2,106 +2,99 @@ using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public class 
InternalTypeTests : MultiplexingTestBase { - public class InternalTypeTests : MultiplexingTestBase + [Test] + public async Task Read_internal_char() { - [Test] - public async Task ReadInternalChar() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT typdelim FROM pg_type WHERE typname='int4'", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetChar(0), Is.EqualTo(',')); - Assert.That(reader.GetValue(0), Is.EqualTo(',')); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(',')); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(char))); - } - } + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT typdelim FROM pg_type WHERE typname='int4'", conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetChar(0), Is.EqualTo(',')); + Assert.That(reader.GetValue(0), Is.EqualTo(',')); + Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(',')); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(char))); + } - [Test] - [TestCase(NpgsqlDbType.Oid)] - [TestCase(NpgsqlDbType.Regtype)] - [TestCase(NpgsqlDbType.Regconfig)] - public async Task InternalUintTypes(NpgsqlDbType npgsqlDbType) - { - var postgresType = npgsqlDbType.ToString().ToLowerInvariant(); - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand($"SELECT @max, 4294967295::{postgresType}, @eight, 8::{postgresType}", conn); - cmd.Parameters.AddWithValue("max", npgsqlDbType, uint.MaxValue); - cmd.Parameters.AddWithValue("eight", npgsqlDbType, 8u); - using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); + [Test] + [TestCase(NpgsqlDbType.Oid)] + [TestCase(NpgsqlDbType.Regtype)] + [TestCase(NpgsqlDbType.Regconfig)] + public async Task Internal_uint_types(NpgsqlDbType npgsqlDbType) + { + var postgresType = npgsqlDbType.ToString().ToLowerInvariant(); + using var conn = await 
OpenConnectionAsync(); + using var cmd = new NpgsqlCommand($"SELECT @max, 4294967295::{postgresType}, @eight, 8::{postgresType}", conn); + cmd.Parameters.AddWithValue("max", npgsqlDbType, uint.MaxValue); + cmd.Parameters.AddWithValue("eight", npgsqlDbType, 8u); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); - for (var i = 0; i < reader.FieldCount; i++) - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(uint))); + for (var i = 0; i < reader.FieldCount; i++) + Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(uint))); - Assert.That(reader.GetValue(0), Is.EqualTo(uint.MaxValue)); - Assert.That(reader.GetValue(1), Is.EqualTo(uint.MaxValue)); - Assert.That(reader.GetValue(2), Is.EqualTo(8u)); - Assert.That(reader.GetValue(3), Is.EqualTo(8u)); - } + Assert.That(reader.GetValue(0), Is.EqualTo(uint.MaxValue)); + Assert.That(reader.GetValue(1), Is.EqualTo(uint.MaxValue)); + Assert.That(reader.GetValue(2), Is.EqualTo(8u)); + Assert.That(reader.GetValue(3), Is.EqualTo(8u)); + } - [Test] - public async Task Tid() - { - var expected = new NpgsqlTid(3, 5); - using (var conn = await OpenConnectionAsync()) - using (var cmd = conn.CreateCommand()) - { - cmd.CommandText = "SELECT '(1234,40000)'::tid, @p::tid"; - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Tid, expected); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.AreEqual(1234, reader.GetFieldValue(0).BlockNumber); - Assert.AreEqual(40000, reader.GetFieldValue(0).OffsetNumber); - Assert.AreEqual(expected.BlockNumber, reader.GetFieldValue(1).BlockNumber); - Assert.AreEqual(expected.OffsetNumber, reader.GetFieldValue(1).OffsetNumber); - } - } - } + [Test] + public async Task Tid() + { + var expected = new NpgsqlTid(3, 5); + using var conn = await OpenConnectionAsync(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT '(1234,40000)'::tid, @p::tid"; + cmd.Parameters.AddWithValue("p", NpgsqlDbType.Tid, expected); + using var reader = await 
cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.AreEqual(1234, reader.GetFieldValue(0).BlockNumber); + Assert.AreEqual(40000, reader.GetFieldValue(0).OffsetNumber); + Assert.AreEqual(expected.BlockNumber, reader.GetFieldValue(1).BlockNumber); + Assert.AreEqual(expected.OffsetNumber, reader.GetFieldValue(1).OffsetNumber); + } - #region NpgsqlLogSequenceNumber / PgLsn + #region NpgsqlLogSequenceNumber / PgLsn - static readonly TestCaseData[] EqualsObjectCases = { - new TestCaseData(new NpgsqlLogSequenceNumber(1ul), null).Returns(false), - new TestCaseData(new NpgsqlLogSequenceNumber(1ul), new object()).Returns(false), - new TestCaseData(new NpgsqlLogSequenceNumber(1ul), 1ul).Returns(false), // no implicit cast - new TestCaseData(new NpgsqlLogSequenceNumber(1ul), "0/0").Returns(false), // no implicit cast/parsing - new TestCaseData(new NpgsqlLogSequenceNumber(1ul), new NpgsqlLogSequenceNumber(1ul)).Returns(true), - }; + static readonly TestCaseData[] EqualsObjectCases = { + new TestCaseData(new NpgsqlLogSequenceNumber(1ul), null).Returns(false), + new TestCaseData(new NpgsqlLogSequenceNumber(1ul), new object()).Returns(false), + new TestCaseData(new NpgsqlLogSequenceNumber(1ul), 1ul).Returns(false), // no implicit cast + new TestCaseData(new NpgsqlLogSequenceNumber(1ul), "0/0").Returns(false), // no implicit cast/parsing + new TestCaseData(new NpgsqlLogSequenceNumber(1ul), new NpgsqlLogSequenceNumber(1ul)).Returns(true), + }; - [Test, TestCaseSource(nameof(EqualsObjectCases))] - public bool NpgsqlLogSequenceNumberEquals(NpgsqlLogSequenceNumber lsn, object? obj) - => lsn.Equals(obj); + [Test, TestCaseSource(nameof(EqualsObjectCases))] + public bool NpgsqlLogSequenceNumber_equals(NpgsqlLogSequenceNumber lsn, object? 
obj) + => lsn.Equals(obj); - [Test] - public async Task PgLsn() - { - var expected1 = new NpgsqlLogSequenceNumber(42949672971ul); - Assert.AreEqual(expected1, NpgsqlLogSequenceNumber.Parse("A/B")); - await using var conn = await OpenConnectionAsync(); - using var cmd = conn.CreateCommand(); - cmd.CommandText = "SELECT 'A/B'::pg_lsn, @p::pg_lsn"; - cmd.Parameters.AddWithValue("p", NpgsqlDbType.PgLsn, expected1); - await using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - var result1 = reader.GetFieldValue(0); - var result2 = reader.GetFieldValue(1); - Assert.AreEqual(expected1, result1); - Assert.AreEqual(42949672971ul, (ulong)result1); - Assert.AreEqual("A/B", result1.ToString()); - Assert.AreEqual(expected1, result2); - Assert.AreEqual(42949672971ul, (ulong)result2); - Assert.AreEqual("A/B", result2.ToString()); - } + [Test] + public async Task NpgsqlLogSequenceNumber() + { + var expected1 = new NpgsqlLogSequenceNumber(42949672971ul); + Assert.AreEqual(expected1, NpgsqlTypes.NpgsqlLogSequenceNumber.Parse("A/B")); + await using var conn = await OpenConnectionAsync(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 'A/B'::pg_lsn, @p::pg_lsn"; + cmd.Parameters.AddWithValue("p", NpgsqlDbType.PgLsn, expected1); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var result1 = reader.GetFieldValue(0); + var result2 = reader.GetFieldValue(1); + Assert.AreEqual(expected1, result1); + Assert.AreEqual(42949672971ul, (ulong)result1); + Assert.AreEqual("A/B", result1.ToString()); + Assert.AreEqual(expected1, result2); + Assert.AreEqual(42949672971ul, (ulong)result2); + Assert.AreEqual("A/B", result2.ToString()); + } - #endregion NpgsqlLogSequenceNumber / PgLsn + #endregion NpgsqlLogSequenceNumber / PgLsn - public InternalTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} - } -} + public InternalTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} +} \ No newline at end of 
file diff --git a/test/Npgsql.Tests/Types/JsonDynamicTests.cs b/test/Npgsql.Tests/Types/JsonDynamicTests.cs new file mode 100644 index 0000000000..73b0965d12 --- /dev/null +++ b/test/Npgsql.Tests/Types/JsonDynamicTests.cs @@ -0,0 +1,382 @@ +using System; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using Npgsql.Properties; +using NpgsqlTypes; +using NUnit.Framework; + +namespace Npgsql.Tests.Types; + +[TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Json)] +[TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Jsonb)] +[TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Json)] +[TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Jsonb)] +public class JsonDynamicTests : MultiplexingTestBase +{ +#if NET6_0_OR_GREATER + [Test] + public Task Roundtrip_JsonObject() + => AssertType( + new JsonObject { ["Bar"] = 8 }, + IsJsonb ? """{"Bar": 8}""" : """{"Bar":8}""", + PostgresType, + NpgsqlDbType, + // By default we map JsonObject to jsonb + isDefaultForWriting: IsJsonb, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false, + comparer: (x, y) => x.ToString() == y.ToString()); + + [Test] + public Task Roundtrip_JsonArray() + => AssertType( + new JsonArray { 1, 2, 3 }, + IsJsonb ? 
"[1, 2, 3]" : "[1,2,3]", + PostgresType, + NpgsqlDbType, + // By default we map JsonArray to jsonb + isDefaultForWriting: IsJsonb, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false, + comparer: (x, y) => x.ToString() == y.ToString()); + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4537")] + public async Task Write_jsonobject_array_without_npgsqldbtype() + { + // By default we map JsonObject to jsonb + if (!IsJsonb) + return; + + await using var conn = await OpenConnectionAsync(); + var tableName = await TestUtil.CreateTempTable(conn, "key SERIAL PRIMARY KEY, ingredients json[]"); + + await using var cmd = new NpgsqlCommand { Connection = conn }; + + var jsonObject1 = new JsonObject + { + { "name", "value1" }, + { "amount", 1 }, + { "unit", "ml" } + }; + + var jsonObject2 = new JsonObject + { + { "name", "value2" }, + { "amount", 2 }, + { "unit", "g" } + }; + + cmd.CommandText = $"INSERT INTO {tableName} (ingredients) VALUES (@p)"; + cmd.Parameters.Add(new("p", new[] { jsonObject1, jsonObject2 })); + await cmd.ExecuteNonQueryAsync(); + } +#endif + + [Test] + public async Task As_poco() + => await AssertType( + new WeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + IsJsonb + ? """{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10}""" + : """{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", + PostgresType, + NpgsqlDbType, + isDefault: false); + + [Test] + public async Task As_poco_long() + { + using var conn = CreateConnection(); + var bigString = new string('x', Math.Max(conn.Settings.ReadBufferSize, conn.Settings.WriteBufferSize)); + + await AssertType( + new WeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = bigString, + TemperatureC = 10 + }, + // Warning: in theory jsonb order and whitespace may change across versions + IsJsonb + ? 
$$"""{"Date": "2019-09-01T00:00:00", "Summary": "{{bigString}}", "TemperatureC": 10}""" + : $$"""{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"{{bigString}}"}""", + PostgresType, + NpgsqlDbType, + isDefault: false); + } + + [Test] + public async Task As_poco_supported_only_with_EnableDynamicJson() + { + // This test uses base.DataSource, which doesn't have EnableDynamicJson() + + var errorMessage = string.Format( + NpgsqlStrings.DynamicJsonNotEnabled, + nameof(WeatherForecast), + nameof(NpgsqlSlimDataSourceBuilder.EnableDynamicJson), + nameof(NpgsqlDataSourceBuilder)); + + var exception = await AssertTypeUnsupportedWrite( + new WeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + PostgresType, + base.DataSource); + + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + + exception = await AssertTypeUnsupportedRead( + IsJsonb + ? """{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10}""" + : """{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", + PostgresType, + base.DataSource); + + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + } + + [Test] + public async Task Poco_does_not_stomp_GetValue_string() + { + var dataSource = CreateDataSourceBuilder() + .EnableDynamicJson(new[] {typeof(WeatherForecast)}, new[] {typeof(WeatherForecast)}) + .Build(); + var sqlLiteral = + IsJsonb + ? """{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10}""" + : """{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{(IsJsonb ? 
"jsonb" : "json")}", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + Assert.That(reader.GetValue(0), Is.TypeOf()); + } + + [Test] + public async Task Custom_JsonSerializerOptions() + { + await using var dataSource = CreateDataSourceBuilder() + .ConfigureJsonOptions(new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }) + .EnableDynamicJson() + .Build(); + + await AssertTypeWrite( + dataSource, + new WeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + IsJsonb + ? """{"date": "2019-09-01T00:00:00", "summary": "Partly cloudy", "temperatureC": 10}""" + : """{"date":"2019-09-01T00:00:00","temperatureC":10,"summary":"Partly cloudy"}""", + PostgresType, + NpgsqlDbType, + isDefault: false); + } + + [Test, Ignore("TODO We should not change the default type for json/jsonb, it makes little sense.")] + public async Task Poco_default_mapping() + { + var dataSourceBuilder = CreateDataSourceBuilder(); + if (IsJsonb) + dataSourceBuilder.EnableDynamicJson(jsonbClrTypes: new[] { typeof(WeatherForecast) }); + else + dataSourceBuilder.EnableDynamicJson(jsonClrTypes: new[] { typeof(WeatherForecast) }); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType( + dataSource, + new WeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + IsJsonb + ? """{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10}""" + : """{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", + PostgresType, + NpgsqlDbType, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + } + + [Test] + public async Task Poco_polymorphic_mapping() + { + // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. 
+ // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. + if (IsJsonb) + return; + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.EnableDynamicJson(jsonClrTypes: new[] { typeof(WeatherForecast) }); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType( + dataSource, + new ExtendedDerivedWeatherForecast() + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", + PostgresType, + NpgsqlDbType, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + } + + [Test] + public async Task Poco_polymorphic_mapping_read_parents() + { + // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. + // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. 
+ if (IsJsonb) + return; + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.EnableDynamicJson(jsonClrTypes: new[] { typeof(WeatherForecast) }); + await using var dataSource = dataSourceBuilder.Build(); + + var value = new ExtendedDerivedWeatherForecast() + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }; + + var sql = """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + + await AssertTypeWrite( + dataSource, + value, + sql, + PostgresType, + NpgsqlDbType, + isNpgsqlDbTypeInferredFromClrType: false); + + // GetFieldValue + await AssertTypeRead(dataSource, sql, PostgresType, value, + comparer: (_, actual) => actual.GetType() == typeof(ExtendedDerivedWeatherForecast), + isDefault: false); + + await AssertTypeRead(dataSource, sql, PostgresType, value, + comparer: (_, actual) => actual.GetType() == typeof(DerivedWeatherForecast), isDefault: false); + + await AssertTypeRead(dataSource, sql, PostgresType, value, isDefault: false); + } + + + [Test] + public async Task Poco_exact_polymorphic_mapping() + { + // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. + // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. 
+ if (IsJsonb) + return; + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.EnableDynamicJson(jsonClrTypes: new[] { typeof(ExtendedDerivedWeatherForecast) }); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType( + dataSource, + new ExtendedDerivedWeatherForecast() + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + """{"TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", + PostgresType, + NpgsqlDbType, + isDefaultForReading: false, + isNpgsqlDbTypeInferredFromClrType: false); + } + + [Test] + public async Task Poco_unspecified_polymorphic_mapping() + { + // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. + // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. + // In this case we don't have any statically mapped base type to check its PolymorphicOptions on. + // Detecting whether the type could be polymorphic would require us to duplicate STJ's nearest polymorphic ancestor search. 
+ if (IsJsonb) + return; + + var value = new ExtendedDerivedWeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }; + + var sql = """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + + await AssertType( + value, + sql, + PostgresType, + NpgsqlDbType, + isDefault: false); + + await AssertTypeRead(DataSource, sql, PostgresType, value, + comparer: (_, actual) => actual.GetType() == typeof(DerivedWeatherForecast), isDefault: false); + + await AssertTypeRead(DataSource, sql, PostgresType, value, + comparer: (_, actual) => actual.GetType() == typeof(ExtendedDerivedWeatherForecast), isDefault: false); + } + + [JsonDerivedType(typeof(ExtendedDerivedWeatherForecast), typeDiscriminator: "extended")] + record WeatherForecast + { + public DateTime Date { get; set; } + public int TemperatureC { get; set; } + public string Summary { get; set; } = ""; + } + + record DerivedWeatherForecast : WeatherForecast + { + } + + record ExtendedDerivedWeatherForecast : DerivedWeatherForecast + { + public int TemperatureF => 32 + (int)(TemperatureC / 0.5556); + } + + public JsonDynamicTests(MultiplexingMode multiplexingMode, NpgsqlDbType npgsqlDbType) + : base(multiplexingMode) + { + DataSource = CreateDataSource(b => b.EnableDynamicJson()); + + if (npgsqlDbType == NpgsqlDbType.Jsonb) + using (var conn = OpenConnection()) + TestUtil.MinimumPgVersion(conn, "9.4.0", "JSONB data type not yet introduced"); + + NpgsqlDbType = npgsqlDbType; + } + + protected override NpgsqlDataSource DataSource { get; } + + bool IsJsonb => NpgsqlDbType == NpgsqlDbType.Jsonb; + string PostgresType => IsJsonb ? 
"jsonb" : "json"; + readonly NpgsqlDbType NpgsqlDbType; +} diff --git a/test/Npgsql.Tests/Types/JsonPathTests.cs b/test/Npgsql.Tests/Types/JsonPathTests.cs index 8ded8f3796..de49a631e0 100644 --- a/test/Npgsql.Tests/Types/JsonPathTests.cs +++ b/test/Npgsql.Tests/Types/JsonPathTests.cs @@ -1,47 +1,59 @@ -using System.Threading.Tasks; +using System.Data; +using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public class JsonPathTests : MultiplexingTestBase { - public class JsonPathTests : MultiplexingTestBase + public JsonPathTests(MultiplexingMode multiplexingMode) + : base(multiplexingMode) { } + + static readonly object[] ReadWriteCases = new[] { - public JsonPathTests(MultiplexingMode multiplexingMode) - : base(multiplexingMode) { } - - static readonly object[] ReadWriteCases = new[] - { - new object[] { "'$'", "$" }, - new object[] { "'$\"varname\"'", "$\"varname\"" }, - }; - - [Test] - [TestCaseSource(nameof(ReadWriteCases))] - public async Task Read(string query, string expected) - { - using var conn = await OpenConnectionAsync(); - MinimumPgVersion(conn, "12.0", "The jsonpath type was introduced in PostgreSQL 12"); - - using var cmd = new NpgsqlCommand($"SELECT {query}::jsonpath", conn); - using var rdr = await cmd.ExecuteReaderAsync(); - - rdr.Read(); - Assert.That(rdr.GetFieldValue(0), Is.EqualTo(expected)); - Assert.That(rdr.GetTextReader(0).ReadToEnd(), Is.EqualTo(expected)); - } - - [Test] - [TestCaseSource(nameof(ReadWriteCases))] - public async Task Write(string query, string expected) - { - using var conn = await OpenConnectionAsync(); - MinimumPgVersion(conn, "12.0", "The jsonpath type was introduced in PostgreSQL 12"); - - using var cmd = new NpgsqlCommand($"SELECT 'Passed' WHERE @p::text = {query}::text", conn) { Parameters = { new NpgsqlParameter("p", NpgsqlDbType.JsonPath) { Value = expected } } }; - using var rdr = await 
cmd.ExecuteReaderAsync(); - - Assert.True(rdr.Read()); - } + new object[] { "'$'", "$" }, + new object[] { "'$\"varname\"'", "$\"varname\"" }, + }; + + [Test] + [TestCase("$")] + [TestCase("$\"varname\"")] + public async Task JsonPath(string jsonPath) + { + using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "12.0", "The jsonpath type was introduced in PostgreSQL 12"); + await AssertType( + jsonPath, jsonPath, "jsonpath", NpgsqlDbType.JsonPath, isDefaultForWriting: false, isNpgsqlDbTypeInferredFromClrType: false, + inferredDbType: DbType.Object); + } + + [Test] + [TestCaseSource(nameof(ReadWriteCases))] + public async Task Read(string query, string expected) + { + using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "12.0", "The jsonpath type was introduced in PostgreSQL 12"); + + using var cmd = new NpgsqlCommand($"SELECT {query}::jsonpath", conn); + using var rdr = await cmd.ExecuteReaderAsync(); + + rdr.Read(); + Assert.That(rdr.GetFieldValue(0), Is.EqualTo(expected)); + Assert.That(rdr.GetTextReader(0).ReadToEnd(), Is.EqualTo(expected)); + } + + [Test] + [TestCaseSource(nameof(ReadWriteCases))] + public async Task Write(string query, string expected) + { + using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "12.0", "The jsonpath type was introduced in PostgreSQL 12"); + + using var cmd = new NpgsqlCommand($"SELECT 'Passed' WHERE @p::text = {query}::text", conn) { Parameters = { new NpgsqlParameter("p", NpgsqlDbType.JsonPath) { Value = expected } } }; + using var rdr = await cmd.ExecuteReaderAsync(); + + Assert.True(rdr.Read()); } } diff --git a/test/Npgsql.Tests/Types/JsonTests.cs b/test/Npgsql.Tests/Types/JsonTests.cs index 9b69a0cd63..e7a9b4576e 100644 --- a/test/Npgsql.Tests/Types/JsonTests.cs +++ b/test/Npgsql.Tests/Types/JsonTests.cs @@ -1,251 +1,184 @@ using System; +using System.Data; +using System.IO; using System.Text; using System.Text.Json; +using System.Text.Json.Nodes; +using 
System.Text.Json.Serialization; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests.Types -{ - [TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Json)] - [TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Jsonb)] - [TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Json)] - [TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Jsonb)] - public class JsonTests : MultiplexingTestBase - { - [Test] - public async Task RoundtripString() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - const string value = @"{""Key"": ""Value""}"; - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType) { Value = value }); - cmd.Parameters.Add(new NpgsqlParameter("p2", NpgsqlDbType) { TypedValue = value }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - for (var i = 0; i < 2; i++) - { - Assert.That(reader.GetString(i), Is.EqualTo(value)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(string))); - - using (var textReader = reader.GetTextReader(i)) - Assert.That(textReader.ReadToEnd(), Is.EqualTo(value)); - } - } - } - } - - [Test] - public async Task RoundtripLongString() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var sb = new StringBuilder(); - sb.Append(@"{""Key"": """); - sb.Append('x', conn.Settings.WriteBufferSize); - sb.Append(@"""}"); - var value = sb.ToString(); - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType) { Value = value }); - cmd.Parameters.Add(new NpgsqlParameter("p2", NpgsqlDbType) { TypedValue = value }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - for (var i = 0; i < 2; i++) - { - Assert.That(reader.GetString(i), Is.EqualTo(value)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(string))); - - using (var textReader = reader.GetTextReader(i)) - 
Assert.That(textReader.ReadToEnd(), Is.EqualTo(value)); - } - } - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3085")] - public async Task RoundtripStringTypes() - { - var expected = "{\"p\":1}"; - // If we serialize to JSONB, Postgres will not store the Json.NET formatting, and will add a space after ':' - var expectedString = NpgsqlDbType.Equals(NpgsqlDbType.Jsonb) ? "{\"p\": 1}" - : "{\"p\":1}"; - - using var conn = OpenConnection(); - using var cmd = new NpgsqlCommand(@"SELECT @p1, @p2, @p3", conn); +namespace Npgsql.Tests.Types; - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType) { Value = expected }); - cmd.Parameters.Add(new NpgsqlParameter("p2", NpgsqlDbType) { Value = expected.ToCharArray() }); - cmd.Parameters.Add(new NpgsqlParameter("p3", NpgsqlDbType) { Value = Encoding.ASCII.GetBytes(expected) }); - - await using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expectedString)); - Assert.That(reader.GetFieldValue(1), Is.EqualTo(expectedString.ToCharArray())); - Assert.That(reader.GetFieldValue(2), Is.EqualTo(Encoding.ASCII.GetBytes(expectedString))); - } - - [Test, Ignore("INpgsqlTypeHandler>.Read currently not yet implemented in TextHandler")] - public async Task RoundtripArraySegment() - { - var expected = "{\"p\":1}"; - // If we serialize to JSONB, Postgres will not store the Json.NET formatting, and will add a space after ':' - var expectedString = NpgsqlDbType.Equals(NpgsqlDbType.Jsonb) ? 
"{\"p\": 1}" - : "{\"p\":1}"; +[TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Json)] +[TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Jsonb)] +[TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Json)] +[TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Jsonb)] +public class JsonTests : MultiplexingTestBase +{ + [Test] + public async Task As_string() + => await AssertType("""{"K": "V"}""", """{"K": "V"}""", PostgresType, NpgsqlDbType, isDefaultForWriting: false); - using var conn = OpenConnection(); - using var cmd = new NpgsqlCommand(@"SELECT @p1", conn); + [Test] + public async Task As_string_long() + { + await using var conn = CreateConnection(); - cmd.Parameters.Add(new NpgsqlParameter>("p1", NpgsqlDbType) { Value = new ArraySegment(expected.ToCharArray()) }); + var value = new StringBuilder() + .Append(@"{""K"": """) + .Append('x', conn.Settings.WriteBufferSize) + .Append(@"""}") + .ToString(); - await using var reader = await cmd.ExecuteReaderAsync(); - reader.Read(); - Assert.That(reader.GetFieldValue>(0), Is.EqualTo(expectedString)); - } + await AssertType(value, value, PostgresType, NpgsqlDbType, isDefaultForWriting: false); + } + [Test] + public async Task As_string_with_GetTextReader() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($$"""SELECT '{"K": "V"}'::{{PostgresType}}""", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + using var textReader = await reader.GetTextReaderAsync(0); + Assert.That(await textReader.ReadToEndAsync(), Is.EqualTo(@"{""K"": ""V""}")); + } - [Test] - public async Task ReadJsonDocument() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var value = @"{""Date"":""2019-09-01T00:00:00"",""TemperatureC"":10,""Summary"":""Partly cloudy""}"; - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType) { Value = value }); - using (var reader = 
await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo(PostgresType)); - var root = reader.GetFieldValue(0).RootElement; - Assert.That(root.GetProperty("Date").GetDateTime(), Is.EqualTo(new DateTime(2019, 9, 1))); - Assert.That(root.GetProperty("Summary").GetString(), Is.EqualTo("Partly cloudy")); - Assert.That(root.GetProperty("TemperatureC").GetInt32(), Is.EqualTo(10)); - } - } - } + [Test] + public async Task As_char_array() + => await AssertType("""{"K": "V"}""".ToCharArray(), """{"K": "V"}""", PostgresType, NpgsqlDbType, isDefault: false); + + [Test] + public async Task As_bytes() + => await AssertType("""{"K": "V"}"""u8.ToArray(), """{"K": "V"}""", PostgresType, NpgsqlDbType, isDefault: false); + + [Test] + public async Task Write_as_ReadOnlyMemory_of_byte() + => await AssertTypeWrite(new ReadOnlyMemory("""{"K": "V"}"""u8.ToArray()), """{"K": "V"}""", PostgresType, NpgsqlDbType, + isDefault: false); + + [Test] + public async Task Write_as_ArraySegment_of_char() + => await AssertTypeWrite(new ArraySegment("""{"K": "V"}""".ToCharArray()), """{"K": "V"}""", PostgresType, NpgsqlDbType, + isDefault: false); + + [Test] + public Task As_MemoryStream() + => AssertTypeWrite(() => new MemoryStream("""{"K": "V"}"""u8.ToArray()), """{"K": "V"}""", PostgresType, NpgsqlDbType, isDefault: false); + + [Test] + public async Task As_JsonDocument() + => await AssertType( + JsonDocument.Parse("""{"K": "V"}"""), + IsJsonb ? 
"""{"K": "V"}""" : """{"K":"V"}""", + PostgresType, + NpgsqlDbType, + isDefault: false, + comparer: (x, y) => x.RootElement.GetProperty("K").GetString() == y.RootElement.GetProperty("K").GetString()); + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5540")] + public async Task As_JsonDocument_with_null_root() + => await AssertType( + JsonDocument.Parse("null"), + "null", + PostgresType, + NpgsqlDbType, + isDefault: false, + comparer: (x, y) => x.RootElement.ValueKind == y.RootElement.ValueKind, + skipArrayCheck: true); + + [Test] + public async Task As_JsonElement_with_null_root() + => await AssertType( + JsonDocument.Parse("null").RootElement, + "null", + PostgresType, + NpgsqlDbType, + isDefault: false, + comparer: (x, y) => x.ValueKind == y.ValueKind, + skipArrayCheck: true); + + [Test] + public async Task As_JsonDocument_supported_only_with_SystemTextJson() + { + await using var slimDataSource = new NpgsqlSlimDataSourceBuilder(ConnectionString).Build(); - [Test] - public async Task WriteJsonDocument() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var value = JsonDocument.Parse(@"{""Date"": ""2019-09-01T00:00:00"", ""Summary"": ""Partly cloudy"", ""TemperatureC"": 10}"); - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType) { Value = value }); - cmd.Parameters.Add(new NpgsqlParameter("p2", NpgsqlDbType) { TypedValue = value }); - if (IsJsonb) - { - cmd.CommandText += ", @p3"; - cmd.Parameters.AddWithValue("p3", value); - } - - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - for (var i = 0; i < reader.FieldCount; i++) - { - // Warning: in theory jsonb order and whitespace may change across versions - Assert.That(reader.GetString(0), Is.EqualTo(IsJsonb - ? 
@"{""Date"": ""2019-09-01T00:00:00"", ""Summary"": ""Partly cloudy"", ""TemperatureC"": 10}" - : @"{""Date"":""2019-09-01T00:00:00"",""Summary"":""Partly cloudy"",""TemperatureC"":10}")); - } - } - } - } + await AssertTypeUnsupported( + JsonDocument.Parse("""{"K": "V"}"""), + """{"K": "V"}""", + PostgresType, + slimDataSource); + } - [Test] - public async Task WriteObject() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var value = new WeatherForecast - { - Date = new DateTime(2019, 9, 1), - Summary = "Partly cloudy", - TemperatureC = 10 - }; - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType) { Value = value }); - cmd.Parameters.Add(new NpgsqlParameter("p2", NpgsqlDbType) { TypedValue = value }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - for (var i = 0; i < 2; i++) - { - // Warning: in theory jsonb order and whitespace may change across versions - Assert.That(reader.GetString(0), Is.EqualTo(IsJsonb - ? 
@"{""Date"": ""2019-09-01T00:00:00"", ""Summary"": ""Partly cloudy"", ""TemperatureC"": 10}" - : @"{""Date"":""2019-09-01T00:00:00"",""TemperatureC"":10,""Summary"":""Partly cloudy""}")); - } - } - } - } + [Test] + public Task Roundtrip_string() + => AssertType( + @"{""p"": 1}", + @"{""p"": 1}", + PostgresType, + NpgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Roundtrip_char_array() + => AssertType( + @"{""p"": 1}".ToCharArray(), + @"{""p"": 1}", + PostgresType, + NpgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + public Task Roundtrip_byte_array() + => AssertType( + Encoding.ASCII.GetBytes(@"{""p"": 1}"), + @"{""p"": 1}", + PostgresType, + NpgsqlDbType, + isDefault: false, + isNpgsqlDbTypeInferredFromClrType: false); + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/2811")] + [IssueLink("https://github.com/npgsql/efcore.pg/issues/1177")] + [IssueLink("https://github.com/npgsql/efcore.pg/issues/1082")] + public async Task Can_read_two_json_documents() + { + await using var conn = await OpenConnectionAsync(); - [Test] - public async Task ReadObject() + JsonDocument car; + await using (var cmd = new NpgsqlCommand("""SELECT '{"key" : "foo"}'::jsonb""", conn)) + await using (var reader = await cmd.ExecuteReaderAsync()) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - var value = @"{""Date"":""2019-09-01T00:00:00"",""TemperatureC"":10,""Summary"":""Partly cloudy""}"; - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType) { Value = value }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo(PostgresType)); - var actual = reader.GetFieldValue(0); - Assert.That(actual.Date, Is.EqualTo(new DateTime(2019, 9, 1))); - Assert.That(actual.Summary, Is.EqualTo("Partly cloudy")); - Assert.That(actual.TemperatureC, 
Is.EqualTo(10)); - } - } + reader.Read(); + car = reader.GetFieldValue(0); } - class WeatherForecast + await using (var cmd = new NpgsqlCommand("""SELECT '{"key" : "bar"}'::jsonb""", conn)) + await using (var reader = await cmd.ExecuteReaderAsync()) { - public DateTime Date { get; set; } - public int TemperatureC { get; set; } - public string Summary { get; set; } = ""; + reader.Read(); + reader.GetFieldValue(0); } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/2811")] - [IssueLink("https://github.com/npgsql/efcore.pg/issues/1177")] - [IssueLink("https://github.com/npgsql/efcore.pg/issues/1082")] - public async Task CanReadTwoJsonDocuments() - { - using var conn = await OpenConnectionAsync(); - - JsonDocument car; - using (var cmd = new NpgsqlCommand(@"SELECT '{""key"" : ""foo""}'::jsonb", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - car = reader.GetFieldValue(0); - } - - using (var cmd = new NpgsqlCommand(@"SELECT '{""key"" : ""bar""}'::jsonb", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - reader.GetFieldValue(0); - } - - Assert.That(car.RootElement.GetProperty("key").GetString(), Is.EqualTo("foo")); - } + Assert.That(car.RootElement.GetProperty("key").GetString(), Is.EqualTo("foo")); + } - public JsonTests(MultiplexingMode multiplexingMode, NpgsqlDbType npgsqlDbType) - : base(multiplexingMode) - { + public JsonTests(MultiplexingMode multiplexingMode, NpgsqlDbType npgsqlDbType) + : base(multiplexingMode) + { + if (npgsqlDbType == NpgsqlDbType.Jsonb) using (var conn = OpenConnection()) TestUtil.MinimumPgVersion(conn, "9.4.0", "JSONB data type not yet introduced"); - NpgsqlDbType = npgsqlDbType; - } - bool IsJsonb => NpgsqlDbType == NpgsqlDbType.Jsonb; - string PostgresType => IsJsonb ? "jsonb" : "json"; - readonly NpgsqlDbType NpgsqlDbType; + NpgsqlDbType = npgsqlDbType; } + + bool IsJsonb => NpgsqlDbType == NpgsqlDbType.Jsonb; + string PostgresType => IsJsonb ? 
"jsonb" : "json"; + readonly NpgsqlDbType NpgsqlDbType; } diff --git a/test/Npgsql.Tests/Types/LQueryTests.cs b/test/Npgsql.Tests/Types/LQueryTests.cs deleted file mode 100644 index d84ff93558..0000000000 --- a/test/Npgsql.Tests/Types/LQueryTests.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System.Collections; -using System.Threading.Tasks; -using NpgsqlTypes; -using NUnit.Framework; - -namespace Npgsql.Tests.Types -{ - [TestFixture(MultiplexingMode.NonMultiplexing, false)] - [TestFixture(MultiplexingMode.NonMultiplexing, true)] - [TestFixture(MultiplexingMode.Multiplexing, false)] - [TestFixture(MultiplexingMode.Multiplexing, true)] - public class LQueryTests : TypeHandlerTestBase - { - public LQueryTests(MultiplexingMode multiplexingMode, bool useTypeName) : base( - multiplexingMode, - useTypeName ? null : NpgsqlDbType.LQuery, - useTypeName ? "lquery" : null, - minVersion: "13.0") - { } - - public static IEnumerable TestCases() => new[] - { - new object[] { "'Top.Science.*'::lquery", "Top.Science.*" } - }; - - [OneTimeSetUp] - public async Task SetUp() - { - using var conn = await OpenConnectionAsync(); - await TestUtil.EnsureExtensionAsync(conn, "ltree"); - } - } -} diff --git a/test/Npgsql.Tests/Types/LTreeTests.cs b/test/Npgsql.Tests/Types/LTreeTests.cs index a5f66a0412..f836b49ca0 100644 --- a/test/Npgsql.Tests/Types/LTreeTests.cs +++ b/test/Npgsql.Tests/Types/LTreeTests.cs @@ -1,33 +1,68 @@ -using System.Collections; -using System.Threading.Tasks; +using System.Threading.Tasks; +using Npgsql.Properties; using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public class LTreeTests : MultiplexingTestBase { - [TestFixture(MultiplexingMode.NonMultiplexing, false)] - [TestFixture(MultiplexingMode.NonMultiplexing, true)] - [TestFixture(MultiplexingMode.Multiplexing, false)] - [TestFixture(MultiplexingMode.Multiplexing, true)] - public class LTreeTests : TypeHandlerTestBase + [Test] + public Task LQuery() + => 
AssertType("Top.Science.*", "Top.Science.*", "lquery", NpgsqlDbType.LQuery, isDefaultForWriting: false); + + [Test] + public Task LTree() + => AssertType("Top.Science.Astronomy", "Top.Science.Astronomy", "ltree", NpgsqlDbType.LTree, isDefaultForWriting: false); + + [Test] + public Task LTxtQuery() + => AssertType("Science & Astronomy", "Science & Astronomy", "ltxtquery", NpgsqlDbType.LTxtQuery, isDefaultForWriting: false); + + [Test] + public async Task LTree_not_supported_by_default_on_NpgsqlSlimSourceBuilder() + { + var errorMessage = string.Format( + NpgsqlStrings.LTreeNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableLTree), nameof(NpgsqlSlimDataSourceBuilder)); + + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + await using var dataSource = dataSourceBuilder.Build(); + + var exception = + await AssertTypeUnsupportedRead>("Top.Science.Astronomy", "ltree", dataSource); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + exception = await AssertTypeUnsupportedWrite("Top.Science.Astronomy", "ltree", dataSource); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + } + + [Test] + public async Task NpgsqlSlimSourceBuilder_EnableLTree() { - public LTreeTests(MultiplexingMode multiplexingMode, bool useTypeName) : base( - multiplexingMode, - useTypeName ? null : NpgsqlDbType.LTree, - useTypeName ? 
"ltree" : null, - minVersion: "13.0") - { } - - public static IEnumerable TestCases() => new[] - { - new object[] { "'Top.Science.Astronomy'::ltree", "Top.Science.Astronomy" } - }; - - [OneTimeSetUp] - public async Task SetUp() - { - using var conn = await OpenConnectionAsync(); - await TestUtil.EnsureExtensionAsync(conn, "ltree"); - } + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableLTree(); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, "Top.Science.Astronomy", "Top.Science.Astronomy", "ltree", NpgsqlDbType.LTree, isDefaultForWriting: false, skipArrayCheck: true); + } + + [Test] + public async Task NpgsqlSlimSourceBuilder_EnableArrays() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableLTree(); + dataSourceBuilder.EnableArrays(); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, "Top.Science.Astronomy", "Top.Science.Astronomy", "ltree", NpgsqlDbType.LTree, isDefaultForWriting: false); } + + [OneTimeSetUp] + public async Task SetUp() + { + await using var conn = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(conn, "13.0"); + await TestUtil.EnsureExtensionAsync(conn, "ltree"); + } + + public LTreeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/LTxtQueryTests.cs b/test/Npgsql.Tests/Types/LTxtQueryTests.cs deleted file mode 100644 index eb17efedbc..0000000000 --- a/test/Npgsql.Tests/Types/LTxtQueryTests.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System.Collections; -using System.Threading.Tasks; -using NpgsqlTypes; -using NUnit.Framework; - -namespace Npgsql.Tests.Types -{ - [TestFixture(MultiplexingMode.NonMultiplexing, false)] - [TestFixture(MultiplexingMode.NonMultiplexing, true)] - [TestFixture(MultiplexingMode.Multiplexing, false)] - [TestFixture(MultiplexingMode.Multiplexing, true)] - public class 
LTxtQueryTests : TypeHandlerTestBase - { - public LTxtQueryTests(MultiplexingMode multiplexingMode, bool useTypeName) : base( - multiplexingMode, - useTypeName ? null : NpgsqlDbType.LTxtQuery, - useTypeName ? "ltxtquery" : null, - minVersion: "13.0") - { } - - public static IEnumerable TestCases() => new[] - { - new object[] { "'Science & Astronomy'::ltxtquery", "Science & Astronomy" } - }; - - [OneTimeSetUp] - public async Task SetUp() - { - using var conn = await OpenConnectionAsync(); - await TestUtil.EnsureExtensionAsync(conn, "ltree"); - } - } -} diff --git a/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs b/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs new file mode 100644 index 0000000000..c500324986 --- /dev/null +++ b/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs @@ -0,0 +1,107 @@ +using System; +using System.Data; +using System.Threading.Tasks; +using Npgsql.Internal.ResolverFactories; +using NpgsqlTypes; +using NUnit.Framework; +using static Npgsql.Util.Statics; + +namespace Npgsql.Tests.Types; + +// Since this test suite manipulates TimeZone, it is incompatible with multiplexing +[NonParallelizable] +public class LegacyDateTimeTests : TestBase +{ + [Test] + public Task Timestamp_with_all_DateTime_kinds([Values] DateTimeKind kind) + => AssertType( + new DateTime(1998, 4, 12, 13, 26, 38, 789, kind), + "1998-04-12 13:26:38.789", + "timestamp without time zone", + NpgsqlDbType.Timestamp, + DbType.DateTime); + + [Test] + public async Task Timestamp_read_as_Unspecified_DateTime() + { + await using var command = DataSource.CreateCommand("SELECT '2020-03-01T10:30:00'::timestamp"); + var dateTime = (DateTime)(await command.ExecuteScalarAsync())!; + Assert.That(dateTime.Kind, Is.EqualTo(DateTimeKind.Unspecified)); + } + + [Test] + public async Task Timestamptz_negative_infinity() + { + var dto = await AssertType(DateTimeOffset.MinValue, "-infinity", "timestamp with time zone", NpgsqlDbType.TimestampTz, + DbType.DateTimeOffset, isDefaultForReading: false); + 
Assert.That(dto.Offset, Is.EqualTo(TimeSpan.Zero)); + } + + [Test] + public async Task Timestamptz_infinity() + { + var dto = await AssertType( + DateTimeOffset.MaxValue, "infinity", "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTimeOffset, + isDefaultForReading: false); + Assert.That(dto.Offset, Is.EqualTo(TimeSpan.Zero)); + } + + [Test] + [TestCase(DateTimeKind.Utc, TestName = "Timestamptz_write_utc_DateTime_does_not_convert")] + [TestCase(DateTimeKind.Unspecified, TestName = "Timestamptz_write_unspecified_DateTime_does_not_convert")] + public Task Timestamptz_write_utc_DateTime_does_not_convert(DateTimeKind kind) + => AssertTypeWrite( + new DateTime(1998, 4, 12, 13, 26, 38, 789, kind), + "1998-04-12 15:26:38.789+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTimeOffset, + isDefault: false); + + [Test] + public Task Timestamptz_local_DateTime_converts() + { + // In legacy mode, we convert local DateTime to UTC when writing, and convert to local when reading, + // using the machine time zone. + var dateTime = new DateTime(1998, 4, 12, 13, 26, 38, 789, DateTimeKind.Utc).ToLocalTime(); + + return AssertType( + dateTime, + "1998-04-12 15:26:38.789+02", + "timestamp with time zone", + NpgsqlDbType.TimestampTz, + DbType.DateTimeOffset, + isDefaultForWriting: false); + } + + NpgsqlDataSource _dataSource = null!; + protected override NpgsqlDataSource DataSource => _dataSource; + + [OneTimeSetUp] + public void Setup() + { +#if DEBUG + LegacyTimestampBehavior = true; + _dataSource = CreateDataSource(builder => + { + // Can't use the static AdoTypeInfoResolver instance, it already captured the feature flag. 
+ builder.AddTypeInfoResolverFactory(new AdoTypeInfoResolverFactory()); + builder.ConnectionStringBuilder.Timezone = "Europe/Berlin"; + }); + NpgsqlDataSourceBuilder.ResetGlobalMappings(overwrite: true); +#else + Assert.Ignore( + "Legacy DateTime tests rely on the Npgsql.EnableLegacyTimestampBehavior AppContext switch and can only be run in DEBUG builds"); +#endif + } + +#if DEBUG + [OneTimeTearDown] + public void Teardown() + { + LegacyTimestampBehavior = false; + _dataSource.Dispose(); + NpgsqlDataSourceBuilder.ResetGlobalMappings(overwrite: true); + } +#endif +} diff --git a/test/Npgsql.Tests/Types/MiscTypeTests.cs b/test/Npgsql.Tests/Types/MiscTypeTests.cs index f0706a1153..d689a268ef 100644 --- a/test/Npgsql.Tests/Types/MiscTypeTests.cs +++ b/test/Npgsql.Tests/Types/MiscTypeTests.cs @@ -1,461 +1,197 @@ using System; using System.Data; -using System.Linq; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; -using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +/// +/// Tests on PostgreSQL types which don't fit elsewhere +/// +class MiscTypeTests : MultiplexingTestBase { - /// - /// Tests on PostgreSQL types which don't fit elsewhere - /// - class MiscTypeTests : MultiplexingTestBase + [Test] + public async Task Boolean() { - [Test, Description("Resolves a base type handler via the different pathways")] - public async Task BaseTypeResolution() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); - - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(BaseTypeResolution), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; - - using (var conn = await OpenConnectionAsync(csb)) - { - // Resolve type by NpgsqlDbType - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Integer, DBNull.Value); - using (var reader = await cmd.ExecuteReaderAsync()) - { - 
reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer")); - } - } - - // Resolve type by DbType - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p", DbType.Int32) { Value = DBNull.Value }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer")); - } - } - - // Resolve type by ClrType (type inference) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName="p", Value = 8 }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer")); - } - } - - // Resolve type by OID (read) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT 8", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("integer")); - } - } - } - - /// - /// https://www.postgresql.org/docs/current/static/datatype-boolean.html - /// - [Test, Description("Roundtrips a bool")] - public async Task Bool() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4", conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Boolean); - var p2 = new NpgsqlParameter("p2", NpgsqlDbType.Boolean); - var p3 = new NpgsqlParameter("p3", DbType.Boolean); - var p4 = new NpgsqlParameter { ParameterName = "p4", Value = true }; - Assert.That(p4.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Boolean)); - Assert.That(p4.DbType, Is.EqualTo(DbType.Boolean)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - cmd.Parameters.Add(p4); - p1.Value = false; - p2.Value = p3.Value = true; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - Assert.That(reader.GetBoolean(0), Is.False); - - for (var 
i = 1; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetBoolean(i), Is.True); - Assert.That(reader.GetValue(i), Is.True); - Assert.That(reader.GetProviderSpecificValue(i), Is.True); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(bool))); - Assert.That(reader.GetDataTypeName(i), Is.EqualTo("boolean")); - } - } - } - } - - /// - /// https://www.postgresql.org/docs/current/static/datatype-uuid.html - /// - [Test, Description("Roundtrips a UUID")] - public async Task Uuid() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - var expected = new Guid("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Uuid); - var p2 = new NpgsqlParameter("p2", DbType.Guid); - var p3 = new NpgsqlParameter {ParameterName = "p3", Value = expected}; - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - p1.Value = p2.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetGuid(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(Guid))); - } - } - } - } - - [Test, Description("Makes sure that the PostgreSQL 'unknown' type (OID 705) is read properly")] - public async Task ReadUnknown() - { - const string expected = "some_text"; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand($"SELECT '{expected}'", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetString(0), Is.EqualTo(expected)); - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected.ToCharArray())); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(string))); - } - } 
+ await AssertType(true, "true", "boolean", NpgsqlDbType.Boolean, DbType.Boolean, skipArrayCheck: true); + await AssertType(false, "false", "boolean", NpgsqlDbType.Boolean, DbType.Boolean, skipArrayCheck: true); - [Test, Description("Roundtrips a null value")] - public async Task Null() - { - using (var conn = await OpenConnectionAsync()) - { - using (var cmd = new NpgsqlCommand("SELECT @p1::TEXT, @p2::TEXT, @p3::TEXT", conn)) - { - cmd.Parameters.AddWithValue("p1", DBNull.Value); - cmd.Parameters.Add(new NpgsqlParameter("p2", null)); - cmd.Parameters.Add(new NpgsqlParameter("p3", DBNull.Value)); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.IsDBNull(i)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(string))); - } - } - } - - // Setting non-generic NpgsqlParameter.Value is not allowed, only DBNull.Value - using (var cmd = new NpgsqlCommand("SELECT @p::TEXT", conn)) - { - cmd.Parameters.AddWithValue("p4", NpgsqlDbType.Text, null!); - Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); - } - } - } + // The literal representations for bools inside array are different ({t,f} instead of true/false, so we check separately. 
+ await AssertType(new[] { true, false }, "{t,f}", "boolean[]", NpgsqlDbType.Boolean | NpgsqlDbType.Array); + } - [Test, Description("PostgreSQL records should be returned as arrays of objects")] - [IssueLink("https://github.com/npgsql/npgsql/issues/724")] - [IssueLink("https://github.com/npgsql/npgsql/issues/1980")] - public async Task Record() - { - var recordLiteral = "(1,'foo'::text)::record"; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - var record = (object[])reader[0]; - Assert.That(record[0], Is.EqualTo(1)); - Assert.That(record[1], Is.EqualTo("foo")); + [Test] + public Task Uuid() + => AssertType( + new Guid("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), + "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", + "uuid", NpgsqlDbType.Uuid, DbType.Guid); - var arr = (object[][])reader[1]; - Assert.That(arr.Length, Is.EqualTo(2)); - Assert.That(arr[0][0], Is.EqualTo(1)); - Assert.That(arr[1][0], Is.EqualTo(1)); - } - } + [Test, Description("Makes sure that the PostgreSQL 'unknown' type (OID 705) is read properly")] + public async Task Read_unknown() + { + const string expected = "some_text"; + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($"SELECT '{expected}'", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetString(0), Is.EqualTo(expected)); + Assert.That(reader.GetValue(0), Is.EqualTo(expected)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected.ToCharArray())); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(string))); + } - [Test] - public async Task Domain() + [Test] + public async Task Null() + { + await using var conn = await OpenConnectionAsync(); + await using (var cmd = new NpgsqlCommand("SELECT @p1::TEXT, @p2::TEXT, @p3::TEXT", conn)) { - using var conn = 
await OpenConnectionAsync(); - await using var _ = await GetTempTypeName(conn, out var type); - await conn.ExecuteNonQueryAsync($"CREATE DOMAIN {type} AS text"); - Assert.That(await conn.ExecuteScalarAsync($"SELECT 'foo'::{type}"), Is.EqualTo("foo")); - } + cmd.Parameters.AddWithValue("p1", DBNull.Value); + cmd.Parameters.Add(new NpgsqlParameter("p2", null)); + cmd.Parameters.Add(new NpgsqlParameter("p3", DBNull.Value)); - [Test, Description("Makes sure that setting DbType.Object makes Npgsql infer the type")] - [IssueLink("https://github.com/npgsql/npgsql/issues/694")] - public async Task DbTypeCausesInference() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + for (var i = 0; i < cmd.Parameters.Count; i++) { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName="p", DbType = DbType.Object, Value = 3 }); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(3)); + Assert.That(reader.IsDBNull(i)); + Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(string))); } } - #region Unrecognized types - - [Test, Description("Attempts to retrieve an unrecognized type without marking it as unknown, triggering an exception")] - public async Task UnrecognizedBinary() + // Setting non-generic NpgsqlParameter.Value to null is not allowed, only DBNull.Value + await using (var cmd = new NpgsqlCommand("SELECT @p4::TEXT", conn)) { - if (IsMultiplexing) - return; - - using (var conn = await OpenConnectionAsync()) - { - conn.TypeMapper.RemoveMapping("boolean"); - using (var cmd = new NpgsqlCommand("SELECT TRUE", conn)) - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess)) - { - reader.Read(); - Assert.That(() => reader.GetValue(0), Throws.Exception.TypeOf()); - } - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + cmd.Parameters.AddWithValue("p4", NpgsqlDbType.Text, null!); + 
Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); } - [Test, Description("Retrieves a type as an unknown type, i.e. untreated string")] - public async Task AllResultTypesAreUnknown() + // Setting generic NpgsqlParameter.Value to null is not allowed, only DBNull.Value + await using (var cmd = new NpgsqlCommand("SELECT @p4::TEXT", conn)) { - if (IsMultiplexing) - return; - - using (var conn = await OpenConnectionAsync()) - { - conn.TypeMapper.RemoveMapping("bool"); - - using (var cmd = new NpgsqlCommand("SELECT TRUE", conn)) - { - cmd.AllResultTypesAreUnknown = true; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(string))); - Assert.That(reader.GetString(0), Is.EqualTo("t")); - } - } - } - } - - [Test, Description("Mixes and matches an unknown type with a known type")] - public async Task UnknownResultTypeList() - { - if (IsMultiplexing) - return; - - using (var conn = await OpenConnectionAsync()) - { - conn.TypeMapper.RemoveMapping("bool"); - - using (var cmd = new NpgsqlCommand("SELECT TRUE, 8", conn)) - { - cmd.UnknownResultTypeList = new[] { true, false }; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(string))); - Assert.That(reader.GetString(0), Is.EqualTo("t")); - Assert.That(reader.GetInt32(1), Is.EqualTo(8)); - } - } - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/711")] - public async Task KnownTypeAsUnknown() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT 8", conn)) - { - cmd.AllResultTypesAreUnknown = true; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("8")); - } + cmd.Parameters.Add(new NpgsqlParameter("p4", NpgsqlDbType.Text) { Value = null! 
}); + Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); } + } - [Test, Description("Sends a null value parameter with no NpgsqlDbType or DbType, but with context for the backend to handle it")] - public async Task UnrecognizedNull() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p::TEXT", conn)) - { - var p = new NpgsqlParameter("p", DBNull.Value); - cmd.Parameters.Add(p); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.IsDBNull(0)); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(string))); - } - } - } + [Test, Description("Makes sure that setting DbType.Object makes Npgsql infer the type")] + [IssueLink("https://github.com/npgsql/npgsql/issues/694")] + public async Task DbType_causes_inference() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.Add(new NpgsqlParameter { ParameterName="p", DbType = DbType.Object, Value = 3 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(3)); + } - [Test, Description("Sends a value parameter with an explicit NpgsqlDbType.Unknown, but with context for the backend to handle it")] - public async Task SendUnknown() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p::INT4", conn)) - { - var p = new NpgsqlParameter("p", "8"); - cmd.Parameters.Add(p); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(int))); - Assert.That(reader.GetInt32(0), Is.EqualTo(8)); - } - } - } + #region Unrecognized types - #endregion + [Test, Description("Retrieves a type as an unknown type, i.e. 
untreated string")] + public async Task AllResultTypesAreUnknown() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT TRUE", conn); + cmd.AllResultTypesAreUnknown = true; + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(string))); + Assert.That(reader.GetString(0), Is.EqualTo("t")); + } - [Test] - public async Task Int2Vector() - { - var expected = new short[] { 4, 5, 6 }; - using (var conn = await OpenConnectionAsync()) - using (var cmd = conn.CreateCommand()) - { - TestUtil.MinimumPgVersion(conn, "9.1.0"); - cmd.CommandText = "SELECT @p::int2vector"; - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Int2Vector, expected); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); - } - } - } + [Test, Description("Mixes and matches an unknown type with a known type")] + public async Task UnknownResultTypeList() + { + if (IsMultiplexing) + return; + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT TRUE, 8", conn); + cmd.UnknownResultTypeList = new[] { true, false }; + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(string))); + Assert.That(reader.GetString(0), Is.EqualTo("t")); + Assert.That(reader.GetInt32(1), Is.EqualTo(8)); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1138")] - public async Task Void() - { - using (var conn = await OpenConnectionAsync()) - Assert.That(await conn.ExecuteScalarAsync("SELECT pg_sleep(0)"), Is.SameAs(DBNull.Value)); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/711")] + public async Task Known_type_as_unknown() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 8", conn); + 
cmd.AllResultTypesAreUnknown = true; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("8")); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1364")] - public async Task UnsupportedDbType() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - Assert.That(() => cmd.Parameters.Add(new NpgsqlParameter("p", DbType.UInt32) { Value = 8u }), - Throws.Exception.TypeOf()); - } - } + [Test, Description("Sends a null value parameter with no NpgsqlDbType or DbType, but with context for the backend to handle it")] + public async Task Unrecognized_null() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p::TEXT", conn); + var p = new NpgsqlParameter("p", DBNull.Value); + cmd.Parameters.Add(p); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.IsDBNull(0)); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(string))); + } - // Older tests + [Test, Description("Sends a value parameter with an explicit NpgsqlDbType.Unknown, but with context for the backend to handle it")] + public async Task Send_unknown() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p::INT4", conn); + var p = new NpgsqlParameter("p", "8"); + cmd.Parameters.Add(p); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(int))); + Assert.That(reader.GetInt32(0), Is.EqualTo(8)); + } - [Test] - public async Task Bug1011085() - { - // Money format is not set in accordance with the system locale format - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("select :moneyvalue", conn)) - { - var expectedValue = 8.99m; - command.Parameters.Add("moneyvalue", NpgsqlDbType.Money).Value = expectedValue; - var result = (decimal?)await 
command.ExecuteScalarAsync(); - Assert.AreEqual(expectedValue, result); + #endregion - expectedValue = 100m; - command.Parameters[0].Value = expectedValue; - result = (decimal?)await command.ExecuteScalarAsync(); - Assert.AreEqual(expectedValue, result); - expectedValue = 72.25m; - command.Parameters[0].Value = expectedValue; - result = (decimal?)await command.ExecuteScalarAsync(); - Assert.AreEqual(expectedValue, result); - } - } + [Test] + public async Task ObjectArray() + { + await AssertTypeWrite(new object?[] { (short)4, null, (long)5, 6 }, "{4,NULL,5,6}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array, isDefault: false); + await AssertTypeWrite(new object?[] { "text", null, DBNull.Value, "chars".ToCharArray(), 'c' }, "{text,NULL,NULL,chars,c}", "text[]", NpgsqlDbType.Text | NpgsqlDbType.Array, isDefault: false); - [Test] - public async Task TestUUIDDataType() - { - using (var conn = await OpenConnectionAsync()) - await using (var _ = await GetTempTableName(conn, out var table)) + await using var dataSource = CreateDataSource(b => b.ConnectionStringBuilder.Timezone = "Europe/Berlin"); + await AssertTypeWrite(dataSource, new object?[] { DateTime.UnixEpoch, null, DBNull.Value, DateTime.UnixEpoch.AddDays(1) }, "{\"1970-01-01 01:00:00+01\",NULL,NULL,\"1970-01-02 01:00:00+01\"}", "timestamp with time zone[]", NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, isDefault: false); + Assert.ThrowsAsync(() => AssertTypeWrite(dataSource, new object?[] { - var createTable = $@" -CREATE TABLE {table} ( - person_id serial PRIMARY KEY NOT NULL, - person_uuid uuid NOT NULL -) WITH(OIDS=FALSE);"; - var command = new NpgsqlCommand(createTable, conn); - await command.ExecuteNonQueryAsync(); + DateTime.Now, null, DBNull.Value, DateTime.UnixEpoch.AddDays(1) + }, "{\"1970-01-01 01:00:00+01\",NULL,NULL,\"1970-01-02 01:00:00+01\"}", "timestamp with time zone[]", + NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, isDefault: false)); + } - var uuidDbParam = new 
NpgsqlParameter(":param1", NpgsqlDbType.Uuid); - uuidDbParam.Value = Guid.NewGuid(); + [Test] + public Task Int2Vector() + => AssertType(new short[] { 4, 5, 6 }, "4 5 6", "int2vector", NpgsqlDbType.Int2Vector, isDefault: false); - command = new NpgsqlCommand($"INSERT INTO {table} (person_uuid) VALUES (:param1);", conn); - command.Parameters.Add(uuidDbParam); - await command.ExecuteNonQueryAsync(); + [Test] + public Task Oidvector() + => AssertType(new uint[] { 4, 5, 6 }, "4 5 6", "oidvector", NpgsqlDbType.Oidvector, isDefault: false); - command = new NpgsqlCommand($"SELECT person_uuid::uuid FROM {table} LIMIT 1", conn); - var result = await command.ExecuteScalarAsync(); - Assert.AreEqual(typeof(Guid), result!.GetType()); - } - } - - [Test] - public async Task OidVector() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = conn.CreateCommand()) - { - cmd.CommandText = "Select '1 2 3'::oidvector, :p1"; - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Oidvector, new uint[] { 4, 5, 6 }); - using (var rdr = await cmd.ExecuteReaderAsync()) - { - rdr.Read(); - Assert.AreEqual(typeof(uint[]), rdr.GetValue(0).GetType()); - Assert.AreEqual(typeof(uint[]), rdr.GetValue(1).GetType()); - Assert.IsTrue(rdr.GetFieldValue(0).SequenceEqual(new uint[] { 1, 2, 3 })); - Assert.IsTrue(rdr.GetFieldValue(1).SequenceEqual(new uint[] { 4, 5, 6 })); - } - } - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1138")] + public async Task Void() + { + await using var conn = await OpenConnectionAsync(); + Assert.That(await conn.ExecuteScalarAsync("SELECT pg_sleep(0)"), Is.SameAs(null)); + } - public MiscTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1364")] + public async Task Unsupported_DbType() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + Assert.That(() => cmd.Parameters.Add(new NpgsqlParameter("p", 
DbType.UInt32) { Value = 8u }), + Throws.Exception.TypeOf()); } + + public MiscTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/MoneyTests.cs b/test/Npgsql.Tests/Types/MoneyTests.cs index c3e7be963a..4c38f3d111 100644 --- a/test/Npgsql.Tests/Types/MoneyTests.cs +++ b/test/Npgsql.Tests/Types/MoneyTests.cs @@ -1,109 +1,62 @@ -using System; -using System.Data; +using System.Data; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public class MoneyTests : TestBase { - public class MoneyTests : MultiplexingTestBase + static readonly object[] MoneyValues = new[] { - static readonly object[] ReadWriteCases = new[] - { - new object[] { "1.22::money", 1.22M }, - new object[] { "1000.22::money", 1000.22M }, - new object[] { "1000000.22::money", 1000000.22M }, - new object[] { "1000000000.22::money", 1000000000.22M }, - new object[] { "1000000000000.22::money", 1000000000000.22M }, - new object[] { "1000000000000000.22::money", 1000000000000000.22M }, - - new object[] { "(+92233720368547758.07::numeric)::money", +92233720368547758.07M }, - new object[] { "(-92233720368547758.08::numeric)::money", -92233720368547758.08M }, - }; - - [Test] - [TestCaseSource(nameof(ReadWriteCases))] - public async Task Read(string query, decimal expected) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT " + query, conn)) - { - Assert.That( - decimal.GetBits((decimal)(await cmd.ExecuteScalarAsync())!), - Is.EqualTo(decimal.GetBits(expected))); - } - } - - [Test] - [TestCaseSource(nameof(ReadWriteCases))] - public async Task Write(string query, decimal expected) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p, @p = " + query, conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Money) { Value = expected }); - using (var rdr = await 
cmd.ExecuteReaderAsync()) - { - rdr.Read(); - Assert.That(decimal.GetBits(rdr.GetFieldValue(0)), Is.EqualTo(decimal.GetBits(expected))); - Assert.That(rdr.GetFieldValue(1)); - } - } - } - - static readonly object[] WriteWithLargeScaleCases = new[] - { - new object[] { "0.004::money", 0.004M, 0.00M }, - new object[] { "0.005::money", 0.005M, 0.01M }, - }; - - [Test] - [TestCaseSource(nameof(WriteWithLargeScaleCases))] - public async Task WriteWithLargeScale(string query, decimal parameter, decimal expected) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p, @p = " + query, conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Money) { Value = parameter }); - using (var rdr = await cmd.ExecuteReaderAsync()) - { - rdr.Read(); - Assert.That(decimal.GetBits(rdr.GetFieldValue(0)), Is.EqualTo(decimal.GetBits(expected))); - Assert.That(rdr.GetFieldValue(1)); - } - } - } + new object[] { "$1.22", 1.22M }, + new object[] { "$1,000.22", 1000.22M }, + new object[] { "$1,000,000.22", 1000000.22M }, + new object[] { "$1,000,000,000.22", 1000000000.22M }, + new object[] { "$1,000,000,000,000.22", 1000000000000.22M }, + new object[] { "$1,000,000,000,000,000.22", 1000000000000000.22M }, + + new object[] { "$92,233,720,368,547,758.07", +92233720368547758.07M }, + new object[] { "-$92,233,720,368,547,758.08", -92233720368547758.08M }, + new object[] { "-$92,233,720,368,547,758.08", -92233720368547758.08M }, + }; + + [Test] + [TestCaseSource(nameof(MoneyValues))] + public async Task Money(string sqlLiteral, decimal money) + { + using var conn = await OpenConnectionAsync(); + await conn.ExecuteNonQueryAsync("SET lc_monetary='C'"); + await AssertType(conn, money, sqlLiteral, "money", NpgsqlDbType.Money, DbType.Currency, isDefault: false); + } - [Test] - public async Task Mapping() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - 
cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Money) { Value = 8M }); - cmd.Parameters.Add(new NpgsqlParameter("p2", DbType.Currency) { Value = 8M }); + [Test] + public async Task Non_decimal_types_are_not_supported() + { + await AssertTypeUnsupportedRead("8", "money"); + await AssertTypeUnsupportedRead("8", "money"); + await AssertTypeUnsupportedRead("8", "money"); + await AssertTypeUnsupportedRead("8", "money"); + await AssertTypeUnsupportedRead("8", "money"); + await AssertTypeUnsupportedRead("8", "money"); + } - using (var rdr = await cmd.ExecuteReaderAsync()) - { - rdr.Read(); - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(rdr.GetFieldType(i), Is.EqualTo(typeof(decimal))); - Assert.That(rdr.GetDataTypeName(i), Is.EqualTo("money")); - Assert.That(rdr.GetValue(i), Is.EqualTo(8M)); - Assert.That(rdr.GetProviderSpecificValue(i), Is.EqualTo(8M)); - Assert.That(rdr.GetFieldValue(i), Is.EqualTo(8M)); - Assert.That(() => rdr.GetFieldValue(i), Throws.InstanceOf()); - Assert.That(() => rdr.GetFieldValue(i), Throws.InstanceOf()); - Assert.That(() => rdr.GetFieldValue(i), Throws.InstanceOf()); - Assert.That(() => rdr.GetFieldValue(i), Throws.InstanceOf()); - Assert.That(() => rdr.GetFieldValue(i), Throws.InstanceOf()); - Assert.That(() => rdr.GetFieldValue(i), Throws.InstanceOf()); - } - } - } - } + static readonly object[] WriteWithLargeScaleCases = new[] + { + new object[] { "0.004::money", 0.004M, 0.00M }, + new object[] { "0.005::money", 0.005M, 0.01M }, + }; - public MoneyTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + [Test] + [TestCaseSource(nameof(WriteWithLargeScaleCases))] + public async Task Write_with_large_scale(string query, decimal parameter, decimal expected) + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p, @p = " + query, conn); + cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Money) { Value = parameter }); + using var rdr = await 
cmd.ExecuteReaderAsync(); + rdr.Read(); + Assert.That(decimal.GetBits(rdr.GetFieldValue(0)), Is.EqualTo(decimal.GetBits(expected))); + Assert.That(rdr.GetFieldValue(1)); } -} +} \ No newline at end of file diff --git a/test/Npgsql.Tests/Types/MultirangeTests.cs b/test/Npgsql.Tests/Types/MultirangeTests.cs new file mode 100644 index 0000000000..0be83ea3e9 --- /dev/null +++ b/test/Npgsql.Tests/Types/MultirangeTests.cs @@ -0,0 +1,186 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Threading.Tasks; +using Npgsql.Properties; +using NpgsqlTypes; +using NUnit.Framework; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests.Types; + +public class MultirangeTests : TestBase +{ + static readonly TestCaseData[] MultirangeTestCases = + { + // int4multirange + new TestCaseData( + new NpgsqlRange[] + { + new(3, true, false, 7, false, false), + new(9, true, false, 0, false, true) + }, + "{[3,7),[9,)}", "int4multirange", NpgsqlDbType.IntegerMultirange, true, true, default(NpgsqlRange)) + .SetName("Int"), + + // int8multirange + new TestCaseData( + new NpgsqlRange[] + { + new(3, true, false, 7, false, false), + new(9, true, false, 0, false, true) + }, + "{[3,7),[9,)}", "int8multirange", NpgsqlDbType.BigIntMultirange, true, true, default(NpgsqlRange)) + .SetName("Long"), + + // nummultirange + // numeric is non-discrete so doesn't undergo normalization, use that to test bound scenarios which otherwise get normalized + new TestCaseData( + new NpgsqlRange[] + { + new(3, true, false, 7, true, false), + new(9, false, false, 0, false, true) + }, + "{[3,7],(9,)}", "nummultirange", NpgsqlDbType.NumericMultirange, true, true, default(NpgsqlRange)) + .SetName("Decimal"), + + // daterange + new TestCaseData( + new NpgsqlRange[] + { + new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), + new(new(2020, 1, 10), true, false, default, false, true) + }, + "{[2020-01-01,2020-01-05),[2020-01-10,)}", "datemultirange", 
NpgsqlDbType.DateMultirange, true, false, default(NpgsqlRange)) + .SetName("DateTime DateMultirange"), + + // tsmultirange + new TestCaseData( + new NpgsqlRange[] + { + new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), + new(new(2020, 1, 10), true, false, default, false, true) + }, + """{["2020-01-01 00:00:00","2020-01-05 00:00:00"),["2020-01-10 00:00:00",)}""", "tsmultirange", NpgsqlDbType.TimestampMultirange, true, true, default(NpgsqlRange)) + .SetName("DateTime TimestampMultirange"), + + // tstzmultirange + new TestCaseData( + new NpgsqlRange[] + { + new(new(2020, 1, 1, 0, 0, 0, kind: DateTimeKind.Utc), true, false, new(2020, 1, 5, 0, 0, 0, kind: DateTimeKind.Utc), false, false), + new(new(2020, 1, 10, 0, 0, 0, kind: DateTimeKind.Utc), true, false, default, false, true) + }, + """{["2020-01-01 01:00:00+01","2020-01-05 01:00:00+01"),["2020-01-10 01:00:00+01",)}""", "tstzmultirange", NpgsqlDbType.TimestampTzMultirange, true, true, default(NpgsqlRange)) + .SetName("DateTime TimestampTzMultirange"), + +#if NET6_0_OR_GREATER + new TestCaseData( + new NpgsqlRange[] + { + new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), + new(new(2020, 1, 10), true, false, default, false, true) + }, + "{[2020-01-01,2020-01-05),[2020-01-10,)}", "datemultirange", NpgsqlDbType.DateMultirange, false, false, default(NpgsqlRange)) + .SetName("DateOnly"), +#endif + }; + + [Test, TestCaseSource(nameof(MultirangeTestCases))] + public Task Multirange_as_array( + T multirangeAsArray, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType, bool isDefaultForReading, bool isDefaultForWriting, TRange _) + => AssertType(multirangeAsArray, sqlLiteral, pgTypeName, npgsqlDbType, isDefaultForReading: isDefaultForReading, + isDefaultForWriting: isDefaultForWriting); + + [Test, TestCaseSource(nameof(MultirangeTestCases))] + public Task Multirange_as_list( + T multirangeAsArray, string sqlLiteral, string pgTypeName, NpgsqlDbType? 
npgsqlDbType, bool isDefaultForReading, bool isDefaultForWriting, TRange _) + where T : IList + => AssertType( + new List(multirangeAsArray), + sqlLiteral, pgTypeName, npgsqlDbType, isDefaultForReading: false, isDefaultForWriting: isDefaultForWriting); + + [Test] + [NonParallelizable] + public async Task Unmapped_multirange_with_mapped_subtype() + { + await using var dataSource = CreateDataSource(b => b.EnableUnmappedTypes().ConnectionStringBuilder.MaxPoolSize = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + + var typeName = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS RANGE(subtype=text)"); + await Task.Yield(); // TODO: fix multiplexing deadlock bug + conn.ReloadTypes(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + + var value = new[] {new NpgsqlRange( + new string('a', conn.Settings.WriteBufferSize + 10).ToCharArray(), + new string('z', conn.Settings.WriteBufferSize + 10).ToCharArray() + )}; + + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.Add(new NpgsqlParameter { DataTypeName = typeName + "_multirange", ParameterName = "p", Value = value }); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); + + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(NpgsqlRange[]))); + var result = reader.GetFieldValue[]>(0); + Assert.That(result, Is.EqualTo(value).Using[]>((actual, expected) => + actual[0].LowerBound!.SequenceEqual(expected[0].LowerBound!) 
&& actual[0].UpperBound!.SequenceEqual(expected[0].UpperBound!))); + } + + [Test] + public async Task Unmapped_multirange_supported_only_with_EnableUnmappedTypes() + { + await using var connection = await DataSource.OpenConnectionAsync(); + var rangeType = await GetTempTypeName(connection); + var multirangeTypeName = rangeType + "_multirange"; + await connection.ExecuteNonQueryAsync($"CREATE TYPE {rangeType} AS RANGE(subtype=text)"); + await Task.Yield(); // TODO: fix multiplexing deadlock bug + await connection.ReloadTypesAsync(); + + var errorMessage = string.Format( + NpgsqlStrings.UnmappedRangesNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableUnmappedTypes), + nameof(NpgsqlDataSourceBuilder)); + + var exception = await AssertTypeUnsupportedWrite( + new NpgsqlRange[] + { + new("bar", "foo"), + new("moo", "zoo"), + }, + multirangeTypeName); + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + + exception = await AssertTypeUnsupportedRead("""{["bar","foo"],["moo","zoo"]}""", + multirangeTypeName); + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + + exception = await AssertTypeUnsupportedRead>( + """{["bar","foo"],["moo","zoo"]}""", + multirangeTypeName); + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + } + + protected override NpgsqlDataSource DataSource { get; } + + public MultirangeTests() => DataSource = CreateDataSource(builder => + { + builder.ConnectionStringBuilder.Timezone = "Europe/Berlin"; + }); + + [OneTimeSetUp] + public async Task Setup() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + } + + [OneTimeTearDown] + public void TearDown() => DataSource.Dispose(); +} diff --git a/test/Npgsql.Tests/Types/NetworkTypeTests.cs 
b/test/Npgsql.Tests/Types/NetworkTypeTests.cs index e9e1d1b5cf..f164b57d75 100644 --- a/test/Npgsql.Tests/Types/NetworkTypeTests.cs +++ b/test/Npgsql.Tests/Types/NetworkTypeTests.cs @@ -5,279 +5,129 @@ using NpgsqlTypes; using NUnit.Framework; -#pragma warning disable 618 // For NpgsqlInet - -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +/// +/// Tests on PostgreSQL numeric types +/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-net-types.html +/// +class NetworkTypeTests : MultiplexingTestBase { - /// - /// Tests on PostgreSQL numeric types - /// - /// - /// https://www.postgresql.org/docs/current/static/datatype-net-types.html - /// - class NetworkTypeTests : MultiplexingTestBase - { - [Test] - public async Task InetV4() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4, @p5, @p6", conn)) - { - var expectedIp = IPAddress.Parse("192.168.1.1"); - var expectedTuple = (Address: expectedIp, Subnet: 24); - var expectedNpgsqlInet = new NpgsqlInet(expectedIp, 24); - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Inet) { Value = expectedIp }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p2", Value = expectedIp }); - cmd.Parameters.Add(new NpgsqlParameter("p3", NpgsqlDbType.Inet) { Value = expectedTuple }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p4", Value = expectedTuple }); - cmd.Parameters.Add(new NpgsqlParameter("p5", NpgsqlDbType.Inet) { Value = expectedNpgsqlInet }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p6", Value = expectedNpgsqlInet }); - - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - // Address only, no subnet - for (var i = 0; i < 2; i++) - { - // Regular type (IPAddress) - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(IPAddress))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expectedIp)); - Assert.That(reader[i], Is.EqualTo(expectedIp)); - 
Assert.That(reader.GetValue(i), Is.EqualTo(expectedIp)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(IPAddress))); - - // Provider-specific type (ValueTuple) - Assert.That(reader.GetProviderSpecificFieldType(i), Is.EqualTo(typeof((IPAddress, int)))); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo((expectedIp, 32))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(new NpgsqlInet(expectedIp))); - } - - // Address and subnet - for (var i = 2; i < 6; i++) - { - // Regular type (IPAddress) - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(IPAddress))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expectedIp)); - Assert.That(reader[i], Is.EqualTo(expectedIp)); - Assert.That(reader.GetValue(i), Is.EqualTo(expectedIp)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(IPAddress))); - - // Provider-specific type (NpgsqlInet) - Assert.That(reader.GetProviderSpecificFieldType(i), Is.EqualTo(typeof((IPAddress, int)))); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(expectedTuple)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expectedNpgsqlInet)); - } - } - } - } - - [Test] - public async Task InetV6() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4, @p5, @p6", conn)) - { - const string addr = "2001:1db8:85a3:1142:1000:8a2e:1370:7334"; - var expectedIp = IPAddress.Parse(addr); - var expectedTuple = (Address: expectedIp, Subnet: 24); - var expectedNpgsqlInet = new NpgsqlInet(expectedIp, 24); - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Inet) { Value = expectedIp }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p2", Value = expectedIp }); - cmd.Parameters.Add(new NpgsqlParameter("p3", NpgsqlDbType.Inet) { Value = expectedTuple }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p4", Value = expectedTuple }); - cmd.Parameters.Add(new NpgsqlParameter("p5", NpgsqlDbType.Inet) { Value = expectedNpgsqlInet }); 
- cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p6", Value = expectedNpgsqlInet }); - - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - // Address only, no subnet - for (var i = 0; i < 2; i++) - { - // Regular type (IPAddress) - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(IPAddress))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expectedIp)); - Assert.That(reader[i], Is.EqualTo(expectedIp)); - Assert.That(reader.GetValue(i), Is.EqualTo(expectedIp)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(IPAddress))); - - // Provider-specific type (ValueTuple) - Assert.That(reader.GetProviderSpecificFieldType(i), Is.EqualTo(typeof((IPAddress, int)))); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo((expectedIp, 128))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(new NpgsqlInet(expectedIp))); - } - - // Address and subnet - for (var i = 2; i < 6; i++) - { - // Regular type (IPAddress) - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(IPAddress))); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expectedIp)); - Assert.That(reader[i], Is.EqualTo(expectedIp)); - Assert.That(reader.GetValue(i), Is.EqualTo(expectedIp)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(IPAddress))); - - // Provider-specific type (NpgsqlInet) - Assert.That(reader.GetProviderSpecificFieldType(i), Is.EqualTo(typeof((IPAddress, int)))); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(expectedTuple)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expectedNpgsqlInet)); - } - } - } - } - - [Test, Description("Tests support for ReadOnlyIPAddress, see https://github.com/dotnet/corefx/issues/33373")] - public async Task IPAddressAny() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Inet) { Value = IPAddress.Any }); - cmd.Parameters.Add(new 
NpgsqlParameter("p2", NpgsqlDbType.Inet) { TypedValue = IPAddress.Any }); - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p3", Value = IPAddress.Any }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - for (var i = 0; i < reader.FieldCount; i++) - Assert.That(reader.GetFieldValue(i), Is.EqualTo(IPAddress.Any)); - } - } - } - - [Test] - public async Task Cidr() - { - var expected = (Address: IPAddress.Parse("192.168.1.0"), Subnet: 24); - //var expectedInet = new NpgsqlInet("192.168.1.0/24"); - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT '192.168.1.0/24'::CIDR", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) + [Test] + public Task Inet_v4_as_IPAddress() + => AssertType(IPAddress.Parse("192.168.1.1"), "192.168.1.1/32", "inet", NpgsqlDbType.Inet, skipArrayCheck: true); + + [Test] + public Task Inet_v4_array_as_IPAddress_array() + => AssertType( + new[] { - reader.Read(); - - // Regular type (IPAddress) - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof((IPAddress, int)))); - Assert.That(reader.GetFieldValue<(IPAddress, int)>(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(new NpgsqlInet(expected.Address, expected.Subnet))); - Assert.That(reader[0], Is.EqualTo(expected)); - Assert.That(reader.GetValue(0), Is.EqualTo(expected)); - } - } - - [Test] - public async Task Macaddr() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) + IPAddress.Parse("192.168.1.1"), + IPAddress.Parse("192.168.1.2") + }, + "{192.168.1.1,192.168.1.2}", "inet[]", NpgsqlDbType.Inet | NpgsqlDbType.Array); + + [Test] + public Task Inet_v6_as_IPAddress() + => AssertType( + IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), + "2001:1db8:85a3:1142:1000:8a2e:1370:7334/128", + "inet", + NpgsqlDbType.Inet, + skipArrayCheck: true); + + [Test] + public Task Inet_v6_array_as_IPAddress_array() + => 
AssertType( + new[] { - var expected = PhysicalAddress.Parse("08-00-2B-01-02-03"); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.MacAddr) { Value = expected }; - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = expected }; - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(PhysicalAddress))); - } - } - } - } - - [Test] - public async Task Macaddr8() - { - using (var conn = await OpenConnectionAsync()) - { - if (conn.PostgreSqlVersion < new Version(10, 0)) - Assert.Ignore("macaddr8 only supported on PostgreSQL 10 and above"); - - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - var send6 = PhysicalAddress.Parse("08-00-2B-01-02-03"); - var expected6 = PhysicalAddress.Parse("08-00-2B-FF-FE-01-02-03"); // 6-byte macaddr8 gets FF and FE inserted in the middle - var expected8 = PhysicalAddress.Parse("08-00-2B-01-02-03-04-05"); - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.MacAddr8) { Value = send6 }); - cmd.Parameters.Add(new NpgsqlParameter("p2", NpgsqlDbType.MacAddr8) { Value = expected8 }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected6)); - Assert.That(reader.GetValue(0), Is.EqualTo(expected6)); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(PhysicalAddress))); - - Assert.That(reader.GetFieldValue(1), Is.EqualTo(expected8)); - Assert.That(reader.GetValue(1), Is.EqualTo(expected8)); - Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(PhysicalAddress))); - } - } - } - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/835")] - public async Task MacaddrMultiple() - { - using (var conn = await 
OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT unnest(ARRAY['08-00-2B-01-02-03'::MACADDR, '08-00-2B-01-02-04'::MACADDR])", conn)) - using (var r = await cmd.ExecuteReaderAsync()) - { - r.Read(); - var p1 = (PhysicalAddress)r[0]; - r.Read(); - var p2 = (PhysicalAddress)r[0]; - Assert.That(p1, Is.EqualTo(PhysicalAddress.Parse("08-00-2B-01-02-03"))); - Assert.That(p2, Is.EqualTo(PhysicalAddress.Parse("08-00-2B-01-02-04"))); - } - } - - [Test] - public async Task MacaddrValidation() - { - using (var conn = await OpenConnectionAsync()) - { - if (conn.PostgreSqlVersion < new Version(10, 0)) - Assert.Ignore("macaddr8 only supported on PostgreSQL 10 and above"); + IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), + IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7335") + }, + "{2001:1db8:85a3:1142:1000:8a2e:1370:7334,2001:1db8:85a3:1142:1000:8a2e:1370:7335}", "inet[]", NpgsqlDbType.Inet | NpgsqlDbType.Array); + + [Test, IssueLink("https://github.com/dotnet/corefx/issues/33373")] + public Task IPAddress_Any() + => AssertTypeWrite(IPAddress.Any, "0.0.0.0/32", "inet", NpgsqlDbType.Inet, skipArrayCheck: true); + + [Test] + public Task Cidr() + => AssertType( + new NpgsqlCidr(IPAddress.Parse("192.168.1.0"), netmask: 24), + "192.168.1.0/24", + "cidr", + NpgsqlDbType.Cidr, + isDefaultForWriting: false); + + [Test] + public Task Inet_v4_as_NpgsqlInet() + => AssertType( + new NpgsqlInet(IPAddress.Parse("192.168.1.1"), 24), + "192.168.1.1/24", + "inet", + NpgsqlDbType.Inet, + isDefaultForReading: false); + + [Test] + public Task Inet_v6_as_NpgsqlInet() + => AssertType( + new NpgsqlInet(IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), 24), + "2001:1db8:85a3:1142:1000:8a2e:1370:7334/24", + "inet", + NpgsqlDbType.Inet, + isDefaultForReading: false); + + [Test] + public Task Macaddr() + => AssertType(PhysicalAddress.Parse("08-00-2B-01-02-03"), "08:00:2b:01:02:03", "macaddr", NpgsqlDbType.MacAddr); + + [Test] + public async Task Macaddr8() + 
{ + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(10, 0)) + Assert.Ignore("macaddr8 only supported on PostgreSQL 10 and above"); - using (var cmd = new NpgsqlCommand("SELECT @p1", conn)) - { - // 6-byte macaddr8 gets FF and FE inserted in the middle - var send8 = PhysicalAddress.Parse("08-00-2B-01-02-03-04-05"); - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.MacAddr) { Value = send8 }); + await AssertType(PhysicalAddress.Parse("08-00-2B-01-02-03-04-05"), "08:00:2b:01:02:03:04:05", "macaddr8", NpgsqlDbType.MacAddr8, + isDefaultForWriting: false); + } - var exception = Assert.ThrowsAsync(() => cmd.ExecuteReaderAsync()); - Assert.That(exception.Message, Does.StartWith("22P03:").And.Contain("1")); - } - } - } + [Test] + public async Task Macaddr8_write_with_6_bytes() + { + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(10, 0)) + Assert.Ignore("macaddr8 only supported on PostgreSQL 10 and above"); - // Older tests from here + await AssertTypeWrite(PhysicalAddress.Parse("08-00-2B-01-02-03"), "08:00:2b:ff:fe:01:02:03", "macaddr8", NpgsqlDbType.MacAddr8, + isDefault: false); + } - [Test] - public async Task TestNpgsqlSpecificTypesCLRTypesNpgsqlInet() - { - // Please, check https://pgfoundry.org/forum/message.php?msg_id=1005483 - // for a discussion where an NpgsqlInet type isn't shown in a datagrid - // This test tries to check if the type returned is an IPAddress when using - // the GetValue() of NpgsqlDataReader and NpgsqlInet when using GetProviderValue(); + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/835")] + public async Task Macaddr_multiple() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT unnest(ARRAY['08-00-2B-01-02-03'::MACADDR, '08-00-2B-01-02-04'::MACADDR])", conn); + await using var r = await cmd.ExecuteReaderAsync(); + r.Read(); + var p1 = (PhysicalAddress)r[0]; + r.Read(); + 
var p2 = (PhysicalAddress)r[0]; + Assert.That(p1, Is.EqualTo(PhysicalAddress.Parse("08-00-2B-01-02-03"))); + Assert.That(p2, Is.EqualTo(PhysicalAddress.Parse("08-00-2B-01-02-04"))); + } - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("select '192.168.10.10'::inet;", conn)) - using (var dr = await command.ExecuteReaderAsync()) - { - dr.Read(); - var result = dr.GetValue(0); - Assert.AreEqual(typeof(IPAddress), result.GetType()); - } - } + [Test] + public async Task Macaddr_write_validation() + { + await using var conn = await OpenConnectionAsync(); + if (conn.PostgreSqlVersion < new Version(10, 0)) + Assert.Ignore("macaddr8 only supported on PostgreSQL 10 and above"); - public NetworkTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + await AssertTypeUnsupportedWrite(PhysicalAddress.Parse("08-00-2B-01-02-03-04-05"), "macaddr"); } + + public NetworkTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/NumericTests.cs b/test/Npgsql.Tests/Types/NumericTests.cs index 501e9ec020..c0bec8f676 100644 --- a/test/Npgsql.Tests/Types/NumericTests.cs +++ b/test/Npgsql.Tests/Types/NumericTests.cs @@ -1,150 +1,220 @@ using System; using System.Data; +using System.Globalization; +using System.Linq; +using System.Numerics; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +public class NumericTests : MultiplexingTestBase { - public class NumericTests : MultiplexingTestBase + static readonly object[] ReadWriteCases = new[] { - static readonly object[] ReadWriteCases = new[] - { - new object[] { "0.0000000000000000000000000001::numeric", 0.0000000000000000000000000001M }, - new object[] { "0.000000000000000000000001::numeric", 0.000000000000000000000001M }, - new object[] { "0.00000000000000000001::numeric", 0.00000000000000000001M }, - new object[] { 
"0.0000000000000001::numeric", 0.0000000000000001M }, - new object[] { "0.000000000001::numeric", 0.000000000001M }, - new object[] { "0.00000001::numeric", 0.00000001M }, - new object[] { "0.0001::numeric", 0.0001M }, - new object[] { "1::numeric", 1M }, - new object[] { "10000::numeric", 10000M }, - new object[] { "100000000::numeric", 100000000M }, - new object[] { "1000000000000::numeric", 1000000000000M }, - new object[] { "10000000000000000::numeric", 10000000000000000M }, - new object[] { "100000000000000000000::numeric", 100000000000000000000M }, - new object[] { "1000000000000000000000000::numeric", 1000000000000000000000000M }, - new object[] { "10000000000000000000000000000::numeric", 10000000000000000000000000000M }, - - new object[] { "11.222233334444555566667777888::numeric", 11.222233334444555566667777888M }, - new object[] { "111.22223333444455556666777788::numeric", 111.22223333444455556666777788M }, - new object[] { "1111.2222333344445555666677778::numeric", 1111.2222333344445555666677778M }, - - new object[] { "+79228162514264337593543950335::numeric", +79228162514264337593543950335M }, - new object[] { "-79228162514264337593543950335::numeric", -79228162514264337593543950335M }, - - // It is important to test rounding on both even and odd - // numbers to make sure midpoint rounding is away from zero. 
- new object[] { "1::numeric(10,2)", 1.00M }, - new object[] { "2::numeric(10,2)", 2.00M }, - - new object[] { "1.2::numeric(10,1)", 1.2M }, - new object[] { "1.2::numeric(10,2)", 1.20M }, - new object[] { "1.2::numeric(10,3)", 1.200M }, - new object[] { "1.2::numeric(10,4)", 1.2000M }, - new object[] { "1.2::numeric(10,5)", 1.20000M }, - - new object[] { "1.4::numeric(10,0)", 1M }, - new object[] { "1.5::numeric(10,0)", 2M }, - new object[] { "2.4::numeric(10,0)", 2M }, - new object[] { "2.5::numeric(10,0)", 3M }, - - new object[] { "-1.4::numeric(10,0)", -1M }, - new object[] { "-1.5::numeric(10,0)", -2M }, - new object[] { "-2.4::numeric(10,0)", -2M }, - new object[] { "-2.5::numeric(10,0)", -3M }, - - // Bug 2033 - new object[] { "0.0036882500000000000000000000", 0.0036882500000000000000000000M }, - }; - - [Test] - [TestCaseSource(nameof(ReadWriteCases))] - public async Task Read(string query, decimal expected) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT " + query, conn)) - { - Assert.That( - decimal.GetBits((decimal)(await cmd.ExecuteScalarAsync())!), - Is.EqualTo(decimal.GetBits(expected))); - } - } + new object[] { "0.0000000000000000000000000001::numeric", 0.0000000000000000000000000001M }, + new object[] { "0.000000000000000000000001::numeric", 0.000000000000000000000001M }, + new object[] { "0.00000000000000000001::numeric", 0.00000000000000000001M }, + new object[] { "0.0000000000000001::numeric", 0.0000000000000001M }, + new object[] { "0.000000000001::numeric", 0.000000000001M }, + new object[] { "0.00000001::numeric", 0.00000001M }, + new object[] { "0.0001::numeric", 0.0001M }, + new object[] { "0.123456000000000100000000::numeric", 0.123456000000000100000000M }, + new object[] { "1::numeric", 1M }, + new object[] { "10000::numeric", 10000M }, + new object[] { "100000000::numeric", 100000000M }, + new object[] { "1000000000000::numeric", 1000000000000M }, + new object[] { 
"10000000000000000::numeric", 10000000000000000M }, + new object[] { "100000000000000000000::numeric", 100000000000000000000M }, + new object[] { "1000000000000000000000000::numeric", 1000000000000000000000000M }, + new object[] { "10000000000000000000000000000::numeric", 10000000000000000000000000000M }, + + new object[] { "1E-28::numeric", 0.0000000000000000000000000001M }, + new object[] { "1E-24::numeric", 0.000000000000000000000001M }, + new object[] { "1E-20::numeric", 0.00000000000000000001M }, + new object[] { "1E-16::numeric", 0.0000000000000001M }, + new object[] { "1E-12::numeric", 0.000000000001M }, + new object[] { "1E-8::numeric", 0.00000001M }, + new object[] { "1E-4::numeric", 0.0001M }, + new object[] { "1E+0::numeric", 1M }, + new object[] { "1E+4::numeric", 10000M }, + new object[] { "1E+8::numeric", 100000000M }, + new object[] { "1E+12::numeric", 1000000000000M }, + new object[] { "1E+16::numeric", 10000000000000000M }, + new object[] { "1E+20::numeric", 100000000000000000000M }, + new object[] { "1E+24::numeric", 1000000000000000000000000M }, + new object[] { "1E+28::numeric", 10000000000000000000000000000M }, + + new object[] { "1.2222333344445555666677778888::numeric", 1.2222333344445555666677778888M }, + new object[] { "11.222233334444555566667777888::numeric", 11.222233334444555566667777888M }, + new object[] { "111.22223333444455556666777788::numeric", 111.22223333444455556666777788M }, + new object[] { "1111.2222333344445555666677778::numeric", 1111.2222333344445555666677778M }, + + new object[] { "+79228162514264337593543950335::numeric", +79228162514264337593543950335M }, + new object[] { "-79228162514264337593543950335::numeric", -79228162514264337593543950335M }, + + // It is important to test rounding on both even and odd + // numbers to make sure midpoint rounding is away from zero. 
+ new object[] { "1::numeric(10,2)", 1.00M }, + new object[] { "2::numeric(10,2)", 2.00M }, + + new object[] { "1.2::numeric(10,1)", 1.2M }, + new object[] { "1.2::numeric(10,2)", 1.20M }, + new object[] { "1.2::numeric(10,3)", 1.200M }, + new object[] { "1.2::numeric(10,4)", 1.2000M }, + new object[] { "1.2::numeric(10,5)", 1.20000M }, + + new object[] { "1.4::numeric(10,0)", 1M }, + new object[] { "1.5::numeric(10,0)", 2M }, + new object[] { "2.4::numeric(10,0)", 2M }, + new object[] { "2.5::numeric(10,0)", 3M }, + + new object[] { "-1.4::numeric(10,0)", -1M }, + new object[] { "-1.5::numeric(10,0)", -2M }, + new object[] { "-2.4::numeric(10,0)", -2M }, + new object[] { "-2.5::numeric(10,0)", -3M }, + + // Bug 2033 + new object[] { "0.0036882500000000000000000000", 0.0036882500000000000000000000M }, + + new object[] { "936490726837837729197", 936490726837837729197M }, + new object[] { "9364907268378377291970000", 9364907268378377291970000M }, + new object[] { "3649072683783772919700000000", 3649072683783772919700000000M }, + new object[] { "1234567844445555.000000000", 1234567844445555.000000000M }, + new object[] { "11112222000000000000", 11112222000000000000M }, + new object[] { "0::numeric", 0M }, + }; + + [Test] + [TestCaseSource(nameof(ReadWriteCases))] + public async Task Read(string query, decimal expected) + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT " + query, conn); + var value = (decimal)(await cmd.ExecuteScalarAsync())!; + Assert.That(decimal.GetBits(value), Is.EqualTo(decimal.GetBits(expected))); + } - [Test] - [TestCaseSource(nameof(ReadWriteCases))] - public async Task Write(string query, decimal expected) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p, @p = " + query, conn)) - { - cmd.Parameters.AddWithValue("p", expected); - using (var rdr = await cmd.ExecuteReaderAsync()) - { - rdr.Read(); - Assert.That(decimal.GetBits(rdr.GetFieldValue(0)), 
Is.EqualTo(decimal.GetBits(expected))); - Assert.That(rdr.GetFieldValue(1)); - } - } - } + [Test] + [TestCaseSource(nameof(ReadWriteCases))] + public async Task Write(string query, decimal expected) + { + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p, @p = " + query, conn); + cmd.Parameters.AddWithValue("p", expected); + using var rdr = await cmd.ExecuteReaderAsync(); + rdr.Read(); + Assert.That(decimal.GetBits(rdr.GetFieldValue(0)), Is.EqualTo(decimal.GetBits(expected))); + Assert.That(rdr.GetFieldValue(1)); + } + + + [Test] + public async Task Numeric() + { + await AssertType(5.5m, "5.5", "numeric", NpgsqlDbType.Numeric, DbType.Decimal); + await AssertTypeWrite(5.5m, "5.5", "numeric", NpgsqlDbType.Numeric, DbType.VarNumeric, inferredDbType: DbType.Decimal); + + await AssertType((short)8, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); + await AssertType(8, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); + await AssertType((byte)8, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); + await AssertType(8F, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); + await AssertType(8D, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); + await AssertType(8M, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); + } - [Test] - public async Task Mapping() + [Test, Description("Tests that when Numeric value does not fit in a System.Decimal and reader is in ReaderState.InResult, the value was read wholly and it is safe to continue reading")] + public async Task Read_overflow_is_safe() + { + using var conn = await OpenConnectionAsync(); + //This 29-digit number causes OverflowException. 
Here it is important to have unread column after failing one to leave it ReaderState.InResult + using var cmd = new NpgsqlCommand(@"SELECT (0.20285714285714285714285714285)::numeric, generate_series FROM generate_series(1, 2)", conn); + using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + var i = 1; + + while (reader.Read()) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.Numeric) { Value = 8M }); - cmd.Parameters.Add(new NpgsqlParameter("p2", DbType.Decimal) { Value = 8M }); - cmd.Parameters.Add(new NpgsqlParameter("p3", DbType.VarNumeric) { Value = 8M }); - cmd.Parameters.Add(new NpgsqlParameter("p4", 8M)); - - using (var rdr = await cmd.ExecuteReaderAsync()) - { - rdr.Read(); - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(rdr.GetFieldType(i), Is.EqualTo(typeof(decimal))); - Assert.That(rdr.GetDataTypeName(i), Is.EqualTo("numeric")); - Assert.That(rdr.GetValue(i), Is.EqualTo(8M)); - Assert.That(rdr.GetProviderSpecificValue(i), Is.EqualTo(8M)); - Assert.That(rdr.GetFieldValue(i), Is.EqualTo(8M)); - Assert.That(rdr.GetFieldValue(i), Is.EqualTo(8)); - Assert.That(rdr.GetFieldValue(i), Is.EqualTo(8)); - Assert.That(rdr.GetFieldValue(i), Is.EqualTo(8)); - Assert.That(rdr.GetFieldValue(i), Is.EqualTo(8)); - Assert.That(rdr.GetFieldValue(i), Is.EqualTo(8.0f)); - Assert.That(rdr.GetFieldValue(i), Is.EqualTo(8.0d)); - } - } - } + Assert.That(() => reader.GetDecimal(0), + Throws.Exception + .With.TypeOf() + .With.Message.EqualTo("Numeric value does not fit in a System.Decimal")); + var intValue = reader.GetInt32(1); + + Assert.That(intValue, Is.EqualTo(i++)); + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + Assert.That(reader.State, Is.EqualTo(ReaderState.InResult)); } + } - 
[Test, Description("Tests that when Numeric value does not fit in a System.Decimal and reader is in ReaderState.InResult, the value was read wholly and it is safe to continue reading")] - [Timeout(5000)] - public async Task ReadOverflowIsSafe() + [Test] + [TestCaseSource(nameof(ReadWriteCases))] + public async Task Read_BigInteger(string query, decimal expected) + { + var bigInt = new BigInteger(expected); + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT " + query, conn); + using var rdr = await cmd.ExecuteReaderAsync(); + await rdr.ReadAsync(); + + if (decimal.Floor(expected) == expected) + Assert.That(rdr.GetFieldValue(0), Is.EqualTo(bigInt)); + else + Assert.That(() => rdr.GetFieldValue(0), + Throws.Exception + .With.TypeOf() + .With.Message.EqualTo("Numeric value with non-zero fractional digits not supported by BigInteger")); + } + + [Test] + [TestCaseSource(nameof(ReadWriteCases))] + public async Task Write_BigInteger(string query, decimal expected) + { + if (decimal.Floor(expected) == expected) { + var bigInt = new BigInteger(expected); using var conn = await OpenConnectionAsync(); - //This 29-digit number causes OverflowException. 
Here it is important to have unread column after failing one to leave it ReaderState.InResult - using var cmd = new NpgsqlCommand(@"SELECT (0.20285714285714285714285714285)::numeric, generate_series FROM generate_series(1, 2)", conn); - using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); - var i = 1; - - while (reader.Read()) - { - Assert.That(() => reader.GetDecimal(0), - Throws.Exception - .With.TypeOf() - .With.Message.EqualTo("Numeric value does not fit in a System.Decimal")); - var intValue = reader.GetInt32(1); - - Assert.That(intValue, Is.EqualTo(i++)); - Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open | ConnectionState.Fetching)); - Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); - Assert.That(reader.State, Is.EqualTo(ReaderState.InResult)); - } + using var cmd = new NpgsqlCommand("SELECT @p, @p = " + query, conn); + cmd.Parameters.AddWithValue("p", bigInt); + using var rdr = await cmd.ExecuteReaderAsync(); + await rdr.ReadAsync(); + Assert.That(rdr.GetFieldValue(0), Is.EqualTo(bigInt)); + Assert.That(rdr.GetFieldValue(1)); } + } + + [Test] + public async Task BigInteger_large() + { + var num = BigInteger.Parse(string.Join("", Enumerable.Range(0, 17000).Select(i => ((i + 1) % 10).ToString()))); + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT '0.1'::numeric, @p", conn); + cmd.Parameters.AddWithValue("p", num); + using var rdr = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await rdr.ReadAsync(); + Assert.Throws(() => rdr.GetFieldValue(0)); + Assert.That(rdr.GetFieldValue(1), Is.EqualTo(num)); + } - public NumericTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + [Test] + public async Task NumericZero_WithScale() + { + // Scale should not be lost when dealing with 0 + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p", conn); + var param = new NpgsqlParameter("p", DbType.Decimal, 
10, null, ParameterDirection.Input, false, 10, 2, DataRowVersion.Default, 0.00M); + cmd.Parameters.Add(param); + using var rdr = await cmd.ExecuteReaderAsync(); + await rdr.ReadAsync(); + var value = rdr.GetFieldValue(0); + +#if NET7_0_OR_GREATER + Assert.That(value.Scale, Is.EqualTo(2)); +#else + Assert.That(value.ToString(CultureInfo.InvariantCulture), Is.EqualTo(0.00M.ToString(CultureInfo.InvariantCulture))); +#endif } + + public NumericTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/NumericTypeTests.cs b/test/Npgsql.Tests/Types/NumericTypeTests.cs index d85686615a..9fcd5b695b 100644 --- a/test/Npgsql.Tests/Types/NumericTypeTests.cs +++ b/test/Npgsql.Tests/Types/NumericTypeTests.cs @@ -1,353 +1,113 @@ using System; -using System.Collections.Generic; using System.Data; using System.Globalization; using System.Threading.Tasks; -using Npgsql.Util; using NpgsqlTypes; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +/// +/// Tests on PostgreSQL numeric types +/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-numeric.html +/// +public class NumericTypeTests : MultiplexingTestBase { - /// - /// Tests on PostgreSQL numeric types - /// - /// - /// https://www.postgresql.org/docs/current/static/datatype-numeric.html - /// - public class NumericTypeTests : MultiplexingTestBase + [Test] + public async Task Int16() { - [Test] - public async Task Int16() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4, @p5", conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Smallint); - var p2 = new NpgsqlParameter("p2", DbType.Int16); - var p3 = new NpgsqlParameter("p3", DbType.Byte); - var p4 = new NpgsqlParameter { ParameterName = "p4", Value = (short)8 }; - var p5 = new NpgsqlParameter { ParameterName = "p5", Value = (byte)8 }; - Assert.That(p4.NpgsqlDbType, 
Is.EqualTo(NpgsqlDbType.Smallint)); - Assert.That(p4.DbType, Is.EqualTo(DbType.Int16)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - cmd.Parameters.Add(p4); - cmd.Parameters.Add(p5); - p1.Value = p2.Value = p3.Value = (long)8; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetInt16(i), Is.EqualTo(8)); - Assert.That(reader.GetInt32(i), Is.EqualTo(8)); - Assert.That(reader.GetInt64(i), Is.EqualTo(8)); - Assert.That(reader.GetByte(i), Is.EqualTo(8)); - Assert.That(reader.GetFloat(i), Is.EqualTo(8.0f)); - Assert.That(reader.GetDouble(i), Is.EqualTo(8.0d)); - Assert.That(reader.GetDecimal(i), Is.EqualTo(8.0m)); - Assert.That(reader.GetValue(i), Is.EqualTo(8)); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(8)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(short))); - Assert.That(reader.GetDataTypeName(i), Is.EqualTo("smallint")); - } - } - } - } - - [Test] - public async Task Int32() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Integer); - var p2 = new NpgsqlParameter("p2", DbType.Int32); - var p3 = new NpgsqlParameter { ParameterName = "p3", Value = 8 }; - Assert.That(p3.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); - Assert.That(p3.DbType, Is.EqualTo(DbType.Int32)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - p1.Value = p2.Value = (long)8; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetInt32(i), Is.EqualTo(8)); - Assert.That(reader.GetInt64(i), Is.EqualTo(8)); - Assert.That(reader.GetInt16(i), Is.EqualTo(8)); - Assert.That(reader.GetByte(i), Is.EqualTo(8)); - Assert.That(reader.GetFloat(i), Is.EqualTo(8.0f)); - 
Assert.That(reader.GetDouble(i), Is.EqualTo(8.0d)); - Assert.That(reader.GetDecimal(i), Is.EqualTo(8.0m)); - Assert.That(reader.GetValue(i), Is.EqualTo(8)); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(8)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(int))); - Assert.That(reader.GetDataTypeName(i), Is.EqualTo("integer")); - } - } - } - } - - [Test, Description("Tests some types which are aliased to UInt32")] - [TestCase(NpgsqlDbType.Oid, TestName="OID")] - [TestCase(NpgsqlDbType.Xid, TestName="XID")] - [TestCase(NpgsqlDbType.Cid, TestName="CID")] - public async Task UInt32(NpgsqlDbType npgsqlDbType) - { - var expected = 8u; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p", npgsqlDbType) { Value = expected }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader[0], Is.EqualTo(expected)); - Assert.That(reader.GetProviderSpecificValue(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(uint))); - } - } - } - - [Test] - public async Task Int64() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Bigint); - var p2 = new NpgsqlParameter("p2", DbType.Int64); - var p3 = new NpgsqlParameter { ParameterName = "p3", Value = (long)8 }; - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - p1.Value = p2.Value = (short)8; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetInt64(i), Is.EqualTo(8)); - Assert.That(reader.GetInt16(i), Is.EqualTo(8)); - Assert.That(reader.GetInt32(i), Is.EqualTo(8)); - Assert.That(reader.GetByte(i), Is.EqualTo(8)); - Assert.That(reader.GetFloat(i), Is.EqualTo(8.0f)); - 
Assert.That(reader.GetDouble(i), Is.EqualTo(8.0d)); - Assert.That(reader.GetDecimal(i), Is.EqualTo(8.0m)); - Assert.That(reader.GetValue(i), Is.EqualTo(8)); - Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(8)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(long))); - Assert.That(reader.GetDataTypeName(i), Is.EqualTo("bigint")); - } - } - } - } - - [Test] - public async Task Double() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - const double expected = 4.123456789012345; - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Double); - var p2 = new NpgsqlParameter("p2", DbType.Double); - var p3 = new NpgsqlParameter {ParameterName = "p3", Value = expected}; - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - p1.Value = p2.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetDouble(i), Is.EqualTo(expected).Within(10E-07)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(double))); - } - } - } - } - - [Test] - [TestCase(double.NaN)] - [TestCase(double.PositiveInfinity)] - [TestCase(double.NegativeInfinity)] - public async Task DoubleSpecial(double value) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Double, value); - var actual = await cmd.ExecuteScalarAsync(); - Assert.That(actual, Is.EqualTo(value)); - } - } - - [Test] - public async Task Float() - { - const float expected = .123456F; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3", conn)) - { - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Real); - var p2 = new NpgsqlParameter("p2", DbType.Single); - var p3 = new NpgsqlParameter {ParameterName = "p3", Value = expected}; - 
cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - p1.Value = p2.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFloat(i), Is.EqualTo(expected).Within(10E-07)); - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(float))); - } - } - } - } - - [Test] - [TestCase(double.NaN)] - [TestCase(double.PositiveInfinity)] - [TestCase(double.NegativeInfinity)] - public async Task DoubleFloat(double value) - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Real, value); - var actual = await cmd.ExecuteScalarAsync(); - Assert.That(actual, Is.EqualTo(value)); - } - } - - [Test, Description("Tests handling of numeric overflow when writing data")] - [TestCase(NpgsqlDbType.Smallint, 1 + short.MaxValue)] - [TestCase(NpgsqlDbType.Smallint, 1L + short.MaxValue)] - [TestCase(NpgsqlDbType.Smallint, 1F + short.MaxValue)] - [TestCase(NpgsqlDbType.Smallint, 1D + short.MaxValue)] - [TestCase(NpgsqlDbType.Integer, 1L + int.MaxValue)] - [TestCase(NpgsqlDbType.Integer, 1F + int.MaxValue)] - [TestCase(NpgsqlDbType.Integer, 1D + int.MaxValue)] - [TestCase(NpgsqlDbType.Bigint, 1F + long.MaxValue)] - [TestCase(NpgsqlDbType.Bigint, 1D + long.MaxValue)] - [TestCase(NpgsqlDbType.InternalChar, 1 + byte.MaxValue)] - public async Task WriteOverflow(NpgsqlDbType type, object value) - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p1", conn); - - var p1 = new NpgsqlParameter("p1", type) - { - Value = value - }; - cmd.Parameters.Add(p1); - Assert.ThrowsAsync(async () => await cmd.ExecuteScalarAsync()); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - } + await AssertType((short)8, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16); + // Clr byte/sbyte maps to 'int2' as 
there is no byte type in PostgreSQL, byte[] maps to bytea however. + await AssertType((byte)8, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefaultForReading: false, skipArrayCheck: true); + await AssertType((sbyte)8, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefaultForReading: false); + + await AssertType(8, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); + await AssertType(8L, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); + await AssertType(8F, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); + await AssertType(8D, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); + await AssertType(8M, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); + } - static IEnumerable ReadOverflowTestCases - { - get - { - yield return new TestCaseData(NpgsqlDbType.Smallint, 1D + byte.MaxValue){ }; - } - } - [Test, Description("Tests handling of numeric overflow when reading data")] - [TestCase((byte)0, NpgsqlDbType.Smallint, 1D + byte.MaxValue)] - [TestCase((sbyte)0, NpgsqlDbType.Smallint, 1D + sbyte.MaxValue)] - [TestCase((byte)0, NpgsqlDbType.Integer, 1D + byte.MaxValue)] - [TestCase((short)0, NpgsqlDbType.Integer, 1D + short.MaxValue)] - [TestCase((byte)0, NpgsqlDbType.Bigint, 1D + byte.MaxValue)] - [TestCase((short)0, NpgsqlDbType.Bigint, 1D + short.MaxValue)] - [TestCase(0, NpgsqlDbType.Bigint, 1D + int.MaxValue)] - public async Task ReadOverflow(T readingType, NpgsqlDbType type, double value) - { - var typeString = GetTypeAsString(type); - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand($"SELECT {value}::{typeString}", conn)) - { - Assert.ThrowsAsync(async() => - { - using (var reader = await cmd.ExecuteReaderAsync()) - { - Assert.True(reader.Read()); - reader.GetFieldValue(0); - } - }); - } + [Test] + public async Task Int32() + { + await AssertType(8, "8", "integer", NpgsqlDbType.Integer, 
DbType.Int32); + + await AssertType((short)8, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); + await AssertType(8L, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); + await AssertType((byte)8, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); + await AssertType(8F, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); + await AssertType(8D, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); + await AssertType(8M, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); + } - string GetTypeAsString(NpgsqlDbType dbType) - => dbType switch - { - NpgsqlDbType.Smallint => "int2", - NpgsqlDbType.Integer => "int4", - NpgsqlDbType.Bigint => "int8", - _ => throw new NotSupportedException() - }; - } + [Test, Description("Tests some types which are aliased to UInt32")] + [TestCase("oid", NpgsqlDbType.Oid, TestName="OID")] + [TestCase("xid", NpgsqlDbType.Xid, TestName="XID")] + [TestCase("cid", NpgsqlDbType.Cid, TestName="CID")] + public Task UInt32(string pgTypeName, NpgsqlDbType npgsqlDbType) + => AssertType(8u, "8", pgTypeName, npgsqlDbType, isDefaultForWriting: false); - // Older tests + [Test] + [TestCase("xid8", NpgsqlDbType.Xid8, TestName="XID8")] + public async Task UInt64(string pgTypeName, NpgsqlDbType npgsqlDbType) + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "13.0", "The xid8 type was introduced in PostgreSQL 13"); - [Test] - public async Task DoubleWithoutPrepared() - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("select :field_float8", conn)) - { - command.Parameters.Add(new NpgsqlParameter(":field_float8", NpgsqlDbType.Double)); - var x = 1d/7d; - command.Parameters[0].Value = x; - var valueReturned = await command.ExecuteScalarAsync(); - Assert.That(valueReturned, Is.EqualTo(x).Within(100).Ulps); - } - } + await AssertType(8ul, "8", pgTypeName, npgsqlDbType, 
isDefaultForWriting: false); + } - [Test] - public async Task NumberConversionWithCulture() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("select :p1", conn)) - using (TestUtil.SetCurrentCulture(new CultureInfo("es-ES"))) - { - var parameter = new NpgsqlParameter("p1", NpgsqlDbType.Double) { Value = 5.5 }; - cmd.Parameters.Add(parameter); - var result = await cmd.ExecuteScalarAsync(); - Assert.AreEqual(5.5, result); - } - } + [Test] + public async Task Int64() + { + await AssertType(8L, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64); + + await AssertType((short)8, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); + await AssertType(8, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); + await AssertType((byte)8, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); + await AssertType(8F, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); + await AssertType(8D, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); + await AssertType(8M, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); + } - [Test] - public async Task Money() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = conn.CreateCommand()) - { - cmd.CommandText = "select '1'::MONEY, '12345'::MONEY / 100, '123456789012345'::MONEY / 100"; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.AreEqual(1M, reader.GetValue(0)); - Assert.AreEqual(123.45M, reader.GetValue(1)); - Assert.AreEqual(1234567890123.45M, reader.GetValue(2)); - } - } - } + [Test] + [TestCase(4.123456789012345, "4.123456789012345", TestName = "Double")] + [TestCase(double.NaN, "NaN", TestName = "Double_NaN")] + [TestCase(double.PositiveInfinity, "Infinity", TestName = "Double_PositiveInfinity")] + [TestCase(double.NegativeInfinity, "-Infinity", TestName = "Double_NegativeInfinity")] + public async Task Double(double value, string sqlLiteral) + 
{ + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "12.0"); - public NumericTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + await AssertType(value, sqlLiteral, "double precision", NpgsqlDbType.Double, DbType.Double); } + + [Test] + [TestCase(0.123456F, "0.123456", TestName = "Float")] + [TestCase(float.NaN, "NaN", TestName = "Float_NaN")] + [TestCase(float.PositiveInfinity, "Infinity", TestName = "Float_PositiveInfinity")] + [TestCase(float.NegativeInfinity, "-Infinity", TestName = "Float_NegativeInfinity")] + public Task Float(float value, string sqlLiteral) + => AssertType(value, sqlLiteral, "real", NpgsqlDbType.Real, DbType.Single); + + [Test] + [TestCase(short.MaxValue + 1, "smallint")] + [TestCase(int.MaxValue + 1L, "integer")] + [TestCase(long.MaxValue + 1D, "bigint")] + public Task Write_overflow(T value, string pgTypeName) + => AssertTypeUnsupportedWrite(value, pgTypeName); + + [Test] + [TestCase((short)0, short.MaxValue + 1D, "int")] + [TestCase(0, int.MaxValue + 1D, "bigint")] + [TestCase(0L, long.MaxValue + 1D, "decimal")] + public Task Read_overflow(T _, double value, string pgTypeName) + => AssertTypeUnsupportedRead(value.ToString(CultureInfo.InvariantCulture), pgTypeName); + + public NumericTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/RangeTests.cs b/test/Npgsql.Tests/Types/RangeTests.cs index 35f2d94ee6..38449d30a2 100644 --- a/test/Npgsql.Tests/Types/RangeTests.cs +++ b/test/Npgsql.Tests/Types/RangeTests.cs @@ -2,436 +2,476 @@ using System.ComponentModel; using System.Data; using System.Globalization; +using System.Linq; using System.Threading.Tasks; +using Npgsql.Properties; +using Npgsql.Util; using NpgsqlTypes; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +class RangeTests : MultiplexingTestBase { - /// - /// 
https://www.postgresql.org/docs/current/static/rangetypes.html - /// - class RangeTests : MultiplexingTestBase + static readonly TestCaseData[] RangeTestCases = { - [Test, NUnit.Framework.Description("Resolves a range type handler via the different pathways")] - public async Task RangeTypeResolution() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "int4range", NpgsqlDbType.IntegerRange) + .SetName("IntegerRange"), + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "int8range", NpgsqlDbType.BigIntRange) + .SetName("BigIntRange"), + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "numrange", NpgsqlDbType.NumericRange) + .SetName("NumericRange"), + new TestCaseData(new NpgsqlRange( + new DateTime(2020, 1, 1, 12, 0, 0), true, + new DateTime(2020, 1, 3, 13, 0, 0), false), + """["2020-01-01 12:00:00","2020-01-03 13:00:00")""", "tsrange", NpgsqlDbType.TimestampRange) + .SetName("TimestampRange"), + // Note that the below text representations are local (according to TimeZone, which is set to Europe/Berlin in this test class), + // because that's how PG does timestamptz *text* representation. 
+ new TestCaseData(new NpgsqlRange( + new DateTime(2020, 1, 1, 12, 0, 0, DateTimeKind.Utc), true, + new DateTime(2020, 1, 3, 13, 0, 0, DateTimeKind.Utc), false), + """["2020-01-01 13:00:00+01","2020-01-03 14:00:00+01")""", "tstzrange", NpgsqlDbType.TimestampTzRange) + .SetName("TimestampTzRange"), + + // Note that numrange is a non-discrete range, and therefore doesn't undergo normalization to inclusive/exclusive in PG + new TestCaseData(NpgsqlRange.Empty, "empty", "numrange", NpgsqlDbType.NumericRange) + .SetName("EmptyRange"), + new TestCaseData(new NpgsqlRange(1, true, 10, true), "[1,10]", "numrange", NpgsqlDbType.NumericRange) + .SetName("Inclusive"), + new TestCaseData(new NpgsqlRange(1, false, 10, false), "(1,10)", "numrange", NpgsqlDbType.NumericRange) + .SetName("Exclusive"), + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "numrange", NpgsqlDbType.NumericRange) + .SetName("InclusiveExclusive"), + new TestCaseData(new NpgsqlRange(1, false, 10, true), "(1,10]", "numrange", NpgsqlDbType.NumericRange) + .SetName("ExclusiveInclusive"), + new TestCaseData(new NpgsqlRange(1, false, true, 10, false, false), "(,10)", "numrange", NpgsqlDbType.NumericRange) + .SetName("InfiniteLowerBound"), + new TestCaseData(new NpgsqlRange(1, true, false, 10, false, true), "[1,)", "numrange", NpgsqlDbType.NumericRange) + .SetName("InfiniteUpperBound") + }; + + // See more test cases in DateTimeTests + [Test, TestCaseSource(nameof(RangeTestCases))] + public Task Range(T range, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType) + => AssertType(range, sqlLiteral, pgTypeName, npgsqlDbType, + // NpgsqlRange[] is mapped to multirange by default, not array, so the built-in AssertType testing for arrays fails + // (see below) + skipArrayCheck: true); + + // This re-executes the same scenario as above, but with isDefaultForWriting: false and without skipArrayCheck: true. + // This tests coverage of range arrays (as opposed to multiranges). 
+ [Test, TestCaseSource(nameof(RangeTestCases))] + public Task Range_array(T range, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType) + => AssertType(range, sqlLiteral, pgTypeName, npgsqlDbType, isDefaultForWriting: false); + + [Test] + public void Equality_finite() + { + var r1 = new NpgsqlRange(0, true, false, 1, false, false); - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - ApplicationName = nameof(RangeTypeResolution), // Prevent backend type caching in TypeHandlerRegistry - Pooling = false - }; + //different bounds + var r2 = new NpgsqlRange(1, true, false, 2, false, false); + Assert.IsFalse(r1 == r2); - using (var conn = await OpenConnectionAsync(csb)) - { - // Resolve type by NpgsqlDbType - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", NpgsqlDbType.Range | NpgsqlDbType.Integer, DBNull.Value); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("int4range")); - } - } - - // Resolve type by ClrType (type inference) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter { ParameterName = "p", Value = new NpgsqlRange(3, 5) }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("int4range")); - } - } - - // Resolve type by OID (read) - conn.ReloadTypes(); - using (var cmd = new NpgsqlCommand("SELECT int4range(3, 5)", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetDataTypeName(0), Is.EqualTo("int4range")); - } - } - } + //lower bound is not inclusive + var r3 = new NpgsqlRange(0, false, false, 1, false, false); + Assert.IsFalse(r1 == r3); - [Test] - public async Task Range() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4", conn)) - { - 
var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Range | NpgsqlDbType.Integer) { Value = NpgsqlRange.Empty }; - var p2 = new NpgsqlParameter { ParameterName = "p2", Value = new NpgsqlRange(1, 10) }; - var p3 = new NpgsqlParameter { ParameterName = "p3", Value = new NpgsqlRange(1, false, 10, false) }; - var p4 = new NpgsqlParameter { ParameterName = "p4", Value = new NpgsqlRange(0, false, true, 10, false, false) }; - Assert.That(p2.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Range | NpgsqlDbType.Integer)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - cmd.Parameters.Add(p4); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - Assert.That(reader[0].ToString(), Is.EqualTo("empty")); - Assert.That(reader[1].ToString(), Is.EqualTo("[1,11)")); - Assert.That(reader[2].ToString(), Is.EqualTo("[2,10)")); - Assert.That(reader[3].ToString(), Is.EqualTo("(,10)")); - } - } - } + //upper bound is inclusive + var r4 = new NpgsqlRange(0, true, false, 1, true, false); + Assert.IsFalse(r1 == r4); - [Test] - public async Task RangeWithLongSubtype() - { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); + var r5 = new NpgsqlRange(0, true, false, 1, false, false); + Assert.IsTrue(r1 == r5); - using (var conn = await OpenConnectionAsync()) - { - await conn.ExecuteNonQueryAsync("CREATE TYPE pg_temp.textrange AS RANGE(subtype=text)"); - conn.ReloadTypes(); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - - var value = new NpgsqlRange( - new string('a', conn.Settings.WriteBufferSize + 10), - new string('z', conn.Settings.WriteBufferSize + 10) - ); - - //var value = new NpgsqlRange("bar", "foo"); - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Range | NpgsqlDbType.Text) { Value = value }); - using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess)) - { - reader.Read(); - Assert.That(reader[0], 
Is.EqualTo(value)); - } - } - } - } - - [Test] - public void RangeEquality_FiniteRange() - { - var r1 = new NpgsqlRange(0, true, false, 1, false, false); + //check some other combinations while we are here + Assert.IsFalse(r2 == r3); + Assert.IsFalse(r2 == r4); + Assert.IsFalse(r3 == r4); + } - //different bounds - var r2 = new NpgsqlRange(1, true, false, 2, false, false); - Assert.IsFalse(r1 == r2); + [Test] + public void Equality_infinite() + { + var r1 = new NpgsqlRange(0, false, true, 1, false, false); - //lower bound is not inclusive - var r3 = new NpgsqlRange(0, false, false, 1, false, false); - Assert.IsFalse(r1 == r3); + //different upper bound (lower bound shoulnd't matter since it is infinite) + var r2 = new NpgsqlRange(1, false, true, 2, false, false); + Assert.IsFalse(r1 == r2); - //upper bound is inclusive - var r4 = new NpgsqlRange(0, true, false, 1, true, false); - Assert.IsFalse(r1 == r4); + //upper bound is inclusive + var r3 = new NpgsqlRange(0, false, true, 1, true, false); + Assert.IsFalse(r1 == r3); - var r5 = new NpgsqlRange(0, true, false, 1, false, false); - Assert.IsTrue(r1 == r5); + //value of lower bound shouldn't matter since it is infinite + var r4 = new NpgsqlRange(10, false, true, 1, false, false); + Assert.IsTrue(r1 == r4); - //check some other combinations while we are here - Assert.IsFalse(r2 == r3); - Assert.IsFalse(r2 == r4); - Assert.IsFalse(r3 == r4); - } + //check some other combinations while we are here + Assert.IsFalse(r2 == r3); + Assert.IsFalse(r2 == r4); + Assert.IsFalse(r3 == r4); + } - [Test] - public void RangeEquality_InfiniteRange() - { - var r1 = new NpgsqlRange(0, false, true, 1, false, false); + [Test] + public void GetHashCode_value_types() + { + NpgsqlRange a = default; + NpgsqlRange b = NpgsqlRange.Empty; + NpgsqlRange c = NpgsqlRange.Parse("(,)"); + + Assert.IsFalse(a.Equals(b)); + Assert.IsFalse(a.Equals(c)); + Assert.IsFalse(b.Equals(c)); + Assert.AreNotEqual(a.GetHashCode(), b.GetHashCode()); + 
Assert.AreNotEqual(a.GetHashCode(), c.GetHashCode()); + Assert.AreNotEqual(b.GetHashCode(), c.GetHashCode()); + } - //different upper bound (lower bound shoulnd't matter since it is infinite) - var r2 = new NpgsqlRange(1, false, true, 2, false, false); - Assert.IsFalse(r1 == r2); + [Test] + public void GetHashCode_reference_types() + { + NpgsqlRange a= default; + NpgsqlRange b = NpgsqlRange.Empty; + NpgsqlRange c = NpgsqlRange.Parse("(,)"); + + Assert.IsFalse(a.Equals(b)); + Assert.IsFalse(a.Equals(c)); + Assert.IsFalse(b.Equals(c)); + Assert.AreNotEqual(a.GetHashCode(), b.GetHashCode()); + Assert.AreNotEqual(a.GetHashCode(), c.GetHashCode()); + Assert.AreNotEqual(b.GetHashCode(), c.GetHashCode()); + } - //upper bound is inclusive - var r3 = new NpgsqlRange(0, false, true, 1, true, false); - Assert.IsFalse(r1 == r3); + [Test] + public async Task TimestampTz_range_with_DateTimeOffset() + { + // The default CLR mapping for timestamptz is DateTime, but it also supports DateTimeOffset. + // The range should also support both, defaulting to the first. 
+ using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p", conn); + + var dto1 = new DateTimeOffset(2010, 1, 3, 10, 0, 0, TimeSpan.Zero); + var dto2 = new DateTimeOffset(2010, 1, 4, 10, 0, 0, TimeSpan.Zero); + var range = new NpgsqlRange(dto1, dto2); + cmd.Parameters.AddWithValue("p", range); + using var reader = await cmd.ExecuteReaderAsync(); + + await reader.ReadAsync(); + var actual = reader.GetFieldValue>(0); + Assert.That(actual, Is.EqualTo(range)); + } - //value of lower bound shoulnd't matter since it is infinite - var r4 = new NpgsqlRange(10, false, true, 1, false, false); - Assert.IsTrue(r1 == r4); + [Test] + [NonParallelizable] + public async Task Unmapped_range_with_mapped_subtype() + { + await using var dataSource = CreateDataSource(b => b.EnableUnmappedTypes().ConnectionStringBuilder.MaxPoolSize = 1); + await using var conn = await dataSource.OpenConnectionAsync(); + + var typeName = await GetTempTypeName(conn); + await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS RANGE(subtype=text)"); + await Task.Yield(); // TODO: fix multiplexing deadlock bug + conn.ReloadTypes(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + + var value = new NpgsqlRange( + new string('a', conn.Settings.WriteBufferSize + 10).ToCharArray(), + new string('z', conn.Settings.WriteBufferSize + 10).ToCharArray() + ); + + await using var cmd = new NpgsqlCommand("SELECT @p", conn); + cmd.Parameters.Add(new NpgsqlParameter { DataTypeName = typeName, ParameterName = "p", Value = value }); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); + + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(NpgsqlRange))); + var result = reader.GetFieldValue>(0); + Assert.That(result, Is.EqualTo(value).Using>((actual, expected) => + actual.LowerBound!.SequenceEqual(expected.LowerBound!) 
&& actual.UpperBound!.SequenceEqual(expected.UpperBound!))); + } - //check some other combinations while we are here - Assert.IsFalse(r2 == r3); - Assert.IsFalse(r2 == r4); - Assert.IsFalse(r3 == r4); - } + [Test] + public async Task Unmapped_range_supported_only_with_EnableUnmappedTypes() + { + await using var connection = await DataSource.OpenConnectionAsync(); + var rangeType = await GetTempTypeName(connection); + await connection.ExecuteNonQueryAsync($"CREATE TYPE {rangeType} AS RANGE(subtype=text)"); + await Task.Yield(); // TODO: fix multiplexing deadlock bug + await connection.ReloadTypesAsync(); + + var errorMessage = string.Format( + NpgsqlStrings.UnmappedRangesNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableUnmappedTypes), + nameof(NpgsqlDataSourceBuilder)); + + var exception = await AssertTypeUnsupportedWrite(new NpgsqlRange("bar", "foo"), rangeType); + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + + exception = await AssertTypeUnsupportedRead("""["bar","foo"]""", rangeType); + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + + exception = await AssertTypeUnsupportedRead>("""["bar","foo"]""", rangeType); + Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + } - [Test] - public void RangeHashCode_ValueTypes() - { - NpgsqlRange a = default; - NpgsqlRange b = NpgsqlRange.Empty; - NpgsqlRange c = NpgsqlRange.Parse("(,)"); - - Assert.IsFalse(a.Equals(b)); - Assert.IsFalse(a.Equals(c)); - Assert.IsFalse(b.Equals(c)); - Assert.AreNotEqual(a.GetHashCode(), b.GetHashCode()); - Assert.AreNotEqual(a.GetHashCode(), c.GetHashCode()); - Assert.AreNotEqual(b.GetHashCode(), c.GetHashCode()); - } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4441")] + public async Task Array_of_range() + { + bool supportsMultirange; - [Test] - 
public void RangeHashCode_ReferenceTypes() + await using (var conn = await OpenConnectionAsync()) { - NpgsqlRange a= default; - NpgsqlRange b = NpgsqlRange.Empty; - NpgsqlRange c = NpgsqlRange.Parse("(,)"); - - Assert.IsFalse(a.Equals(b)); - Assert.IsFalse(a.Equals(c)); - Assert.IsFalse(b.Equals(c)); - Assert.AreNotEqual(a.GetHashCode(), b.GetHashCode()); - Assert.AreNotEqual(a.GetHashCode(), c.GetHashCode()); - Assert.AreNotEqual(b.GetHashCode(), c.GetHashCode()); + supportsMultirange = conn.PostgreSqlVersion.IsGreaterOrEqual(14); } - [Test] - public async Task TimestampTzRangeWithDateTimeOffset() - { - // The default CLR mapping for timestamptz is DateTime, but it also supports DateTimeOffset. - // The range should also support both, defaulting to the first. - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand("SELECT @p", conn); - - var dto1 = new DateTimeOffset(2010, 1, 3, 10, 0, 0, TimeSpan.Zero); - var dto2 = new DateTimeOffset(2010, 1, 4, 10, 0, 0, TimeSpan.Zero); - var range = new NpgsqlRange(dto1, dto2); - cmd.Parameters.AddWithValue("p", range); - using var reader = await cmd.ExecuteReaderAsync(); - - await reader.ReadAsync(); - var actual = reader.GetFieldValue>(0); - Assert.That(actual, Is.EqualTo(range)); - } + // Starting with PG14, we map CLR NpgsqlRange[] to PG multiranges by default, but also support mapping to PG array of range. + // (wee also MultirangeTests for additional multirange-specific tests). + // Earlier versions don't have multirange, so the default mapping is to PG array of range. - [OneTimeSetUp] - public async Task OneTimeSetUp() - { - using (var conn = await OpenConnectionAsync()) - TestUtil.MinimumPgVersion(conn, "9.2.0"); - } + // Note that when NpgsqlDbType inference, we don't know the PG version (since NpgsqlParameter can exist in isolation). So + // if NpgsqlParameter.Value is set to NpgsqlRange[], NpgsqlDbType always returns multirange (hence + // isNpgsqlDbTypeInferredFromClrType is false). 
+ await AssertType( + new NpgsqlRange[] + { + new(3, lowerBoundIsInclusive: true, 4, upperBoundIsInclusive: false), + new(5, lowerBoundIsInclusive: true, 6, upperBoundIsInclusive: false) + }, + """{"[3,4)","[5,6)"}""", + "int4range[]", + NpgsqlDbType.IntegerRange | NpgsqlDbType.Array, + isDefaultForWriting: !supportsMultirange, + isNpgsqlDbTypeInferredFromClrType: false); + } - #region ParseTests + [Test] + public async Task Ranges_not_supported_by_default_on_NpgsqlSlimSourceBuilder() + { + var errorMessage = string.Format( + NpgsqlStrings.RangesNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableRanges), nameof(NpgsqlSlimDataSourceBuilder)); - [Theory] - [TestCaseSource(nameof(DateTimeRangeTheoryData))] - public void GivenDateRangeLiteral_WhenConverted_ThenReturnsDateRange(NpgsqlRange input) - { - // Arrange - var wellKnownText = input.ToString(); + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + await using var dataSource = dataSourceBuilder.Build(); - // Act - var result = NpgsqlRange.Parse(wellKnownText); + var exception = await AssertTypeUnsupportedRead>("[1,10)", "int4range", dataSource); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + exception = await AssertTypeUnsupportedWrite(new NpgsqlRange(1, true, 10, false), "int4range", dataSource); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + } - // Assert - Assert.AreEqual(input, result); - } + [Test] + public async Task NpgsqlSlimSourceBuilder_EnableRanges() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableRanges(); + await using var dataSource = dataSourceBuilder.Build(); - [Theory] - [TestCase("empty")] - [TestCase("EMPTY")] - [TestCase(" EmPtY ")] - public void GivenEmptyIntRangeLiteral_WhenParsed_ThenReturnsEmptyIntRange(string value) - { - // Act - var result = NpgsqlRange.Parse(value); + await AssertType( + dataSource, + new NpgsqlRange(1, true, 10, false), 
"[1,10)", "int4range", NpgsqlDbType.IntegerRange, skipArrayCheck: true); + } - // Assert - Assert.AreEqual(NpgsqlRange.Empty, result); - } + protected override NpgsqlConnection OpenConnection() + => throw new NotSupportedException(); - [Theory] - [TestCase("(0,1)")] - [TestCase("(0,1]")] - [TestCase("[0,1)")] - [TestCase("[0,1]")] - [TestCase(" [ 0 , 1 ] ")] - public void GivenIntRangeLiteral_WhenParsed_ThenReturnsIntRange(string input) - { - // Act - var result = NpgsqlRange.Parse(input); + #region ParseTests - // Assert - Assert.AreEqual(input.Replace(" ", null), result.ToString()); - } + [Theory] + [TestCaseSource(nameof(DateTimeRangeTheoryData))] + public void Roundtrip_DateTime_ranges_through_ToString_and_Parse(NpgsqlRange input) + { + var wellKnownText = input.ToString(); + var result = NpgsqlRange.Parse(wellKnownText); + Assert.AreEqual(input, result); + } - [Theory] - [TestCase("(1,1)", "empty")] - [TestCase("[1,1)", "empty")] - [TestCase("[,1]", "(,1]")] - [TestCase("[1,]", "[1,)")] - [TestCase("[,]", "(,)")] - [TestCase("[-infinity,infinity]", "(,)")] - [TestCase("[ -infinity , infinity ]", "(,)")] - [TestCase("[-infinity,infinity)", "(,)")] - [TestCase("(-infinity,infinity]", "(,)")] - [TestCase("(-infinity,infinity)", "(,)")] - [TestCase("[null,null]", "(,)")] - [TestCase("[null,infinity]", "(,)")] - [TestCase("[-infinity,null]", "(,)")] - public void GivenPoorlyFormedIntRangeLiteral_WhenParsed_ThenReturnsIntRange(string input, string normalized) - { - // Act - var result = NpgsqlRange.Parse(input); + [Theory] + [TestCase("empty")] + [TestCase("EMPTY")] + [TestCase(" EmPtY ")] + public void Parse_empty(string value) + { + var result = NpgsqlRange.Parse(value); + Assert.AreEqual(NpgsqlRange.Empty, result); + } - // Assert - Assert.AreEqual(normalized, result.ToString()); - } + [Theory] + [TestCase("(0,1)")] + [TestCase("(0,1]")] + [TestCase("[0,1)")] + [TestCase("[0,1]")] + [TestCase(" [ 0 , 1 ] ")] + public void 
Roundtrip_int_ranges_through_ToString_and_Parse(string input) + { + var result = NpgsqlRange.Parse(input); + Assert.AreEqual(input.Replace(" ", null), result.ToString()); + } - [Theory] - [TestCase("(1,1)", "empty")] - [TestCase("[1,1)", "empty")] - [TestCase("[,1]", "(,1]")] - [TestCase("[1,]", "[1,)")] - [TestCase("[,]", "(,)")] - [TestCase("[-infinity,infinity]", "(,)")] - [TestCase("[ -infinity , infinity ]", "(,)")] - [TestCase("[-infinity,infinity)", "(,)")] - [TestCase("(-infinity,infinity]", "(,)")] - [TestCase("(-infinity,infinity)", "(,)")] - [TestCase("[null,null]", "(,)")] - [TestCase("[null,infinity]", "(,)")] - [TestCase("[-infinity,null]", "(,)")] - public void GivenPoorlyFormedNullableIntRangeLiteral_WhenParsed_ThenReturnsNullableIntRange(string input, string normalized) - { - // Act - var result = NpgsqlRange.Parse(input); + [Theory] + [TestCase("(1,1)", "empty")] + [TestCase("[1,1)", "empty")] + [TestCase("[,1]", "(,1]")] + [TestCase("[1,]", "[1,)")] + [TestCase("[,]", "(,)")] + [TestCase("[-infinity,infinity]", "(,)")] + [TestCase("[ -infinity , infinity ]", "(,)")] + [TestCase("[-infinity,infinity)", "(,)")] + [TestCase("(-infinity,infinity]", "(,)")] + [TestCase("(-infinity,infinity)", "(,)")] + [TestCase("[null,null]", "(,)")] + [TestCase("[null,infinity]", "(,)")] + [TestCase("[-infinity,null]", "(,)")] + public void Int_range_Parse_ToString_returns_normalized_representations(string input, string normalized) + { + var result = NpgsqlRange.Parse(input); + Assert.AreEqual(normalized, result.ToString()); + } - // Assert - Assert.AreEqual(normalized, result.ToString()); - } + [Theory] + [TestCase("(1,1)", "empty")] + [TestCase("[1,1)", "empty")] + [TestCase("[,1]", "(,1]")] + [TestCase("[1,]", "[1,)")] + [TestCase("[,]", "(,)")] + [TestCase("[-infinity,infinity]", "(,)")] + [TestCase("[ -infinity , infinity ]", "(,)")] + [TestCase("[-infinity,infinity)", "(,)")] + [TestCase("(-infinity,infinity]", "(,)")] + [TestCase("(-infinity,infinity)", 
"(,)")] + [TestCase("[null,null]", "(,)")] + [TestCase("[null,infinity]", "(,)")] + [TestCase("[-infinity,null]", "(,)")] + public void Nullable_int_range_Parse_ToString_returns_normalized_representations(string input, string normalized) + { + var result = NpgsqlRange.Parse(input); + Assert.AreEqual(normalized, result.ToString()); + } - [Theory] - [TestCase("(a,a)", "empty")] - [TestCase("[a,a)", "empty")] - [TestCase("[a,a]", "[a,a]")] - [TestCase("(a,b)", "(a,b)")] - public void GivenStringRangeLiteral_WhenParsed_ThenReturnsStringRange(string input, string normalized) - { - // Act - var result = NpgsqlRange.Parse(input); + [Theory] + [TestCase("(a,a)", "empty")] + [TestCase("[a,a)", "empty")] + [TestCase("[a,a]", "[a,a]")] + [TestCase("(a,b)", "(a,b)")] + public void String_range_Parse_ToString_returns_normalized_representations(string input, string normalized) + { + var result = NpgsqlRange.Parse(input); + Assert.AreEqual(normalized, result.ToString()); + } - // Assert - Assert.AreEqual(normalized, result.ToString()); - } + [Theory] + [TestCase("(one,two)")] + public void Roundtrip_string_ranges_through_ToString_and_Parse2(string input) + { + var result = NpgsqlRange.Parse(input); + Assert.AreEqual(input, result.ToString()); + } - [Theory] - [TestCase("(one,two)")] - public void GivenSimpleTypeRangeLiteral_WhenParsed_ThenReturnsSimpleTypeRange(string input) - { - // Act - var result = NpgsqlRange.Parse(input); + [Theory] + [TestCase("0, 1)")] + [TestCase("(0 1)")] + [TestCase("(0, 1")] + [TestCase(" 0, 1 ")] + public void Parse_malformed_range_throws(string input) + => Assert.Throws(() => NpgsqlRange.Parse(input)); - // Assert - Assert.AreEqual(input, result.ToString()); - } + [Test, Ignore("Fails only on build server, can't reproduce locally.")] + public void TypeConverter() + { + // Arrange + NpgsqlRange.RangeTypeConverter.Register(); + var converter = TypeDescriptor.GetConverter(typeof(NpgsqlRange)); - [Theory] - [TestCase("0, 1)")] - [TestCase("(0 1)")] - 
[TestCase("(0, 1")] - [TestCase(" 0, 1 ")] - public void GivenMalformedRangeLiteral_WhenParsed_ThenThrowsFormatException(string input) - { - Assert.Throws(() => NpgsqlRange.Parse(input)); - } + // Act + Assert.IsInstanceOf.RangeTypeConverter>(converter); + Assert.IsTrue(converter.CanConvertFrom(typeof(string))); + var result = converter.ConvertFromString("empty"); - [Test, Ignore("Fails only on build server, can't reproduce locally.")] - public void CanGetTypeConverter() - { - // Arrange - NpgsqlRange.RangeTypeConverter.Register(); - var converter = TypeDescriptor.GetConverter(typeof(NpgsqlRange)); + // Assert + Assert.AreEqual(NpgsqlRange.Empty, result); + } - // Act - Assert.IsInstanceOf.RangeTypeConverter>(converter); - Assert.IsTrue(converter.CanConvertFrom(typeof(string))); - var result = converter.ConvertFromString("empty"); + #endregion - // Assert - Assert.AreEqual(NpgsqlRange.Empty, result); - } + #region TheoryData - #endregion + [TypeConverter(typeof(SimpleTypeConverter))] + class SimpleType + { + string? Value { get; } - #region TheoryData + SimpleType(string? value) + { + Value = value; + } - [TypeConverter(typeof(SimpleTypeConverter))] - class SimpleType + public override string? ToString() { - string? Value { get; } + return Value; + } - SimpleType(string? value) - { - Value = value; - } + class SimpleTypeConverter : TypeConverter + { + public override bool CanConvertFrom(ITypeDescriptorContext? context, Type sourceType) + => typeof(string) == sourceType; - public override string? ToString() - { - return Value; - } + public override object ConvertFrom(ITypeDescriptorContext? context, CultureInfo? 
culture, object value) + => new SimpleType(value.ToString()); + } + } - class SimpleTypeConverter : TypeConverter - { - public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) - => typeof(string) == sourceType; + // ReSharper disable once InconsistentNaming + static readonly DateTime May_17_2018 = DateTime.Parse("2018-05-17"); - public override object ConvertFrom(ITypeDescriptorContext context, CultureInfo culture, object value) - => new SimpleType(value.ToString()); - } - } + // ReSharper disable once InconsistentNaming + static readonly DateTime May_18_2018 = DateTime.Parse("2018-05-18"); - // ReSharper disable once InconsistentNaming - static readonly DateTime May_17_2018 = DateTime.Parse("2018-05-17"); + /// + /// Provides theory data for of . + /// + static object[][] DateTimeRangeTheoryData => + new object[][] + { + // (2018-05-17, 2018-05-18) + new object[] { new NpgsqlRange(May_17_2018, false, false, May_18_2018, false, false) }, - // ReSharper disable once InconsistentNaming - static readonly DateTime May_18_2018 = DateTime.Parse("2018-05-18"); + // [2018-05-17, 2018-05-18] + new object[] { new NpgsqlRange(May_17_2018, true, false, May_18_2018, true, false) }, - /// - /// Provides theory data for of . 
- /// - static object[][] DateTimeRangeTheoryData => - new object[][] - { - // (2018-05-17, 2018-05-18) - new object[] { new NpgsqlRange(May_17_2018, false, false, May_18_2018, false, false) }, + // [2018-05-17, 2018-05-18) + new object[] { new NpgsqlRange(May_17_2018, true, false, May_18_2018, false, false) }, - // [2018-05-17, 2018-05-18] - new object[] { new NpgsqlRange(May_17_2018, true, false, May_18_2018, true, false) }, + // (2018-05-17, 2018-05-18] + new object[] { new NpgsqlRange(May_17_2018, false, false, May_18_2018, true, false) }, - // [2018-05-17, 2018-05-18) - new object[] { new NpgsqlRange(May_17_2018, true, false, May_18_2018, false, false) }, + // (,) + new object[] { new NpgsqlRange(default, false, true, default, false, true) }, + new object[] { new NpgsqlRange(May_17_2018, false, true, May_18_2018, false, true) }, - // (2018-05-17, 2018-05-18] - new object[] { new NpgsqlRange(May_17_2018, false, false, May_18_2018, true, false) }, + // (2018-05-17,) + new object[] { new NpgsqlRange(May_17_2018, false, false, default, false, true) }, + new object[] { new NpgsqlRange(May_17_2018, false, false, May_18_2018, false, true) }, - // (,) - new object[] { new NpgsqlRange(default, false, true, default, false, true) }, - new object[] { new NpgsqlRange(May_17_2018, false, true, May_18_2018, false, true) }, + // (,2018-05-18) + new object[] { new NpgsqlRange(default, false, true, May_18_2018, false, false) }, + new object[] { new NpgsqlRange(May_17_2018, false, true, May_18_2018, false, false) } + }; - // (2018-05-17,) - new object[] { new NpgsqlRange(May_17_2018, false, false, default, false, true) }, - new object[] { new NpgsqlRange(May_17_2018, false, false, May_18_2018, false, true) }, + #endregion - // (,2018-05-18) - new object[] { new NpgsqlRange(default, false, true, May_18_2018, false, false) }, - new object[] { new NpgsqlRange(May_17_2018, false, true, May_18_2018, false, false) } - }; + protected override NpgsqlDataSource DataSource { get; } - 
#endregion + public RangeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) + => DataSource = CreateDataSource(builder => + { + builder.ConnectionStringBuilder.Timezone = "Europe/Berlin"; + }); - public RangeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} - } + [OneTimeTearDown] + public void TearDown() => DataSource.Dispose(); } diff --git a/test/Npgsql.Tests/Types/RecordTests.cs b/test/Npgsql.Tests/Types/RecordTests.cs new file mode 100644 index 0000000000..7aefe1e98d --- /dev/null +++ b/test/Npgsql.Tests/Types/RecordTests.cs @@ -0,0 +1,157 @@ +using System; +using System.Data; +using System.Threading.Tasks; +using Npgsql.Properties; +using NUnit.Framework; +using NUnit.Framework.Constraints; + +namespace Npgsql.Tests.Types; + +public class RecordTests : MultiplexingTestBase +{ + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/724")] + [IssueLink("https://github.com/npgsql/npgsql/issues/1980")] + public async Task Read_Record_as_object_array() + { + var recordLiteral = "(1,'foo'::text)::record"; + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + + var record = (object[])reader[0]; + Assert.That(record[0], Is.EqualTo(1)); + Assert.That(record[1], Is.EqualTo("foo")); + + var array = (object[][])reader[1]; + Assert.That(array.Length, Is.EqualTo(2)); + Assert.That(array[0][0], Is.EqualTo(1)); + Assert.That(array[1][0], Is.EqualTo(1)); + } + + [Test] + public async Task Read_Record_as_ValueTuple() + { + await using var dataSource = CreateDataSource(b => b.EnableRecordsAsTuples()); + await using var conn = await dataSource.OpenConnectionAsync(); + + var recordLiteral = "(1,'foo'::text)::record"; + await using var cmd = new NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn); + await using 
var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + + var record = reader.GetFieldValue<(int, string)>(0); + Assert.That(record.Item1, Is.EqualTo(1)); + Assert.That(record.Item2, Is.EqualTo("foo")); + + var array = reader.GetFieldValue<(int, string)[]>(1); + Assert.That(array.Length, Is.EqualTo(2)); + Assert.That(array[0].Item1, Is.EqualTo(1)); + Assert.That(array[0].Item2, Is.EqualTo("foo")); + Assert.That(array[1].Item1, Is.EqualTo(1)); + Assert.That(array[1].Item2, Is.EqualTo("foo")); + } + + [Test] + public async Task Read_Record_as_Tuple() + { + await using var dataSource = CreateDataSource(b => b.EnableRecordsAsTuples()); + await using var conn = await dataSource.OpenConnectionAsync(); + + var recordLiteral = "(1,'foo'::text)::record"; + await using var cmd = new NpgsqlCommand($"SELECT {recordLiteral}, ARRAY[{recordLiteral}, {recordLiteral}]", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + + var record = reader.GetFieldValue>(0); + Assert.That(record.Item1, Is.EqualTo(1)); + Assert.That(record.Item2, Is.EqualTo("foo")); + + var array = reader.GetFieldValue[]>(1); + Assert.That(array.Length, Is.EqualTo(2)); + Assert.That(array[0].Item1, Is.EqualTo(1)); + Assert.That(array[0].Item2, Is.EqualTo("foo")); + Assert.That(array[1].Item1, Is.EqualTo(1)); + Assert.That(array[1].Item2, Is.EqualTo("foo")); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1238")] + public async Task Record_with_non_int_field() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT ('one'::TEXT, 2)", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var record = reader.GetFieldValue(0); + Assert.That(record[0], Is.EqualTo("one")); + Assert.That(record[1], Is.EqualTo(2)); + } + + [Test] + public async Task As_ValueTuple_supported_only_with_EnableRecordsAsTuples() + { + await using var connection = await DataSource.OpenConnectionAsync(); + 
await using var command = new NpgsqlCommand("SELECT (1, 'foo')::record", connection); + await using var reader = await command.ExecuteReaderAsync(); + await reader.ReadAsync(); + + var errorMessage = string.Format( + NpgsqlStrings.RecordsNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableRecordsAsTuples), + nameof(NpgsqlDataSourceBuilder), + nameof(NpgsqlSlimDataSourceBuilder.EnableRecords)); + + var exception = Assert.Throws(() => reader.GetFieldValue<(int, string)>(0))!; + Assert.IsInstanceOf(exception.InnerException); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); + } + + [Test] + public async Task Records_not_supported_by_default_on_NpgsqlSlimSourceBuilder() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + await using var dataSource = dataSourceBuilder.Build(); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + + // RecordHandler doesn't support writing, so we only check for reading + cmd.CommandText = "SELECT ('one'::text, 2)"; + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + var errorMessage = string.Format( + NpgsqlStrings.RecordsNotEnabled, + nameof(NpgsqlSlimDataSourceBuilder.EnableRecordsAsTuples), + nameof(NpgsqlSlimDataSourceBuilder), + nameof(NpgsqlSlimDataSourceBuilder.EnableRecords)); + + var exception = Assert.Throws(() => reader.GetValue(0))!; + Assert.IsInstanceOf(exception.InnerException); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); + + exception = Assert.Throws(() => reader.GetFieldValue(0))!; + Assert.IsInstanceOf(exception.InnerException); + Assert.AreEqual(errorMessage, exception.InnerException!.Message); + } + + [Test] + public async Task NpgsqlSlimSourceBuilder_EnableRecords() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableRecords(); + await using var dataSource = dataSourceBuilder.Build(); + await 
using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + + // RecordHandler doesn't support writing, so we only check for reading + cmd.CommandText = "SELECT ('one'::text, 2)"; + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + Assert.That(() => reader.GetValue(0), Throws.Nothing); + Assert.That(() => reader.GetFieldValue(0), Throws.Nothing); + } + + public RecordTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} +} diff --git a/test/Npgsql.Tests/Types/TextTests.cs b/test/Npgsql.Tests/Types/TextTests.cs index 10bd2f9fd2..7e86fb131b 100644 --- a/test/Npgsql.Tests/Types/TextTests.cs +++ b/test/Npgsql.Tests/Types/TextTests.cs @@ -1,285 +1,154 @@ using System; using System.Data; +using System.IO; using System.Text; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; -namespace Npgsql.Tests.Types +namespace Npgsql.Tests.Types; + +/// +/// Tests on PostgreSQL text +/// +/// +/// https://www.postgresql.org/docs/current/static/datatype-character.html +/// +public class TextTests : MultiplexingTestBase { - /// - /// Tests on PostgreSQL text - /// - /// - /// https://www.postgresql.org/docs/current/static/datatype-character.html - /// - public class TextTests : MultiplexingTestBase + [Test] + public Task Text_as_string() + => AssertType("foo", "foo", "text", NpgsqlDbType.Text, DbType.String); + + [Test] + public Task Text_as_array_of_chars() + => AssertType("foo".ToCharArray(), "foo", "text", NpgsqlDbType.Text, DbType.String, isDefaultForReading: false); + + [Test] + public Task Text_as_ArraySegment_of_chars() + => AssertTypeWrite(new ArraySegment("foo".ToCharArray()), "foo", "text", NpgsqlDbType.Text, DbType.String, + isDefault: false); + + [Test] + public Task Text_as_array_of_bytes() + => AssertType(Encoding.UTF8.GetBytes("foo"), "foo", "text", NpgsqlDbType.Text, DbType.String, isDefault: false); + + [Test] + 
public Task Text_as_ReadOnlyMemory_of_bytes() + => AssertTypeWrite(new ReadOnlyMemory(Encoding.UTF8.GetBytes("foo")), "foo", "text", NpgsqlDbType.Text, DbType.String, + isDefault: false); + + [Test] + public Task Char_as_char() + => AssertType('f', "f", "character", NpgsqlDbType.Char, inferredDbType: DbType.String, isDefault: false); + + [Test] + [NonParallelizable] + public async Task Citext_as_string() { - [Test, Description("Roundtrips a string")] - public async Task Roundtrip() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2, @p3, @p4, @p5, @p6, @p7", conn)) - { - const string expected = "Something"; - var expectedBytes = Encoding.UTF8.GetBytes(expected); - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Text); - var p2 = new NpgsqlParameter("p2", NpgsqlDbType.Varchar); - var p3 = new NpgsqlParameter("p3", DbType.String); - var p4 = new NpgsqlParameter { ParameterName = "p4", Value = expected }; - var p5 = new NpgsqlParameter("p5", NpgsqlDbType.Text); - var p6 = new NpgsqlParameter("p6", NpgsqlDbType.Text); - var p7 = new NpgsqlParameter("p7", NpgsqlDbType.Text); - Assert.That(p2.DbType, Is.EqualTo(DbType.String)); - Assert.That(p3.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); - Assert.That(p3.DbType, Is.EqualTo(DbType.String)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - cmd.Parameters.Add(p3); - cmd.Parameters.Add(p4); - cmd.Parameters.Add(p5); - cmd.Parameters.Add(p6); - cmd.Parameters.Add(p7); - p1.Value = p2.Value = p3.Value = expected; - p5.Value = expected.ToCharArray(); - p6.Value = new ArraySegment(("X" + expected).ToCharArray(), 1, expected.Length); - p7.Value = expectedBytes; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(string))); - Assert.That(reader.GetString(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(i), 
Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expected.ToCharArray())); - Assert.That(reader.GetFieldValue(i), Is.EqualTo(expectedBytes)); - } - } - } - } - - [Test] - public async Task Long([Values(CommandBehavior.Default, CommandBehavior.SequentialAccess)] CommandBehavior behavior) - { - using (var conn = await OpenConnectionAsync()) - { - await using var _ = await CreateTempTable(conn, "name TEXT", out var table); - var builder = new StringBuilder("ABCDEééé", conn.Settings.WriteBufferSize); - builder.Append('X', conn.Settings.WriteBufferSize); - var expected = builder.ToString(); - using (var cmd = new NpgsqlCommand($"INSERT INTO {table} (name) VALUES (@p)", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p", expected)); - await cmd.ExecuteNonQueryAsync(); - } - - using (var cmd = new NpgsqlCommand($"SELECT name, 'foo', name, name, name, name FROM {table}", conn)) - { - var reader = await cmd.ExecuteReaderAsync(behavior); - reader.Read(); - - var actual = reader[0]; - Assert.That(actual, Is.EqualTo(expected)); - - if (behavior.IsSequential()) - Assert.That(() => reader[0], Throws.Exception.TypeOf(), "Seek back sequential"); - else - Assert.That(reader[0], Is.EqualTo(expected)); - - Assert.That(reader.GetString(1), Is.EqualTo("foo")); - Assert.That(reader.GetFieldValue(2), Is.EqualTo(expected)); - Assert.That(reader.GetValue(3), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(4), Is.EqualTo(expected)); - //Assert.That(reader.GetFieldValue(5), Is.EqualTo(expected.ToCharArray())); - } - } - } + await using var conn = await OpenConnectionAsync(); + await EnsureExtensionAsync(conn, "citext"); - [Test, Description("Tests that strings are truncated when the NpgsqlParameter's Size is set")] - public async Task Truncate() - { - const string data = "SomeText"; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p::TEXT", conn)) - { 
- var p = new NpgsqlParameter("p", data) { Size = 4 }; - cmd.Parameters.Add(p); - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data.Substring(0, 4))); - - // NpgsqlParameter.Size needs to persist when value is changed - const string data2 = "AnotherValue"; - p.Value = data2; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2.Substring(0, 4))); - - // NpgsqlParameter.Size larger than the value size should mean the value size, as well as 0 and -1 - p.Size = data2.Length + 10; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); - p.Size = 0; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); - p.Size = -1; - Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); - - Assert.That(() => p.Size = -2, Throws.Exception.TypeOf()); - } - } + await AssertType("foo", "foo", "citext", NpgsqlDbType.Citext, inferredDbType: DbType.String, isDefaultForWriting: false); + } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/488")] - public async Task NullCharacter() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1", conn)) - { - cmd.Parameters.Add(new NpgsqlParameter("p1", "string with \0\0\0 null \0bytes")); - Assert.That(async () => await cmd.ExecuteReaderAsync(), - Throws.Exception.TypeOf() - .With.Property(nameof(PostgresException.SqlState)).EqualTo("22021") - ); - } - } + [Test] + public Task Text_as_MemoryStream() + => AssertTypeWrite(() => new MemoryStream("foo"u8.ToArray()), "foo", "text", NpgsqlDbType.Text, DbType.String, isDefault: false); - [Test, Description("Tests some types which are aliased to strings")] - [TestCase("Varchar")] - [TestCase("Name")] - public async Task AliasedPgTypes(string typename) - { - const string expected = "some_text"; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand($"SELECT '{expected}'::{typename}", conn)) - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - 
Assert.That(reader.GetString(0), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected.ToCharArray())); - } - } + [Test] + public async Task Text_long() + { + await using var conn = await OpenConnectionAsync(); + var builder = new StringBuilder("ABCDEééé", conn.Settings.WriteBufferSize); + builder.Append('X', conn.Settings.WriteBufferSize); + var value = builder.ToString(); + await AssertType(value, value, "text", NpgsqlDbType.Text, DbType.String); + } - [Test] - [TestCase(DbType.AnsiString)] - [TestCase(DbType.AnsiStringFixedLength)] - public async Task AliasedDbTypes(DbType dbType) - { - using (var conn = await OpenConnectionAsync()) - using (var command = new NpgsqlCommand("SELECT @p", conn)) - { - command.Parameters.Add(new NpgsqlParameter("p", dbType) { Value = "SomeString" }); - Assert.That(await command.ExecuteScalarAsync(), Is.EqualTo("SomeString")); - } - } + [Test, Description("Tests that strings are truncated when the NpgsqlParameter's Size is set")] + public async Task Truncate() + { + const string data = "SomeText"; + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand("SELECT @p::TEXT", conn); + var p = new NpgsqlParameter("p", data) { Size = 4 }; + cmd.Parameters.Add(p); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data.Substring(0, 4))); + + // NpgsqlParameter.Size needs to persist when value is changed + const string data2 = "AnotherValue"; + p.Value = data2; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2.Substring(0, 4))); + + // NpgsqlParameter.Size larger than the value size should mean the value size, as well as 0 and -1 + p.Value = data2; + p.Size = data2.Length + 10; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); + p.Size = 0; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); + p.Size = -1; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(data2)); + + Assert.That(() => p.Size = -2, Throws.Exception.TypeOf()); + } - 
[Test, Description("Tests the PostgreSQL internal \"char\" type")] - public async Task InternalChar() - { - using (var conn = await OpenConnectionAsync()) - using (var cmd = conn.CreateCommand()) - { - var testArr = new byte[] { (byte)'}', (byte)'"', 3 }; - var testArr2 = new char[] { '}', '"', (char)3 }; + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/488")] + public async Task Null_character() + { + var exception = await AssertTypeUnsupportedWrite("string with \0\0\0 null \0bytes"); + Assert.That(exception.SqlState, Is.EqualTo(PostgresErrorCodes.CharacterNotInRepertoire)); + } - cmd.CommandText = "Select 'a'::\"char\", (-3)::\"char\", :p1, :p2, :p3, :p4, :p5"; - cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.InternalChar) { Value = 'b' }); - cmd.Parameters.Add(new NpgsqlParameter("p2", NpgsqlDbType.InternalChar) { Value = (byte)66 }); - cmd.Parameters.Add(new NpgsqlParameter("p3", NpgsqlDbType.InternalChar) { Value = (byte)230 }); - cmd.Parameters.Add(new NpgsqlParameter("p4", NpgsqlDbType.InternalChar | NpgsqlDbType.Array) { Value = testArr }); - cmd.Parameters.Add(new NpgsqlParameter("p5", NpgsqlDbType.InternalChar | NpgsqlDbType.Array) { Value = testArr2 }); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - var expected = new char[] { 'a', (char)(256 - 3), 'b', (char)66, (char)230 }; - for (var i = 0; i < expected.Length; i++) - { - Assert.AreEqual(expected[i], reader.GetChar(i)); - } - var arr = (char[])reader.GetValue(5); - var arr2 = (char[])reader.GetValue(6); - Assert.AreEqual(testArr.Length, arr.Length); - for (var i = 0; i < arr.Length; i++) - { - Assert.AreEqual(testArr[i], arr[i]); - Assert.AreEqual(testArr2[i], arr2[i]); - } - } - } - } + [Test, Description("Tests some types which are aliased to strings")] + [TestCase("character varying", NpgsqlDbType.Varchar)] + [TestCase("name", NpgsqlDbType.Name)] + public Task Aliased_postgres_types(string pgTypeName, NpgsqlDbType npgsqlDbType) + => 
AssertType("foo", "foo", pgTypeName, npgsqlDbType, inferredDbType: DbType.String, isDefaultForWriting: false); - [Test] - public async Task Char() - { - var expected = 'f'; - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p", conn)) - { - cmd.Parameters.AddWithValue("p", expected); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetChar(0), Is.EqualTo(expected)); - Assert.That(reader.GetString(0), Is.EqualTo(expected.ToString())); - } - } - } + [Test] + [TestCase(DbType.AnsiString)] + [TestCase(DbType.AnsiStringFixedLength)] + public async Task Aliased_DbTypes(DbType dbType) + { + await using var conn = await OpenConnectionAsync(); + await using var command = new NpgsqlCommand("SELECT @p", conn); + command.Parameters.Add(new NpgsqlParameter("p", dbType) { Value = "SomeString" }); + Assert.That(await command.ExecuteScalarAsync(), Is.EqualTo("SomeString")); // Inferred DbType... + } - [Test, Description("Checks support for the citext contrib type")] - [IssueLink("https://github.com/npgsql/npgsql/issues/695")] - public async Task Citext() + [Test, Description("Tests the PostgreSQL internal \"char\" type")] + public async Task Internal_char() + { + using var conn = await OpenConnectionAsync(); + using var cmd = conn.CreateCommand(); + var testArr = new byte[] { (byte)'}', (byte)'"', 3 }; + var testArr2 = new char[] { '}', '"', (char)3 }; + + cmd.CommandText = "Select 'a'::\"char\", (-3)::\"char\", :p1, :p2, :p3, :p4, :p5"; + cmd.Parameters.Add(new NpgsqlParameter("p1", NpgsqlDbType.InternalChar) { Value = 'b' }); + cmd.Parameters.Add(new NpgsqlParameter("p2", NpgsqlDbType.InternalChar) { Value = (byte)66 }); + cmd.Parameters.Add(new NpgsqlParameter("p3", NpgsqlDbType.InternalChar) { Value = (byte)230 }); + cmd.Parameters.Add(new NpgsqlParameter("p4", NpgsqlDbType.InternalChar | NpgsqlDbType.Array) { Value = testArr }); + cmd.Parameters.Add(new NpgsqlParameter("p5", 
NpgsqlDbType.InternalChar | NpgsqlDbType.Array) { Value = testArr2 }); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var expected = new char[] { 'a', (char)(256 - 3), 'b', (char)66, (char)230 }; + for (var i = 0; i < expected.Length; i++) { - using (var conn = await OpenConnectionAsync()) - { - await EnsureExtensionAsync(conn, "citext"); - - var value = "Foo"; - using (var cmd = new NpgsqlCommand("SELECT @p::CITEXT", conn)) - { - cmd.Parameters.AddWithValue("p", value); - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - Assert.That(reader.GetString(0), Is.EqualTo(value)); - } - } - using (var cmd = new NpgsqlCommand("SELECT @p1::CITEXT = @p2::CITEXT", conn)) - { - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Citext, "abc"); - cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Citext, "ABC"); - Assert.That(await cmd.ExecuteScalarAsync(), Is.True); - } - } + Assert.AreEqual(expected[i], reader.GetChar(i)); } - - [Test] - public async Task Xml() + var arr = (char[])reader.GetValue(5); + var arr2 = (char[])reader.GetValue(6); + Assert.AreEqual(testArr.Length, arr.Length); + for (var i = 0; i < arr.Length; i++) { - using (var conn = await OpenConnectionAsync()) - using (var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn)) - { - const string expected = "foo"; - var p1 = new NpgsqlParameter("p1", NpgsqlDbType.Xml); - var p2 = new NpgsqlParameter("p2", DbType.Xml); - Assert.That(p1.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Xml)); - Assert.That(p2.DbType, Is.EqualTo(DbType.Xml)); - cmd.Parameters.Add(p1); - cmd.Parameters.Add(p2); - p1.Value = p2.Value = expected; - using (var reader = await cmd.ExecuteReaderAsync()) - { - reader.Read(); - - for (var i = 0; i < cmd.Parameters.Count; i++) - { - Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(string))); - Assert.That(reader.GetDataTypeName(i), Is.EqualTo("xml")); - Assert.That(reader.GetString(i), Is.EqualTo(expected)); - Assert.That(reader.GetFieldValue(i), 
Is.EqualTo(expected)); - Assert.That(reader.GetValue(i), Is.EqualTo(expected)); - } - } - } + Assert.AreEqual(testArr[i], arr[i]); + Assert.AreEqual(testArr2[i], arr2[i]); } - - public TextTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } + + public TextTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/TypeHandlerTestBase.cs b/test/Npgsql.Tests/Types/TypeHandlerTestBase.cs deleted file mode 100644 index 8ce0d00b38..0000000000 --- a/test/Npgsql.Tests/Types/TypeHandlerTestBase.cs +++ /dev/null @@ -1,57 +0,0 @@ -using System.Threading.Tasks; -using NpgsqlTypes; -using NUnit.Framework; - -namespace Npgsql.Tests.Types -{ - public abstract class TypeHandlerTestBase : MultiplexingTestBase - { - readonly NpgsqlDbType? _npgsqlDbType; - readonly string? _typeName; - readonly string? _minVersion; - - protected TypeHandlerTestBase(MultiplexingMode multiplexingMode, NpgsqlDbType? npgsqlDbType, string? typeName, string? minVersion = null) - : base(multiplexingMode) => (_npgsqlDbType, _typeName, _minVersion) = (npgsqlDbType, typeName, minVersion); - - [OneTimeSetUp] - public async Task MinimumPgVersion() - { - if (_minVersion is string minVersion) - { - using var conn = await OpenConnectionAsync(); - TestUtil.MinimumPgVersion(conn, minVersion); - } - } - - [Test] - [TestCaseSource("TestCases")] - public async Task Read(string query, T expected) - { - using var conn = await OpenConnectionAsync(); - using var cmd = new NpgsqlCommand($"SELECT {query}", conn); - - Assert.AreEqual(await cmd.ExecuteScalarAsync(), expected); - } - - [Test] - [TestCaseSource("TestCases")] - public async Task Write(string query, T expected) - { - var parameter = new NpgsqlParameter("p", expected); - - if (_npgsqlDbType != null) - parameter.NpgsqlDbType = _npgsqlDbType.Value; - - if (_typeName != null) - parameter.DataTypeName = _typeName; - - using var conn = await OpenConnectionAsync(); - using var cmd = new 
NpgsqlCommand($"SELECT {query}::text = @p::text", conn) - { - Parameters = { parameter } - }; - - Assert.That(await cmd.ExecuteScalarAsync(), Is.True); - } - } -} diff --git a/test/Npgsql.Tests/TypesTests.cs b/test/Npgsql.Tests/TypesTests.cs index 5bb13edd60..79e1415344 100644 --- a/test/Npgsql.Tests/TypesTests.cs +++ b/test/Npgsql.Tests/TypesTests.cs @@ -1,508 +1,253 @@ using System; -using System.Globalization; using System.Net; -using Npgsql.Util; using NpgsqlTypes; using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +/// +/// Tests NpgsqlTypes.* independent of a database +/// +public class TypesTests { - /// - /// Tests NpgsqlTypes.* independent of a database - /// - [TestFixture] - public class TypesTests +#pragma warning disable CS0618 // {NpgsqlTsVector,NpgsqlTsQuery}.Parse are obsolete + [Test] + public void TsVector() { - [Test] - public void NpgsqlIntervalParse() - { - string input; - NpgsqlTimeSpan test; - - input = "1 day"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(1).Ticks, test.TotalTicks, input); - - input = "2 days"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(2).Ticks, test.TotalTicks, input); - - input = "2 days 3:04:05"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(new TimeSpan(2, 3, 4, 5).Ticks, test.TotalTicks, input); - - input = "-2 days"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(-2).Ticks, test.TotalTicks, input); - - input = "-2 days -3:04:05"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(new TimeSpan(-2, -3, -4, -5).Ticks, test.TotalTicks, input); - - input = "-2 days -0:01:02"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(new TimeSpan(-2, 0, -1, -2).Ticks, test.TotalTicks, input); - - input = "2 days -12:00"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(new TimeSpan(2, -12, 0, 0).Ticks, test.TotalTicks, input); - - input = "1 mon"; - test = NpgsqlTimeSpan.Parse(input); - 
Assert.AreEqual(TimeSpan.FromDays(30).Ticks, test.TotalTicks, input); - - input = "2 mons"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(60).Ticks, test.TotalTicks, input); - - input = "1 mon -1 day"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(29).Ticks, test.TotalTicks, input); - - input = "1 mon -2 days"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(28).Ticks, test.TotalTicks, input); - - input = "-1 mon -2 days -3:04:05"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(new TimeSpan(-32, -3, -4, -5).Ticks, test.TotalTicks, input); - - input = "1 year"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(30*12).Ticks, test.TotalTicks, input); - - input = "2 years"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(30*24).Ticks, test.TotalTicks, input); - - input = "1 year -1 mon"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(30*11).Ticks, test.TotalTicks, input); - - input = "1 year -2 mons"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(30*10).Ticks, test.TotalTicks, input); - - input = "1 year -1 day"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(30*12 - 1).Ticks, test.TotalTicks, input); - - input = "1 year -2 days"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(30*12 - 2).Ticks, test.TotalTicks, input); - - input = "1 year -1 mon -1 day"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(30*11 - 1).Ticks, test.TotalTicks, input); - - input = "1 year -2 mons -2 days"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(TimeSpan.FromDays(30*10 - 2).Ticks, test.TotalTicks, input); - - input = "1 day 2:3:4.005"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(new TimeSpan(1, 2, 3, 4, 5).Ticks, test.TotalTicks, input); - - var testCulture = new CultureInfo("fr-FR"); - 
Assert.AreEqual(",", testCulture.NumberFormat.NumberDecimalSeparator, "decimal seperator"); - using (TestUtil.SetCurrentCulture(testCulture)) - { - input = "1 day 2:3:4.005"; - test = NpgsqlTimeSpan.Parse(input); - Assert.AreEqual(new TimeSpan(1, 2, 3, 4, 5).Ticks, test.TotalTicks, input); - } - } + NpgsqlTsVector vec; - [Test] - public void NpgsqlIntervalConstructors() - { - NpgsqlTimeSpan test; - - test = new NpgsqlTimeSpan(); - Assert.AreEqual(0, test.Months, "Months"); - Assert.AreEqual(0, test.Days, "Days"); - Assert.AreEqual(0, test.Hours, "Hours"); - Assert.AreEqual(0, test.Minutes, "Minutes"); - Assert.AreEqual(0, test.Seconds, "Seconds"); - Assert.AreEqual(0, test.Milliseconds, "Milliseconds"); - Assert.AreEqual(0, test.Microseconds, "Microseconds"); - - test = new NpgsqlTimeSpan(1234567890); - Assert.AreEqual(0, test.Months, "Months"); - Assert.AreEqual(0, test.Days, "Days"); - Assert.AreEqual(0, test.Hours, "Hours"); - Assert.AreEqual(2, test.Minutes, "Minutes"); - Assert.AreEqual(3, test.Seconds, "Seconds"); - Assert.AreEqual(456, test.Milliseconds, "Milliseconds"); - Assert.AreEqual(456789, test.Microseconds, "Microseconds"); - - test = new NpgsqlTimeSpan(new TimeSpan(1, 2, 3, 4, 5)).JustifyInterval(); - Assert.AreEqual(0, test.Months, "Months"); - Assert.AreEqual(1, test.Days, "Days"); - Assert.AreEqual(2, test.Hours, "Hours"); - Assert.AreEqual(3, test.Minutes, "Minutes"); - Assert.AreEqual(4, test.Seconds, "Seconds"); - Assert.AreEqual(5, test.Milliseconds, "Milliseconds"); - Assert.AreEqual(5000, test.Microseconds, "Microseconds"); - - test = new NpgsqlTimeSpan(3, 2, 1234567890); - Assert.AreEqual(3, test.Months, "Months"); - Assert.AreEqual(2, test.Days, "Days"); - Assert.AreEqual(0, test.Hours, "Hours"); - Assert.AreEqual(2, test.Minutes, "Minutes"); - Assert.AreEqual(3, test.Seconds, "Seconds"); - Assert.AreEqual(456, test.Milliseconds, "Milliseconds"); - Assert.AreEqual(456789, test.Microseconds, "Microseconds"); - - test = new 
NpgsqlTimeSpan(1, 2, 3, 4); - Assert.AreEqual(0, test.Months, "Months"); - Assert.AreEqual(1, test.Days, "Days"); - Assert.AreEqual(2, test.Hours, "Hours"); - Assert.AreEqual(3, test.Minutes, "Minutes"); - Assert.AreEqual(4, test.Seconds, "Seconds"); - Assert.AreEqual(0, test.Milliseconds, "Milliseconds"); - Assert.AreEqual(0, test.Microseconds, "Microseconds"); - - test = new NpgsqlTimeSpan(1, 2, 3, 4, 5); - Assert.AreEqual(0, test.Months, "Months"); - Assert.AreEqual(1, test.Days, "Days"); - Assert.AreEqual(2, test.Hours, "Hours"); - Assert.AreEqual(3, test.Minutes, "Minutes"); - Assert.AreEqual(4, test.Seconds, "Seconds"); - Assert.AreEqual(5, test.Milliseconds, "Milliseconds"); - Assert.AreEqual(5000, test.Microseconds, "Microseconds"); - - test = new NpgsqlTimeSpan(1, 2, 3, 4, 5, 6); - Assert.AreEqual(1, test.Months, "Months"); - Assert.AreEqual(2, test.Days, "Days"); - Assert.AreEqual(3, test.Hours, "Hours"); - Assert.AreEqual(4, test.Minutes, "Minutes"); - Assert.AreEqual(5, test.Seconds, "Seconds"); - Assert.AreEqual(6, test.Milliseconds, "Milliseconds"); - Assert.AreEqual(6000, test.Microseconds, "Microseconds"); - - test = new NpgsqlTimeSpan(1, 2, 3, 4, 5, 6, 7); - Assert.AreEqual(14, test.Months, "Months"); - Assert.AreEqual(3, test.Days, "Days"); - Assert.AreEqual(4, test.Hours, "Hours"); - Assert.AreEqual(5, test.Minutes, "Minutes"); - Assert.AreEqual(6, test.Seconds, "Seconds"); - Assert.AreEqual(7, test.Milliseconds, "Milliseconds"); - Assert.AreEqual(7000, test.Microseconds, "Microseconds"); - } + vec = NpgsqlTsVector.Parse("a"); + Assert.AreEqual("'a'", vec.ToString()); - [Test] - public void NpgsqlIntervalToString() - { - Assert.AreEqual("00:00:00", new NpgsqlTimeSpan().ToString()); + vec = NpgsqlTsVector.Parse("a "); + Assert.AreEqual("'a'", vec.ToString()); - Assert.AreEqual("00:02:03.456789", new NpgsqlTimeSpan(1234567890).ToString()); + vec = NpgsqlTsVector.Parse("a:1A"); + Assert.AreEqual("'a':1A", vec.ToString()); - 
Assert.AreEqual("00:02:03.456789", new NpgsqlTimeSpan(1234567891).ToString()); + vec = NpgsqlTsVector.Parse(@"\abc\def:1a "); + Assert.AreEqual("'abcdef':1A", vec.ToString()); - Assert.AreEqual("1 day 02:03:04.005", - new NpgsqlTimeSpan(new TimeSpan(1, 2, 3, 4, 5)).JustifyInterval().ToString()); + vec = NpgsqlTsVector.Parse(@"abc:3A 'abc' abc:4B 'hello''yo' 'meh\'\\':5"); + Assert.AreEqual(@"'abc':3A,4B 'hello''yo' 'meh''\\':5", vec.ToString()); - Assert.AreEqual("3 mons 2 days 00:02:03.456789", new NpgsqlTimeSpan(3, 2, 1234567890).ToString()); + vec = NpgsqlTsVector.Parse(" a:12345C a:24D a:25B b c d 1 2 a:25A,26B,27,28"); + Assert.AreEqual("'1' '2' 'a':24,25A,26B,27,28,12345C 'b' 'c' 'd'", vec.ToString()); + } - Assert.AreEqual("1 day 02:03:04", new NpgsqlTimeSpan(1, 2, 3, 4).ToString()); + [Test] + public void TsQuery() + { + NpgsqlTsQuery query; + + query = new NpgsqlTsQueryLexeme("a", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.B); + query = new NpgsqlTsQueryOr(query, query); + query = new NpgsqlTsQueryOr(query, query); + + var str = query.ToString(); + + query = NpgsqlTsQuery.Parse("a & b | c"); + Assert.AreEqual("'a' & 'b' | 'c'", query.ToString()); + + query = NpgsqlTsQuery.Parse("'a''':*ab&d:d&!c"); + Assert.AreEqual("'a''':*AB & 'd':D & !'c'", query.ToString()); + + query = NpgsqlTsQuery.Parse("(a & !(c | d)) & (!!a&b) | c | d | e"); + Assert.AreEqual("( ( 'a' & !( 'c' | 'd' ) & !( !'a' ) & 'b' | 'c' ) | 'd' ) | 'e'", query.ToString()); + Assert.AreEqual(query.ToString(), NpgsqlTsQuery.Parse(query.ToString()).ToString()); + + query = NpgsqlTsQuery.Parse("(((a:*)))"); + Assert.AreEqual("'a':*", query.ToString()); + + query = NpgsqlTsQuery.Parse(@"'a\\b''cde'"); + Assert.AreEqual(@"a\b'cde", ((NpgsqlTsQueryLexeme)query).Text); + Assert.AreEqual(@"'a\\b''cde'", query.ToString()); + + query = NpgsqlTsQuery.Parse(@"a <-> b"); + Assert.AreEqual("'a' <-> 'b'", query.ToString()); + + query = NpgsqlTsQuery.Parse("((a & b) <5> c) <-> !d <0> e"); + 
Assert.AreEqual("( ( 'a' & 'b' <5> 'c' ) <-> !'d' ) <0> 'e'", query.ToString()); + + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("a b c & &")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("&")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("|")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("!")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("(")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse(")")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("()")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("<")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("<-")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("<->")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("a <->")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("<>")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("a b")); + Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("a <-1> b")); + } +#pragma warning restore CS0618 // {NpgsqlTsVector,NpgsqlTsQuery}.Parse are obsolete - Assert.AreEqual("1 day 02:03:04.005", new NpgsqlTimeSpan(1, 2, 3, 4, 5).ToString()); + [Test] + public void TsQueryEquatibility() + { + //Debugger.Launch(); + AreEqual( + new NpgsqlTsQueryLexeme("lexeme"), + new NpgsqlTsQueryLexeme("lexeme")); - Assert.AreEqual("1 mon 2 days 03:04:05.006", new NpgsqlTimeSpan(1, 2, 3, 4, 5, 6).ToString()); + AreEqual( + new NpgsqlTsQueryLexeme("lexeme", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.B), + new NpgsqlTsQueryLexeme("lexeme", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.B)); - Assert.AreEqual("14 mons 3 days 04:05:06.007", new NpgsqlTimeSpan(1, 2, 3, 4, 5, 6, 7).ToString()); + AreEqual( + new NpgsqlTsQueryLexeme("lexeme", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.B, 
true), + new NpgsqlTsQueryLexeme("lexeme", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.B, true)); - Assert.AreEqual(new NpgsqlTimeSpan(0, 2, 3, 4, 5).ToString(), new NpgsqlTimeSpan(new TimeSpan(0, 2, 3, 4, 5)).ToString()); + AreEqual( + new NpgsqlTsQueryNot(new NpgsqlTsQueryLexeme("not")), + new NpgsqlTsQueryNot(new NpgsqlTsQueryLexeme("not"))); - Assert.AreEqual(new NpgsqlTimeSpan(1, 2, 3, 4, 5).ToString(), new NpgsqlTimeSpan(new TimeSpan(1, 2, 3, 4, 5)).ToString()); - const long moreThanAMonthInTicks = TimeSpan.TicksPerDay*40; - Assert.AreEqual(new NpgsqlTimeSpan(moreThanAMonthInTicks).ToString(), new NpgsqlTimeSpan(new TimeSpan(moreThanAMonthInTicks)).ToString()); + AreEqual( + new NpgsqlTsQueryAnd(new NpgsqlTsQueryLexeme("left"), new NpgsqlTsQueryLexeme("right")), + new NpgsqlTsQueryAnd(new NpgsqlTsQueryLexeme("left"), new NpgsqlTsQueryLexeme("right"))); - var testCulture = new CultureInfo("fr-FR"); - Assert.AreEqual(",", testCulture.NumberFormat.NumberDecimalSeparator, "decimal seperator"); - using (TestUtil.SetCurrentCulture(testCulture)) - { - Assert.AreEqual("14 mons 3 days 04:05:06.007", new NpgsqlTimeSpan(1, 2, 3, 4, 5, 6, 7).ToString()); - } - } + AreEqual( + new NpgsqlTsQueryOr(new NpgsqlTsQueryLexeme("left"), new NpgsqlTsQueryLexeme("right")), + new NpgsqlTsQueryOr(new NpgsqlTsQueryLexeme("left"), new NpgsqlTsQueryLexeme("right"))); - [Test] - public void NpgsqlDateConstructors() - { - NpgsqlDate date; - DateTime dateTime; - System.Globalization.Calendar calendar = new System.Globalization.GregorianCalendar(); - - date = new NpgsqlDate(); - Assert.AreEqual(1, date.Day); - Assert.AreEqual(DayOfWeek.Monday, date.DayOfWeek); - Assert.AreEqual(1, date.DayOfYear); - Assert.AreEqual(false, date.IsLeapYear); - Assert.AreEqual(1, date.Month); - Assert.AreEqual(1, date.Year); - - dateTime = new DateTime(2009, 5, 31); - date = new NpgsqlDate(dateTime); - Assert.AreEqual(dateTime.Day, date.Day); - Assert.AreEqual(dateTime.DayOfWeek, date.DayOfWeek); 
- Assert.AreEqual(dateTime.DayOfYear, date.DayOfYear); - Assert.AreEqual(calendar.IsLeapYear(2009), date.IsLeapYear); - Assert.AreEqual(dateTime.Month, date.Month); - Assert.AreEqual(dateTime.Year, date.Year); - - //Console.WriteLine(new DateTime(2009, 5, 31).Ticks); - //Console.WriteLine((new DateTime(2009, 5, 31) - new DateTime(1, 1, 1)).TotalDays); - // 2009-5-31 - dateTime = new DateTime(633793248000000000); // ticks since 1 Jan 1 - date = new NpgsqlDate(733557); // days since 1 Jan 1 - Assert.AreEqual(dateTime.Day, date.Day); - Assert.AreEqual(dateTime.DayOfWeek, date.DayOfWeek); - Assert.AreEqual(dateTime.DayOfYear, date.DayOfYear); - Assert.AreEqual(calendar.IsLeapYear(2009), date.IsLeapYear); - Assert.AreEqual(dateTime.Month, date.Month); - Assert.AreEqual(dateTime.Year, date.Year); - - // copy previous value. should get same result - date = new NpgsqlDate(date); - Assert.AreEqual(dateTime.Day, date.Day); - Assert.AreEqual(dateTime.DayOfWeek, date.DayOfWeek); - Assert.AreEqual(dateTime.DayOfYear, date.DayOfYear); - Assert.AreEqual(calendar.IsLeapYear(2009), date.IsLeapYear); - Assert.AreEqual(dateTime.Month, date.Month); - Assert.AreEqual(dateTime.Year, date.Year); - } + AreEqual( + new NpgsqlTsQueryFollowedBy(new NpgsqlTsQueryLexeme("left"), 0, new NpgsqlTsQueryLexeme("right")), + new NpgsqlTsQueryFollowedBy(new NpgsqlTsQueryLexeme("left"), 0, new NpgsqlTsQueryLexeme("right"))); - [Test] - public void NpgsqlDateToString() - { - Assert.AreEqual("2009-05-31", new NpgsqlDate(2009, 5, 31).ToString()); + AreEqual( + new NpgsqlTsQueryFollowedBy(new NpgsqlTsQueryLexeme("left"), 1, new NpgsqlTsQueryLexeme("right")), + new NpgsqlTsQueryFollowedBy(new NpgsqlTsQueryLexeme("left"), 1, new NpgsqlTsQueryLexeme("right"))); - Assert.AreEqual("0001-05-07 BC", new NpgsqlDate(-1, 5, 7).ToString()); + AreEqual( + new NpgsqlTsQueryEmpty(), + new NpgsqlTsQueryEmpty()); - var testCulture = new CultureInfo("fr-FR"); - Assert.AreEqual(",", 
testCulture.NumberFormat.NumberDecimalSeparator, "decimal seperator"); - using (TestUtil.SetCurrentCulture(testCulture)) - Assert.AreEqual("2009-05-31", new NpgsqlDate(2009, 5, 31).ToString()); - } + AreNotEqual( + new NpgsqlTsQueryLexeme("lexeme a"), + new NpgsqlTsQueryLexeme("lexeme b")); - [Test] - public void SpecialDates() - { - NpgsqlDate date; - DateTime dateTime; - System.Globalization.Calendar calendar = new System.Globalization.GregorianCalendar(); - - // a date after a leap year. - dateTime = new DateTime(2008, 5, 31); - date = new NpgsqlDate(dateTime); - Assert.AreEqual(dateTime.Day, date.Day); - Assert.AreEqual(dateTime.DayOfWeek, date.DayOfWeek); - Assert.AreEqual(dateTime.DayOfYear, date.DayOfYear); - Assert.AreEqual(calendar.IsLeapYear(2008), date.IsLeapYear); - Assert.AreEqual(dateTime.Month, date.Month); - Assert.AreEqual(dateTime.Year, date.Year); - - // A date that is a leap year day. - dateTime = new DateTime(2000, 2, 29); - date = new NpgsqlDate(2000, 2, 29); - Assert.AreEqual(dateTime.Day, date.Day); - Assert.AreEqual(dateTime.DayOfWeek, date.DayOfWeek); - Assert.AreEqual(dateTime.DayOfYear, date.DayOfYear); - Assert.AreEqual(calendar.IsLeapYear(2000), date.IsLeapYear); - Assert.AreEqual(dateTime.Month, date.Month); - Assert.AreEqual(dateTime.Year, date.Year); - - // A date that is not in a leap year. - dateTime = new DateTime(1900, 3, 1); - date = new NpgsqlDate(1900, 3, 1); - Assert.AreEqual(dateTime.Day, date.Day); - Assert.AreEqual(dateTime.DayOfWeek, date.DayOfWeek); - Assert.AreEqual(dateTime.DayOfYear, date.DayOfYear); - Assert.AreEqual(calendar.IsLeapYear(1900), date.IsLeapYear); - Assert.AreEqual(dateTime.Month, date.Month); - Assert.AreEqual(dateTime.Year, date.Year); - - // a date after a leap year. 
- date = new NpgsqlDate(-1, 12, 31); - Assert.AreEqual(31, date.Day); - Assert.AreEqual(DayOfWeek.Sunday, date.DayOfWeek); - Assert.AreEqual(366, date.DayOfYear); - Assert.AreEqual(true, date.IsLeapYear); - Assert.AreEqual(12, date.Month); - Assert.AreEqual(-1, date.Year); - } + AreNotEqual( + new NpgsqlTsQueryLexeme("lexeme", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.D), + new NpgsqlTsQueryLexeme("lexeme", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.B)); - [Test] - public void NpgsqlDateMath() - { - NpgsqlDate date; - - // add a day to the empty constructor - date = new NpgsqlDate() + new NpgsqlTimeSpan(0, 1, 0); - Assert.AreEqual(2, date.Day); - Assert.AreEqual(DayOfWeek.Tuesday, date.DayOfWeek); - Assert.AreEqual(2, date.DayOfYear); - Assert.AreEqual(false, date.IsLeapYear); - Assert.AreEqual(1, date.Month); - Assert.AreEqual(1, date.Year); - - // add a day the same value as the empty constructor - date = new NpgsqlDate(1, 1, 1) + new NpgsqlTimeSpan(0, 1, 0); - Assert.AreEqual(2, date.Day); - Assert.AreEqual(DayOfWeek.Tuesday, date.DayOfWeek); - Assert.AreEqual(2, date.DayOfYear); - Assert.AreEqual(false, date.IsLeapYear); - Assert.AreEqual(1, date.Month); - Assert.AreEqual(1, date.Year); - - var diff = new NpgsqlDate(1, 1, 1) - new NpgsqlDate(-1, 12, 31); - Assert.AreEqual(new NpgsqlTimeSpan(0, 1, 0), diff); - - // Test of the addMonths method (positive values added) - var dateForTestMonths = new NpgsqlDate(2008, 1, 1); - Assert.AreEqual(dateForTestMonths.AddMonths(0), dateForTestMonths); - Assert.AreEqual(dateForTestMonths.AddMonths(4), new NpgsqlDate(2008, 5, 1)); - Assert.AreEqual(dateForTestMonths.AddMonths(11), new NpgsqlDate(2008, 12, 1)); - Assert.AreEqual(dateForTestMonths.AddMonths(12), new NpgsqlDate(2009, 1, 1)); - Assert.AreEqual(dateForTestMonths.AddMonths(14), new NpgsqlDate(2009, 3, 1)); - dateForTestMonths = new NpgsqlDate(2008, 1, 31); - Assert.AreEqual(dateForTestMonths.AddMonths(1), new NpgsqlDate(2008, 2, 29)); 
- Assert.AreEqual(dateForTestMonths.AddMonths(13), new NpgsqlDate(2009, 2, 28)); - - // Test of the addMonths method (negative values added) - dateForTestMonths = new NpgsqlDate(2009, 1, 1); - Assert.AreEqual(dateForTestMonths.AddMonths(0), dateForTestMonths); - Assert.AreEqual(dateForTestMonths.AddMonths(-4), new NpgsqlDate(2008, 9, 1)); - Assert.AreEqual(dateForTestMonths.AddMonths(-12), new NpgsqlDate(2008, 1, 1)); - Assert.AreEqual(dateForTestMonths.AddMonths(-13), new NpgsqlDate(2007, 12, 1)); - dateForTestMonths = new NpgsqlDate(2009, 3, 31); - Assert.AreEqual(dateForTestMonths.AddMonths(-1), new NpgsqlDate(2009, 2, 28)); - Assert.AreEqual(dateForTestMonths.AddMonths(-13), new NpgsqlDate(2008, 2, 29)); - } + AreNotEqual( + new NpgsqlTsQueryLexeme("lexeme", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.B, true), + new NpgsqlTsQueryLexeme("lexeme", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.B, false)); + + AreNotEqual( + new NpgsqlTsQueryNot(new NpgsqlTsQueryLexeme("not")), + new NpgsqlTsQueryNot(new NpgsqlTsQueryLexeme("ton"))); - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3019")] - public void NpgsqlDateTimeMath() + AreNotEqual( + new NpgsqlTsQueryAnd(new NpgsqlTsQueryLexeme("right"), new NpgsqlTsQueryLexeme("left")), + new NpgsqlTsQueryAnd(new NpgsqlTsQueryLexeme("left"), new NpgsqlTsQueryLexeme("right"))); + + AreNotEqual( + new NpgsqlTsQueryOr(new NpgsqlTsQueryLexeme("right"), new NpgsqlTsQueryLexeme("left")), + new NpgsqlTsQueryOr(new NpgsqlTsQueryLexeme("left"), new NpgsqlTsQueryLexeme("right"))); + + AreNotEqual( + new NpgsqlTsQueryFollowedBy(new NpgsqlTsQueryLexeme("right"), 0, new NpgsqlTsQueryLexeme("left")), + new NpgsqlTsQueryFollowedBy(new NpgsqlTsQueryLexeme("left"), 0, new NpgsqlTsQueryLexeme("right"))); + + AreNotEqual( + new NpgsqlTsQueryFollowedBy(new NpgsqlTsQueryLexeme("left"), 0, new NpgsqlTsQueryLexeme("right")), + new NpgsqlTsQueryFollowedBy(new NpgsqlTsQueryLexeme("left"), 1, new 
NpgsqlTsQueryLexeme("right"))); + + void AreEqual(NpgsqlTsQuery left, NpgsqlTsQuery right) { - // Note* NpgsqlTimespan treats 1 month as 30 days - Assert.That(new NpgsqlDateTime(2020, 1, 1, 0, 0, 0).Add(new NpgsqlTimeSpan(1, 2, 0)), - Is.EqualTo(new NpgsqlDateTime(2020, 2, 2, 0, 0, 0))); - Assert.That(new NpgsqlDateTime(2020, 1, 1, 0, 0, 0).Add(new NpgsqlTimeSpan(0, -1, 0)), - Is.EqualTo(new NpgsqlDateTime(2019, 12, 31, 0, 0, 0))); - Assert.That(new NpgsqlDateTime(2020, 1, 1, 0, 0, 0).Add(new NpgsqlTimeSpan(0, 0, 0)), - Is.EqualTo(new NpgsqlDateTime(2020, 1, 1, 0, 0, 0))); - Assert.That(new NpgsqlDateTime(2020, 1, 1, 0, 0, 0).Add(new NpgsqlTimeSpan(0, 0, 10000000)), - Is.EqualTo(new NpgsqlDateTime(2020, 1, 1, 0, 0, 1))); - Assert.That(new NpgsqlDateTime(2020, 1, 1, 0, 0, 0).Subtract(new NpgsqlTimeSpan(1, 1, 0)), - Is.EqualTo(new NpgsqlDateTime(2019, 12, 1, 0, 0, 0))); - // Add 1 month = 2020-03-01 then add 30 days (1 month in npgsqlTimespan = 30 days) = 2020-03-31 - Assert.That(new NpgsqlDateTime(2020, 2, 1, 0, 0, 0).AddMonths(1).Add(new NpgsqlTimeSpan(1, 0, 0)), - Is.EqualTo(new NpgsqlDateTime(2020, 3, 31, 0, 0, 0))); + Assert.True(left == right); + Assert.False(left != right); + Assert.AreEqual(left, right); + Assert.AreEqual(left.GetHashCode(), right.GetHashCode()); } - [Test] - public void TsVector() + void AreNotEqual(NpgsqlTsQuery left, NpgsqlTsQuery right) { - NpgsqlTsVector vec; + Assert.False(left == right); + Assert.True(left != right); + Assert.AreNotEqual(left, right); + Assert.AreNotEqual(left.GetHashCode(), right.GetHashCode()); + } + } - vec = NpgsqlTsVector.Parse("a"); - Assert.AreEqual("'a'", vec.ToString()); +#pragma warning disable CS0618 // {NpgsqlTsVector,NpgsqlTsQuery}.Parse are obsolete + [Test] + public void TsQueryOperatorPrecedence() + { + var query = NpgsqlTsQuery.Parse("!a <-> b & c | d & e"); + var expectedGrouping = NpgsqlTsQuery.Parse("((!(a) <-> b) & c) | (d & e)"); + Assert.AreEqual(expectedGrouping.ToString(), query.ToString()); + 
} +#pragma warning restore CS0618 // {NpgsqlTsVector,NpgsqlTsQuery}.Parse are obsolete - vec = NpgsqlTsVector.Parse("a "); - Assert.AreEqual("'a'", vec.ToString()); + [Test] + public void NpgsqlPath_empty() + => Assert.That(new NpgsqlPath { new(1, 2) }, Is.EqualTo(new NpgsqlPath(new NpgsqlPoint(1, 2)))); - vec = NpgsqlTsVector.Parse("a:1A"); - Assert.AreEqual("'a':1A", vec.ToString()); + [Test] + public void NpgsqlPolygon_empty() + => Assert.That(new NpgsqlPolygon { new(1, 2) }, Is.EqualTo(new NpgsqlPolygon(new NpgsqlPoint(1, 2)))); - vec = NpgsqlTsVector.Parse(@"\abc\def:1a "); - Assert.AreEqual("'abcdef':1A", vec.ToString()); + [Test] + public void Bug1011018() + { + var p = new NpgsqlParameter(); + p.NpgsqlDbType = NpgsqlDbType.Time; + p.Value = DateTime.Now; + var o = p.Value; + } - vec = NpgsqlTsVector.Parse(@"abc:3A 'abc' abc:4B 'hello''yo' 'meh\'\\':5"); - Assert.AreEqual(@"'abc':3A,4B 'hello''yo' 'meh''\\':5", vec.ToString()); +#pragma warning disable 618 + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/750")] + public void NpgsqlInet() + { + var v = new NpgsqlInet(IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), 32); + Assert.That(v.ToString(), Is.EqualTo("2001:1db8:85a3:1142:1000:8a2e:1370:7334/32")); + } +#pragma warning restore 618 - vec = NpgsqlTsVector.Parse(" a:12345C a:24D a:25B b c d 1 2 a:25A,26B,27,28"); - Assert.AreEqual("'1' '2' 'a':24,25A,26B,27,28,12345C 'b' 'c' 'd'", vec.ToString()); - } + [Test] + public void NpgsqlInet_parse_ipv4() + { + var ipv4 = new NpgsqlInet("192.168.1.1/8"); + Assert.That(ipv4.Address, Is.EqualTo(IPAddress.Parse("192.168.1.1"))); + Assert.That(ipv4.Netmask, Is.EqualTo(8)); - [Test] - public void TsQuery() - { - NpgsqlTsQuery query; - - query = new NpgsqlTsQueryLexeme("a", NpgsqlTsQueryLexeme.Weight.A | NpgsqlTsQueryLexeme.Weight.B); - query = new NpgsqlTsQueryOr(query, query); - query = new NpgsqlTsQueryOr(query, query); - - var str = query.ToString(); - - query = NpgsqlTsQuery.Parse("a & b | 
c"); - Assert.AreEqual("'a' & 'b' | 'c'", query.ToString()); - - query = NpgsqlTsQuery.Parse("'a''':*ab&d:d&!c"); - Assert.AreEqual("'a''':*AB & 'd':D & !'c'", query.ToString()); - - query = NpgsqlTsQuery.Parse("(a & !(c | d)) & (!!a&b) | c | d | e"); - Assert.AreEqual("( ( 'a' & !( 'c' | 'd' ) & !( !'a' ) & 'b' | 'c' ) | 'd' ) | 'e'", query.ToString()); - Assert.AreEqual(query.ToString(), NpgsqlTsQuery.Parse(query.ToString()).ToString()); - - query = NpgsqlTsQuery.Parse("(((a:*)))"); - Assert.AreEqual("'a':*", query.ToString()); - - query = NpgsqlTsQuery.Parse(@"'a\\b''cde'"); - Assert.AreEqual(@"a\b'cde", ((NpgsqlTsQueryLexeme)query).Text); - Assert.AreEqual(@"'a\\b''cde'", query.ToString()); - - query = NpgsqlTsQuery.Parse(@"a <-> b"); - Assert.AreEqual("'a' <-> 'b'", query.ToString()); - - query = NpgsqlTsQuery.Parse("((a & b) <5> c) <-> !d <0> e"); - Assert.AreEqual("( ( 'a' & 'b' <5> 'c' ) <-> !'d' ) <0> 'e'", query.ToString()); - - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("a b c & &")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("&")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("|")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("!")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("(")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse(")")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("()")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("<")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("<-")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("<->")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("a <->")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("<>")); - Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("a b")); - Assert.Throws(typeof(FormatException), () => 
NpgsqlTsQuery.Parse("a <-1> b")); - } + ipv4 = new NpgsqlInet("192.168.1.1/32"); + Assert.That(ipv4.Address, Is.EqualTo(IPAddress.Parse("192.168.1.1"))); + Assert.That(ipv4.Netmask, Is.EqualTo(32)); + } - [Test] - public void TsQueryOperatorPrecedence() - { - var query = NpgsqlTsQuery.Parse("!a <-> b & c | d & e"); - var expectedGrouping = NpgsqlTsQuery.Parse("((!(a) <-> b) & c) | (d & e)"); - Assert.AreEqual(expectedGrouping.ToString(), query.ToString()); - } + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/5638")] + public void NpgsqlInet_parse_ipv6() + { + var ipv6 = new NpgsqlInet("2001:0000:130F:0000:0000:09C0:876A:130B/32"); + Assert.That(ipv6.Address, Is.EqualTo(IPAddress.Parse("2001:0000:130F:0000:0000:09C0:876A:130B"))); + Assert.That(ipv6.Netmask, Is.EqualTo(32)); - [Test] - public void Bug1011018() - { - var p = new NpgsqlParameter(); - p.NpgsqlDbType = NpgsqlDbType.Time; - p.Value = DateTime.Now; - var o = p.Value; - } + ipv6 = new NpgsqlInet("2001:0000:130F:0000:0000:09C0:876A:130B"); + Assert.That(ipv6.Address, Is.EqualTo(IPAddress.Parse("2001:0000:130F:0000:0000:09C0:876A:130B"))); + Assert.That(ipv6.Netmask, Is.EqualTo(128)); + } -#pragma warning disable 618 - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/750")] - public void NpgsqlInet() - { - var v = new NpgsqlInet(IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), 32); - Assert.That(v.ToString(), Is.EqualTo("2001:1db8:85a3:1142:1000:8a2e:1370:7334/32")); + [Test] + public void NpgsqlInet_ToString_ipv4() + { + Assert.That(new NpgsqlInet("192.168.1.1/8").ToString(), Is.EqualTo("192.168.1.1/8")); + Assert.That(new NpgsqlInet("192.168.1.1/32").ToString(), Is.EqualTo("192.168.1.1")); + } -#pragma warning disable CS8625 - Assert.That(v != null); // #776 -#pragma warning disable CS8625 - } -#pragma warning restore 618 + [Test] + public void NpgsqlInet_ToString_ipv6() + { + Assert.That(new NpgsqlInet("2001:0:130f::9c0:876a:130b/32").ToString(), 
Is.EqualTo("2001:0:130f::9c0:876a:130b/32")); + Assert.That(new NpgsqlInet("2001:0:130f::9c0:876a:130b/128").ToString(), Is.EqualTo("2001:0:130f::9c0:876a:130b")); } } diff --git a/test/Npgsql.Tests/WriteBufferTests.cs b/test/Npgsql.Tests/WriteBufferTests.cs index b11fa323b9..f87732a62b 100644 --- a/test/Npgsql.Tests/WriteBufferTests.cs +++ b/test/Npgsql.Tests/WriteBufferTests.cs @@ -1,68 +1,124 @@ -using System.IO; -using Npgsql.Util; +using System; +using System.IO; +using Npgsql.Internal; using NUnit.Framework; -namespace Npgsql.Tests +namespace Npgsql.Tests; + +[FixtureLifeCycle(LifeCycle.InstancePerTestCase)] // Parallel access to a single buffer +class WriteBufferTests { - class WriteBufferTests + [Test] + public void Buffered_full_buffer_no_flush() { - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1275")] - public void WriteZeroChars() - { - // Fill up the buffer entirely - WriteBuffer.WriteBytes(new byte[WriteBuffer.Size], 0, WriteBuffer.Size); - Assert.That(WriteBuffer.WriteSpaceLeft, Is.Zero); - - int charsUsed; - bool completed; - WriteBuffer.WriteStringChunked("hello", 0, 5, true, out charsUsed, out completed); - Assert.That(charsUsed, Is.Zero); - Assert.That(completed, Is.False); - WriteBuffer.WriteStringChunked("hello".ToCharArray(), 0, 5, true, out charsUsed, out completed); - Assert.That(charsUsed, Is.Zero); - Assert.That(completed, Is.False); - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2849")] - public void ChunkedStringEncodingFits() - { - WriteBuffer.WriteBytes(new byte[WriteBuffer.Size - 1], 0, WriteBuffer.Size - 1); - Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); - - var charsUsed = 1; - var completed = true; - // This unicode character is three bytes when encoded in UTF8 - Assert.That(() => WriteBuffer.WriteStringChunked("\uD55C", 0, 1, true, out charsUsed, out completed), Throws.Nothing); - Assert.That(charsUsed, Is.EqualTo(0)); - Assert.That(completed, Is.False); - } - - [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/2849")] - public void ChunkedByteArrayEncodingFits() + WriteBuffer.WritePosition += WriteBuffer.WriteSpaceLeft - sizeof(int); + var writer = WriteBuffer.GetWriter(null!, FlushMode.NonBlocking); + Assert.That(writer.ShouldFlush(sizeof(int)), Is.False); + + Assert.DoesNotThrow(() => { - WriteBuffer.WriteBytes(new byte[WriteBuffer.Size - 1], 0, WriteBuffer.Size - 1); - Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); + Span intBytes = stackalloc byte[4]; + writer.WriteBytes(intBytes); + }); + } - var charsUsed = 1; - var completed = true; - // This unicode character is three bytes when encoded in UTF8 - Assert.That(() => WriteBuffer.WriteStringChunked("\uD55C".ToCharArray(), 0, 1, true, out charsUsed, out completed), Throws.Nothing); - Assert.That(charsUsed, Is.EqualTo(0)); - Assert.That(completed, Is.False); - } + [Test] + public void GetWriter_Full_Buffer() + { + WriteBuffer.WritePosition += WriteBuffer.WriteSpaceLeft; + var writer = WriteBuffer.GetWriter(null!, FlushMode.Blocking); + Assert.That(writer.ShouldFlush(sizeof(byte)), Is.True); + writer.Flush(); + Assert.That(writer.ShouldFlush(sizeof(byte)), Is.False); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1275")] + public void Write_zero_characters() + { + // Fill up the buffer entirely + WriteBuffer.WriteBytes(new byte[WriteBuffer.Size], 0, WriteBuffer.Size); + Assert.That(WriteBuffer.WriteSpaceLeft, Is.Zero); + + int charsUsed; + bool completed; + WriteBuffer.WriteStringChunked("hello", 0, 5, true, out charsUsed, out completed); + Assert.That(charsUsed, Is.Zero); + Assert.That(completed, Is.False); + WriteBuffer.WriteStringChunked("hello".ToCharArray(), 0, 5, true, out charsUsed, out completed); + Assert.That(charsUsed, Is.Zero); + Assert.That(completed, Is.False); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2849")] + public void Chunked_string_encoding_fits() + { + WriteBuffer.WriteBytes(new 
byte[WriteBuffer.Size - 1], 0, WriteBuffer.Size - 1); + Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); + + var charsUsed = 1; + var completed = true; + // This unicode character is three bytes when encoded in UTF8 + Assert.That(() => WriteBuffer.WriteStringChunked("\uD55C", 0, 1, true, out charsUsed, out completed), Throws.Nothing); + Assert.That(charsUsed, Is.EqualTo(0)); + Assert.That(completed, Is.False); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2849")] + public void Chunked_byte_array_encoding_fits() + { + WriteBuffer.WriteBytes(new byte[WriteBuffer.Size - 1], 0, WriteBuffer.Size - 1); + Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); + + var charsUsed = 1; + var completed = true; + // This unicode character is three bytes when encoded in UTF8 + Assert.That(() => WriteBuffer.WriteStringChunked("\uD55C".ToCharArray(), 0, 1, true, out charsUsed, out completed), Throws.Nothing); + Assert.That(charsUsed, Is.EqualTo(0)); + Assert.That(completed, Is.False); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3733")] + public void Chunked_string_encoding_fits_with_surrogates() + { + WriteBuffer.WriteBytes(new byte[WriteBuffer.Size - 1]); + Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); + + var charsUsed = 1; + var completed = true; + var cyclone = "🌀"; + + Assert.That(() => WriteBuffer.WriteStringChunked(cyclone, 0, cyclone.Length, true, out charsUsed, out completed), Throws.Nothing); + Assert.That(charsUsed, Is.EqualTo(0)); + Assert.That(completed, Is.False); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3733")] + public void Chunked_char_array_encoding_fits_with_surrogates() + { + WriteBuffer.WriteBytes(new byte[WriteBuffer.Size - 1]); + Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); + + var charsUsed = 1; + var completed = true; + var cyclone = "🌀"; + + Assert.That(() => WriteBuffer.WriteStringChunked(cyclone.ToCharArray(), 0, cyclone.Length, true, out charsUsed, out 
completed), Throws.Nothing); + Assert.That(charsUsed, Is.EqualTo(0)); + Assert.That(completed, Is.False); + } #pragma warning disable CS8625 - [SetUp] - public void SetUp() - { - Underlying = new MemoryStream(); - WriteBuffer = new NpgsqlWriteBuffer(null, Underlying, null, NpgsqlReadBuffer.DefaultSize, PGUtil.UTF8Encoding); - } + [SetUp] + public void SetUp() + { + Underlying = new MemoryStream(); + WriteBuffer = new NpgsqlWriteBuffer(null, Underlying, null, NpgsqlReadBuffer.DefaultSize, NpgsqlWriteBuffer.UTF8Encoding); + WriteBuffer.MessageLengthValidation = false; + } #pragma warning restore CS8625 - // ReSharper disable once InconsistentNaming - NpgsqlWriteBuffer WriteBuffer = default!; - // ReSharper disable once InconsistentNaming - MemoryStream Underlying = default!; - } + // ReSharper disable once InconsistentNaming + NpgsqlWriteBuffer WriteBuffer = default!; + // ReSharper disable once InconsistentNaming + MemoryStream Underlying = default!; }