diff --git a/.cirrus.star b/.cirrus.star index e9bb672b95936..7c1caaa12f1f3 100644 --- a/.cirrus.star +++ b/.cirrus.star @@ -73,7 +73,7 @@ def compute_environment_vars(): # REPO_CI_AUTOMATIC_TRIGGER_TASKS="task_name other_task" under "Repository # Settings" on Cirrus CI's website. - default_manual_trigger_tasks = ['mingw', 'netbsd', 'openbsd'] + default_manual_trigger_tasks = [] repo_ci_automatic_trigger_tasks = env.get('REPO_CI_AUTOMATIC_TRIGGER_TASKS', '') for task in default_manual_trigger_tasks: diff --git a/.cirrus.tasks.yml b/.cirrus.tasks.yml index eca9d62fc2297..ddb5305dc815e 100644 --- a/.cirrus.tasks.yml +++ b/.cirrus.tasks.yml @@ -21,12 +21,14 @@ env: # target to test, for all but windows CHECK: check-world PROVE_FLAGS=$PROVE_FLAGS - CHECKFLAGS: -Otarget + # TODO were we avoiding --keep-going on purpose? + CHECKFLAGS: -Otarget --keep-going PROVE_FLAGS: --timer # Build test dependencies as part of the build step, to see compiler # errors/warnings in one place. MBUILD_TARGET: all testprep MTEST_ARGS: --print-errorlogs --no-rebuild -C build + MTEST_SUITES: --suite setup --suite pytest --suite ssl PGCTLTIMEOUT: 120 # avoids spurious failures during parallel tests TEMP_CONFIG: ${CIRRUS_WORKING_DIR}/src/tools/ci/pg_ci_base.conf PG_TEST_EXTRA: kerberos ldap ssl libpq_encryption load_balance oauth @@ -44,6 +46,7 @@ env: -Dldap=enabled -Dssl=openssl -Dtap_tests=enabled + -Dpytest=enabled -Dplperl=enabled -Dplpython=enabled -Ddocs=enabled @@ -222,7 +225,10 @@ task: chown root:postgres /tmp/cores sysctl kern.corefile='/tmp/cores/%N.%P.core' setup_additional_packages_script: | - #pkg install -y ... + pkg install -y \ + py311-cryptography \ + py311-packaging \ + py311-pytest # NB: Intentionally build without -Dllvm. The freebsd image size is already # large enough to make VM startup slow, and even without llvm freebsd @@ -242,7 +248,7 @@ task: test_world_script: | su postgres <<-EOF ulimit -c unlimited - meson test $MTEST_ARGS --num-processes ${TEST_JOBS} + meson test $MTEST_ARGS --num-processes ${TEST_JOBS} ${MTEST_SUITES} EOF # test runningcheck, freebsd chosen because it's currently fast enough @@ -311,7 +317,11 @@ task: -Dpam=enabled setup_additional_packages_script: | - #pkgin -y install ... + pkgin -y install \ + py312-cryptography \ + py312-packaging \ + py312-test + ln -s /usr/pkg/bin/pytest-3.12 /usr/pkg/bin/pytest <<: *netbsd_task_template - name: OpenBSD - Meson @@ -322,6 +332,7 @@ task: OS_NAME: openbsd IMAGE_FAMILY: pg-ci-openbsd-postgres PKGCONFIG_PATH: '/usr/lib/pkgconfig:/usr/local/lib/pkgconfig' + TERM: # TODO why does pytest print ANSI escapes on OpenBSD? MESON_FEATURES: >- -Dbsd_auth=enabled @@ -330,7 +341,10 @@ task: -Duuid=e2fs setup_additional_packages_script: | - #pkg_add -I ... + pkg_add -I \ + py3-cryptography \ + py3-packaging \ + py3-test # Always core dump to ${CORE_DUMP_DIR} set_core_dump_script: sysctl -w kern.nosuidcoredump=2 <<: *openbsd_task_template @@ -378,7 +392,7 @@ task: # Otherwise tests will fail on OpenBSD, due to inability to start enough # processes. ulimit -p 256 - meson test $MTEST_ARGS --num-processes ${TEST_JOBS} + meson test $MTEST_ARGS --num-processes ${TEST_JOBS} ${MTEST_SUITES} EOF on_failure: @@ -489,8 +503,11 @@ task: EOF setup_additional_packages_script: | - #apt-get update - #DEBIAN_FRONTEND=noninteractive apt-get -y install ... + apt-get update + DEBIAN_FRONTEND=noninteractive apt-get -y install \ + python3-cryptography \ + python3-packaging \ + python3-pytest matrix: # SPECIAL: @@ -513,14 +530,15 @@ task: su postgres <<-EOF ./configure \ --enable-cassert --enable-injection-points --enable-debug \ - --enable-tap-tests --enable-nls \ + --enable-tap-tests --enable-pytest --enable-nls \ --with-segsize-blocks=6 \ --with-libnuma \ --with-liburing \ \ ${LINUX_CONFIGURE_FEATURES} \ \ - CLANG="ccache clang-16" + CLANG="ccache clang-16" \ + PYTEST="env LD_PRELOAD=/lib/x86_64-linux-gnu/libasan.so.8 pytest" EOF build_script: su postgres -c "make -s -j${BUILD_JOBS} world-bin" upload_caches: ccache @@ -588,7 +606,7 @@ task: test_world_script: | su postgres <<-EOF ulimit -c unlimited - meson test $MTEST_ARGS --num-processes ${TEST_JOBS} + meson test $MTEST_ARGS --num-processes ${TEST_JOBS} ${MTEST_SUITES} EOF # so that we don't upload 64bit logs if 32bit fails rm -rf build/ @@ -600,7 +618,7 @@ task: test_world_32_script: | su postgres <<-EOF ulimit -c unlimited - PYTHONCOERCECLOCALE=0 LANG=C meson test $MTEST_ARGS -C build-32 --num-processes ${TEST_JOBS} + PYTHONCOERCECLOCALE=0 LANG=C meson test $MTEST_ARGS -C build-32 --num-processes ${TEST_JOBS} ${MTEST_SUITES} EOF on_failure: @@ -630,6 +648,7 @@ task: CIRRUS_WORKING_DIR: ${HOME}/pgsql/ CCACHE_DIR: ${HOME}/ccache MACPORTS_CACHE: ${HOME}/macports-cache + PYTEST_DEBUG_TEMPROOT: /tmp # default is too long for UNIX sockets on Mac MESON_FEATURES: >- -Dbonjour=enabled @@ -650,6 +669,9 @@ task: p5.34-io-tty p5.34-ipc-run python312 + py312-cryptography + py312-packaging + py312-pytest tcl zstd @@ -699,6 +721,7 @@ task: sh src/tools/ci/ci_macports_packages.sh $MACOS_PACKAGE_LIST # system python doesn't provide headers sudo /opt/local/bin/port select python3 python312 + sudo /opt/local/bin/port select pytest pytest312 # Make macports install visible for subsequent steps echo PATH=/opt/local/sbin/:/opt/local/bin/:$PATH >> $CIRRUS_ENV upload_caches: macports @@ -721,7 +744,7 @@ task: test_world_script: | ulimit -c unlimited # default is 0 ulimit -n 1024 # default is 256, pretty low - meson test $MTEST_ARGS --num-processes ${TEST_JOBS} + meson test $MTEST_ARGS --num-processes ${TEST_JOBS} ${MTEST_SUITES} on_failure: <<: *on_failure_meson @@ -772,6 +795,8 @@ task: -Dldap=enabled -Dssl=openssl -Dtap_tests=enabled + -Dpytest=enabled + -DPYTEST=c:\Windows\system32\config\systemprofile\AppData\Roaming\Python\Python310\Scripts\pytest.exe -Dplperl=enabled -Dplpython=enabled @@ -780,8 +805,10 @@ task: depends_on: SanityCheck only_if: $CI_WINDOWS_ENABLED + # XXX Does Chocolatey really not have any Python package installers? setup_additional_packages_script: | REM choco install -y --no-progress ... + pip3 install --user cryptography packaging pytest setup_hosts_file_script: | echo 127.0.0.1 pg-loadbalancetest >> c:\Windows\System32\Drivers\etc\hosts @@ -800,7 +827,7 @@ task: check_world_script: | vcvarsall x64 - meson test %MTEST_ARGS% --num-processes %TEST_JOBS% + meson test %MTEST_ARGS% --num-processes %TEST_JOBS% %MTEST_SUITES% on_failure: <<: *on_failure_meson @@ -844,7 +871,7 @@ task: folder: ${CCACHE_DIR} setup_additional_packages_script: | - REM C:\msys64\usr\bin\pacman.exe -S --noconfirm ... + C:\msys64\usr\bin\pacman.exe -S --noconfirm mingw-w64-ucrt-x86_64-python-cryptography mingw-w64-ucrt-x86_64-python-packaging mingw-w64-ucrt-x86_64-python-pytest mingw_info_script: | %BASH% -c "where gcc" @@ -861,7 +888,7 @@ task: upload_caches: ccache test_world_script: | - %BASH% -c "meson test %MTEST_ARGS% --num-processes %TEST_JOBS%" + %BASH% -c "meson test %MTEST_ARGS% --num-processes %TEST_JOBS% %MTEST_SUITES%" on_failure: <<: *on_failure_meson diff --git a/.gitignore b/.gitignore index 4e911395fe3ba..268426003b190 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ win32ver.rc *.exe lib*dll.def lib*.pc +__pycache__/ # Local excludes in root directory /GNUmakefile diff --git a/config/check_pytest.py b/config/check_pytest.py new file mode 100644 index 0000000000000..14b2f3eec9be5 --- /dev/null +++ b/config/check_pytest.py @@ -0,0 +1,138 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group +# +# Verify that pytest-requirements.txt is satisfied. This would probably be +# easier with pip, but requiring pip on build machines is a non-starter for +# many. +# +# The design philosophy of this script is to bend over backwards to help people +# figure out what is missing. The target audience for error output is the +# buildfarm operator who just wants to get the tests running, not the test +# developer who presumably already knows how to solve these problems. + +import sys +from typing import List # TODO: Python 3.9 will remove the need for this + + +def main(): + if len(sys.argv) != 2: + sys.exit("usage: python {} REQUIREMENTS_FILE".format(sys.argv[0])) + + requirements_file = sys.argv[1] + with open(requirements_file, "r") as f: + requirements = f.readlines() + + found = packaging_check(requirements) + if not found: + sys.exit("See src/test/pytest/README for package installation help.") + + +def packaging_check(requirements: List[str]) -> bool: + """ + The preferred dependency check, which unfortunately needs newer Python + facilities. Returns True if all dependencies were found. + """ + try: + # First, attempt to find importlib.metadata. This is part of the + # standard library from 3.8 onwards. Earlier Python versions have an + # official backport called importlib_metadata, which can generally be + # installed as a separate OS package (e.g. python3-importlib-metadata). + # This complication can be removed once we stop supporting Python 3.7. + try: + from importlib import metadata + except ImportError: + import importlib_metadata as metadata + + # packaging contains the PyPA definitions of requirement specifiers. + # This is again contained in a separate OS package (for example, + # python3-packaging). + import packaging + from packaging.requirements import Requirement + + except ImportError as err: + # We don't even have enough prerequisites to check our prerequisites. + # Try to fall back on the deprecated parser, to get a better error + # message. + found = setuptools_fallback(requirements) + + if not found: + # Well, the best we can do is just print the import error as-is. + print(err, file=sys.stderr) + + return False + + # Strip extraneous whitespace, whole-line comments, and empty lines from our + # specifier list. + requirements = [r.strip() for r in requirements] + requirements = [r for r in requirements if r and r[0] != "#"] + + found = True + for spec in requirements: + req = Requirement(spec) + + # Skip any packages marked as unneeded for this particular Python env. + if req.marker and not req.marker.evaluate(): + continue + + # Make sure the package is installed... + try: + version = metadata.version(req.name) + except metadata.PackageNotFoundError: + print("Package '{}' is not installed".format(req.name), file=sys.stderr) + found = False + continue + + # ...and that it has a compatible version. + if not req.specifier.contains(version): + print( + "Package '{}' has version {}, but '{}' is required".format( + req.name, version, req.specifier + ), + file=sys.stderr, + ) + found = False + continue + + return found + + +def setuptools_fallback(requirements: List[str]) -> bool: + """ + An alternative dependency helper, based on the old deprecated pkg_resources + module in setuptools, which is pretty widely available in older Pythons. The + point of this is to bootstrap the user into an environment that can run the + packaging_check(). + + Returns False if pkg_resources is also unavailable, in which case we just + have to do our best. + """ + try: + import pkg_resources + except ModuleNotFoundError: + return False + + # An extra newline makes the Autoconf output easier to read. + print(file=sys.stderr) + + # Go one-by-one through the requirements, printing each missing dependency. + found = True + for r in requirements: + try: + pkg_resources.require(r) + except pkg_resources.DistributionNotFound as err: + # The error descriptions given here are pretty good as-is. + print(err, file=sys.stderr) + found = False + except pkg_resources.RequirementParseError as err: + assert False # TODO + + # The only reason the fallback would be called is if we're missing required + # packages. So if we "found them", the requirements file is broken... + assert ( + not found + ), "setuptools_fallback() succeeded unexpectedly; is the requirements file incomplete?" + + return True + + +if __name__ == "__main__": + main() diff --git a/config/pytest-requirements.txt b/config/pytest-requirements.txt new file mode 100644 index 0000000000000..d2e2604046616 --- /dev/null +++ b/config/pytest-requirements.txt @@ -0,0 +1,32 @@ +# +# This file contains the Python packages which are required in order for us to +# enable pytest. +# +# The syntax is a *subset* of pip's requirements.txt syntax, so that both pip +# and check_pytest.py can use it. Only whole-line comments and standard Python +# dependency specifiers are allowed. pip-specific goodies like includes and +# environment substitutions are not supported; keep it simple. +# +# Packages belong here if their absence should cause a configuration failure. If +# you'd like to make a package optional, consider using pytest.importorskip() +# instead. +# + +# pytest 7.0 was the last version which supported Python 3.6, but the BSDs have +# started putting 8.x into ports, so we support both. (pytest 8 can be used +# throughout once we drop support for Python 3.7.) +pytest >= 7.0, < 9 + +# These are meta-packages which allow check_pytest.py to run. +packaging +importlib_metadata ; python_version < "3.8" + +# Notes on the cryptography package: +# - 3.3.2 is shipped on Debian bullseye. +# - 3.4.x drops support for Python 2, making it a version of note for older LTS +# distros. +# - 35.x switched versioning schemes and moved to Rust parsing. +# - 40.x is the last version supporting Python 3.6. +# XXX Is it appropriate to require cryptography, or should we simply skip +# dependent tests? +cryptography >= 3.3.2 diff --git a/configure b/configure index 39c68161ceced..860b07763dcdd 100755 --- a/configure +++ b/configure @@ -630,6 +630,7 @@ vpath_build PG_SYSROOT PG_VERSION_NUM LDFLAGS_EX_BE +PYTEST PROVE DBTOEPUB FOP @@ -773,6 +774,7 @@ CFLAGS CC enable_injection_points PG_TEST_EXTRA +enable_pytest enable_tap_tests enable_dtrace DTRACEFLAGS @@ -851,6 +853,7 @@ enable_profiling enable_coverage enable_dtrace enable_tap_tests +enable_pytest enable_injection_points with_blocksize with_segsize @@ -1551,7 +1554,10 @@ Optional Features: --enable-profiling build with profiling enabled --enable-coverage build with coverage testing instrumentation --enable-dtrace build with DTrace support - --enable-tap-tests enable TAP tests (requires Perl and IPC::Run) + --enable-tap-tests enable (Perl-based) TAP tests (requires Perl and + IPC::Run) + --enable-pytest enable (Python-based) pytest suites (requires + Python) --enable-injection-points enable injection points (for testing) --enable-depend turn on automatic dependency tracking @@ -3639,7 +3645,7 @@ fi # -# TAP tests +# Test frameworks # @@ -3667,6 +3673,32 @@ fi + +# Check whether --enable-pytest was given. +if test "${enable_pytest+set}" = set; then : + enableval=$enable_pytest; + case $enableval in + yes) + : + ;; + no) + : + ;; + *) + as_fn_error $? "no argument expected for --enable-pytest option" "$LINENO" 5 + ;; + esac + +else + enable_pytest=no + +fi + + + + + + # # Injection points # @@ -19120,6 +19152,140 @@ $as_echo "$modulestderr" >&6; } fi fi +if test "$enable_pytest" = yes; then + # Mirror the prove checks, above, for pytest. We don't require the user to + # have selected --with-python, but we do need a Python installation. + if test -z "$PYTHON"; then + if test -z "$PYTHON"; then + for ac_prog in python3 python +do + # Extract the first word of "$ac_prog", so it can be a program name with args. +set dummy $ac_prog; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_path_PYTHON+:} false; then : + $as_echo_n "(cached) " >&6 +else + case $PYTHON in + [\\/]* | ?:[\\/]*) + ac_cv_path_PYTHON="$PYTHON" # Let the user override the test with a path. + ;; + *) + as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_path_PYTHON="$as_dir/$ac_word$ac_exec_ext" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + + ;; +esac +fi +PYTHON=$ac_cv_path_PYTHON +if test -n "$PYTHON"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $PYTHON" >&5 +$as_echo "$PYTHON" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + + test -n "$PYTHON" && break +done + +else + # Report the value of PYTHON in configure's output in all cases. + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for PYTHON" >&5 +$as_echo_n "checking for PYTHON... " >&6; } + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $PYTHON" >&5 +$as_echo "$PYTHON" >&6; } +fi + +if test x"$PYTHON" = x""; then + as_fn_error $? "Python not found" "$LINENO" 5 +fi + + fi + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for Python packages required for pytest" >&5 +$as_echo_n "checking for Python packages required for pytest... " >&6; } + modulestderr=`"$PYTHON" "$srcdir/config/check_pytest.py" "$srcdir/config/pytest-requirements.txt" 2>&1 >/dev/null` + if test $? -eq 0; then + echo "$modulestderr" >&5 + { $as_echo "$as_me:${as_lineno-$LINENO}: result: yes" >&5 +$as_echo "yes" >&6; } + else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $modulestderr" >&5 +$as_echo "$modulestderr" >&6; } + as_fn_error $? "Additional Python packages are required to run the pytest suites" "$LINENO" 5 + fi + if test -z "$PYTEST"; then + for ac_prog in pytest py.test +do + # Extract the first word of "$ac_prog", so it can be a program name with args. +set dummy $ac_prog; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_path_PYTEST+:} false; then : + $as_echo_n "(cached) " >&6 +else + case $PYTEST in + [\\/]* | ?:[\\/]*) + ac_cv_path_PYTEST="$PYTEST" # Let the user override the test with a path. + ;; + *) + as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_path_PYTEST="$as_dir/$ac_word$ac_exec_ext" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + + ;; +esac +fi +PYTEST=$ac_cv_path_PYTEST +if test -n "$PYTEST"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $PYTEST" >&5 +$as_echo "$PYTEST" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + + test -n "$PYTEST" && break +done + +else + # Report the value of PYTEST in configure's output in all cases. + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for PYTEST" >&5 +$as_echo_n "checking for PYTEST... " >&6; } + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $PYTEST" >&5 +$as_echo "$PYTEST" >&6; } +fi + + if test -z "$PYTEST"; then + as_fn_error $? "pytest not found" "$LINENO" 5 + fi +fi + # If compiler will take -Wl,--as-needed (or various platform-specific # spellings thereof) then add that to LDFLAGS. This is much easier than # trying to filter LIBS to the minimum for each executable. diff --git a/configure.ac b/configure.ac index 066e3976c0aac..f4bf94a078f4f 100644 --- a/configure.ac +++ b/configure.ac @@ -231,11 +231,16 @@ AC_SUBST(DTRACEFLAGS)]) AC_SUBST(enable_dtrace) # -# TAP tests +# Test frameworks # PGAC_ARG_BOOL(enable, tap-tests, no, - [enable TAP tests (requires Perl and IPC::Run)]) + [enable (Perl-based) TAP tests (requires Perl and IPC::Run)]) AC_SUBST(enable_tap_tests) + +PGAC_ARG_BOOL(enable, pytest, no, + [enable (Python-based) pytest suites (requires Python)]) +AC_SUBST(enable_pytest) + AC_ARG_VAR(PG_TEST_EXTRA, [enable selected extra tests (overridden at runtime by PG_TEST_EXTRA environment variable)]) @@ -2442,6 +2447,27 @@ if test "$enable_tap_tests" = yes; then fi fi +if test "$enable_pytest" = yes; then + # Mirror the prove checks, above, for pytest. We don't require the user to + # have selected --with-python, but we do need a Python installation. + if test -z "$PYTHON"; then + PGAC_PATH_PYTHON + fi + AC_MSG_CHECKING(for Python packages required for pytest) + [modulestderr=`"$PYTHON" "$srcdir/config/check_pytest.py" "$srcdir/config/pytest-requirements.txt" 2>&1 >/dev/null`] + if test $? -eq 0; then + echo "$modulestderr" >&AS_MESSAGE_LOG_FD + AC_MSG_RESULT(yes) + else + AC_MSG_RESULT([$modulestderr]) + AC_MSG_ERROR([Additional Python packages are required to run the pytest suites]) + fi + PGAC_PATH_PROGS(PYTEST, pytest py.test) + if test -z "$PYTEST"; then + AC_MSG_ERROR([pytest not found]) + fi +fi + # If compiler will take -Wl,--as-needed (or various platform-specific # spellings thereof) then add that to LDFLAGS. This is much easier than # trying to filter LIBS to the minimum for each executable. diff --git a/meson.build b/meson.build index ab8101d67b26d..5166d82e60794 100644 --- a/meson.build +++ b/meson.build @@ -1699,6 +1699,35 @@ endif +############################################################### +# Library: pytest +############################################################### + +pytest_enabled = false +pytest = not_found_dep + +pytestopt = get_option('pytest') +if not pytestopt.disabled() + pytest_check = run_command(python, 'config/check_pytest.py', + 'config/pytest-requirements.txt', check: false) + if pytest_check.returncode() != 0 + message(pytest_check.stderr().strip()) + if pytestopt.enabled() + error('Additional Python packages are required to run the pytest suites.') + else + warning('Additional Python packages are required to run the pytest suites.') + endif + endif + + pytest = find_program(get_option('PYTEST'), native: true, required: pytestopt) + + if pytest.found() and pytest_check.returncode() == 0 + pytest_enabled = true + endif +endif + + + ############################################################### # Library: zstd ############################################################### @@ -3776,6 +3805,63 @@ foreach test_dir : tests ) endforeach install_suites += test_group + elif kind == 'pytest' + testwrap_pytest = testwrap_base + if not pytest_enabled + testwrap_pytest += ['--skip', 'pytest not enabled'] + endif + + test_command = [ + pytest.full_path(), + '-c', meson.project_source_root() / 'pytest.ini', + '--verbose', + '-p', 'pgtap', # enable our test reporter plugin + '-ra', # show skipped and xfailed tests too + ] + + # Add temporary install, the build directory for non-installed binaries and + # also test/ for non-installed test binaries built separately. + env = test_env + env.prepend('PATH', temp_install_bindir, test_dir['bd'], test_dir['bd'] / 'test') + temp_install_datadir = '@0@@1@'.format(test_install_destdir, dir_prefix / dir_data) + env.set('share_contrib_dir', temp_install_datadir / 'contrib') + env.prepend('PYTHONPATH', meson.project_source_root() / 'src' / 'test' / 'pytest' / 'plugins') + + foreach name, value : t.get('env', {}) + env.set(name, value) + endforeach + + test_group = test_dir['name'] + test_kwargs = { + 'protocol': 'tap', + 'suite': test_group, + 'timeout': 1000, + 'depends': test_deps + t.get('deps', []), + 'env': env, + } + t.get('test_kwargs', {}) + + foreach onetest : t['tests'] + # Make test names prettier, remove pyt/ and .py + onetest_p = onetest + if onetest_p.startswith('pyt/') + onetest_p = onetest.split('pyt/')[1] + endif + if onetest_p.endswith('.py') + onetest_p = fs.stem(onetest_p) + endif + + test(test_dir['name'] / onetest_p, + python, + kwargs: test_kwargs, + args: testwrap_pytest + [ + '--testgroup', test_dir['name'], + '--testname', onetest_p, + '--', test_command, + test_dir['sd'] / onetest, + ], + ) + endforeach + install_suites += test_group else error('unknown kind @0@ of test in @1@'.format(kind, test_dir['sd'])) endif @@ -3949,6 +4035,8 @@ summary( 'bison': '@0@ @1@'.format(bison.full_path(), bison_version), 'dtrace': dtrace, 'flex': '@0@ @1@'.format(flex.full_path(), flex_version), + 'prove': prove, + 'pytest': pytest, }, section: 'Programs', ) @@ -3985,3 +4073,12 @@ summary( section: 'External libraries', list_sep: ' ', ) + +summary( + { + 'tap': tap_tests_enabled, + 'pytest': pytest_enabled, + }, + section: 'Other features', + list_sep: ' ', +) diff --git a/meson_options.txt b/meson_options.txt index 06bf5627d3c03..88f22e699d918 100644 --- a/meson_options.txt +++ b/meson_options.txt @@ -41,7 +41,10 @@ option('cassert', type: 'boolean', value: false, description: 'Enable assertion checks (for debugging)') option('tap_tests', type: 'feature', value: 'auto', - description: 'Enable TAP tests') + description: 'Enable (Perl-based) TAP tests') + +option('pytest', type: 'feature', value: 'auto', + description: 'Enable (Python-based) pytest suites') option('injection_points', type: 'boolean', value: false, description: 'Enable injection points') @@ -195,6 +198,9 @@ option('PERL', type: 'string', value: 'perl', option('PROVE', type: 'string', value: 'prove', description: 'Path to prove binary') +option('PYTEST', type: 'array', value: ['pytest', 'py.test'], + description: 'Path to pytest binary') + option('PYTHON', type: 'array', value: ['python3', 'python'], description: 'Path to python binary') diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000000..837097ba0bd42 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +[pytest] + +minversion = 7.0 + +# Common test code can be found here. +pythonpath = src/test/pytest diff --git a/src/Makefile.global.in b/src/Makefile.global.in index 8b1b357beaa04..fc744166bd2de 100644 --- a/src/Makefile.global.in +++ b/src/Makefile.global.in @@ -211,6 +211,7 @@ enable_dtrace = @enable_dtrace@ enable_coverage = @enable_coverage@ enable_injection_points = @enable_injection_points@ enable_tap_tests = @enable_tap_tests@ +enable_pytest = @enable_pytest@ python_includespec = @python_includespec@ python_libdir = @python_libdir@ @@ -354,6 +355,7 @@ MSGFMT = @MSGFMT@ MSGFMT_FLAGS = @MSGFMT_FLAGS@ MSGMERGE = @MSGMERGE@ OPENSSL = @OPENSSL@ +PYTEST = @PYTEST@ PYTHON = @PYTHON@ TAR = @TAR@ XGETTEXT = @XGETTEXT@ @@ -508,6 +510,27 @@ prove_installcheck = @echo "TAP tests not enabled. Try configuring with --enable prove_check = $(prove_installcheck) endif +ifeq ($(enable_pytest),yes) + +pytest_installcheck = @echo "Installcheck is not currently supported for pytest." + +define pytest_check +echo "# +++ pytest check in $(subdir) +++" && \ +rm -rf '$(CURDIR)'/tmp_check && \ +$(MKDIR_P) '$(CURDIR)'/tmp_check && \ +cd $(srcdir) && \ + TESTLOGDIR='$(CURDIR)/tmp_check/log' \ + TESTDATADIR='$(CURDIR)/tmp_check' \ + PYTHONPATH='$(abs_top_srcdir)/src/test/pytest/plugins:$$PYTHONPATH' \ + $(with_temp_install) \ + $(PYTEST) -c '$(abs_top_srcdir)/pytest.ini' --verbose -ra ./pyt/ +endef + +else +pytest_installcheck = @echo "pytest is not enabled. Try configuring with --enable-pytest" +pytest_check = $(pytest_installcheck) +endif + # Installation. install_bin = @install_bin@ diff --git a/src/makefiles/meson.build b/src/makefiles/meson.build index 54dbc059adac7..f69eb1068db01 100644 --- a/src/makefiles/meson.build +++ b/src/makefiles/meson.build @@ -56,6 +56,7 @@ pgxs_kv = { 'enable_nls': libintl.found() ? 'yes' : 'no', 'enable_injection_points': get_option('injection_points') ? 'yes' : 'no', 'enable_tap_tests': tap_tests_enabled ? 'yes' : 'no', + 'enable_pytest': pytest_enabled ? 'yes' : 'no', 'enable_debug': get_option('debug') ? 'yes' : 'no', 'enable_coverage': 'no', 'enable_dtrace': dtrace.found() ? 'yes' : 'no', @@ -147,6 +148,7 @@ pgxs_bins = { 'OPENSSL': openssl, 'PERL': perl, 'PROVE': prove, + 'PYTEST': pytest, 'PYTHON': python, 'TAR': tar, 'ZSTD': program_zstd, diff --git a/src/test/Makefile b/src/test/Makefile index 511a72e6238a5..0be9771d71f5f 100644 --- a/src/test/Makefile +++ b/src/test/Makefile @@ -12,7 +12,16 @@ subdir = src/test top_builddir = ../.. include $(top_builddir)/src/Makefile.global -SUBDIRS = perl postmaster regress isolation modules authentication recovery subscription +SUBDIRS = \ + authentication \ + isolation \ + modules \ + perl \ + postmaster \ + pytest \ + recovery \ + regress \ + subscription ifeq ($(with_icu),yes) SUBDIRS += icu diff --git a/src/test/meson.build b/src/test/meson.build index ccc31d6a86a1b..d08a6ef61c229 100644 --- a/src/test/meson.build +++ b/src/test/meson.build @@ -5,6 +5,7 @@ subdir('isolation') subdir('authentication') subdir('postmaster') +subdir('pytest') subdir('recovery') subdir('subscription') subdir('modules') diff --git a/src/test/pytest/Makefile b/src/test/pytest/Makefile new file mode 100644 index 0000000000000..2bdca96ccbee3 --- /dev/null +++ b/src/test/pytest/Makefile @@ -0,0 +1,20 @@ +#------------------------------------------------------------------------- +# +# Makefile for pytest +# +# Portions Copyright (c) 1996-2025, PostgreSQL Global Development Group +# Portions Copyright (c) 1994, Regents of the University of California +# +# src/test/pytest/Makefile +# +#------------------------------------------------------------------------- + +subdir = src/test/pytest +top_builddir = ../../.. +include $(top_builddir)/src/Makefile.global + +check: + $(pytest_check) + +clean distclean maintainer-clean: + rm -rf tmp_check diff --git a/src/test/pytest/README b/src/test/pytest/README new file mode 100644 index 0000000000000..1333ed77b7e1e --- /dev/null +++ b/src/test/pytest/README @@ -0,0 +1 @@ +TODO diff --git a/src/test/pytest/meson.build b/src/test/pytest/meson.build new file mode 100644 index 0000000000000..f53193e868680 --- /dev/null +++ b/src/test/pytest/meson.build @@ -0,0 +1,17 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +if not pytest_enabled + subdir_done() +endif + +tests += { + 'name': 'pytest', + 'sd': meson.current_source_dir(), + 'bd': meson.current_build_dir(), + 'pytest': { + 'tests': [ + 'pyt/test_something.py', + 'pyt/test_libpq.py', + ], + }, +} diff --git a/src/test/pytest/pg/__init__.py b/src/test/pytest/pg/__init__.py new file mode 100644 index 0000000000000..5dae49b6406e1 --- /dev/null +++ b/src/test/pytest/pg/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +from ._env import has_test_extra, require_test_extra +from ._win32 import current_windows_user diff --git a/src/test/pytest/pg/_env.py b/src/test/pytest/pg/_env.py new file mode 100644 index 0000000000000..6f18af078449d --- /dev/null +++ b/src/test/pytest/pg/_env.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import logging +import os +from typing import List, Optional + +import pytest + +logger = logging.getLogger(__name__) + + +def has_test_extra(key: str) -> bool: + """ + Returns True if the PG_TEST_EXTRA environment variable contains the given + key. + """ + extra = os.getenv("PG_TEST_EXTRA", "") + return key in extra.split() + + +def require_test_extra(*keys: str) -> bool: + """ + A convenience annotation which will skip tests if all of the required keys + are not present in PG_TEST_EXTRA. + + To skip a particular test function or class: + + @pg.require_test_extra("ldap") + def test_some_ldap_feature(): + ... + + To skip an entire module: + + pytestmark = pg.require_test_extra("ssl", "kerberos") + """ + return pytest.mark.skipif( + not all([has_test_extra(k) for k in keys]), + reason="requires {} to be set in PG_TEST_EXTRA".format(", ".join(keys)), + ) + + +def test_timeout_default() -> int: + """ + Returns the value of the PG_TEST_TIMEOUT_DEFAULT environment variable, in + seconds, or 180 if one was not provided. + """ + default = os.getenv("PG_TEST_TIMEOUT_DEFAULT", "") + if not default: + return 180 + + try: + return int(default) + except ValueError as v: + logger.warning("PG_TEST_TIMEOUT_DEFAULT could not be parsed: " + str(v)) + return 180 diff --git a/src/test/pytest/pg/_win32.py b/src/test/pytest/pg/_win32.py new file mode 100644 index 0000000000000..3fd67b101912f --- /dev/null +++ b/src/test/pytest/pg/_win32.py @@ -0,0 +1,145 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import ctypes +import platform + + +def current_windows_user(): + """ + A port of pg_regress.c's current_windows_user() helper. Returns + (accountname, domainname). + + XXX This is dead code now, but I'm keeping it as a motivating example of + Win32 interaction, and someone may find it useful in the future when writing + SSPI tests? + """ + try: + advapi32 = ctypes.windll.advapi32 + kernel32 = ctypes.windll.kernel32 + except AttributeError: + raise RuntimeError( + f"current_windows_user() is not supported on {platform.system()}" + ) + + def raise_winerror_when_false(result, func, arguments): + """ + A ctypes errcheck handler that raises WinError (which will contain the + result of GetLastError()) when the function's return value is false. + """ + if not result: + raise ctypes.WinError() + + # + # Function Prototypes + # + + from ctypes import wintypes + + # GetCurrentProcess + kernel32.GetCurrentProcess.restype = wintypes.HANDLE + kernel32.GetCurrentProcess.argtypes = [] + + # OpenProcessToken + TOKEN_READ = 0x00020008 + + advapi32.OpenProcessToken.restype = wintypes.BOOL + advapi32.OpenProcessToken.argtypes = [ + wintypes.HANDLE, + wintypes.DWORD, + wintypes.PHANDLE, + ] + advapi32.OpenProcessToken.errcheck = raise_winerror_when_false + + # GetTokenInformation + PSID = wintypes.LPVOID # we don't need the internals + TOKEN_INFORMATION_CLASS = wintypes.INT + TokenUser = 1 + + class SID_AND_ATTRIBUTES(ctypes.Structure): + _fields_ = [ + ("Sid", PSID), + ("Attributes", wintypes.DWORD), + ] + + class TOKEN_USER(ctypes.Structure): + _fields_ = [ + ("User", SID_AND_ATTRIBUTES), + ] + + advapi32.GetTokenInformation.restype = wintypes.BOOL + advapi32.GetTokenInformation.argtypes = [ + wintypes.HANDLE, + TOKEN_INFORMATION_CLASS, + wintypes.LPVOID, + wintypes.DWORD, + wintypes.PDWORD, + ] + advapi32.GetTokenInformation.errcheck = raise_winerror_when_false + + # LookupAccountSid + SID_NAME_USE = wintypes.INT + PSID_NAME_USE = ctypes.POINTER(SID_NAME_USE) + + advapi32.LookupAccountSidW.restype = wintypes.BOOL + advapi32.LookupAccountSidW.argtypes = [ + wintypes.LPCWSTR, + PSID, + wintypes.LPWSTR, + wintypes.LPDWORD, + wintypes.LPWSTR, + wintypes.LPDWORD, + PSID_NAME_USE, + ] + advapi32.LookupAccountSidW.errcheck = raise_winerror_when_false + + # + # Implementation (see pg_SSPI_recv_auth()) + # + + # Get the current process token... + token = wintypes.HANDLE() + proc = kernel32.GetCurrentProcess() + advapi32.OpenProcessToken(proc, TOKEN_READ, token) + + # ...then read the TOKEN_USER struct for that token... + info = TOKEN_USER() + infolen = wintypes.DWORD() + + try: + # (GetTokenInformation creates a buffer bigger than TOKEN_USER, so we + # have to query the correct length first.) + advapi32.GetTokenInformation(token, TokenUser, None, 0, ctypes.byref(infolen)) + assert False, "GetTokenInformation succeeded unexpectedly" + + except OSError as err: + assert err.winerror == 122 # insufficient buffer + + ctypes.resize(info, infolen.value) + advapi32.GetTokenInformation( + token, + TokenUser, + ctypes.byref(info), + ctypes.sizeof(info), + ctypes.byref(infolen), + ) + + # ...then pull the account and domain names out of the user SID. + MAXPGPATH = 1024 + + account = ctypes.create_unicode_buffer(MAXPGPATH) + domain = ctypes.create_unicode_buffer(MAXPGPATH) + accountlen = wintypes.DWORD(ctypes.sizeof(account)) + domainlen = wintypes.DWORD(ctypes.sizeof(domain)) + use = SID_NAME_USE() + + advapi32.LookupAccountSidW( + None, + info.User.Sid, + account, + ctypes.byref(accountlen), + domain, + ctypes.byref(domainlen), + ctypes.byref(use), + ) + + return (account.value, domain.value) diff --git a/src/test/pytest/pg/fixtures.py b/src/test/pytest/pg/fixtures.py new file mode 100644 index 0000000000000..b5d3bff69a832 --- /dev/null +++ b/src/test/pytest/pg/fixtures.py @@ -0,0 +1,212 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import contextlib +import ctypes +import platform +import time +from typing import Any, Callable, Dict + +import pytest + +from ._env import test_timeout_default + + +@pytest.fixture +def remaining_timeout(): + """ + This fixture provides a function that returns how much of the + PG_TEST_TIMEOUT_DEFAULT remains for the current test, in fractional seconds. + This value is never less than zero. + + This fixture is per-test, so the deadline is also reset on a per-test basis. + """ + now = time.monotonic() + deadline = now + test_timeout_default() + + return lambda: max(deadline - time.monotonic(), 0) + + +class _PGconn(ctypes.Structure): + pass + + +class _PGresult(ctypes.Structure): + pass + + +_PGconn_p = ctypes.POINTER(_PGconn) +_PGresult_p = ctypes.POINTER(_PGresult) + + +@pytest.fixture(scope="session") +def libpq_handle(): + """ + Loads a ctypes handle for libpq. Some common function prototypes are + initialized for general use. + """ + system = platform.system() + + if system in ("Linux", "FreeBSD", "NetBSD", "OpenBSD"): + name = "libpq.so.5" + elif system == "Darwin": + name = "libpq.5.dylib" + elif system == "Windows": + name = "libpq.dll" + else: + assert False, f"the libpq fixture must be updated for {system}" + + # XXX ctypes.CDLL() is a little stricter with load paths on Windows. The + # preferred way around that is to know the absolute path to libpq.dll, but + # that doesn't seem to mesh well with the current test infrastructure. For + # now, enable "standard" LoadLibrary behavior. + loadopts = {} + if system == "Windows": + loadopts["winmode"] = 0 + + lib = ctypes.CDLL(name, **loadopts) + + # + # Function Prototypes + # + + lib.PQconnectdb.restype = _PGconn_p + lib.PQconnectdb.argtypes = [ctypes.c_char_p] + + lib.PQstatus.restype = ctypes.c_int + lib.PQstatus.argtypes = [_PGconn_p] + + lib.PQexec.restype = _PGresult_p + lib.PQexec.argtypes = [_PGconn_p, ctypes.c_char_p] + + lib.PQresultStatus.restype = ctypes.c_int + lib.PQresultStatus.argtypes = [_PGresult_p] + + lib.PQclear.restype = None + lib.PQclear.argtypes = [_PGresult_p] + + lib.PQerrorMessage.restype = ctypes.c_char_p + lib.PQerrorMessage.argtypes = [_PGconn_p] + + lib.PQfinish.restype = None + lib.PQfinish.argtypes = [_PGconn_p] + + return lib + + +class PGresult(contextlib.AbstractContextManager): + """Wraps a raw _PGresult_p with a more friendly interface.""" + + def __init__(self, lib: ctypes.CDLL, res: _PGresult_p): + self._lib = lib + self._res = res + + def __exit__(self, *exc): + self._lib.PQclear(self._res) + self._res = None + + def status(self): + return self._lib.PQresultStatus(self._res) + + +class PGconn(contextlib.AbstractContextManager): + """ + Wraps a raw _PGconn_p with a more friendly interface. This is just a + stub; it's expected to grow. + """ + + def __init__( + self, + lib: ctypes.CDLL, + handle: _PGconn_p, + stack: contextlib.ExitStack, + ): + self._lib = lib + self._handle = handle + self._stack = stack + + def __exit__(self, *exc): + self._lib.PQfinish(self._handle) + self._handle = None + + def exec(self, query: str) -> PGresult: + """ + Executes a query via PQexec() and returns a PGresult. + """ + res = self._lib.PQexec(self._handle, query.encode()) + return self._stack.enter_context(PGresult(self._lib, res)) + + +@pytest.fixture +def libpq(libpq_handle, remaining_timeout): + """ + Provides a ctypes-based API wrapped around libpq.so. This fixture keeps + track of allocated resources and cleans them up during teardown. See + _Libpq's public API for details. + """ + + class _Libpq(contextlib.ExitStack): + CONNECTION_OK = 0 + + PGRES_EMPTY_QUERY = 0 + + class Error(RuntimeError): + """ + libpq.Error is the exception class for application-level errors that + are encountered during libpq operations. + """ + + pass + + def __init__(self): + super().__init__() + self.lib = libpq_handle + + def _connstr(self, opts: Dict[str, Any]) -> str: + """ + Flattens the provided options into a libpq connection string. Values + are converted to str and quoted/escaped as necessary. + """ + settings = [] + + for k, v in opts.items(): + v = str(v) + if not v: + v = "''" + else: + v = v.replace("\\", "\\\\") + v = v.replace("'", "\\'") + + if " " in v: + v = f"'{v}'" + + settings.append(f"{k}={v}") + + return " ".join(settings) + + def must_connect(self, **opts) -> PGconn: + """ + Connects to a server, using the given connection options, and + returns a libpq.PGconn object wrapping the connection handle. A + failure will raise libpq.Error. + + Connections honor PG_TEST_TIMEOUT_DEFAULT unless connect_timeout is + explicitly overridden in opts. + """ + + if "connect_timeout" not in opts: + t = int(remaining_timeout()) + opts["connect_timeout"] = max(t, 1) + + conn_p = self.lib.PQconnectdb(self._connstr(opts).encode()) + + # Ensure the connection handle is always closed at the end of the + # test. + conn = self.enter_context(PGconn(self.lib, conn_p, stack=self)) + + if self.lib.PQstatus(conn_p) != self.CONNECTION_OK: + raise self.Error(self.lib.PQerrorMessage(conn_p).decode()) + + return conn + + with _Libpq() as lib: + yield lib diff --git a/src/test/pytest/plugins/pgtap.py b/src/test/pytest/plugins/pgtap.py new file mode 100644 index 0000000000000..ef8291e291c0c --- /dev/null +++ b/src/test/pytest/plugins/pgtap.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import os +import sys +from typing import Optional + +import pytest + +# +# Helpers +# + + +class TAP: + """ + A basic API for reporting via the TAP protocol. + """ + + def __init__(self): + self.count = 0 + + # XXX interacts poorly with testwrap's boilerplate diagnostics + # self.print("TAP version 13") + + def expect(self, num: int): + self.print(f"1..{num}") + + def print(self, *args): + print(*args, file=sys.__stdout__) + + def ok(self, name: str): + self.count += 1 + self.print("ok", self.count, "-", name) + + def skip(self, name: str, reason: str): + self.count += 1 + self.print("ok", self.count, "-", name, "# skip", reason) + + def fail(self, name: str, details: str): + self.count += 1 + self.print("not ok", self.count, "-", name) + + # mtest has some odd behavior around TAP tests where it won't print + # diagnostics on failure if they're part of the stdout stream, so we + # might as well just dump the details directly to stderr instead. + print(details, file=sys.__stderr__) + + +tap = TAP() + + +class TestNotes: + """ + Annotations for a single test. The existing pytest hooks keep interesting + information somewhat separated across the different stages + (setup/test/teardown), so this class is used to correlate them. + """ + + skipped = False + skip_reason = None + + failed = False + details = "" + + +# Register a custom key in the stash dictionary for keeping our TestNotes. +notes_key = pytest.StashKey[TestNotes]() + + +# +# Hook Implementations +# + + +@pytest.hookimpl(tryfirst=True) +def pytest_configure(config): + """ + Hijacks the standard streams as soon as possible during pytest startup. The + pytest-formatted output gets logged to file instead, and we'll use the + original sys.__stdout__/__stderr__ streams for the TAP protocol. + """ + logdir = os.getenv("TESTLOGDIR") + if not logdir: + raise RuntimeError("pgtap requires the TESTLOGDIR envvar to be set") + + os.makedirs(logdir) + logpath = os.path.join(logdir, "pytest.log") + sys.stdout = sys.stderr = open(logpath, "a", buffering=1) + + +@pytest.hookimpl(trylast=True) +def pytest_sessionfinish(session, exitstatus): + """ + Suppresses nonzero exit codes due to failed tests. (In that case, we want + Meson to report a failure count, not a generic ERROR.) + """ + if exitstatus == pytest.ExitCode.TESTS_FAILED: + session.exitstatus = pytest.ExitCode.OK + + +@pytest.hookimpl +def pytest_collectreport(report): + # Include collection failures directly in Meson error output. + if report.failed: + print(report.longreprtext, file=sys.__stderr__) + + +@pytest.hookimpl +def pytest_internalerror(excrepr, excinfo): + # Include internal errors directly in Meson error output. + print(excrepr, file=sys.__stderr__) + + +# +# Hook Wrappers +# +# In pytest parlance, a "wrapper" for a hook can inspect and optionally modify +# existing hooks' behavior, but it does not replace the hook chain. This is done +# through a generator-style API which chains the hooks together (see the use of +# `yield`). +# + + +@pytest.hookimpl(hookwrapper=True) +def pytest_collection(session): + """Reports the number of gathered tests after collection is finished.""" + res = yield + tap.expect(session.testscollected) + return res + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_makereport(item, call): + """ + Annotates a test item with our TestNotes and grabs relevant information for + reporting. + + This is called multiple times per test, so it's not correct to print the TAP + result here. (A test and its teardown stage can both fail, and we want to + see the details for both.) We instead combine all the information for use by + our pytest_runtest_protocol wrapper later on. + """ + res = yield + + if notes_key not in item.stash: + item.stash[notes_key] = TestNotes() + notes = item.stash[notes_key] + + report = res.get_result() + if report.passed: + pass # no annotation needed + + elif report.skipped: + notes.skipped = True + _, _, notes.skip_reason = report.longrepr + + elif report.failed: + notes.failed = True + + if not notes.details: + notes.details += "{:_^72}\n\n".format(f" {report.head_line} ") + + if report.when in ("setup", "teardown"): + notes.details += "\n{:_^72}\n\n".format( + f" Error during {report.when} of {report.head_line} " + ) + + notes.details += report.longreprtext + "\n" + + else: + raise RuntimeError("pytest_runtest_makereport received unknown test status") + + return res + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_protocol(item, nextitem): + """ + Reports the TAP result for this test item using our gathered TestNotes. + """ + res = yield + + assert notes_key in item.stash, "pgtap didn't annotate a test item?" + notes = item.stash[notes_key] + + if notes.failed: + tap.fail(item.nodeid, notes.details) + elif notes.skipped: + tap.skip(item.nodeid, notes.skip_reason) + else: + tap.ok(item.nodeid) + + return res diff --git a/src/test/pytest/pyt/conftest.py b/src/test/pytest/pyt/conftest.py new file mode 100644 index 0000000000000..ecb72be26d722 --- /dev/null +++ b/src/test/pytest/pyt/conftest.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +from pg.fixtures import * diff --git a/src/test/pytest/pyt/test_libpq.py b/src/test/pytest/pyt/test_libpq.py new file mode 100644 index 0000000000000..9f0857cc6124a --- /dev/null +++ b/src/test/pytest/pyt/test_libpq.py @@ -0,0 +1,171 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import contextlib +import os +import socket +import struct +import threading +from typing import Callable + +import pytest + + +@pytest.mark.parametrize( + "opts, expected", + [ + (dict(), ""), + (dict(port=5432), "port=5432"), + (dict(port=5432, dbname="postgres"), "port=5432 dbname=postgres"), + (dict(host=""), "host=''"), + (dict(host=" "), r"host=' '"), + (dict(keyword="'"), r"keyword=\'"), + (dict(keyword=" \\' "), r"keyword=' \\\' '"), + ], +) +def test_connstr(libpq, opts, expected): + """Tests the escape behavior for libpq._connstr().""" + assert libpq._connstr(opts) == expected + + +def test_must_connect_errors(libpq): + """Tests that must_connect() raises libpq.Error.""" + with pytest.raises(libpq.Error, match="invalid connection option"): + libpq.must_connect(some_unknown_keyword="whatever") + + +@pytest.fixture +def local_server(tmp_path, remaining_timeout): + """ + Opens up a local UNIX socket for mocking a Postgres server on a background + thread. See the _Server API for usage. + + This fixture requires AF_UNIX support; dependent tests will be skipped on + platforms that don't provide it. + """ + + try: + from socket import AF_UNIX + except ImportError: + pytest.skip("AF_UNIX not supported on this platform") + + class _Server(contextlib.ExitStack): + """ + Implementation class for local_server. See .background() for the primary + entry point for tests. Postgres clients may connect to this server via + local_server.host/local_server.port. + + _Server derives from contextlib.ExitStack to provide easy cleanup of + associated resources; see the documentation for that class for a full + explanation. + """ + + def __init__(self): + super().__init__() + + self.host = tmp_path + self.port = 5432 + + self._thread = None + self._thread_exc = None + self._listener = self.enter_context( + socket.socket(AF_UNIX, socket.SOCK_STREAM), + ) + + def bind_and_listen(self): + """ + Does the actual work of binding the UNIX socket using the Postgres + server conventions and listening for connections. + + The listen backlog is currently hardcoded to one. + """ + sockfile = self.host / ".s.PGSQL.{}".format(self.port) + + # Lock down the permissions on the new socket. + prev_mask = os.umask(0o077) + + # Bind (creating the socket file), and immediately register it for + # deletion from disk when the stack is cleaned up. + self._listener.bind(bytes(sockfile)) + self.callback(os.unlink, sockfile) + + os.umask(prev_mask) + + self._listener.listen(1) + + def background(self, fn: Callable[[socket.socket], None]) -> None: + """ + Accepts a client connection on a background thread and passes it to + the provided callback. Any exceptions raised from the callback will + be re-raised on the main thread during fixture teardown. + + Blocking operations on the connected socket default to using the + remaining_timeout(), though this can be changed by the test via the + socket's .settimeout(). + """ + + def _bg(): + try: + self._listener.settimeout(remaining_timeout()) + sock, _ = self._listener.accept() + + with sock: + sock.settimeout(remaining_timeout()) + fn(sock) + + except Exception as e: + # Save the exception for re-raising on the main thread. + self._thread_exc = e + + # TODO: rather than using callback(), consider explicitly signaling + # the fn() implementation to stop early if we get an exception. + # Otherwise we'll hang until the end of the timeout. + self._thread = threading.Thread(target=_bg) + self.callback(self._join) + + self._thread.start() + + def _join(self): + """ + Waits for the background thread to finish and raises any thrown + exception. This is called during fixture teardown. + """ + # Give a little bit of wiggle room on the join timeout, since we're + # racing against the test's own use of remaining_timeout(). (It's + # preferable to let tests report timeouts; the stack traces will + # help with debugging.) + self._thread.join(remaining_timeout() + 1) + if self._thread.is_alive(): + raise TimeoutError("background thread is still running after timeout") + + if self._thread_exc is not None: + raise self._thread_exc + + with _Server() as s: + s.bind_and_listen() + yield s + + +def test_connection_is_finished_on_error(libpq, local_server, remaining_timeout): + """Tests that PQfinish() gets called at the end of testing.""" + expected_error = "something is wrong" + + def serve_error(s: socket.socket) -> None: + pktlen = struct.unpack("!I", s.recv(4))[0] + + # Quick check for the startup packet version. + version = struct.unpack("!HH", s.recv(4)) + assert version == (3, 0) + + # Discard the remainder of the startup packet and send a v2 error. + s.recv(pktlen - 8) + s.send(b"E" + expected_error.encode() + b"\0") + + # And now the socket should be closed. + assert not s.recv(1), "client sent unexpected data" + + local_server.background(serve_error) + + with pytest.raises(libpq.Error, match=expected_error): + # Exiting this context should result in PQfinish(). + with libpq: + libpq.must_connect(host=local_server.host, port=local_server.port) diff --git a/src/test/pytest/pyt/test_something.py b/src/test/pytest/pyt/test_something.py new file mode 100644 index 0000000000000..5bd4561851202 --- /dev/null +++ b/src/test/pytest/pyt/test_something.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import pytest + + +@pytest.fixture +def hey(): + yield + raise "uh-oh" + + +def test_something(hey): + assert 2 == 4 + + +def test_something_else(): + assert 2 == 2 diff --git a/src/test/ssl/Makefile b/src/test/ssl/Makefile index e8a1639db2d3d..895ea5ea41cf0 100644 --- a/src/test/ssl/Makefile +++ b/src/test/ssl/Makefile @@ -30,6 +30,8 @@ clean distclean: # Doesn't depend on sslfiles because we don't rebuild them by default check: $(prove_check) + # XXX these suites should run independently, not serially + $(pytest_check) installcheck: $(prove_installcheck) diff --git a/src/test/ssl/meson.build b/src/test/ssl/meson.build index d8e0fb518e0a2..a0ee2af0899cf 100644 --- a/src/test/ssl/meson.build +++ b/src/test/ssl/meson.build @@ -15,4 +15,10 @@ tests += { 't/003_sslinfo.pl', ], }, + 'pytest': { + 'tests': [ + 'pyt/test_client.py', + 'pyt/test_server.py', + ], + }, } diff --git a/src/test/ssl/pyt/conftest.py b/src/test/ssl/pyt/conftest.py new file mode 100644 index 0000000000000..85d2c99482891 --- /dev/null +++ b/src/test/ssl/pyt/conftest.py @@ -0,0 +1,242 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import datetime +import os +import pathlib +import platform +import secrets +import socket +import subprocess +import tempfile +from collections import namedtuple + +import pytest + +import pg +from pg.fixtures import * + + +@pytest.fixture(scope="session") +def cryptography(): + return pytest.importorskip("cryptography", "3.3.2") + + +Cert = namedtuple("Cert", "cert, certpath, key, keypath") + + +@pytest.fixture(scope="session") +def certs(cryptography, tmp_path_factory): + """ + Caches commonly used certificates at the session level, and provides a way + to create new ones. + + - certs.ca: the root CA certificate + + - certs.server: the "standard" server certficate, signed by certs.ca + + - certs.server_host: the hostname of the certs.server certificate + + - certs.new(): creates a custom certificate, signed by certs.ca + """ + + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + + tmpdir = tmp_path_factory.mktemp("test-certs") + + class _Certs: + def __init__(self): + self.ca = self.new( + x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "PG pytest CA")], + ), + ca=True, + ) + + self.server_host = "example.org" + self.server = self.new( + x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, self.server_host)], + ) + ) + + def new(self, subject: x509.Name, *, ca=False) -> Cert: + """ + Creates and signs a new Cert with the given subject name. If ca is + True, the certificate will be self-signed; otherwise the certificate + is signed by self.ca. + """ + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + builder = x509.CertificateBuilder() + now = datetime.datetime.now(datetime.timezone.utc) + + builder = ( + builder.subject_name(subject) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(hours=1)) + ) + + if ca: + builder = builder.issuer_name(subject) + else: + builder = builder.issuer_name(self.ca.cert.subject) + + builder = builder.add_extension( + x509.BasicConstraints(ca=ca, path_length=None), + critical=True, + ) + + cert = builder.sign( + private_key=key if ca else self.ca.key, + algorithm=hashes.SHA256(), + ) + + # Dump the certificate and key to file. + keypath = self._tofile( + key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ), + suffix=".key", + ) + certpath = self._tofile( + cert.public_bytes(serialization.Encoding.PEM), + suffix="-ca.crt" if ca else ".crt", + ) + + return Cert( + cert=cert, + certpath=certpath, + key=key, + keypath=keypath, + ) + + def _tofile(self, data: bytes, *, suffix) -> str: + """ + Dumps data to a file on disk with the requested suffix and returns + the path. The file is located somewhere in pytest's temporary + directory root. + """ + f = tempfile.NamedTemporaryFile(suffix=suffix, dir=tmpdir, delete=False) + with f: + f.write(data) + + return f.name + + return _Certs() + + +@pytest.fixture(scope="session") +def datadir(tmp_path_factory): + """ + Returns the directory name to use as the server data directory. If + TESTDATADIR is provided, that will be used; otherwise a new temporary + directory is created in the pytest temp root. + """ + d = os.getenv("TESTDATADIR") + if d: + d = pathlib.Path(d) + else: + d = tmp_path_factory.mktemp("tmp_check") + + return d + + +@pytest.fixture(scope="session") +def sockdir(tmp_path_factory): + """ + Returns the directory name to use as the server's unix_socket_directories + setting. Local client connections use this as the PGHOST. + + At the moment, this is always put under the pytest temp root. + """ + return tmp_path_factory.mktemp("sockfiles") + + +@pytest.fixture(scope="session") +def winpassword(): + """The per-session SCRAM password for the server admin on Windows.""" + return secrets.token_urlsafe(16) + + +@pytest.fixture(scope="session") +def server_instance(certs, datadir, sockdir, winpassword): + """ + Starts a running Postgres server listening on localhost. The HBA initially + allows only local UNIX connections from the same user. + + TODO: when installcheck is supported, this should optionally point to the + currently running server instead. + """ + + # Lock down the HBA by default; tests can open it back up later. + if platform.system() == "Windows": + # On Windows, for admin connections, use SCRAM with a generated password + # over local sockets. This requires additional work during initdb. + method = "scram-sha-256" + + # NamedTemporaryFile doesn't work very nicely on Windows until Python + # 3.12, which introduces NamedTemporaryFile(delete_on_close=False). + # Until then, specify delete=False and manually unlink after use. + with tempfile.NamedTemporaryFile("w", delete=False) as pwfile: + pwfile.write(winpassword) + + subprocess.check_call( + ["initdb", "--auth=scram-sha-256", "--pwfile", pwfile.name, datadir] + ) + os.unlink(pwfile.name) + + else: + # For other OSes we can just use peer auth. + method = "peer" + subprocess.check_call(["pg_ctl", "-D", datadir, "init"]) + + with open(datadir / "pg_hba.conf", "w") as f: + print(f"# default: local {method} connections only", file=f) + print(f"local all all {method}", file=f) + + # Figure out a port to listen on. Attempt to reserve both IPv4 and IPv6 + # addresses in one go. + # + # Note: socket.has_dualstack_ipv6/create_server are only in Python 3.8+. + if hasattr(socket, "has_dualstack_ipv6") and socket.has_dualstack_ipv6(): + addr = ("::1", 0) + s = socket.create_server(addr, family=socket.AF_INET6, dualstack_ipv6=True) + + hostaddr, port, _, _ = s.getsockname() + addrs = [hostaddr, "127.0.0.1"] + + else: + addr = ("127.0.0.1", 0) + + s = socket.socket() + s.bind(addr) + + hostaddr, port = s.getsockname() + addrs = [hostaddr] + + log = os.path.join(datadir, "postgresql.log") + + with s, open(os.path.join(datadir, "postgresql.conf"), "a") as f: + print(file=f) + print("unix_socket_directories = '{}'".format(sockdir.as_posix()), file=f) + print("listen_addresses = '{}'".format(",".join(addrs)), file=f) + print("port =", port, file=f) + print("log_connections = all", file=f) + + # Between closing of the socket, s, and server start, we're racing against + # anything that wants to open up ephemeral ports, so try not to put any new + # work here. + + subprocess.check_call(["pg_ctl", "-D", datadir, "-l", log, "start"]) + yield (hostaddr, port) + subprocess.check_call(["pg_ctl", "-D", datadir, "-l", log, "stop"]) diff --git a/src/test/ssl/pyt/test_client.py b/src/test/ssl/pyt/test_client.py new file mode 100644 index 0000000000000..28110ae07178e --- /dev/null +++ b/src/test/ssl/pyt/test_client.py @@ -0,0 +1,278 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import contextlib +import ctypes +import socket +import ssl +import struct +import threading +from typing import Callable + +import pytest + +import pg + +# This suite opens up local TCP ports and is hidden behind PG_TEST_EXTRA=ssl. +pytestmark = pg.require_test_extra("ssl") + + +@pytest.fixture(scope="session", autouse=True) +def skip_if_no_ssl_support(libpq_handle): + """Skips tests if SSL support is not configured.""" + + # Declare PQsslAttribute(). + PQsslAttribute = libpq_handle.PQsslAttribute + PQsslAttribute.restype = ctypes.c_char_p + PQsslAttribute.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + + if not PQsslAttribute(None, b"library"): + pytest.skip("requires SSL support to be configured") + + +# +# Test Fixtures +# + + +@pytest.fixture +def tcp_server_class(remaining_timeout): + """ + Metafixture to combine related logic for tcp_server and ssl_server. + + TODO: combine with test_libpq.local_server + """ + + class _TCPServer(contextlib.ExitStack): + """ + Implementation class for tcp_server. See .background() for the primary + entry point for tests. Postgres clients may connect to this server via + **tcp_server.conninfo. + + _TCPServer derives from contextlib.ExitStack to provide easy cleanup of + associated resources; see the documentation for that class for a full + explanation. + """ + + def __init__(self): + super().__init__() + + self._thread = None + self._thread_exc = None + self._listener = self.enter_context( + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + ) + + self._bind_and_listen() + sockname = self._listener.getsockname() + self.conninfo = dict( + hostaddr=sockname[0], + port=sockname[1], + ) + + def _bind_and_listen(self): + """ + Does the actual work of binding the socket and listening for + connections. + + The listen backlog is currently hardcoded to one. + """ + self._listener.bind(("127.0.0.1", 0)) + self._listener.listen(1) + + def background(self, fn: Callable[[socket.socket], None]) -> None: + """ + Accepts a client connection on a background thread and passes it to + the provided callback. Any exceptions raised from the callback will + be re-raised on the main thread during fixture teardown. + + Blocking operations on the connected socket default to using the + remaining_timeout(), though this can be changed by the test via the + socket's .settimeout(). + """ + + def _bg(): + try: + self._listener.settimeout(remaining_timeout()) + sock, _ = self._listener.accept() + + with sock: + sock.settimeout(remaining_timeout()) + fn(sock) + + except Exception as e: + # Save the exception for re-raising on the main thread. + self._thread_exc = e + + # TODO: rather than using callback(), consider explicitly signaling + # the fn() implementation to stop early if we get an exception. + # Otherwise we'll hang until the end of the timeout. + self._thread = threading.Thread(target=_bg) + self.callback(self._join) + + self._thread.start() + + def _join(self): + """ + Waits for the background thread to finish and raises any thrown + exception. This is called during fixture teardown. + """ + # Give a little bit of wiggle room on the join timeout, since we're + # racing against the test's own use of remaining_timeout(). (It's + # preferable to let tests report timeouts; the stack traces will + # help with debugging.) + self._thread.join(remaining_timeout() + 1) + if self._thread.is_alive(): + raise TimeoutError("background thread is still running after timeout") + + if self._thread_exc is not None: + raise self._thread_exc + + return _TCPServer + + +@pytest.fixture +def tcp_server(tcp_server_class): + """ + Opens up a local TCP socket for mocking a Postgres server on a background + thread. See the _TCPServer API for usage. + """ + with tcp_server_class() as s: + yield s + + +@pytest.fixture +def ssl_server(tcp_server_class, certs): + """ + Like tcp_server, but with an additional .background_ssl() method which will + perform a SSLRequest handshake on the socket before handing the connection + to the test callback. + + This server uses certs.server as its identity. + """ + + class _SSLServer(tcp_server_class): + def __init__(self): + super().__init__() + + self.conninfo["host"] = certs.server_host + + self._ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + self._ctx.load_cert_chain(certs.server.certpath, certs.server.keypath) + + def background_ssl(self, fn: Callable[[ssl.SSLSocket], None]) -> None: + """ + Invokes a server callback as with .background(), but an SSLRequest + handshake is performed first, and the socket provided to the + callback has been wrapped in an OpenSSL layer. + """ + + def handshake(s: socket.socket): + pktlen = struct.unpack("!I", s.recv(4))[0] + + # Make sure we get an SSLRequest. + version = struct.unpack("!HH", s.recv(4)) + assert version == (1234, 5679) + assert pktlen == 8 + + # Accept the SSLRequest. + s.send(b"S") + + with self._ctx.wrap_socket(s, server_side=True) as wrapped: + fn(wrapped) + + self.background(handshake) + + with _SSLServer() as s: + yield s + + +# +# Tests +# + + +@pytest.mark.parametrize("sslmode", ("require", "verify-ca", "verify-full")) +def test_server_with_ssl_disabled(libpq, tcp_server, certs, sslmode): + """ + Make sure client refuses to talk to non-SSL servers with stricter + sslmodes. + """ + + def refuse_ssl(s: socket.socket): + pktlen = struct.unpack("!I", s.recv(4))[0] + + # Make sure we get an SSLRequest. + version = struct.unpack("!HH", s.recv(4)) + assert version == (1234, 5679) + assert pktlen == 8 + + # Refuse the SSLRequest. + s.send(b"N") + + # Wait for the client to close the connection. + assert not s.recv(1), "client sent unexpected data" + + tcp_server.background(refuse_ssl) + + with pytest.raises(libpq.Error, match="server does not support SSL"): + with libpq: # XXX tests shouldn't need to do this + libpq.must_connect( + **tcp_server.conninfo, + sslrootcert=certs.ca.certpath, + sslmode=sslmode, + ) + + +def test_verify_full_connection(libpq, ssl_server, certs): + """Completes a verify-full connection and empty query.""" + + def handle_empty_query(s: ssl.SSLSocket): + pktlen = struct.unpack("!I", s.recv(4))[0] + + # Check the startup packet version, then discard the remainder. + version = struct.unpack("!HH", s.recv(4)) + assert version == (3, 0) + s.recv(pktlen - 8) + + # Send the required litany of server messages. + s.send(struct.pack("!cII", b"R", 8, 0)) # AuthenticationOK + + # ParameterStatus: client_encoding + key = b"client_encoding\0" + val = b"UTF-8\0" + s.send(struct.pack("!cI", b"S", 4 + len(key) + len(val)) + key + val) + + # ParameterStatus: DateStyle + key = b"DateStyle\0" + val = b"ISO, MDY\0" + s.send(struct.pack("!cI", b"S", 4 + len(key) + len(val)) + key + val) + + s.send(struct.pack("!cIII", b"K", 12, 1234, 1234)) # BackendKeyData + s.send(struct.pack("!cIc", b"Z", 5, b"I")) # ReadyForQuery + + # Expect an empty query. + pkttype = s.recv(1) + assert pkttype == b"Q" + pktlen = struct.unpack("!I", s.recv(4))[0] + assert s.recv(pktlen - 4) == b"\0" + + # Send an EmptyQueryResponse+ReadyForQuery. + s.send(struct.pack("!cI", b"I", 4)) + s.send(struct.pack("!cIc", b"Z", 5, b"I")) + + # libpq should terminate and close the connection. + assert s.recv(1) == b"X" + pktlen = struct.unpack("!I", s.recv(4))[0] + assert pktlen == 4 + + assert not s.recv(1), "client sent unexpected data" + + ssl_server.background_ssl(handle_empty_query) + + conn = libpq.must_connect( + **ssl_server.conninfo, + sslrootcert=certs.ca.certpath, + sslmode="verify-full", + ) + with conn: + assert conn.exec("").status() == libpq.PGRES_EMPTY_QUERY diff --git a/src/test/ssl/pyt/test_server.py b/src/test/ssl/pyt/test_server.py new file mode 100644 index 0000000000000..2d0be735371eb --- /dev/null +++ b/src/test/ssl/pyt/test_server.py @@ -0,0 +1,538 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import contextlib +import os +import pathlib +import platform +import re +import shutil +import socket +import ssl +import struct +import subprocess +import tempfile +from collections import namedtuple +from typing import Dict, List, Union + +import pytest + +import pg + +# This suite opens up local TCP ports and is hidden behind PG_TEST_EXTRA=ssl. +pytestmark = pg.require_test_extra("ssl") + + +# +# Test Fixtures +# + + +@pytest.fixture(scope="session") +def connenv(server_instance, sockdir, datadir): + """ + Provides the values for several PG* environment variables needed for our + utility programs to connect to the server_instance. + """ + return { + "PGHOST": str(sockdir), + "PGPORT": str(server_instance[1]), + "PGDATABASE": "postgres", + "PGDATA": str(datadir), + } + + +class FileBackup(contextlib.AbstractContextManager): + """ + A context manager which backs up a file's contents, restoring them on exit. + """ + + def __init__(self, file: pathlib.Path): + super().__init__() + + self._file = file + + def __enter__(self): + with tempfile.NamedTemporaryFile( + prefix=self._file.name, dir=self._file.parent, delete=False + ) as f: + self._backup = pathlib.Path(f.name) + + shutil.copyfile(self._file, self._backup) + + return self + + def __exit__(self, *exc): + # Swap the backup and the original file, so that the modified contents + # can still be inspected in case of failure. + # + # TODO: this is less helpful if there are multiple layers, because it's + # not clear which backup to look at. Can the backup name be printed as + # part of the failed test output? Should we only swap on test failure? + tmp = self._backup.parent / (self._backup.name + ".tmp") + + shutil.copyfile(self._file, tmp) + shutil.copyfile(self._backup, self._file) + shutil.move(tmp, self._backup) + + +class HBA(FileBackup): + """ + Backs up a server's HBA configuration and provides means for temporarily + editing it. See also pg_server, which provides an instance of this class and + context managers for enforcing the reload/restart order of operations. + """ + + def __init__(self, datadir: pathlib.Path): + super().__init__(datadir / "pg_hba.conf") + + def prepend(self, *lines: Union[str, List[str]]): + """ + Temporarily prepends lines to the server's pg_hba.conf. + + As sugar for aligning HBA columns in the tests, each line can be either + a string or a list of strings. List elements will be joined by single + spaces before they are written to file. + """ + with open(self._file, "r") as f: + prior_data = f.read() + + with open(self._file, "w") as f: + for l in lines: + if isinstance(l, list): + print(*l, file=f) + else: + print(l, file=f) + + f.write(prior_data) + + +class Config(FileBackup): + """ + Backs up a server's postgresql.conf and provides means for temporarily + editing it. See also pg_server, which provides an instance of this class and + context managers for enforcing the reload/restart order of operations. + """ + + def __init__(self, datadir: pathlib.Path): + super().__init__(datadir / "postgresql.conf") + + def set(self, **gucs): + """ + Temporarily appends GUC settings to the server's postgresql.conf. + """ + + with open(self._file, "a") as f: + print(file=f) + + for n, v in gucs.items(): + v = str(v) + + # TODO: proper quoting + v = v.replace("\\", "\\\\") + v = v.replace("'", "\\'") + v = "'{}'".format(v) + + print(n, "=", v, file=f) + + +@pytest.fixture(scope="session") +def pg_server_session(server_instance, connenv, datadir, winpassword): + """ + Provides common routines for configuring and connecting to the + server_instance. For example: + + users = pg_server_session.create_users("one", "two") + dbs = pg_server_session.create_dbs("default") + + with pg_server_session.reloading() as s: + s.hba.prepend(["local", dbs["default"], users["two"], "peer"]) + + conn = connect_somehow(**pg_server_session.conninfo) + ... + + Attributes of note are + - .conninfo: provides TCP connection info for the server + + This fixture unwinds its configuration changes at the end of the pytest + session. For more granular changes, pg_server_session.subcontext() splits + off a "nested" context to allow smaller scopes. + """ + + class _Server(contextlib.ExitStack): + conninfo = dict( + hostaddr=server_instance[0], + port=server_instance[1], + ) + + # for _backup_configuration() + _Backup = namedtuple("Backup", "conf, hba") + + def subcontext(self): + """ + Creates a new server stack instance that can be tied to a smaller + scope than "session". + """ + # So far, there doesn't seem to be a need to link the two objects, + # since HBA/Config/FileBackup operate directly on the filesystem and + # will appear to "nest" naturally. + return self.__class__() + + def create_users(self, *userkeys: str) -> Dict[str, str]: + """ + Creates new users which will be dropped at the end of the server + context. + + For each provided key, a related user name will be selected and + stored in a map. This map is returned to let calling code look up + the selected usernames (instead of hardcoding them and potentially + stomping on an existing installation). + """ + usermap = {} + + for u in userkeys: + # TODO: use a uniquifier to support installcheck + name = u + "user" + usermap[u] = name + + # TODO: proper escaping + self.psql("-c", "CREATE USER " + name) + self.callback(self.psql, "-c", "DROP USER " + name) + + return usermap + + def create_dbs(self, *dbkeys: str) -> Dict[str, str]: + """ + Creates new databases which will be dropped at the end of the server + context. See create_users() for the meaning of the keys and returned + map. + """ + dbmap = {} + + for d in dbkeys: + # TODO: use a uniquifier to support installcheck + name = d + "db" + dbmap[d] = name + + # TODO: proper escaping + self.psql("-c", "CREATE DATABASE " + name) + self.callback(self.psql, "-c", "DROP DATABASE " + name) + + return dbmap + + @contextlib.contextmanager + def reloading(self): + """ + Provides a context manager for making configuration changes. + + If the context suite finishes successfully, the configuration will + be reloaded via pg_ctl. On teardown, the configuration changes will + be unwound, and the server will be signaled to reload again. + + The context target contains the following attributes which can be + used to configure the server: + - .conf: modifies postgresql.conf + - .hba: modifies pg_hba.conf + + For example: + + with pg_server_session.reloading() as s: + s.conf.set(log_connections="on") + s.hba.prepend("local all all trust") + """ + try: + # Push a reload onto the stack before making any other + # unwindable changes. That way the order of operations will be + # + # # test + # - config change 1 + # - config change 2 + # - reload + # # teardown + # - undo config change 2 + # - undo config change 1 + # - reload + # + self.callback(self.pg_ctl, "reload") + yield self._backup_configuration() + except: + # We only want to reload at the end of the suite if there were + # no errors. During exceptions, the pushed callback handles + # things instead, so there's nothing to do here. + raise + else: + # Suite completed successfully. + self.pg_ctl("reload") + + @contextlib.contextmanager + def restarting(self): + """Like .reloading(), but with a full server restart.""" + try: + self.callback(self.pg_ctl, "restart") + yield self._backup_configuration() + except: + raise + else: + self.pg_ctl("restart") + + def psql(self, *args): + """ + Runs psql with the given arguments. Password prompts are always + disabled. On Windows, the admin password will be included in the + environment. + """ + if platform.system() == "Windows": + pw = dict(PGPASSWORD=winpassword) + else: + pw = None + + self._run("psql", "-w", *args, addenv=pw) + + def pg_ctl(self, *args): + """ + Runs pg_ctl with the given arguments. Log output will be placed in + postgresql.log in the server's data directory. + + TODO: put the log in TESTLOGDIR + """ + self._run("pg_ctl", "-l", str(datadir / "postgresql.log"), *args) + + def _run(self, cmd, *args, addenv: dict = None): + # Override the existing environment with the connenv values and + # anything the caller wanted to add. (Python 3.9 gives us the + # less-ugly `os.environ | connenv` merge operator.) + subenv = dict(os.environ, **connenv) + if addenv: + subenv.update(addenv) + + subprocess.check_call([cmd, *args], env=subenv) + + def _backup_configuration(self): + # Wrap the existing HBA and configuration with FileBackups. + return self._Backup( + hba=self.enter_context(HBA(datadir)), + conf=self.enter_context(Config(datadir)), + ) + + with _Server() as s: + yield s + + +@pytest.fixture(scope="module", autouse=True) +def ssl_setup(pg_server_session, certs, datadir): + """ + Sets up required server settings for all tests in this module. The fixture + variable is a tuple (users, dbs) containing the user and database names that + have been chosen for the test session. + """ + try: + with pg_server_session.restarting() as s: + s.conf.set( + ssl="on", + ssl_ca_file=certs.ca.certpath, + ssl_cert_file=certs.server.certpath, + ssl_key_file=certs.server.keypath, + ) + + # Reject by default. + s.hba.prepend("hostssl all all all reject") + + except subprocess.CalledProcessError: + # This is a decent place to skip if the server isn't set up for SSL. + logpath = datadir / "postgresql.log" + unsupported = re.compile("SSL is not supported") + + with open(logpath, "r") as log: + for line in log: + if unsupported.search(line): + pytest.skip("the server does not support SSL") + + # Some other error happened. + raise + + users = pg_server_session.create_users( + "ssl", + ) + + dbs = pg_server_session.create_dbs( + "ssl", + ) + + return (users, dbs) + + +@pytest.fixture(scope="module") +def client_cert(ssl_setup, certs): + """ + Creates a Cert for the "ssl" user. + """ + from cryptography import x509 + from cryptography.x509.oid import NameOID + + users, _ = ssl_setup + user = users["ssl"] + + return certs.new(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, user)])) + + +@pytest.fixture +def pg_server(pg_server_session): + """ + A per-test instance of pg_server_session. Use this fixture to make changes + to the server which will be rolled back at the end of every test. + """ + with pg_server_session.subcontext() as s: + yield s + + +# +# Tests +# + + +# For use with the `creds` parameter below. +CLIENT = "client" +SERVER = "server" + + +@pytest.mark.parametrize( + # fmt: off + "auth_method, creds, expected_error", +[ + # Trust allows anything. + ("trust", None, None), + ("trust", CLIENT, None), + ("trust", SERVER, None), + + # verify-ca allows any CA-signed certificate. + ("trust clientcert=verify-ca", None, "requires a valid client certificate"), + ("trust clientcert=verify-ca", CLIENT, None), + ("trust clientcert=verify-ca", SERVER, None), + + # cert and verify-full allow only the correct certificate. + ("trust clientcert=verify-full", None, "requires a valid client certificate"), + ("trust clientcert=verify-full", CLIENT, None), + ("trust clientcert=verify-full", SERVER, "authentication failed for user"), + ("cert", None, "requires a valid client certificate"), + ("cert", CLIENT, None), + ("cert", SERVER, "authentication failed for user"), +], + # fmt: on +) +def test_direct_ssl_certificate_authentication( + pg_server, + ssl_setup, + certs, + client_cert, + remaining_timeout, + # test parameters + auth_method, + creds, + expected_error, +): + """ + Tests direct SSL connections with various client-certificate/HBA + combinations. + """ + + # Set up the HBA as desired by the test. + users, dbs = ssl_setup + + user = users["ssl"] + db = dbs["ssl"] + + with pg_server.reloading() as s: + s.hba.prepend( + ["hostssl", db, user, "127.0.0.1/32", auth_method], + ["hostssl", db, user, "::1/128", auth_method], + ) + + # Configure the SSL settings for the client. + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_verify_locations(cafile=certs.ca.certpath) + ctx.set_alpn_protocols(["postgresql"]) # for direct SSL + + # Load up a client certificate if required by the test. + if creds == CLIENT: + ctx.load_cert_chain(client_cert.certpath, client_cert.keypath) + elif creds == SERVER: + # Using a server certificate as the client credential is expected to + # work only for clientcert=verify-ca (and `trust`, naturally). + ctx.load_cert_chain(certs.server.certpath, certs.server.keypath) + + # Make a direct SSL connection. There's no SSLRequest in the handshake; we + # simply wrap a TCP connection with OpenSSL. + addr = (pg_server.conninfo["hostaddr"], pg_server.conninfo["port"]) + with socket.create_connection(addr) as s: + s.settimeout(remaining_timeout()) # XXX this resets every operation + + with ctx.wrap_socket(s, server_hostname=certs.server_host) as conn: + # Build and send the startup packet. + startup_options = dict( + user=user, + database=db, + application_name="pytest", + ) + + payload = b"" + for k, v in startup_options.items(): + payload += k.encode() + b"\0" + payload += str(v).encode() + b"\0" + payload += b"\0" # null terminator + + pktlen = 4 + 4 + len(payload) + conn.send(struct.pack("!IHH", pktlen, 3, 0) + payload) + + if not expected_error: + # Expect an AuthenticationOK to come back. + pkttype, pktlen = struct.unpack("!cI", conn.recv(5)) + assert pkttype == b"R" + assert pktlen == 8 + + authn_result = struct.unpack("!I", conn.recv(4))[0] + assert authn_result == 0 + + # Read and discard to ReadyForQuery. + while True: + pkttype, pktlen = struct.unpack("!cI", conn.recv(5)) + payload = conn.recv(pktlen - 4) + + if pkttype == b"Z": + assert payload == b"I" + break + + # Send an empty query. + conn.send(struct.pack("!cI", b"Q", 5) + b"\0") + + # Expect EmptyQueryResponse+ReadyForQuery. + pkttype, pktlen = struct.unpack("!cI", conn.recv(5)) + assert pkttype == b"I" + assert pktlen == 4 + + pkttype, pktlen = struct.unpack("!cI", conn.recv(5)) + assert pkttype == b"Z" + + payload = conn.recv(pktlen - 4) + assert payload == b"I" + + else: + # Match the expected authentication error. + pkttype, pktlen = struct.unpack("!cI", conn.recv(5)) + assert pkttype == b"E" + + payload = conn.recv(pktlen - 4) + msg = None + + for component in payload.split(b"\0"): + if not component: + break # end of message + + key, val = component[:1], component[1:] + if key == b"S": + assert val == b"FATAL" + elif key == b"M": + msg = val.decode() + + assert re.search(expected_error, msg), "server error did not match" + + # Terminate. + conn.send(struct.pack("!cI", b"X", 4))