diff --git a/.flake8 b/.flake8 index decf40da..5311c61e 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,5 @@ [flake8] +select = C90,E,F,W,Y0 ignore = E402,E731,W503,W504,E252 -exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv,.tox +exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox +per-file-ignores = *.pyi: F401, F403, F405, F811, E127, E128, E203, E266, E301, E302, E305, E501, E701, E704, E741, B303, W503, W504 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index eef0799e..450f471e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -22,7 +22,7 @@ jobs: github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} version_file: asyncpg/_version.py version_line_pattern: | - __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) + __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) - name: Stop if not approved if: steps.checkver.outputs.approved != 'true' diff --git a/.gitignore b/.gitignore index 21286094..a04d0b91 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ docs/_build /.eggs /.vscode /.mypy_cache +/.venv* +/.tox +/.vim diff --git a/.gitmodules b/.gitmodules index c8d0b650..9dc433a1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "asyncpg/pgproto"] path = asyncpg/pgproto - url = https://github.com/MagicStack/py-pgproto.git + url = https://github.com/bryanforbes/py-pgproto.git diff --git a/MANIFEST.in b/MANIFEST.in index 2389f6fa..a51fa57c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,6 @@ recursive-include docs *.py *.rst Makefile *.css recursive-include examples *.py recursive-include tests *.py *.pem -recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.c *.h +recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h include LICENSE README.rst Makefile performance.png .flake8 +include asyncpg/py.typed diff --git a/asyncpg/__init__.py b/asyncpg/__init__.py index e8cd11eb..dff9f58f 100644 --- a/asyncpg/__init__.py +++ b/asyncpg/__init__.py @@ -4,6 +4,7 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations from .connection import connect, Connection # NOQA from .exceptions import * # NOQA @@ -11,9 +12,9 @@ from .protocol import Record # NOQA from .types import * # NOQA - +from . import exceptions from ._version import __version__ # NOQA -__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection') +__all__ = ['connect', 'create_pool', 'Pool', 'Record', 'Connection'] __all__ += exceptions.__all__ # NOQA diff --git a/asyncpg/_asyncio_compat.py b/asyncpg/_asyncio_compat.py index ad7dfd8c..b6f515d7 100644 --- a/asyncpg/_asyncio_compat.py +++ b/asyncpg/_asyncio_compat.py @@ -4,10 +4,15 @@ # # SPDX-License-Identifier: PSF-2.0 +from __future__ import annotations import asyncio import functools import sys +import typing + +if typing.TYPE_CHECKING: + from . import compat if sys.version_info < (3, 11): from async_timeout import timeout as timeout_ctx @@ -15,7 +20,12 @@ from asyncio import timeout as timeout_ctx -async def wait_for(fut, timeout): +_T = typing.TypeVar('_T') + + +async def wait_for( + fut: compat.Awaitable[_T], timeout: float | None +) -> _T: """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. @@ -65,7 +75,7 @@ async def wait_for(fut, timeout): return await fut -async def _cancel_and_wait(fut): +async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None: """Cancel the *fut* future or task and wait until it completes.""" loop = asyncio.get_running_loop() @@ -82,6 +92,6 @@ async def _cancel_and_wait(fut): fut.remove_done_callback(cb) -def _release_waiter(waiter, *args): +def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None: if not waiter.done(): waiter.set_result(None) diff --git a/asyncpg/_version.py b/asyncpg/_version.py index 67fd67ab..383fe4d2 100644 --- a/asyncpg/_version.py +++ b/asyncpg/_version.py @@ -10,4 +10,8 @@ # supported platforms, publish the packages on PyPI, merge the PR # to the target branch, create a Git tag pointing to the commit. -__version__ = '0.30.0.dev0' +from __future__ import annotations + +import typing + +__version__: typing.Final = '0.30.0.dev0' diff --git a/asyncpg/cluster.py b/asyncpg/cluster.py index 4467cc2a..8615d228 100644 --- a/asyncpg/cluster.py +++ b/asyncpg/cluster.py @@ -4,6 +4,7 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio import os @@ -17,28 +18,46 @@ import tempfile import textwrap import time +import typing import asyncpg from asyncpg import serverversion +from asyncpg import exceptions + +if sys.version_info < (3, 12): + from typing_extensions import Unpack +else: + from typing import Unpack + +if typing.TYPE_CHECKING: + import _typeshed + from . import types + from . import connection -_system = platform.uname().system +class _ConnectionSpec(typing.TypedDict): + host: str + port: str + + +_system: typing.Final = platform.uname().system if _system == 'Windows': - def platform_exe(name): + def platform_exe(name: str) -> str: if name.endswith('.exe'): return name return name + '.exe' else: - def platform_exe(name): + def platform_exe(name: str) -> str: return name -def find_available_port(): +def find_available_port() -> int | None: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: sock.bind(('127.0.0.1', 0)) - return sock.getsockname()[1] + sock_name: tuple[str, int] = sock.getsockname() + return sock_name[1] except Exception: return None finally: @@ -50,7 +69,18 @@ class ClusterError(Exception): class Cluster: - def __init__(self, data_dir, *, pg_config_path=None): + _data_dir: str + _pg_config_path: str | None + _pg_bin_dir: str | None + _pg_ctl: str | None + _daemon_pid: int | None + _daemon_process: subprocess.Popen[bytes] | None + _connection_addr: _ConnectionSpec | None + _connection_spec_override: _ConnectionSpec | None + + def __init__( + self, data_dir: str, *, pg_config_path: str | None = None + ) -> None: self._data_dir = data_dir self._pg_config_path = pg_config_path self._pg_bin_dir = ( @@ -63,21 +93,21 @@ def __init__(self, data_dir, *, pg_config_path=None): self._connection_addr = None self._connection_spec_override = None - def get_pg_version(self): + def get_pg_version(self) -> types.ServerVersion: return self._pg_version - def is_managed(self): + def is_managed(self) -> bool: return True - def get_data_dir(self): + def get_data_dir(self) -> str: return self._data_dir - def get_status(self): + def get_status(self) -> str: if self._pg_ctl is None: self._init_env() process = subprocess.run( - [self._pg_ctl, 'status', '-D', self._data_dir], + [typing.cast(str, self._pg_ctl), 'status', '-D', self._data_dir], stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = process.stdout, process.stderr @@ -96,15 +126,24 @@ def get_status(self): return self._test_connection(timeout=0) else: raise ClusterError( - 'pg_ctl status exited with status {:d}: {}'.format( + 'pg_ctl status exited with status {:d}: {!r}'.format( process.returncode, stderr)) - async def connect(self, loop=None, **kwargs): - conn_info = self.get_connection_spec() + async def connect( + self, + loop: asyncio.AbstractEventLoop | None = None, + **kwargs: object + ) -> connection.Connection[typing.Any]: + conn_info = typing.cast( + 'dict[str, typing.Any]', self.get_connection_spec() + ) conn_info.update(kwargs) - return await asyncpg.connect(loop=loop, **conn_info) + return typing.cast( + 'connection.Connection[typing.Any]', + await asyncpg.connect(loop=loop, **conn_info) + ) - def init(self, **settings): + def init(self, **settings: str) -> str: """Initialize cluster.""" if self.get_status() != 'not-initialized': raise ClusterError( @@ -123,8 +162,12 @@ def init(self, **settings): extra_args = [] process = subprocess.run( - [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + [ + typing.cast(str, self._pg_ctl), 'init', '-D', self._data_dir + ] + extra_args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT + ) output = process.stdout @@ -135,7 +178,13 @@ def init(self, **settings): return output.decode() - def start(self, wait=60, *, server_settings={}, **opts): + def start( + self, + wait: int = 60, + *, + server_settings: dict[str, str] = {}, + **opts: object + ) -> None: """Start the cluster.""" status = self.get_status() if status == 'running': @@ -178,17 +227,19 @@ def start(self, wait=60, *, server_settings={}, **opts): for k, v in server_settings.items(): extra_args.extend(['-c', '{}={}'.format(k, v)]) + pg_ctl = typing.cast(str, self._pg_ctl) + if _system == 'Windows': # On Windows we have to use pg_ctl as direct execution # of postgres daemon under an Administrative account # is not permitted and there is no easy way to drop # privileges. if os.getenv('ASYNCPG_DEBUG_SERVER'): - stdout = sys.stdout + stdout: int | typing.TextIO = sys.stdout print( 'asyncpg.cluster: Running', ' '.join([ - self._pg_ctl, 'start', '-D', self._data_dir, + pg_ctl, 'start', '-D', self._data_dir, '-o', ' '.join(extra_args) ]), file=sys.stderr, @@ -197,7 +248,7 @@ def start(self, wait=60, *, server_settings={}, **opts): stdout = subprocess.DEVNULL process = subprocess.run( - [self._pg_ctl, 'start', '-D', self._data_dir, + [pg_ctl, 'start', '-D', self._data_dir, '-o', ' '.join(extra_args)], stdout=stdout, stderr=subprocess.STDOUT) @@ -224,14 +275,14 @@ def start(self, wait=60, *, server_settings={}, **opts): self._test_connection(timeout=wait) - def reload(self): + def reload(self) -> None: """Reload server configuration.""" status = self.get_status() if status != 'running': raise ClusterError('cannot reload: cluster is not running') process = subprocess.run( - [self._pg_ctl, 'reload', '-D', self._data_dir], + [typing.cast(str, self._pg_ctl), 'reload', '-D', self._data_dir], stdout=subprocess.PIPE, stderr=subprocess.PIPE) stderr = process.stderr @@ -241,11 +292,21 @@ def reload(self): 'pg_ctl stop exited with status {:d}: {}'.format( process.returncode, stderr.decode())) - def stop(self, wait=60): + def stop(self, wait: int = 60) -> None: process = subprocess.run( - [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait), - '-m', 'fast'], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + [ + typing.cast(str, self._pg_ctl), + 'stop', + '-D', + self._data_dir, + '-t', + str(wait), + '-m', + 'fast' + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) stderr = process.stderr @@ -258,14 +319,14 @@ def stop(self, wait=60): self._daemon_process.returncode is None): self._daemon_process.kill() - def destroy(self): + def destroy(self) -> None: status = self.get_status() if status == 'stopped' or status == 'not-initialized': shutil.rmtree(self._data_dir) else: raise ClusterError('cannot destroy {} cluster'.format(status)) - def _get_connection_spec(self): + def _get_connection_spec(self) -> _ConnectionSpec | None: if self._connection_addr is None: self._connection_addr = self._connection_addr_from_pidfile() @@ -277,17 +338,26 @@ def _get_connection_spec(self): else: return self._connection_addr - def get_connection_spec(self): + return None + + def get_connection_spec(self) -> _ConnectionSpec: status = self.get_status() if status != 'running': raise ClusterError('cluster is not running') - return self._get_connection_spec() + spec = self._get_connection_spec() + + if spec is None: + raise ClusterError('cannot determine server connection address') + + return spec - def override_connection_spec(self, **kwargs): - self._connection_spec_override = kwargs + def override_connection_spec(self, **kwargs: str) -> None: + self._connection_spec_override = typing.cast(_ConnectionSpec, kwargs) - def reset_wal(self, *, oid=None, xid=None): + def reset_wal( + self, *, oid: int | None = None, xid: int | None = None + ) -> None: status = self.get_status() if status == 'not-initialized': raise ClusterError( @@ -297,7 +367,7 @@ def reset_wal(self, *, oid=None, xid=None): raise ClusterError( 'cannot modify WAL status: cluster is running') - opts = [] + opts: list[str] = [] if oid is not None: opts.extend(['-o', str(oid)]) if xid is not None: @@ -323,7 +393,7 @@ def reset_wal(self, *, oid=None, xid=None): 'pg_resetwal exited with status {:d}: {}'.format( process.returncode, stderr.decode())) - def reset_hba(self): + def reset_hba(self) -> None: """Remove all records from pg_hba.conf.""" status = self.get_status() if status == 'not-initialized': @@ -339,8 +409,16 @@ def reset_hba(self): raise ClusterError( 'cannot modify HBA records: {}'.format(e)) from e - def add_hba_entry(self, *, type='host', database, user, address=None, - auth_method, auth_options=None): + def add_hba_entry( + self, + *, + type: str = 'host', + database: str, + user: str, + address: str | None = None, + auth_method: str, + auth_options: dict[str, str] | None = None, + ) -> None: """Add a record to pg_hba.conf.""" status = self.get_status() if status == 'not-initialized': @@ -365,7 +443,7 @@ def add_hba_entry(self, *, type='host', database, user, address=None, if auth_options is not None: record += ' ' + ' '.join( - '{}={}'.format(k, v) for k, v in auth_options) + '{}={}'.format(k, v) for k, v in auth_options.items()) try: with open(pg_hba, 'a') as f: @@ -374,7 +452,7 @@ def add_hba_entry(self, *, type='host', database, user, address=None, raise ClusterError( 'cannot modify HBA records: {}'.format(e)) from e - def trust_local_connections(self): + def trust_local_connections(self) -> None: self.reset_hba() if _system != 'Windows': @@ -390,7 +468,7 @@ def trust_local_connections(self): if status == 'running': self.reload() - def trust_local_replication_by(self, user): + def trust_local_replication_by(self, user: str) -> None: if _system != 'Windows': self.add_hba_entry(type='local', database='replication', user=user, auth_method='trust') @@ -404,7 +482,7 @@ def trust_local_replication_by(self, user): if status == 'running': self.reload() - def _init_env(self): + def _init_env(self) -> None: if not self._pg_bin_dir: pg_config = self._find_pg_config(self._pg_config_path) pg_config_data = self._run_pg_config(pg_config) @@ -418,7 +496,7 @@ def _init_env(self): self._postgres = self._find_pg_binary('postgres') self._pg_version = self._get_pg_version() - def _connection_addr_from_pidfile(self): + def _connection_addr_from_pidfile(self) -> _ConnectionSpec | None: pidfile = os.path.join(self._data_dir, 'postmaster.pid') try: @@ -464,7 +542,7 @@ def _connection_addr_from_pidfile(self): 'port': portnum } - def _test_connection(self, timeout=60): + def _test_connection(self, timeout: int = 60) -> str: self._connection_addr = None loop = asyncio.new_event_loop() @@ -478,17 +556,24 @@ def _test_connection(self, timeout=60): continue try: - con = loop.run_until_complete( - asyncpg.connect(database='postgres', - user='postgres', - timeout=5, loop=loop, - **self._connection_addr)) + con: connection.Connection[ + typing.Any + ] = loop.run_until_complete( + asyncpg.connect( + database='postgres', + user='postgres', + timeout=5, loop=loop, + **typing.cast( + _ConnectionSpec, self._connection_addr + ) + ) + ) except (OSError, asyncio.TimeoutError, - asyncpg.CannotConnectNowError, - asyncpg.PostgresConnectionError): + exceptions.CannotConnectNowError, + exceptions.PostgresConnectionError): time.sleep(1) continue - except asyncpg.PostgresError: + except exceptions.PostgresError: # Any other error other than ServerNotReadyError or # ConnectionError is interpreted to indicate the server is # up. @@ -501,16 +586,19 @@ def _test_connection(self, timeout=60): return 'running' - def _run_pg_config(self, pg_config_path): + def _run_pg_config(self, pg_config_path: str) -> dict[str, str]: process = subprocess.run( pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = process.stdout, process.stderr if process.returncode != 0: - raise ClusterError('pg_config exited with status {:d}: {}'.format( - process.returncode, stderr)) + raise ClusterError( + 'pg_config exited with status {:d}: {!r}'.format( + process.returncode, stderr + ) + ) else: - config = {} + config: dict[str, str] = {} for line in stdout.splitlines(): k, eq, v = line.decode('utf-8').partition('=') @@ -519,7 +607,7 @@ def _run_pg_config(self, pg_config_path): return config - def _find_pg_config(self, pg_config_path): + def _find_pg_config(self, pg_config_path: str | None) -> str: if pg_config_path is None: pg_install = ( os.environ.get('PGINSTALLATION') @@ -529,7 +617,9 @@ def _find_pg_config(self, pg_config_path): pg_config_path = platform_exe( os.path.join(pg_install, 'pg_config')) else: - pathenv = os.environ.get('PATH').split(os.pathsep) + pathenv = typing.cast( + str, os.environ.get('PATH') + ).split(os.pathsep) for path in pathenv: pg_config_path = platform_exe( os.path.join(path, 'pg_config')) @@ -547,8 +637,10 @@ def _find_pg_config(self, pg_config_path): return pg_config_path - def _find_pg_binary(self, binary): - bpath = platform_exe(os.path.join(self._pg_bin_dir, binary)) + def _find_pg_binary(self, binary: str) -> str: + bpath = platform_exe( + os.path.join(typing.cast(str, self._pg_bin_dir), binary) + ) if not os.path.isfile(bpath): raise ClusterError( @@ -557,7 +649,7 @@ def _find_pg_binary(self, binary): return bpath - def _get_pg_version(self): + def _get_pg_version(self) -> types.ServerVersion: process = subprocess.run( [self._postgres, '--version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -565,7 +657,7 @@ def _get_pg_version(self): if process.returncode != 0: raise ClusterError( - 'postgres --version exited with status {:d}: {}'.format( + 'postgres --version exited with status {:d}: {!r}'.format( process.returncode, stderr)) version_string = stdout.decode('utf-8').strip(' \n') @@ -580,9 +672,14 @@ def _get_pg_version(self): class TempCluster(Cluster): - def __init__(self, *, - data_dir_suffix=None, data_dir_prefix=None, - data_dir_parent=None, pg_config_path=None): + def __init__( + self, + *, + data_dir_suffix: str | None = None, + data_dir_prefix: str | None = None, + data_dir_parent: _typeshed.StrPath | None = None, + pg_config_path: str | None = None, + ) -> None: self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix, prefix=data_dir_prefix, dir=data_dir_parent) @@ -590,10 +687,16 @@ def __init__(self, *, class HotStandbyCluster(TempCluster): - def __init__(self, *, - master, replication_user, - data_dir_suffix=None, data_dir_prefix=None, - data_dir_parent=None, pg_config_path=None): + def __init__( + self, + *, + master: _ConnectionSpec, + replication_user: str, + data_dir_suffix: str | None = None, + data_dir_prefix: str | None = None, + data_dir_parent: _typeshed.StrPath | None = None, + pg_config_path: str | None = None, + ) -> None: self._master = master self._repl_user = replication_user super().__init__( @@ -602,11 +705,11 @@ def __init__(self, *, data_dir_parent=data_dir_parent, pg_config_path=pg_config_path) - def _init_env(self): + def _init_env(self) -> None: super()._init_env() self._pg_basebackup = self._find_pg_binary('pg_basebackup') - def init(self, **settings): + def init(self, **settings: str) -> str: """Initialize cluster.""" if self.get_status() != 'not-initialized': raise ClusterError( @@ -641,7 +744,13 @@ def init(self, **settings): return output.decode() - def start(self, wait=60, *, server_settings={}, **opts): + def start( + self, + wait: int = 60, + *, + server_settings: dict[str, str] = {}, + **opts: object + ) -> None: if self._pg_version >= (12, 0): server_settings = server_settings.copy() server_settings['primary_conninfo'] = ( @@ -656,33 +765,43 @@ def start(self, wait=60, *, server_settings={}, **opts): class RunningCluster(Cluster): - def __init__(self, **kwargs): + conn_spec: _ConnectionSpec + + def __init__(self, **kwargs: Unpack[_ConnectionSpec]) -> None: self.conn_spec = kwargs - def is_managed(self): + def is_managed(self) -> bool: return False - def get_connection_spec(self): - return dict(self.conn_spec) + def get_connection_spec(self) -> _ConnectionSpec: + return typing.cast(_ConnectionSpec, dict(self.conn_spec)) - def get_status(self): + def get_status(self) -> str: return 'running' - def init(self, **settings): - pass + def init(self, **settings: str) -> str: # type: ignore[empty-body] + ... - def start(self, wait=60, **settings): - pass + def start(self, wait: int = 60, **settings: object) -> None: + ... - def stop(self, wait=60): - pass + def stop(self, wait: int = 60) -> None: + ... - def destroy(self): - pass + def destroy(self) -> None: + ... - def reset_hba(self): + def reset_hba(self) -> None: raise ClusterError('cannot modify HBA records of unmanaged cluster') - def add_hba_entry(self, *, type='host', database, user, address=None, - auth_method, auth_options=None): + def add_hba_entry( + self, + *, + type: str = 'host', + database: str, + user: str, + address: str | None = None, + auth_method: str, + auth_options: dict[str, str] | None = None, + ) -> None: raise ClusterError('cannot modify HBA records of unmanaged cluster') diff --git a/asyncpg/compat.py b/asyncpg/compat.py index 3eec9eb7..0ff6c6da 100644 --- a/asyncpg/compat.py +++ b/asyncpg/compat.py @@ -4,22 +4,26 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import pathlib import platform import typing import sys +if typing.TYPE_CHECKING: + import asyncio -SYSTEM = platform.uname().system +SYSTEM: typing.Final = platform.uname().system -if SYSTEM == 'Windows': + +if sys.platform == 'win32': import ctypes.wintypes CSIDL_APPDATA = 0x001a - def get_pg_home_directory() -> typing.Optional[pathlib.Path]: + def get_pg_home_directory() -> pathlib.Path | None: # We cannot simply use expanduser() as that returns the user's # home directory, whereas Postgres stores its config in # %AppData% on Windows. @@ -31,14 +35,14 @@ def get_pg_home_directory() -> typing.Optional[pathlib.Path]: return pathlib.Path(buf.value) / 'postgresql' else: - def get_pg_home_directory() -> typing.Optional[pathlib.Path]: + def get_pg_home_directory() -> pathlib.Path | None: try: return pathlib.Path.home() except (RuntimeError, KeyError): return None -async def wait_closed(stream): +async def wait_closed(stream: asyncio.StreamWriter) -> None: # Not all asyncio versions have StreamWriter.wait_closed(). if hasattr(stream, 'wait_closed'): try: @@ -59,3 +63,40 @@ async def wait_closed(stream): from ._asyncio_compat import timeout_ctx as timeout # noqa: F401 else: from asyncio import timeout as timeout # noqa: F401 + +if sys.version_info < (3, 9): + from typing import ( + AsyncIterable as AsyncIterable, + Awaitable as Awaitable, + Callable as Callable, + Coroutine as Coroutine, + Deque as deque, + Generator as Generator, + Iterable as Iterable, + Iterator as Iterator, + List as list, + OrderedDict as OrderedDict, + Sequence as Sequence, + Sized as Sized, + Tuple as tuple, + ) +else: + from builtins import ( # noqa: F401 + list as list, + tuple as tuple, + ) + from collections import ( # noqa: F401 + deque as deque, + OrderedDict as OrderedDict, + ) + from collections.abc import ( # noqa: F401 + AsyncIterable as AsyncIterable, + Awaitable as Awaitable, + Callable as Callable, + Coroutine as Coroutine, + Generator as Generator, + Iterable as Iterable, + Iterator as Iterator, + Sequence as Sequence, + Sized as Sized, + ) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 414231fd..a9789a28 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -4,9 +4,9 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio -import collections import enum import functools import getpass @@ -29,6 +29,58 @@ from . import exceptions from . import protocol +if typing.TYPE_CHECKING: + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + + from . import connection + +_ConnectionT = typing.TypeVar( + '_ConnectionT', + bound='connection.Connection[typing.Any]' +) +_ProtocolT = typing.TypeVar( + '_ProtocolT', + bound='protocol.Protocol[typing.Any]' +) +_AsyncProtocolT = typing.TypeVar( + '_AsyncProtocolT', bound='asyncio.protocols.Protocol' +) +_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record) +_ParsedSSLType = typing.Union[ + ssl_module.SSLContext, typing.Literal[False] +] +_SSLStringValues = typing.Literal[ + 'disable', 'prefer', 'allow', 'require', 'verify-ca', 'verify-full' +] +_TPTupleType = compat.tuple[ + asyncio.WriteTransport, + _AsyncProtocolT +] +AddrType = typing.Union[ + compat.tuple[str, int], + str +] +HostType = typing.Union[compat.list[str], compat.tuple[str, ...], str] +PasswordType = typing.Union[ + str, + compat.Callable[[], str], + compat.Callable[[], compat.Awaitable[str]] +] +PortListType = typing.Union[ + compat.list[typing.Union[int, str]], + compat.list[int], + compat.list[str], +] +PortType = typing.Union[ + PortListType, + int, + str +] +SSLType = typing.Union[_ParsedSSLType, _SSLStringValues, bool] + class SSLMode(enum.IntEnum): disable = 0 @@ -39,48 +91,40 @@ class SSLMode(enum.IntEnum): verify_full = 5 @classmethod - def parse(cls, sslmode): + def parse(cls, sslmode: str | Self) -> Self: if isinstance(sslmode, cls): return sslmode - return getattr(cls, sslmode.replace('-', '_')) - - -_ConnectionParameters = collections.namedtuple( - 'ConnectionParameters', - [ - 'user', - 'password', - 'database', - 'ssl', - 'sslmode', - 'direct_tls', - 'server_settings', - 'target_session_attrs', - ]) - + return typing.cast( + 'Self', + getattr(cls, typing.cast(str, sslmode).replace('-', '_')) + ) -_ClientConfiguration = collections.namedtuple( - 'ConnectionConfiguration', - [ - 'command_timeout', - 'statement_cache_size', - 'max_cached_statement_lifetime', - 'max_cacheable_statement_size', - ]) +class _ConnectionParameters(typing.NamedTuple): + user: str + password: PasswordType | None + database: str + ssl: _ParsedSSLType | None + sslmode: SSLMode | None + direct_tls: bool + server_settings: dict[str, str] | None + target_session_attrs: SessionAttribute -_system = platform.uname().system +class _ClientConfiguration(typing.NamedTuple): + command_timeout: float | None + statement_cache_size: int + max_cached_statement_lifetime: int + max_cacheable_statement_size: int -if _system == 'Windows': - PGPASSFILE = 'pgpass.conf' -else: - PGPASSFILE = '.pgpass' +_system: typing.Final = platform.uname().system +PGPASSFILE: typing.Final = ( + 'pgpass.conf' if _system == 'Windows' else '.pgpass' +) -def _read_password_file(passfile: pathlib.Path) \ - -> typing.List[typing.Tuple[str, ...]]: +def _read_password_file(passfile: pathlib.Path) -> list[tuple[str, ...]]: passtab = [] try: @@ -122,11 +166,13 @@ def _read_password_file(passfile: pathlib.Path) \ def _read_password_from_pgpass( - *, passfile: typing.Optional[pathlib.Path], - hosts: typing.List[str], - ports: typing.List[int], - database: str, - user: str): + *, + passfile: pathlib.Path, + hosts: compat.Iterable[str], + ports: list[int], + database: str, + user: str +) -> str | None: """Parse the pgpass file and return the matching password. :return: @@ -158,7 +204,7 @@ def _read_password_from_pgpass( return None -def _validate_port_spec(hosts, port): +def _validate_port_spec(hosts: compat.Sized, port: PortType) -> list[int]: if isinstance(port, list): # If there is a list of ports, its length must # match that of the host list. @@ -166,42 +212,49 @@ def _validate_port_spec(hosts, port): raise exceptions.ClientConfigurationError( 'could not match {} port numbers to {} hosts'.format( len(port), len(hosts))) + return [int(p) for p in port] else: - port = [port for _ in range(len(hosts))] - - return port + return [int(port) for _ in range(len(hosts))] -def _parse_hostlist(hostlist, port, *, unquote=False): +def _parse_hostlist( + hostlist: str, + port: PortType | None, + *, + unquote: bool = False +) -> tuple[list[str], PortListType]: if ',' in hostlist: # A comma-separated list of host addresses. hostspecs = hostlist.split(',') else: hostspecs = [hostlist] - hosts = [] - hostlist_ports = [] + hosts: list[str] = [] + hostlist_ports: list[int] = [] + ports: list[int] | None = None if not port: portspec = os.environ.get('PGPORT') if portspec: if ',' in portspec: - default_port = [int(p) for p in portspec.split(',')] + temp_port: list[int] | int = [ + int(p) for p in portspec.split(',') + ] else: - default_port = int(portspec) + temp_port = int(portspec) else: - default_port = 5432 + temp_port = 5432 - default_port = _validate_port_spec(hostspecs, default_port) + default_port = _validate_port_spec(hostspecs, temp_port) else: - port = _validate_port_spec(hostspecs, port) + ports = _validate_port_spec(hostspecs, port) for i, hostspec in enumerate(hostspecs): if hostspec[0] == '/': # Unix socket addr = hostspec - hostspec_port = '' + hostspec_port: str = '' elif hostspec[0] == '[': # IPv6 address m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec) @@ -230,13 +283,13 @@ def _parse_hostlist(hostlist, port, *, unquote=False): else: hostlist_ports.append(default_port[i]) - if not port: - port = hostlist_ports + if not ports: + ports = hostlist_ports - return hosts, port + return hosts, ports -def _parse_tls_version(tls_version): +def _parse_tls_version(tls_version: str) -> ssl_module.TLSVersion: if tls_version.startswith('SSL'): raise exceptions.ClientConfigurationError( f"Unsupported TLS version: {tls_version}" @@ -249,7 +302,7 @@ def _parse_tls_version(tls_version): ) -def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: +def _dot_postgresql_path(filename: str) -> pathlib.Path | None: try: homedir = pathlib.Path.home() except (RuntimeError, KeyError): @@ -258,15 +311,34 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: return (homedir / '.postgresql' / filename).resolve() -def _parse_connect_dsn_and_args(*, dsn, host, port, user, - password, passfile, database, ssl, - direct_tls, server_settings, - target_session_attrs): +def _parse_connect_dsn_and_args( + *, + dsn: str | None, + host: HostType | None, + port: PortType | None, + user: str | None, + password: str | None, + passfile: str | None, + database: str | None, + ssl: SSLType | None, + direct_tls: bool, + server_settings: dict[str, str] | None, + target_session_attrs: SessionAttribute | None, +) -> tuple[list[tuple[str, int] | str], _ConnectionParameters]: # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. - auth_hosts = None - sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None + auth_hosts: list[str] | tuple[str, ...] | None = None + sslcert: str | pathlib.Path | None = None + sslkey: str | pathlib.Path | None = None + sslrootcert: str | pathlib.Path | None = None + sslcrl: str | pathlib.Path | None = None + sslpassword = None ssl_min_protocol_version = ssl_max_protocol_version = None + ssl_val: SSLType | str | None = ssl + ssl_parsed: _ParsedSSLType | None = None + target_session_attrs_val: ( + SessionAttribute | str | None + ) = target_session_attrs if dsn: parsed = urllib.parse.urlparse(dsn) @@ -306,10 +378,12 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, password = urllib.parse.unquote(dsn_password) if parsed.query: - query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) - for key, val in query.items(): - if isinstance(val, list): - query[key] = val[-1] + query: dict[str, str] = { + key: val[-1] if isinstance(val, list) else val + for key, val in urllib.parse.parse_qs( + parsed.query, strict_parsing=True + ).items() + } if 'port' in query: val = query.pop('port') @@ -348,8 +422,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if 'sslmode' in query: val = query.pop('sslmode') - if ssl is None: - ssl = val + if ssl_val is None: + ssl_val = val if 'sslcert' in query: sslcert = query.pop('sslcert') @@ -380,8 +454,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, dsn_target_session_attrs = query.pop( 'target_session_attrs' ) - if target_session_attrs is None: - target_session_attrs = dsn_target_session_attrs + if target_session_attrs_val is None: + target_session_attrs_val = dsn_target_session_attrs if query: if server_settings is None: @@ -425,7 +499,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, else: port = int(port) - port = _validate_port_spec(host, port) + validated_ports = _validate_port_spec(host, port) if user is None: user = os.getenv('PGUSER') @@ -456,21 +530,21 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if passfile is None: homedir = compat.get_pg_home_directory() if homedir: - passfile = homedir / PGPASSFILE + passfile_path: pathlib.Path | None = homedir / PGPASSFILE else: - passfile = None + passfile_path = None else: - passfile = pathlib.Path(passfile) + passfile_path = pathlib.Path(passfile) - if passfile is not None: + if passfile_path is not None: password = _read_password_from_pgpass( - hosts=auth_hosts, ports=port, + hosts=auth_hosts, ports=validated_ports, database=database, user=user, - passfile=passfile) + passfile=passfile_path) - addrs = [] + addrs: list[AddrType] = [] have_tcp_addrs = False - for h, p in zip(host, port): + for h, p in zip(host, validated_ports): if h.startswith('/'): # UNIX socket name if '.s.PGSQL.' not in h: @@ -485,15 +559,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, raise exceptions.InternalClientError( 'could not determine the database address to connect to') - if ssl is None: - ssl = os.getenv('PGSSLMODE') + if ssl_val is None: + ssl_val = os.getenv('PGSSLMODE') - if ssl is None and have_tcp_addrs: - ssl = 'prefer' + if ssl_val is None and have_tcp_addrs: + ssl_val = 'prefer' - if isinstance(ssl, (str, SSLMode)): + if isinstance(ssl_val, (str, SSLMode)): try: - sslmode = SSLMode.parse(ssl) + sslmode = SSLMode.parse(ssl_val) except AttributeError: modes = ', '.join(m.name.replace('_', '-') for m in SSLMode) raise exceptions.ClientConfigurationError( @@ -501,23 +575,25 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html if sslmode < SSLMode.allow: - ssl = False + ssl_parsed = False else: - ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) - ssl.check_hostname = sslmode >= SSLMode.verify_full + ssl_parsed = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) + ssl_parsed.check_hostname = sslmode >= SSLMode.verify_full if sslmode < SSLMode.require: - ssl.verify_mode = ssl_module.CERT_NONE + ssl_parsed.verify_mode = ssl_module.CERT_NONE else: if sslrootcert is None: sslrootcert = os.getenv('PGSSLROOTCERT') if sslrootcert: - ssl.load_verify_locations(cafile=sslrootcert) - ssl.verify_mode = ssl_module.CERT_REQUIRED + ssl_parsed.load_verify_locations(cafile=sslrootcert) + ssl_parsed.verify_mode = ssl_module.CERT_REQUIRED else: try: sslrootcert = _dot_postgresql_path('root.crt') if sslrootcert is not None: - ssl.load_verify_locations(cafile=sslrootcert) + ssl_parsed.load_verify_locations( + cafile=sslrootcert + ) else: raise exceptions.ClientConfigurationError( 'cannot determine location of user ' @@ -548,29 +624,31 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, detail=detail, ) elif sslmode == SSLMode.require: - ssl.verify_mode = ssl_module.CERT_NONE + ssl_parsed.verify_mode = ssl_module.CERT_NONE else: assert False, 'unreachable' else: - ssl.verify_mode = ssl_module.CERT_REQUIRED + ssl_parsed.verify_mode = ssl_module.CERT_REQUIRED if sslcrl is None: sslcrl = os.getenv('PGSSLCRL') if sslcrl: - ssl.load_verify_locations(cafile=sslcrl) - ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN + ssl_parsed.load_verify_locations(cafile=sslcrl) + ssl_parsed.verify_flags |= ( + ssl_module.VERIFY_CRL_CHECK_CHAIN + ) else: sslcrl = _dot_postgresql_path('root.crl') if sslcrl is not None: try: - ssl.load_verify_locations(cafile=sslcrl) + ssl_parsed.load_verify_locations(cafile=sslcrl) except ( FileNotFoundError, NotADirectoryError, ): pass else: - ssl.verify_flags |= \ + ssl_parsed.verify_flags |= \ ssl_module.VERIFY_CRL_CHECK_CHAIN if sslkey is None: @@ -584,14 +662,14 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if sslcert is None: sslcert = os.getenv('PGSSLCERT') if sslcert: - ssl.load_cert_chain( + ssl_parsed.load_cert_chain( sslcert, keyfile=sslkey, password=lambda: sslpassword ) else: sslcert = _dot_postgresql_path('postgresql.crt') if sslcert is not None: try: - ssl.load_cert_chain( + ssl_parsed.load_cert_chain( sslcert, keyfile=sslkey, password=lambda: sslpassword @@ -603,28 +681,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if hasattr(ssl, 'keylog_filename'): keylogfile = os.environ.get('SSLKEYLOGFILE') if keylogfile and not sys.flags.ignore_environment: - ssl.keylog_filename = keylogfile + ssl_parsed.keylog_filename = keylogfile if ssl_min_protocol_version is None: ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION') if ssl_min_protocol_version: - ssl.minimum_version = _parse_tls_version( + ssl_parsed.minimum_version = _parse_tls_version( ssl_min_protocol_version ) else: - ssl.minimum_version = _parse_tls_version('TLSv1.2') + ssl_parsed.minimum_version = _parse_tls_version('TLSv1.2') if ssl_max_protocol_version is None: ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION') if ssl_max_protocol_version: - ssl.maximum_version = _parse_tls_version( + ssl_parsed.maximum_version = _parse_tls_version( ssl_max_protocol_version ) - elif ssl is True: - ssl = ssl_module.create_default_context() + elif ssl_val is True: + ssl_parsed = ssl_module.create_default_context() sslmode = SSLMode.verify_full else: + ssl_parsed = ssl_val sslmode = SSLMode.disable if server_settings is not None and ( @@ -635,23 +714,23 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, 'server_settings is expected to be None or ' 'a Dict[str, str]') - if target_session_attrs is None: - target_session_attrs = os.getenv( + if target_session_attrs_val is None: + target_session_attrs_val = os.getenv( "PGTARGETSESSIONATTRS", SessionAttribute.any ) try: - target_session_attrs = SessionAttribute(target_session_attrs) + target_session_attrs = SessionAttribute(target_session_attrs_val) except ValueError: raise exceptions.ClientConfigurationError( "target_session_attrs is expected to be one of " "{!r}" ", got {!r}".format( - SessionAttribute.__members__.values, target_session_attrs + SessionAttribute.__members__.values, target_session_attrs_val ) ) from None params = _ConnectionParameters( - user=user, password=password, database=database, ssl=ssl, + user=user, password=password, database=database, ssl=ssl_parsed, sslmode=sslmode, direct_tls=direct_tls, server_settings=server_settings, target_session_attrs=target_session_attrs) @@ -659,13 +738,26 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, return addrs, params -def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, - database, command_timeout, - statement_cache_size, - max_cached_statement_lifetime, - max_cacheable_statement_size, - ssl, direct_tls, server_settings, - target_session_attrs): +def _parse_connect_arguments( + *, + dsn: str | None, + host: HostType | None, + port: PortType | None, + user: str | None, + password: str | None, + passfile: str | None, + database: str | None, + command_timeout: float | typing.SupportsFloat | None, + statement_cache_size: int, + max_cached_statement_lifetime: int, + max_cacheable_statement_size: int, + ssl: SSLType | None, + direct_tls: bool, + server_settings: dict[str, str] | None, + target_session_attrs: SessionAttribute, +) -> tuple[ + list[tuple[str, int] | str], _ConnectionParameters, _ClientConfiguration +]: local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -706,14 +798,27 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, class TLSUpgradeProto(asyncio.Protocol): - def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): + on_data: asyncio.Future[bool] + host: str + port: int + ssl_context: ssl_module.SSLContext + ssl_is_advisory: bool | None + + def __init__( + self, + loop: asyncio.AbstractEventLoop | None, + host: str, + port: int, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool | None + ) -> None: self.on_data = _create_future(loop) self.host = host self.port = port self.ssl_context = ssl_context self.ssl_is_advisory = ssl_is_advisory - def data_received(self, data): + def data_received(self, data: bytes) -> None: if data == b'S': self.on_data.set_result(True) elif (self.ssl_is_advisory and @@ -731,20 +836,63 @@ def data_received(self, data): 'rejected SSL upgrade'.format( host=self.host, port=self.port))) - def connection_lost(self, exc): + def connection_lost(self, exc: Exception | None) -> None: if not self.on_data.done(): if exc is None: exc = ConnectionError('unexpected connection_lost() call') self.on_data.set_exception(exc) -async def _create_ssl_connection(protocol_factory, host, port, *, - loop, ssl_context, ssl_is_advisory=False): - - tr, pr = await loop.create_connection( - lambda: TLSUpgradeProto(loop, host, port, - ssl_context, ssl_is_advisory), - host, port) +@typing.overload +async def _create_ssl_connection( + protocol_factory: compat.Callable[[], _ProtocolT], + host: str, + port: int, + *, + loop: asyncio.AbstractEventLoop, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool | None = False +) -> _TPTupleType[_ProtocolT]: + ... + + +@typing.overload +async def _create_ssl_connection( + protocol_factory: compat.Callable[[], '_CancelProto'], + host: str, + port: int, + *, + loop: asyncio.AbstractEventLoop, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool | None = False +) -> _TPTupleType['_CancelProto']: + ... + + +async def _create_ssl_connection( + protocol_factory: compat.Callable[ + [], _ProtocolT + ] | compat.Callable[ + [], '_CancelProto' + ], + host: str, + port: int, + *, + loop: asyncio.AbstractEventLoop, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: typing.Optional[bool] = False +) -> _TPTupleType[typing.Any]: + + tr, pr = typing.cast( + compat.tuple[asyncio.WriteTransport, TLSUpgradeProto], + await loop.create_connection( + lambda: TLSUpgradeProto( + loop, host, port, ssl_context, ssl_is_advisory + ), + host, + port + ) + ) tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message. @@ -757,8 +905,12 @@ async def _create_ssl_connection(protocol_factory, host, port, *, if hasattr(loop, 'start_tls'): if do_ssl_upgrade: try: - new_tr = await loop.start_tls( - tr, pr, ssl_context, server_hostname=host) + new_tr = typing.cast( + asyncio.WriteTransport, + await loop.start_tls( + tr, pr, ssl_context, server_hostname=host + ) + ) except (Exception, asyncio.CancelledError): tr.close() raise @@ -795,13 +947,13 @@ async def _create_ssl_connection(protocol_factory, host, port, *, async def _connect_addr( *, - addr, - loop, - params, - config, - connection_class, - record_class -): + addr: AddrType, + loop: asyncio.AbstractEventLoop, + params: _ConnectionParameters, + config: _ClientConfiguration, + connection_class: type[_ConnectionT], + record_class: type[_RecordT] +) -> _ConnectionT: assert loop is not None params_input = params @@ -810,7 +962,7 @@ async def _connect_addr( if inspect.isawaitable(password): password = await password - params = params._replace(password=password) + params = params._replace(password=typing.cast(str, password)) args = (addr, loop, config, connection_class, record_class, params_input) # prepare the params (which attempt has ssl) for the 2 attempts @@ -838,15 +990,15 @@ class _RetryConnectSignal(Exception): async def __connect_addr( - params, - retry, - addr, - loop, - config, - connection_class, - record_class, - params_input, -): + params: _ConnectionParameters, + retry: bool, + addr: AddrType, + loop: asyncio.AbstractEventLoop, + config: _ClientConfiguration, + connection_class: type[_ConnectionT], + record_class: type[_RecordT], + params_input: _ConnectionParameters, +) -> _ConnectionT: connected = _create_future(loop) proto_factory = lambda: protocol.Protocol( @@ -854,13 +1006,21 @@ async def __connect_addr( if isinstance(addr, str): # UNIX socket - connector = loop.create_unix_connection(proto_factory, addr) + connector = typing.cast( + compat.Coroutine[ + typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]'] + ], + loop.create_unix_connection(proto_factory, addr) + ) elif params.ssl and params.direct_tls: # if ssl and direct_tls are given, skip STARTTLS and perform direct # SSL connection - connector = loop.create_connection( - proto_factory, *addr, ssl=params.ssl + connector = typing.cast( + compat.Coroutine[ + typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]'] + ], + loop.create_connection(proto_factory, *addr, ssl=params.ssl) ) elif params.ssl: @@ -868,7 +1028,12 @@ async def __connect_addr( proto_factory, *addr, loop=loop, ssl_context=params.ssl, ssl_is_advisory=params.sslmode == SSLMode.prefer) else: - connector = loop.create_connection(proto_factory, *addr) + connector = typing.cast( + compat.Coroutine[ + typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]'] + ], + loop.create_connection(proto_factory, *addr) + ) tr, pr = await connector @@ -921,18 +1086,24 @@ class SessionAttribute(str, enum.Enum): read_only = "read-only" -def _accept_in_hot_standby(should_be_in_hot_standby: bool): +def _accept_in_hot_standby(should_be_in_hot_standby: bool) -> compat.Callable[ + [connection.Connection[typing.Any]], compat.Awaitable[bool] +]: """ If the server didn't report "in_hot_standby" at startup, we must determine the state by checking "SELECT pg_catalog.pg_is_in_recovery()". If the server allows a connection and states it is in recovery it must be a replica/standby server. """ - async def can_be_used(connection): + async def can_be_used( + connection: connection.Connection[typing.Any] + ) -> bool: settings = connection.get_settings() - hot_standby_status = getattr(settings, 'in_hot_standby', None) + hot_standby_status: str | None = getattr( + settings, 'in_hot_standby', None + ) if hot_standby_status is not None: - is_in_hot_standby = hot_standby_status == 'on' + is_in_hot_standby: bool = hot_standby_status == 'on' else: is_in_hot_standby = await connection.fetchval( "SELECT pg_catalog.pg_is_in_recovery()" @@ -942,11 +1113,15 @@ async def can_be_used(connection): return can_be_used -def _accept_read_only(should_be_read_only: bool): +def _accept_read_only(should_be_read_only: bool) -> compat.Callable[ + [connection.Connection[typing.Any]], compat.Awaitable[bool] +]: """ Verify the server has not set default_transaction_read_only=True """ - async def can_be_used(connection): + async def can_be_used( + connection: connection.Connection[typing.Any] + ) -> bool: settings = connection.get_settings() is_readonly = getattr(settings, 'default_transaction_read_only', 'off') @@ -957,11 +1132,19 @@ async def can_be_used(connection): return can_be_used -async def _accept_any(_): +async def _accept_any(_: connection.Connection[typing.Any]) -> bool: return True -target_attrs_check = { +target_attrs_check: typing.Final[ + dict[ + SessionAttribute, + compat.Callable[ + [connection.Connection[typing.Any]], + compat.Awaitable[bool] + ] + ] +] = { SessionAttribute.any: _accept_any, SessionAttribute.primary: _accept_in_hot_standby(False), SessionAttribute.standby: _accept_in_hot_standby(True), @@ -971,21 +1154,30 @@ async def _accept_any(_): } -async def _can_use_connection(connection, attr: SessionAttribute): +async def _can_use_connection( + connection: connection.Connection[typing.Any], + attr: SessionAttribute +) -> bool: can_use = target_attrs_check[attr] return await can_use(connection) -async def _connect(*, loop, connection_class, record_class, **kwargs): +async def _connect( + *, + loop: asyncio.AbstractEventLoop | None, + connection_class: type[_ConnectionT], + record_class: type[_RecordT], + **kwargs: typing.Any +) -> _ConnectionT: if loop is None: loop = asyncio.get_event_loop() addrs, params, config = _parse_connect_arguments(**kwargs) target_attr = params.target_session_attrs - candidates = [] + candidates: list[_ConnectionT] = [] chosen_connection = None - last_error = None + last_error: BaseException | None = None for addr in addrs: try: conn = await _connect_addr( @@ -1020,32 +1212,44 @@ async def _connect(*, loop, connection_class, record_class, **kwargs): ) -async def _cancel(*, loop, addr, params: _ConnectionParameters, - backend_pid, backend_secret): +class _CancelProto(asyncio.Protocol): - class CancelProto(asyncio.Protocol): + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self.on_disconnect = _create_future(loop) + self.is_ssl = False - def __init__(self): - self.on_disconnect = _create_future(loop) - self.is_ssl = False + def connection_lost(self, exc: Exception | None) -> None: + if not self.on_disconnect.done(): + self.on_disconnect.set_result(True) - def connection_lost(self, exc): - if not self.on_disconnect.done(): - self.on_disconnect.set_result(True) + +async def _cancel( + *, + loop: asyncio.AbstractEventLoop, + addr: AddrType, + params: _ConnectionParameters, + backend_pid: int, + backend_secret: str +) -> None: + proto_factory: compat.Callable[ + [], _CancelProto + ] = lambda: _CancelProto(loop) if isinstance(addr, str): - tr, pr = await loop.create_unix_connection(CancelProto, addr) + tr, pr = typing.cast( + _TPTupleType[_CancelProto], + await loop.create_unix_connection(proto_factory, addr) + ) else: if params.ssl and params.sslmode != SSLMode.allow: tr, pr = await _create_ssl_connection( - CancelProto, + proto_factory, *addr, loop=loop, ssl_context=params.ssl, ssl_is_advisory=params.sslmode == SSLMode.prefer) else: - tr, pr = await loop.create_connection( - CancelProto, *addr) + tr, pr = await loop.create_connection(proto_factory, *addr) _set_nodelay(_get_socket(tr)) # Pack a CancelRequest message @@ -1058,7 +1262,7 @@ def connection_lost(self, exc): tr.close() -def _get_socket(transport): +def _get_socket(transport: asyncio.BaseTransport) -> typing.Any: sock = transport.get_extra_info('socket') if sock is None: # Shouldn't happen with any asyncio-complaint event loop. @@ -1067,14 +1271,16 @@ def _get_socket(transport): return sock -def _set_nodelay(sock): +def _set_nodelay(sock: typing.Any) -> None: if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) -def _create_future(loop): +def _create_future( + loop: asyncio.AbstractEventLoop | None +) -> asyncio.Future[typing.Any]: try: - create_future = loop.create_future + create_future = loop.create_future # type: ignore[union-attr] except AttributeError: return asyncio.Future(loop=loop) else: diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 0367e365..d551e537 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -4,12 +4,14 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio import asyncpg import collections import collections.abc import contextlib +import dataclasses import functools import itertools import inspect @@ -32,15 +34,114 @@ from . import transaction from . import utils +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + +if typing.TYPE_CHECKING: + import io + + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + + from .protocol import protocol as _cprotocol + from .exceptions import _postgres_message + from . import pool_connection_proxy as _pool + from . import types + +_ConnectionT = typing.TypeVar('_ConnectionT', bound='Connection[typing.Any]') +_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record) +_OtherRecordT = typing.TypeVar('_OtherRecordT', bound=protocol.Record) +_P = ParamSpec('_P') + +_WriterType = compat.Callable[ + [bytes], compat.Coroutine[typing.Any, typing.Any, None] +] +_OutputType = typing.Union[ + 'os.PathLike[typing.Any]', typing.BinaryIO, _WriterType +] +_CopyFormat = typing.Literal['text', 'csv', 'binary'] +_SourceType = typing.Union[ + 'os.PathLike[typing.Any]', typing.BinaryIO, compat.AsyncIterable[bytes] +] +_RecordsType = compat.list[_RecordT] +_RecordsTupleType = compat.tuple[_RecordsType[_RecordT], bytes, bool] + + +class Listener(typing.Protocol): + def __call__( + self, + con_ref: Connection[ + typing.Any + ] | _pool.PoolConnectionProxy[typing.Any], + pid: int, + channel: str, + payload: object, + /, + ) -> compat.Coroutine[typing.Any, typing.Any, None] | None: + ... + + +class LogListener(typing.Protocol): + def __call__( + self, + con_ref: Connection[ + typing.Any + ] | _pool.PoolConnectionProxy[typing.Any], + message: _postgres_message.PostgresMessage, + /, + ) -> compat.Coroutine[typing.Any, typing.Any, None] | None: + ... + + +class TerminationListener(typing.Protocol): + def __call__( + self, + con_ref: Connection[ + typing.Any + ] | _pool.PoolConnectionProxy[typing.Any], + /, + ) -> compat.Coroutine[typing.Any, typing.Any, None] | None: + ... + + +class QueryLogger(typing.Protocol): + def __call__( + self, record: LoggedQuery, / + ) -> compat.Coroutine[typing.Any, typing.Any, None] | None: + ... + + +class Executor(typing.Protocol[_RecordT]): + def __call__( + self, + statement: _cprotocol.PreparedStatementState[_RecordT], + timeout: float | None, + / + ) -> typing.Any: + ... + + +class OnRemove(typing.Protocol[_RecordT]): + def __call__( + self, + statement: _cprotocol.PreparedStatementState[_RecordT], + / + ) -> None: + ... + class ConnectionMeta(type): - def __instancecheck__(cls, instance): + def __instancecheck__(cls, instance: object) -> bool: mro = type(instance).__mro__ return Connection in mro or _ConnectionProxy in mro -class Connection(metaclass=ConnectionMeta): +class Connection(typing.Generic[_RecordT], metaclass=ConnectionMeta): """A representation of a database session. Connections are created by calling :func:`~asyncpg.connection.connect`. @@ -56,10 +157,66 @@ class Connection(metaclass=ConnectionMeta): '_log_listeners', '_termination_listeners', '_cancellations', '_source_traceback', '_query_loggers', '__weakref__') - def __init__(self, protocol, transport, loop, - addr, - config: connect_utils._ClientConfiguration, - params: connect_utils._ConnectionParameters): + _protocol: _cprotocol.BaseProtocol[_RecordT] + _transport: object + _loop: asyncio.AbstractEventLoop + _top_xact: transaction.Transaction | None + _aborted: bool + _pool_release_ctr: int + _stmt_cache: _StatementCache + _stmts_to_close: set[_cprotocol.PreparedStatementState[typing.Any]] + _stmt_cache_enabled: bool + _listeners: dict[ + str, + set[ + _Callback[ + [ + Connection[typing.Any] | + _pool.PoolConnectionProxy[typing.Any], + int, + str, + object + ] + ] + ] + ] + _server_version: types.ServerVersion + _server_caps: ServerCapabilities + _intro_query: str + _reset_query: str | None + _proxy: _pool.PoolConnectionProxy[typing.Any] | None + _stmt_exclusive_section: _Atomic + _config: connect_utils._ClientConfiguration + _params: connect_utils._ConnectionParameters + _addr: connect_utils.AddrType + _log_listeners: set[ + _Callback[ + [ + Connection[typing.Any] | _pool.PoolConnectionProxy[typing.Any], + _postgres_message.PostgresMessage, + ] + ] + ] + _termination_listeners: set[ + _Callback[ + [ + Connection[typing.Any] | _pool.PoolConnectionProxy[typing.Any], + ] + ] + ] + _cancellations: set[asyncio.Task[typing.Any]] + _source_traceback: str | None + _query_loggers: set[_Callback[[LoggedQuery]]] + + def __init__( + self, + protocol: _cprotocol.BaseProtocol[_RecordT], + transport: object, + loop: asyncio.AbstractEventLoop, + addr: tuple[str, int] | str, + config: connect_utils._ClientConfiguration, + params: connect_utils._ConnectionParameters, + ) -> None: self._protocol = protocol self._transport = transport self._loop = loop @@ -120,7 +277,7 @@ def __init__(self, protocol, transport, loop, else: self._source_traceback = None - def __del__(self): + def __del__(self) -> None: if not self.is_closed() and self._protocol is not None: if self._source_traceback: msg = "unclosed connection {!r}; created at:\n {}".format( @@ -136,7 +293,7 @@ def __del__(self): if not self._loop.is_closed(): self.terminate() - async def add_listener(self, channel, callback): + async def add_listener(self, channel: str, callback: Listener) -> None: """Add a listener for Postgres notifications. :param str channel: Channel to listen on. @@ -158,7 +315,7 @@ async def add_listener(self, channel, callback): self._listeners[channel] = set() self._listeners[channel].add(_Callback.from_callable(callback)) - async def remove_listener(self, channel, callback): + async def remove_listener(self, channel: str, callback: Listener) -> None: """Remove a listening callback on the specified channel.""" if self.is_closed(): return @@ -172,7 +329,7 @@ async def remove_listener(self, channel, callback): del self._listeners[channel] await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel))) - def add_log_listener(self, callback): + def add_log_listener(self, callback: LogListener) -> None: """Add a listener for Postgres log messages. It will be called when asyncronous NoticeResponse is received @@ -194,14 +351,14 @@ def add_log_listener(self, callback): raise exceptions.InterfaceError('connection is closed') self._log_listeners.add(_Callback.from_callable(callback)) - def remove_log_listener(self, callback): + def remove_log_listener(self, callback: LogListener) -> None: """Remove a listening callback for log messages. .. versionadded:: 0.12.0 """ self._log_listeners.discard(_Callback.from_callable(callback)) - def add_termination_listener(self, callback): + def add_termination_listener(self, callback: TerminationListener) -> None: """Add a listener that will be called when the connection is closed. :param callable callback: @@ -215,7 +372,9 @@ def add_termination_listener(self, callback): """ self._termination_listeners.add(_Callback.from_callable(callback)) - def remove_termination_listener(self, callback): + def remove_termination_listener( + self, callback: TerminationListener + ) -> None: """Remove a listening callback for connection termination. :param callable callback: @@ -226,7 +385,7 @@ def remove_termination_listener(self, callback): """ self._termination_listeners.discard(_Callback.from_callable(callback)) - def add_query_logger(self, callback): + def add_query_logger(self, callback: QueryLogger) -> None: """Add a logger that will be called when queries are executed. :param callable callback: @@ -239,7 +398,7 @@ def add_query_logger(self, callback): """ self._query_loggers.add(_Callback.from_callable(callback)) - def remove_query_logger(self, callback): + def remove_query_logger(self, callback: QueryLogger) -> None: """Remove a query logger callback. :param callable callback: @@ -250,11 +409,11 @@ def remove_query_logger(self, callback): """ self._query_loggers.discard(_Callback.from_callable(callback)) - def get_server_pid(self): + def get_server_pid(self) -> int: """Return the PID of the Postgres server the connection is bound to.""" return self._protocol.get_server_pid() - def get_server_version(self): + def get_server_version(self) -> types.ServerVersion: """Return the version of the connected PostgreSQL server. The returned value is a named tuple similar to that in @@ -270,15 +429,20 @@ def get_server_version(self): """ return self._server_version - def get_settings(self): + def get_settings(self) -> _cprotocol.ConnectionSettings: """Return connection settings. :return: :class:`~asyncpg.ConnectionSettings`. """ return self._protocol.get_settings() - def transaction(self, *, isolation=None, readonly=False, - deferrable=False): + def transaction( + self, + *, + isolation: transaction.IsolationLevels | None = None, + readonly: bool = False, + deferrable: bool = False, + ) -> transaction.Transaction: """Create a :class:`~transaction.Transaction` object. Refer to `PostgreSQL documentation`_ on the meaning of transaction @@ -303,7 +467,7 @@ def transaction(self, *, isolation=None, readonly=False, self._check_open() return transaction.Transaction(self, isolation, readonly, deferrable) - def is_in_transaction(self): + def is_in_transaction(self) -> bool: """Return True if Connection is currently inside a transaction. :return bool: True if inside transaction, False otherwise. @@ -312,7 +476,9 @@ def is_in_transaction(self): """ return self._protocol.is_in_transaction() - async def execute(self, query: str, *args, timeout: float=None) -> str: + async def execute( + self, query: str, *args: object, timeout: float | None = None + ) -> str: """Execute an SQL command (or commands). This method can execute many SQL commands at once, when no arguments @@ -359,7 +525,13 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: ) return status.decode() - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany( + self, + command: str, + args: compat.Iterable[compat.Sequence[object]], + *, + timeout: float | None = None, + ) -> None: """Execute an SQL *command* for each sequence of arguments in *args*. Example: @@ -390,16 +562,42 @@ async def executemany(self, command: str, args, *, timeout: float=None): self._check_open() return await self._executemany(command, args, timeout) + @typing.overload async def _get_statement( self, - query, - timeout, + query: str, + timeout: float | None, *, - named=False, - use_cache=True, - ignore_custom_codec=False, - record_class=None - ): + named: bool | str = ..., + use_cache: bool = ..., + ignore_custom_codec: bool = ..., + record_class: None = ..., + ) -> _cprotocol.PreparedStatementState[_RecordT]: + ... + + @typing.overload + async def _get_statement( + self, + query: str, + timeout: float | None, + *, + named: bool | str = ..., + use_cache: bool = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT], + ) -> _cprotocol.PreparedStatementState[_OtherRecordT]: + ... + + async def _get_statement( + self, + query: str, + timeout: float | None, + *, + named: bool | str = False, + use_cache: bool = True, + ignore_custom_codec: bool = False, + record_class: type[typing.Any] | None = None + ) -> _cprotocol.PreparedStatementState[typing.Any]: if record_class is None: record_class = self._protocol.get_record_class() else: @@ -492,7 +690,11 @@ async def _get_statement( return statement - async def _introspect_types(self, typeoids, timeout): + async def _introspect_types( + self, + typeoids: compat.Iterable[int], + timeout: float | None + ) -> tuple[typing.Any, _cprotocol.PreparedStatementState[_RecordT]]: if self._server_caps.jit: try: cfgrow, _ = await self.__execute( @@ -534,7 +736,7 @@ async def _introspect_types(self, typeoids, timeout): return result - async def _introspect_type(self, typename, schema): + async def _introspect_type(self, typename: str, schema: str) -> typing.Any: if ( schema == 'pg_catalog' and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP @@ -562,14 +764,47 @@ async def _introspect_type(self, typename, schema): return rows[0] + @typing.overload def cursor( self, - query, - *args, - prefetch=None, - timeout=None, - record_class=None - ): + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: None = ..., + ) -> cursor.CursorFactory[_RecordT]: + ... + + @typing.overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> cursor.CursorFactory[_OtherRecordT]: + ... + + @typing.overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> cursor.CursorFactory[_RecordT] | cursor.CursorFactory[_OtherRecordT]: + ... + + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = None, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None, + ) -> cursor.CursorFactory[typing.Any]: """Return a *cursor factory* for the specified query. :param args: @@ -601,13 +836,52 @@ def cursor( record_class, ) + @typing.overload async def prepare( self, - query, + query: str, *, - name=None, - timeout=None, - record_class=None, + name: str | None = ..., + timeout: float | None = ..., + record_class: None = ..., + ) -> prepared_stmt.PreparedStatement[_RecordT]: + ... + + @typing.overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> prepared_stmt.PreparedStatement[_OtherRecordT]: + ... + + @typing.overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] + ): + ... + + async def prepare( + self, + query: str, + *, + name: str | None = None, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None, + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] ): """Create a *prepared statement* for the specified query. @@ -641,32 +915,108 @@ async def prepare( record_class=record_class, ) + @typing.overload + async def _prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + use_cache: bool = ..., + record_class: None = ..., + ) -> prepared_stmt.PreparedStatement[_RecordT]: + ... + + @typing.overload + async def _prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + use_cache: bool = ..., + record_class: type[_OtherRecordT], + ) -> prepared_stmt.PreparedStatement[_OtherRecordT]: + ... + + @typing.overload async def _prepare( self, - query, + query: str, *, - name=None, - timeout=None, - use_cache: bool=False, - record_class=None + name: str | None = ..., + timeout: float | None = ..., + use_cache: bool = ..., + record_class: type[_OtherRecordT] | None, + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] + ): + ... + + async def _prepare( + self, + query: str, + *, + name: str | None = None, + timeout: float | None = None, + use_cache: bool = False, + record_class: type[_OtherRecordT] | None = None + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] ): self._check_open() + + named: bool | str = True if name is None else name stmt = await self._get_statement( query, timeout, - named=True if name is None else name, + named=named, use_cache=use_cache, record_class=record_class, ) - return prepared_stmt.PreparedStatement(self, query, stmt) + return prepared_stmt.PreparedStatement(self, query, typing.cast( + '_cprotocol.PreparedStatementState[typing.Any]', stmt + )) + + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> list[_RecordT]: + ... + @typing.overload async def fetch( self, - query, - *args, - timeout=None, - record_class=None - ) -> list: + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> list[_OtherRecordT]: + ... + + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> list[_RecordT] | list[_OtherRecordT]: + ... + + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None + ) -> list[_RecordT] | list[_OtherRecordT]: """Run a query and return the results as a list of :class:`Record`. :param str query: @@ -696,7 +1046,13 @@ async def fetch( record_class=record_class, ) - async def fetchval(self, query, *args, column=0, timeout=None): + async def fetchval( + self, + query: str, + *args: object, + column: int = 0, + timeout: float | None = None, + ) -> typing.Any: """Run a query and return a value in the first row. :param str query: Query text. @@ -717,13 +1073,43 @@ async def fetchval(self, query, *args, column=0, timeout=None): return None return data[0][column] + @typing.overload async def fetchrow( self, - query, - *args, - timeout=None, - record_class=None - ): + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> _RecordT | None: + ... + + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> _OtherRecordT | None: + ... + + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> _RecordT | _OtherRecordT | None: + ... + + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None + ) -> _RecordT | _OtherRecordT | None: """Run a query and return the first row. :param str query: @@ -757,11 +1143,24 @@ async def fetchrow( return None return data[0] - async def copy_from_table(self, table_name, *, output, - columns=None, schema_name=None, timeout=None, - format=None, oids=None, delimiter=None, - null=None, header=None, quote=None, - escape=None, force_quote=None, encoding=None): + async def copy_from_table( + self, + table_name: str, + *, + output: _OutputType, + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + format: _CopyFormat | None = None, + oids: int | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + ) -> str: """Copy table contents to a file or file-like object. :param str table_name: @@ -829,11 +1228,22 @@ async def copy_from_table(self, table_name, *, output, return await self._copy_out(copy_stmt, output, timeout) - async def copy_from_query(self, query, *args, output, - timeout=None, format=None, oids=None, - delimiter=None, null=None, header=None, - quote=None, escape=None, force_quote=None, - encoding=None): + async def copy_from_query( + self, + query: str, + *args: object, + output: _OutputType, + timeout: float | None = None, + format: _CopyFormat | None = None, + oids: int | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + ) -> str: """Copy the results of a query to a file or file-like object. :param str query: @@ -891,13 +1301,28 @@ async def copy_from_query(self, query, *args, output, return await self._copy_out(copy_stmt, output, timeout) - async def copy_to_table(self, table_name, *, source, - columns=None, schema_name=None, timeout=None, - format=None, oids=None, freeze=None, - delimiter=None, null=None, header=None, - quote=None, escape=None, force_quote=None, - force_not_null=None, force_null=None, - encoding=None, where=None): + async def copy_to_table( + self, + table_name: str, + *, + source: _SourceType, + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + format: _CopyFormat | None = None, + oids: int | None = None, + freeze: bool | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + force_not_null: bool | compat.Iterable[str] | None = None, + force_null: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + where: str | None = None, + ) -> str: """Copy data to the specified table. :param str table_name: @@ -979,9 +1404,18 @@ async def copy_to_table(self, table_name, *, source, return await self._copy_in(copy_stmt, source, timeout) - async def copy_records_to_table(self, table_name, *, records, - columns=None, schema_name=None, - timeout=None, where=None): + async def copy_records_to_table( + self, + table_name: str, + *, + records: compat.Iterable[ + compat.Sequence[object] + ] | compat.AsyncIterable[compat.Sequence[object]], + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + where: str | None = None, + ) -> str: """Copy a list of records to the specified table using binary COPY. :param str table_name: @@ -1081,7 +1515,7 @@ async def copy_records_to_table(self, table_name, *, records, return await self._protocol.copy_in( copy_stmt, None, None, records, intro_ps._state, timeout) - def _format_copy_where(self, where): + def _format_copy_where(self, where: str | None) -> str: if where and not self._server_caps.sql_copy_from_where: raise exceptions.UnsupportedServerFeatureError( 'the `where` parameter requires PostgreSQL 12 or later') @@ -1093,13 +1527,25 @@ def _format_copy_where(self, where): return where_clause - def _format_copy_opts(self, *, format=None, oids=None, freeze=None, - delimiter=None, null=None, header=None, quote=None, - escape=None, force_quote=None, force_not_null=None, - force_null=None, encoding=None): + def _format_copy_opts( + self, + *, + format: _CopyFormat | None = None, + oids: int | None = None, + freeze: bool | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + force_not_null: bool | compat.Iterable[str] | None = None, + force_null: bool | compat.Iterable[str] | None = None, + encoding: str | None = None + ) -> str: kwargs = dict(locals()) kwargs.pop('self') - opts = [] + opts: list[str] = [] if force_quote is not None and isinstance(force_quote, bool): kwargs.pop('force_quote') @@ -1122,24 +1568,31 @@ def _format_copy_opts(self, *, format=None, oids=None, freeze=None, else: return '' - async def _copy_out(self, copy_stmt, output, timeout): + async def _copy_out( + self, copy_stmt: str, output: _OutputType, timeout: float | None + ) -> str: try: - path = os.fspath(output) + path: str | bytes | None = typing.cast( + 'str | bytes', os.fspath(typing.cast(typing.Any, output)) + ) except TypeError: # output is not a path-like object path = None - writer = None + writer: _WriterType | None = None opened_by_us = False run_in_executor = self._loop.run_in_executor if path is not None: # a path - f = await run_in_executor(None, open, path, 'wb') + f = typing.cast( + 'io.BufferedWriter', + await run_in_executor(None, open, path, 'wb') + ) opened_by_us = True elif hasattr(output, 'write'): # file-like - f = output + f = typing.cast('io.BufferedWriter', output) elif callable(output): # assuming calling output returns an awaitable. writer = output @@ -1151,7 +1604,7 @@ async def _copy_out(self, copy_stmt, output, timeout): ) if writer is None: - async def _writer(data): + async def _writer(data: bytes) -> None: await run_in_executor(None, f.write, data) writer = _writer @@ -1161,14 +1614,18 @@ async def _writer(data): if opened_by_us: f.close() - async def _copy_in(self, copy_stmt, source, timeout): + async def _copy_in( + self, copy_stmt: str, source: _SourceType, timeout: float | None + ) -> str: try: - path = os.fspath(source) + path: str | bytes | None = typing.cast( + 'str | bytes', os.fspath(typing.cast(typing.Any, source)) + ) except TypeError: # source is not a path-like object path = None - f = None + f: typing.BinaryIO | None = None reader = None data = None opened_by_us = False @@ -1176,11 +1633,14 @@ async def _copy_in(self, copy_stmt, source, timeout): if path is not None: # a path - f = await run_in_executor(None, open, path, 'rb') + f = typing.cast( + 'io.BufferedWriter', + await run_in_executor(None, open, path, 'rb') + ) opened_by_us = True elif hasattr(source, 'read'): # file-like - f = source + f = typing.cast('io.BufferedWriter', source) elif isinstance(source, collections.abc.AsyncIterable): # assuming calling output returns an awaitable. # copy_in() is designed to handle very large amounts of data, and @@ -1194,11 +1654,13 @@ async def _copy_in(self, copy_stmt, source, timeout): if f is not None: # Copying from a file-like object. class _Reader: - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): - data = await run_in_executor(None, f.read, 524288) + async def __anext__(self) -> bytes: + data = await run_in_executor( + None, typing.cast(typing.BinaryIO, f).read, 524288 + ) if len(data) == 0: raise StopAsyncIteration else: @@ -1211,11 +1673,20 @@ async def __anext__(self): copy_stmt, reader, data, None, None, timeout) finally: if opened_by_us: - await run_in_executor(None, f.close) + await run_in_executor( + None, + typing.cast(typing.BinaryIO, f).close + ) - async def set_type_codec(self, typename, *, - schema='public', encoder, decoder, - format='text'): + async def set_type_codec( + self, + typename: str, + *, + schema: str = 'public', + encoder: compat.Callable[[typing.Any], typing.Any], + decoder: compat.Callable[[typing.Any], typing.Any], + format: str = 'text', + ) -> None: """Set an encoder/decoder pair for the specified data type. :param typename: @@ -1337,7 +1808,7 @@ async def set_type_codec(self, typename, *, self._check_open() settings = self._protocol.get_settings() typeinfo = await self._introspect_type(typename, schema) - full_typeinfos = [] + full_typeinfos: list[object] = [] if introspection.is_scalar_type(typeinfo): kind = 'scalar' elif introspection.is_composite_type(typeinfo): @@ -1375,7 +1846,9 @@ async def set_type_codec(self, typename, *, # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - async def reset_type_codec(self, typename, *, schema='public'): + async def reset_type_codec( + self, typename: str, *, schema: str = 'public' + ) -> None: """Reset *typename* codec to the default implementation. :param typename: @@ -1395,9 +1868,14 @@ async def reset_type_codec(self, typename, *, schema='public'): # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - async def set_builtin_type_codec(self, typename, *, - schema='public', codec_name, - format=None): + async def set_builtin_type_codec( + self, + typename: str, + *, + schema: str = 'public', + codec_name: str, + format: str | None = None, + ) -> None: """Set a builtin codec for the specified scalar data type. This method has two uses. The first is to register a builtin @@ -1445,7 +1923,7 @@ async def set_builtin_type_codec(self, typename, *, # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - def is_closed(self): + def is_closed(self) -> bool: """Return ``True`` if the connection is closed, ``False`` otherwise. :return bool: ``True`` if the connection is closed, ``False`` @@ -1453,7 +1931,7 @@ def is_closed(self): """ return self._aborted or not self._protocol.is_connected() - async def close(self, *, timeout=None): + async def close(self, *, timeout: float | None = None) -> None: """Close the connection gracefully. :param float timeout: @@ -1472,13 +1950,13 @@ async def close(self, *, timeout=None): finally: self._cleanup() - def terminate(self): + def terminate(self) -> None: """Terminate the connection without waiting for pending data.""" if not self.is_closed(): self._abort() self._cleanup() - async def reset(self, *, timeout=None): + async def reset(self, *, timeout: float | None = None) -> None: self._check_open() self._listeners.clear() self._log_listeners.clear() @@ -1499,13 +1977,13 @@ async def reset(self, *, timeout=None): if reset_query: await self.execute(reset_query, timeout=timeout) - def _abort(self): + def _abort(self) -> None: # Put the connection into the aborted state. self._aborted = True self._protocol.abort() - self._protocol = None + self._protocol = None # type: ignore[assignment] - def _cleanup(self): + def _cleanup(self) -> None: self._call_termination_listeners() # Free the resources associated with this connection. # This must be called when a connection is terminated. @@ -1521,7 +1999,7 @@ def _cleanup(self): self._query_loggers.clear() self._clean_tasks() - def _clean_tasks(self): + def _clean_tasks(self) -> None: # Wrap-up any remaining tasks associated with this connection. if self._cancellations: for fut in self._cancellations: @@ -1529,16 +2007,16 @@ def _clean_tasks(self): fut.cancel() self._cancellations.clear() - def _check_open(self): + def _check_open(self) -> None: if self.is_closed(): raise exceptions.InterfaceError('connection is closed') - def _get_unique_id(self, prefix): + def _get_unique_id(self, prefix: str) -> str: global _uid _uid += 1 return '__asyncpg_{}_{:x}__'.format(prefix, _uid) - def _mark_stmts_as_closed(self): + def _mark_stmts_as_closed(self) -> None: for stmt in self._stmt_cache.iter_statements(): stmt.mark_closed() @@ -1548,7 +2026,9 @@ def _mark_stmts_as_closed(self): self._stmt_cache.clear() self._stmts_to_close.clear() - def _maybe_gc_stmt(self, stmt): + def _maybe_gc_stmt( + self, stmt: _cprotocol.PreparedStatementState[typing.Any] + ) -> None: if ( stmt.refs == 0 and stmt.name @@ -1567,7 +2047,7 @@ def _maybe_gc_stmt(self, stmt): stmt.mark_closed() self._stmts_to_close.add(stmt) - async def _cleanup_stmts(self): + async def _cleanup_stmts(self) -> None: # Called whenever we create a new prepared statement in # `Connection._get_statement()` and `_stmts_to_close` is # not empty. @@ -1578,7 +2058,7 @@ async def _cleanup_stmts(self): # so we ignore the timeout. await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT) - async def _cancel(self, waiter): + async def _cancel(self, waiter: asyncio.Future[None]) -> None: try: # Open new connection to the server await connect_utils._cancel( @@ -1602,15 +2082,18 @@ async def _cancel(self, waiter): if not waiter.done(): waiter.set_exception(ex) finally: - self._cancellations.discard( - asyncio.current_task(self._loop)) + current_task = asyncio.current_task(self._loop) + if current_task is not None: + self._cancellations.discard(current_task) if not waiter.done(): waiter.set_result(None) - def _cancel_current_command(self, waiter): + def _cancel_current_command(self, waiter: asyncio.Future[None]) -> None: self._cancellations.add(self._loop.create_task(self._cancel(waiter))) - def _process_log_message(self, fields, last_query): + def _process_log_message( + self, fields: dict[str, str], last_query: str + ) -> None: if not self._log_listeners: return @@ -1618,38 +2101,31 @@ def _process_log_message(self, fields, last_query): con_ref = self._unwrap() for cb in self._log_listeners: - if cb.is_async: - self._loop.create_task(cb.cb(con_ref, message)) - else: - self._loop.call_soon(cb.cb, con_ref, message) + cb.invoke(self._loop, con_ref, message) - def _call_termination_listeners(self): + def _call_termination_listeners(self) -> None: if not self._termination_listeners: return con_ref = self._unwrap() for cb in self._termination_listeners: - if cb.is_async: - self._loop.create_task(cb.cb(con_ref)) - else: - self._loop.call_soon(cb.cb, con_ref) + cb.invoke(self._loop, con_ref) self._termination_listeners.clear() - def _process_notification(self, pid, channel, payload): + def _process_notification( + self, pid: int, channel: str, payload: typing.Any + ) -> None: if channel not in self._listeners: return con_ref = self._unwrap() for cb in self._listeners[channel]: - if cb.is_async: - self._loop.create_task(cb.cb(con_ref, pid, channel, payload)) - else: - self._loop.call_soon(cb.cb, con_ref, pid, channel, payload) + cb.invoke(self._loop, con_ref, pid, channel, payload) - def _unwrap(self): + def _unwrap(self) -> Self | _pool.PoolConnectionProxy[typing.Any]: if self._proxy is None: - con_ref = self + con_ref: Self | _pool.PoolConnectionProxy[typing.Any] = self else: # `_proxy` is not None when the connection is a member # of a connection pool. Which means that the user is working @@ -1658,13 +2134,13 @@ def _unwrap(self): con_ref = self._proxy return con_ref - def _get_reset_query(self): + def _get_reset_query(self) -> str: if self._reset_query is not None: return self._reset_query caps = self._server_caps - _reset_query = [] + _reset_query: list[str] = [] if caps.advisory_locks: _reset_query.append('SELECT pg_advisory_unlock_all();') if caps.sql_close_all: @@ -1674,12 +2150,11 @@ def _get_reset_query(self): if caps.sql_reset: _reset_query.append('RESET ALL;') - _reset_query = '\n'.join(_reset_query) - self._reset_query = _reset_query + self._reset_query = '\n'.join(_reset_query) - return _reset_query + return self._reset_query - def _set_proxy(self, proxy): + def _set_proxy(self, proxy: _pool.PoolConnectionProxy[typing.Any]) -> None: if self._proxy is not None and proxy is not None: # Should not happen unless there is a bug in `Pool`. raise exceptions.InterfaceError( @@ -1687,7 +2162,9 @@ def _set_proxy(self, proxy): self._proxy = proxy - def _check_listeners(self, listeners, listener_type): + def _check_listeners( + self, listeners: compat.Sized, listener_type: str + ) -> None: if listeners: count = len(listeners) @@ -1699,7 +2176,7 @@ def _check_listeners(self, listeners, listener_type): warnings.warn(w) - def _on_release(self, stacklevel=1): + def _on_release(self, stacklevel: int = 1) -> None: # Invalidate external references to the connection. self._pool_release_ctr += 1 # Called when the connection is about to be released to the pool. @@ -1710,10 +2187,10 @@ def _on_release(self, stacklevel=1): self._check_listeners( self._log_listeners, 'log') - def _drop_local_statement_cache(self): + def _drop_local_statement_cache(self) -> None: self._stmt_cache.clear() - def _drop_global_statement_cache(self): + def _drop_global_statement_cache(self) -> None: if self._proxy is not None: # This connection is a member of a pool, so we delegate # the cache drop to the pool. @@ -1722,10 +2199,10 @@ def _drop_global_statement_cache(self): else: self._drop_local_statement_cache() - def _drop_local_type_cache(self): + def _drop_local_type_cache(self) -> None: self._protocol.get_settings().clear_type_cache() - def _drop_global_type_cache(self): + def _drop_global_type_cache(self) -> None: if self._proxy is not None: # This connection is a member of a pool, so we delegate # the cache drop to the pool. @@ -1734,7 +2211,7 @@ def _drop_global_type_cache(self): else: self._drop_local_type_cache() - async def reload_schema_state(self): + async def reload_schema_state(self) -> None: """Indicate that the database schema information must be reloaded. For performance reasons, asyncpg caches certain aspects of the @@ -1779,17 +2256,101 @@ async def reload_schema_state(self): self._drop_global_type_cache() self._drop_global_statement_cache() + @typing.overload async def _execute( self, - query, - args, - limit, - timeout, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, *, - return_status=False, - ignore_custom_codec=False, - record_class=None - ): + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: None = ... + ) -> _RecordsType[_RecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] + ) -> _RecordsType[_OtherRecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] | None + ) -> _RecordsType[_RecordT] | _RecordsType[_OtherRecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: None = ... + ) -> _RecordsTupleType[_RecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] + ) -> _RecordsTupleType[_OtherRecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] | None + ) -> _RecordsTupleType[_RecordT] | _RecordsTupleType[_OtherRecordT]: + ... + + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: bool = False, + ignore_custom_codec: bool = False, + record_class: type[_OtherRecordT] | None = None + ) -> _RecordsType[typing.Any] | _RecordsTupleType[typing.Any]: with self._stmt_exclusive_section: result, _ = await self.__execute( query, @@ -1803,7 +2364,7 @@ async def _execute( return result @contextlib.contextmanager - def query_logger(self, callback): + def query_logger(self, callback: QueryLogger) -> compat.Iterator[None]: """Context manager that adds `callback` to the list of query loggers, and removes it upon exit. @@ -1834,7 +2395,9 @@ def __call__(self, record): self.remove_query_logger(callback) @contextlib.contextmanager - def _time_and_log(self, query, args, timeout): + def _time_and_log( + self, query: str, args: typing.Any, timeout: float | None + ) -> compat.Iterator[None]: start = time.monotonic() exception = None try: @@ -1854,23 +2417,127 @@ def _time_and_log(self, query, args, timeout): conn_params=self._params, ) for cb in self._query_loggers: - if cb.is_async: - self._loop.create_task(cb.cb(record)) - else: - self._loop.call_soon(cb.cb, record) + cb.invoke(self._loop, record) + @typing.overload async def __execute( self, - query, - args, - limit, - timeout, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, *, - return_status=False, - ignore_custom_codec=False, - record_class=None - ): - executor = lambda stmt, timeout: self._protocol.bind_execute( + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: None = ... + ) -> tuple[ + _RecordsType[_RecordT], _cprotocol.PreparedStatementState[_RecordT] + ]: + ... + + @typing.overload + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] + ) -> tuple[ + _RecordsType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ]: + ... + + @typing.overload + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: None = ... + ) -> tuple[ + _RecordsTupleType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ]: + ... + + @typing.overload + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] + ) -> tuple[ + _RecordsTupleType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ]: + ... + + @typing.overload + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: bool, + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] | None + ) -> tuple[ + _RecordsTupleType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ] | tuple[ + _RecordsType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ] | tuple[ + _RecordsTupleType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ] | tuple[ + _RecordsType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ]: + ... + + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: bool = False, + ignore_custom_codec: bool = False, + record_class: type[_OtherRecordT] | None = None + ) -> tuple[ + _RecordsTupleType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ] | tuple[ + _RecordsType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ] | tuple[ + _RecordsTupleType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ] | tuple[ + _RecordsType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ]: + executor: Executor[ + _OtherRecordT + ] = lambda stmt, timeout: self._protocol.bind_execute( state=stmt, args=args, portal_name='', @@ -1898,8 +2565,15 @@ async def __execute( ) return result, stmt - async def _executemany(self, query, args, timeout): - executor = lambda stmt, timeout: self._protocol.bind_execute_many( + async def _executemany( + self, + query: str, + args: compat.Iterable[compat.Sequence[object]], + timeout: float | None, + ) -> None: + executor: Executor[ + _RecordT + ] = lambda stmt, timeout: self._protocol.bind_execute_many( state=stmt, args=args, portal_name='', @@ -1908,19 +2582,20 @@ async def _executemany(self, query, args, timeout): timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: with self._time_and_log(query, args, timeout): + result: None result, _ = await self._do_execute(query, executor, timeout) return result async def _do_execute( self, - query, - executor, - timeout, - retry=True, + query: str, + executor: Executor[typing.Any], + timeout: float | None, + retry: bool = True, *, - ignore_custom_codec=False, - record_class=None - ): + ignore_custom_codec: bool = False, + record_class: type[_OtherRecordT] | None = None, + ) -> tuple[typing.Any, _cprotocol.PreparedStatementState[typing.Any]]: if timeout is None: stmt = await self._get_statement( query, @@ -1948,7 +2623,7 @@ async def _do_execute( result = await executor(stmt, timeout) finally: after = time.monotonic() - timeout -= after - before + timeout -= after - before # pyright: ignore [reportPossiblyUnboundVariable] # noqa: E501 except exceptions.OutdatedSchemaCacheError: # This exception is raised when we detect a difference between @@ -1992,22 +2667,103 @@ async def _do_execute( return result, stmt -async def connect(dsn=None, *, - host=None, port=None, - user=None, password=None, passfile=None, - database=None, - loop=None, - timeout=60, - statement_cache_size=100, - max_cached_statement_lifetime=300, - max_cacheable_statement_size=1024 * 15, - command_timeout=None, - ssl=None, - direct_tls=False, - connection_class=Connection, - record_class=protocol.Record, - server_settings=None, - target_session_attrs=None): +@typing.overload +async def connect( + dsn: str | None = ..., + *, + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + direct_tls: bool = ..., + record_class: type[_RecordT], + server_settings: dict[str, str] | None = ..., + target_session_attrs: connect_utils.SessionAttribute | None = ..., +) -> Connection[_RecordT]: + ... + + +@typing.overload +async def connect( + dsn: str | None = ..., + *, + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + direct_tls: bool = ..., + connection_class: type[_ConnectionT], + record_class: type[_RecordT] = ..., + server_settings: dict[str, str] | None = ..., + target_session_attrs: connect_utils.SessionAttribute | None = ..., +) -> _ConnectionT: + ... + + +@typing.overload +async def connect( + dsn: str | None = ..., + *, + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + direct_tls: bool = ..., + server_settings: dict[str, str] | None = ..., + target_session_attrs: connect_utils.SessionAttribute | None = ..., +) -> Connection[protocol.Record]: + ... + + +async def connect( + dsn: str | None = None, + *, + host: connect_utils.HostType | None = None, + port: connect_utils.PortType | None = None, + user: str | None = None, + password: connect_utils.PasswordType | None = None, + passfile: str | None = None, + database: str | None = None, + loop: asyncio.AbstractEventLoop | None = None, + timeout: float = 60, + statement_cache_size: int = 100, + max_cached_statement_lifetime: int = 300, + max_cacheable_statement_size: int = 1024 * 15, + command_timeout: float | None = None, + ssl: connect_utils.SSLType | None = None, + direct_tls: bool = False, + connection_class: type[_ConnectionT] = typing.cast(typing.Any, Connection), + record_class: type[_RecordT] = typing.cast(typing.Any, protocol.Record), + server_settings: dict[str, str] | None = None, + target_session_attrs: connect_utils.SessionAttribute | None = None, +) -> Connection[typing.Any]: r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2348,11 +3104,24 @@ async def connect(dsn=None, *, ) -class _StatementCacheEntry: +_StatementCacheKey = compat.tuple[str, 'type[_RecordT]', bool] + + +class _StatementCacheEntry(typing.Generic[_RecordT]): __slots__ = ('_query', '_statement', '_cache', '_cleanup_cb') - def __init__(self, cache, query, statement): + _query: _StatementCacheKey[_RecordT] + _statement: _cprotocol.PreparedStatementState[_RecordT] + _cache: _StatementCache + _cleanup_cb: asyncio.TimerHandle | None + + def __init__( + self, + cache: _StatementCache, + query: _StatementCacheKey[_RecordT], + statement: _cprotocol.PreparedStatementState[_RecordT] + ) -> None: self._cache = cache self._query = query self._statement = statement @@ -2364,7 +3133,23 @@ class _StatementCache: __slots__ = ('_loop', '_entries', '_max_size', '_on_remove', '_max_lifetime') - def __init__(self, *, loop, max_size, on_remove, max_lifetime): + _loop: asyncio.AbstractEventLoop + _entries: compat.OrderedDict[ + _StatementCacheKey[typing.Any], + _StatementCacheEntry[typing.Any] + ] + _max_size: int + _on_remove: OnRemove[typing.Any] + _max_lifetime: float + + def __init__( + self, + *, + loop: asyncio.AbstractEventLoop, + max_size: int, + on_remove: OnRemove[typing.Any], + max_lifetime: float + ) -> None: self._loop = loop self._max_size = max_size self._on_remove = on_remove @@ -2389,21 +3174,21 @@ def __init__(self, *, loop, max_size, on_remove, max_lifetime): # beginning of it. self._entries = collections.OrderedDict() - def __len__(self): + def __len__(self) -> int: return len(self._entries) - def get_max_size(self): + def get_max_size(self) -> int: return self._max_size - def set_max_size(self, new_size): + def set_max_size(self, new_size: int) -> None: assert new_size >= 0 self._max_size = new_size self._maybe_cleanup() - def get_max_lifetime(self): + def get_max_lifetime(self) -> float: return self._max_lifetime - def set_max_lifetime(self, new_lifetime): + def set_max_lifetime(self, new_lifetime: float) -> None: assert new_lifetime >= 0 self._max_lifetime = new_lifetime for entry in self._entries.values(): @@ -2411,14 +3196,16 @@ def set_max_lifetime(self, new_lifetime): # and setup a new one if necessary. self._set_entry_timeout(entry) - def get(self, query, *, promote=True): + def get( + self, query: _StatementCacheKey[_RecordT], *, promote: bool = True + ) -> _cprotocol.PreparedStatementState[_RecordT] | None: if not self._max_size: # The cache is disabled. - return + return None - entry = self._entries.get(query) # type: _StatementCacheEntry + entry: _StatementCacheEntry[_RecordT] | None = self._entries.get(query) if entry is None: - return + return None if entry._statement.closed: # Happens in unittests when we call `stmt._state.mark_closed()` @@ -2426,7 +3213,7 @@ def get(self, query, *, promote=True): # cache error. self._entries.pop(query) self._clear_entry_callback(entry) - return + return None if promote: # `promote` is `False` when `get()` is called by `has()`. @@ -2434,10 +3221,14 @@ def get(self, query, *, promote=True): return entry._statement - def has(self, query): + def has(self, query: _StatementCacheKey[_RecordT]) -> bool: return self.get(query, promote=False) is not None - def put(self, query, statement): + def put( + self, + query: _StatementCacheKey[_RecordT], + statement: _cprotocol.PreparedStatementState[_RecordT], + ) -> None: if not self._max_size: # The cache is disabled. return @@ -2448,10 +3239,12 @@ def put(self, query, statement): # if necessary. self._maybe_cleanup() - def iter_statements(self): + def iter_statements( + self + ) -> compat.Iterator[_cprotocol.PreparedStatementState[typing.Any]]: return (e._statement for e in self._entries.values()) - def clear(self): + def clear(self) -> None: # Store entries for later. entries = tuple(self._entries.values()) @@ -2464,7 +3257,9 @@ def clear(self): self._clear_entry_callback(entry) self._on_remove(entry._statement) - def _set_entry_timeout(self, entry): + def _set_entry_timeout( + self, entry: _StatementCacheEntry[typing.Any] + ) -> None: # Clear the existing timeout. self._clear_entry_callback(entry) @@ -2473,23 +3268,31 @@ def _set_entry_timeout(self, entry): entry._cleanup_cb = self._loop.call_later( self._max_lifetime, self._on_entry_expired, entry) - def _new_entry(self, query, statement): + def _new_entry( + self, + query: _StatementCacheKey[_RecordT], + statement: _cprotocol.PreparedStatementState[_RecordT], + ) -> _StatementCacheEntry[_RecordT]: entry = _StatementCacheEntry(self, query, statement) self._set_entry_timeout(entry) return entry - def _on_entry_expired(self, entry): + def _on_entry_expired( + self, entry: _StatementCacheEntry[typing.Any] + ) -> None: # `call_later` callback, called when an entry stayed longer # than `self._max_lifetime`. if self._entries.get(entry._query) is entry: self._entries.pop(entry._query) self._on_remove(entry._statement) - def _clear_entry_callback(self, entry): + def _clear_entry_callback( + self, entry: _StatementCacheEntry[typing.Any] + ) -> None: if entry._cleanup_cb is not None: entry._cleanup_cb.cancel() - def _maybe_cleanup(self): + def _maybe_cleanup(self) -> None: # Delete cache entries until the size of the cache is `max_size`. while len(self._entries) > self._max_size: old_query, old_entry = self._entries.popitem(last=False) @@ -2500,13 +3303,35 @@ def _maybe_cleanup(self): self._on_remove(old_entry._statement) -class _Callback(typing.NamedTuple): +_CallbackType = compat.Callable[ + _P, + 'compat.Coroutine[typing.Any, typing.Any, None] | None' +] + + +@dataclasses.dataclass(frozen=True) +class _Callback(typing.Generic[_P]): + __slots__ = ('cb', 'is_async') - cb: typing.Callable[..., None] + cb: _CallbackType[_P] is_async: bool + def invoke( + self, + loop: asyncio.AbstractEventLoop, + /, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> None: + if self.is_async: + loop.create_task( + typing.cast(typing.Any, self.cb(*args, **kwargs)) + ) + else: + loop.call_soon(lambda: self.cb(*args, **kwargs)) + @classmethod - def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback': + def from_callable(cls, cb: _CallbackType[_P]) -> Self: if inspect.iscoroutinefunction(cb): is_async = True elif callable(cb): @@ -2523,39 +3348,52 @@ def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback': class _Atomic: __slots__ = ('_acquired',) - def __init__(self): + _acquired: int + + def __init__(self) -> None: self._acquired = 0 - def __enter__(self): + def __enter__(self) -> None: if self._acquired: raise exceptions.InterfaceError( 'cannot perform operation: another operation is in progress') self._acquired = 1 - def __exit__(self, t, e, tb): + def __exit__(self, t: object, e: object, tb: object) -> None: self._acquired = 0 -class _ConnectionProxy: +class _ConnectionProxy(typing.Generic[_RecordT]): # Base class to enable `isinstance(Connection)` check. __slots__ = () -LoggedQuery = collections.namedtuple( - 'LoggedQuery', - ['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr', - 'conn_params']) -LoggedQuery.__doc__ = 'Log record of an executed query.' - - -ServerCapabilities = collections.namedtuple( - 'ServerCapabilities', - ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', - 'sql_close_all', 'sql_copy_from_where', 'jit']) -ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.' - - -def _detect_server_capabilities(server_version, connection_settings): +class LoggedQuery(typing.NamedTuple): + '''Log record of an executed query.''' + query: str + args: typing.Any + timeout: float | None + elapsed: float + exception: BaseException | None + conn_addr: tuple[str, int] | str + conn_params: connect_utils._ConnectionParameters + + +class ServerCapabilities(typing.NamedTuple): + '''PostgreSQL server capabilities.''' + advisory_locks: bool + notifications: bool + plpgsql: bool + sql_reset: bool + sql_close_all: bool + sql_copy_from_where: bool + jit: bool + + +def _detect_server_capabilities( + server_version: types.ServerVersion, + connection_settings: _cprotocol.ConnectionSettings, +) -> ServerCapabilities: if hasattr(connection_settings, 'padb_revision'): # Amazon Redshift detected. advisory_locks = False @@ -2604,18 +3442,18 @@ def _detect_server_capabilities(server_version, connection_settings): ) -def _extract_stack(limit=10): +def _extract_stack(limit: int = 10) -> str: """Replacement for traceback.extract_stack() that only does the necessary work for asyncio debug mode. """ frame = sys._getframe().f_back try: - stack = traceback.StackSummary.extract( + stack: list[traceback.FrameSummary] = traceback.StackSummary.extract( traceback.walk_stack(frame), lookup_lines=False) finally: del frame - apg_path = asyncpg.__path__[0] + apg_path = list(asyncpg.__path__)[0] i = 0 while i < len(stack) and stack[i][0].startswith(apg_path): i += 1 @@ -2625,7 +3463,7 @@ def _extract_stack(limit=10): return ''.join(traceback.format_list(stack)) -def _check_record_class(record_class): +def _check_record_class(record_class: type[typing.Any]) -> None: if record_class is protocol.Record: pass elif ( @@ -2646,7 +3484,10 @@ def _check_record_class(record_class): ) -def _weak_maybe_gc_stmt(weak_ref, stmt): +def _weak_maybe_gc_stmt( + weak_ref: weakref.ref[Connection[typing.Any]], + stmt: _cprotocol.PreparedStatementState[typing.Any], +) -> None: self = weak_ref() if self is not None: self._maybe_gc_stmt(stmt) diff --git a/asyncpg/connresource.py b/asyncpg/connresource.py index 3b0c1d3c..60aa97a6 100644 --- a/asyncpg/connresource.py +++ b/asyncpg/connresource.py @@ -5,31 +5,46 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import functools +import typing from . import exceptions +if typing.TYPE_CHECKING: + from . import compat + from . import connection as _conn -def guarded(meth): +_F = typing.TypeVar('_F', bound='compat.Callable[..., typing.Any]') + + +def guarded(meth: _F) -> _F: """A decorator to add a sanity check to ConnectionResource methods.""" @functools.wraps(meth) - def _check(self, *args, **kwargs): + def _check( + self: ConnectionResource, + *args: typing.Any, + **kwargs: typing.Any + ) -> typing.Any: self._check_conn_validity(meth.__name__) return meth(self, *args, **kwargs) - return _check + return typing.cast(_F, _check) class ConnectionResource: __slots__ = ('_connection', '_con_release_ctr') - def __init__(self, connection): + _connection: _conn.Connection[typing.Any] + _con_release_ctr: int + + def __init__(self, connection: _conn.Connection[typing.Any]) -> None: self._connection = connection self._con_release_ctr = connection._pool_release_ctr - def _check_conn_validity(self, meth_name): + def _check_conn_validity(self, meth_name: str) -> None: con_release_ctr = self._connection._pool_release_ctr if con_release_ctr != self._con_release_ctr: raise exceptions.InterfaceError( diff --git a/asyncpg/cursor.py b/asyncpg/cursor.py index b4abeed1..0b3980ba 100644 --- a/asyncpg/cursor.py +++ b/asyncpg/cursor.py @@ -4,14 +4,30 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import collections +import typing from . import connresource from . import exceptions +if typing.TYPE_CHECKING: + import sys -class CursorFactory(connresource.ConnectionResource): + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + + from .protocol import protocol as _cprotocol + from . import connection as _connection + from . import compat + +_RecordT = typing.TypeVar('_RecordT', bound='_cprotocol.Record') + + +class CursorFactory(connresource.ConnectionResource, typing.Generic[_RecordT]): """A cursor interface for the results of a query. A cursor interface can be used to initiate efficient traversal of the @@ -27,16 +43,49 @@ class CursorFactory(connresource.ConnectionResource): '_record_class', ) + _state: _cprotocol.PreparedStatementState[_RecordT] | None + _args: compat.Sequence[object] + _prefetch: int | None + _query: str + _timeout: float | None + _record_class: type[_RecordT] | None + + @typing.overload + def __init__( + self: CursorFactory[_RecordT], + connection: _connection.Connection[_RecordT], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + prefetch: int | None, + timeout: float | None, + record_class: None + ) -> None: + ... + + @typing.overload + def __init__( + self: CursorFactory[_RecordT], + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + prefetch: int | None, + timeout: float | None, + record_class: type[_RecordT] + ) -> None: + ... + def __init__( self, - connection, - query, - state, - args, - prefetch, - timeout, - record_class - ): + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + prefetch: int | None, + timeout: float | None, + record_class: type[_RecordT] | None + ) -> None: super().__init__(connection) self._args = args self._prefetch = prefetch @@ -48,7 +97,7 @@ def __init__( state.attach() @connresource.guarded - def __aiter__(self): + def __aiter__(self) -> CursorIterator[_RecordT]: prefetch = 50 if self._prefetch is None else self._prefetch return CursorIterator( self._connection, @@ -61,11 +110,13 @@ def __aiter__(self): ) @connresource.guarded - def __await__(self): + def __await__( + self + ) -> compat.Generator[typing.Any, None, Cursor[_RecordT]]: if self._prefetch is not None: raise exceptions.InterfaceError( 'prefetch argument can only be specified for iterable cursor') - cursor = Cursor( + cursor: Cursor[_RecordT] = Cursor( self._connection, self._query, self._state, @@ -74,13 +125,13 @@ def __await__(self): ) return cursor._init(self._timeout).__await__() - def __del__(self): + def __del__(self) -> None: if self._state is not None: self._state.detach() self._connection._maybe_gc_stmt(self._state) -class BaseCursor(connresource.ConnectionResource): +class BaseCursor(connresource.ConnectionResource, typing.Generic[_RecordT]): __slots__ = ( '_state', @@ -91,7 +142,43 @@ class BaseCursor(connresource.ConnectionResource): '_record_class', ) - def __init__(self, connection, query, state, args, record_class): + _state: _cprotocol.PreparedStatementState[_RecordT] | None + _args: compat.Sequence[object] + _portal_name: str | None + _exhausted: bool + _query: str + _record_class: type[_RecordT] | None + + @typing.overload + def __init__( + self: BaseCursor[_RecordT], + connection: _connection.Connection[_RecordT], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: None, + ) -> None: + ... + + @typing.overload + def __init__( + self: BaseCursor[_RecordT], + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: type[_RecordT], + ) -> None: + ... + + def __init__( + self, + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: type[_RecordT] | None, + ) -> None: super().__init__(connection) self._args = args self._state = state @@ -102,7 +189,7 @@ def __init__(self, connection, query, state, args, record_class): self._query = query self._record_class = record_class - def _check_ready(self): + def _check_ready(self) -> None: if self._state is None: raise exceptions.InterfaceError( 'cursor: no associated prepared statement') @@ -115,7 +202,7 @@ def _check_ready(self): raise exceptions.NoActiveSQLTransactionError( 'cursor cannot be created outside of a transaction') - async def _bind_exec(self, n, timeout): + async def _bind_exec(self, n: int, timeout: float | None) -> typing.Any: self._check_ready() if self._portal_name: @@ -126,11 +213,15 @@ async def _bind_exec(self, n, timeout): protocol = con._protocol self._portal_name = con._get_unique_id('portal') + + if typing.TYPE_CHECKING: + assert self._state is not None + buffer, _, self._exhausted = await protocol.bind_execute( self._state, self._args, self._portal_name, n, True, timeout) return buffer - async def _bind(self, timeout): + async def _bind(self, timeout: float | None) -> typing.Any: self._check_ready() if self._portal_name: @@ -141,12 +232,16 @@ async def _bind(self, timeout): protocol = con._protocol self._portal_name = con._get_unique_id('portal') + + if typing.TYPE_CHECKING: + assert self._state is not None + buffer = await protocol.bind(self._state, self._args, self._portal_name, timeout) return buffer - async def _exec(self, n, timeout): + async def _exec(self, n: int, timeout: float | None) -> typing.Any: self._check_ready() if not self._portal_name: @@ -158,7 +253,7 @@ async def _exec(self, n, timeout): self._state, self._portal_name, n, True, timeout) return buffer - async def _close_portal(self, timeout): + async def _close_portal(self, timeout: float | None) -> None: self._check_ready() if not self._portal_name: @@ -169,8 +264,8 @@ async def _close_portal(self, timeout): await protocol.close_portal(self._portal_name, timeout) self._portal_name = None - def __repr__(self): - attrs = [] + def __repr__(self) -> str: + attrs: list[str] = [] if self._exhausted: attrs.append('exhausted') attrs.append('') # to separate from id @@ -182,29 +277,59 @@ def __repr__(self): return '<{}.{} "{!s:.30}" {}{:#x}>'.format( mod, self.__class__.__name__, - self._state.query, + self._state.query if self._state is not None else '', ' '.join(attrs), id(self)) - def __del__(self): + def __del__(self) -> None: if self._state is not None: self._state.detach() self._connection._maybe_gc_stmt(self._state) -class CursorIterator(BaseCursor): +class CursorIterator(BaseCursor[_RecordT]): __slots__ = ('_buffer', '_prefetch', '_timeout') + _buffer: compat.deque[_RecordT] + _prefetch: int + _timeout: float | None + + @typing.overload + def __init__( + self: CursorIterator[_RecordT], + connection: _connection.Connection[_RecordT], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: None, + prefetch: int, + timeout: float | None, + ) -> None: + ... + + @typing.overload + def __init__( + self: CursorIterator[_RecordT], + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: type[_RecordT], + prefetch: int, + timeout: float | None, + ) -> None: + ... + def __init__( self, - connection, - query, - state, - args, - record_class, - prefetch, - timeout - ): + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: type[_RecordT] | None, + prefetch: int, + timeout: float | None, + ) -> None: super().__init__(connection, query, state, args, record_class) if prefetch <= 0: @@ -216,11 +341,11 @@ def __init__( self._timeout = timeout @connresource.guarded - def __aiter__(self): + def __aiter__(self) -> Self: return self @connresource.guarded - async def __anext__(self): + async def __anext__(self) -> _RecordT: if self._state is None: self._state = await self._connection._get_statement( self._query, @@ -247,12 +372,12 @@ async def __anext__(self): raise StopAsyncIteration -class Cursor(BaseCursor): +class Cursor(BaseCursor[_RecordT]): """An open *portal* into the results of a query.""" __slots__ = () - async def _init(self, timeout): + async def _init(self, timeout: float | None) -> Self: if self._state is None: self._state = await self._connection._get_statement( self._query, @@ -266,7 +391,9 @@ async def _init(self, timeout): return self @connresource.guarded - async def fetch(self, n, *, timeout=None): + async def fetch( + self, n: int, *, timeout: float | None = None + ) -> list[_RecordT]: r"""Return the next *n* rows as a list of :class:`Record` objects. :param float timeout: Optional timeout value in seconds. @@ -278,13 +405,15 @@ async def fetch(self, n, *, timeout=None): raise exceptions.InterfaceError('n must be greater than zero') if self._exhausted: return [] - recs = await self._exec(n, timeout) + recs: list[_RecordT] = await self._exec(n, timeout) if len(recs) < n: self._exhausted = True return recs @connresource.guarded - async def fetchrow(self, *, timeout=None): + async def fetchrow( + self, *, timeout: float | None = None + ) -> _RecordT | None: r"""Return the next row. :param float timeout: Optional timeout value in seconds. @@ -294,14 +423,14 @@ async def fetchrow(self, *, timeout=None): self._check_ready() if self._exhausted: return None - recs = await self._exec(1, timeout) + recs: list[_RecordT] = await self._exec(1, timeout) if len(recs) < 1: self._exhausted = True return None return recs[0] @connresource.guarded - async def forward(self, n, *, timeout=None) -> int: + async def forward(self, n: int, *, timeout: float | None = None) -> int: r"""Skip over the next *n* rows. :param float timeout: Optional timeout value in seconds. diff --git a/asyncpg/exceptions/__init__.py b/asyncpg/exceptions/__init__.py index 8c97d5a0..4769f766 100644 --- a/asyncpg/exceptions/__init__.py +++ b/asyncpg/exceptions/__init__.py @@ -1,88 +1,91 @@ # GENERATED FROM postgresql/src/backend/utils/errcodes.txt # DO NOT MODIFY, use tools/generate_exceptions.py to update +from __future__ import annotations + +import typing from ._base import * # NOQA from . import _base class PostgresWarning(_base.PostgresLogMessage, Warning): - sqlstate = '01000' + sqlstate: typing.ClassVar[str] = '01000' class DynamicResultSetsReturned(PostgresWarning): - sqlstate = '0100C' + sqlstate: typing.ClassVar[str] = '0100C' class ImplicitZeroBitPadding(PostgresWarning): - sqlstate = '01008' + sqlstate: typing.ClassVar[str] = '01008' class NullValueEliminatedInSetFunction(PostgresWarning): - sqlstate = '01003' + sqlstate: typing.ClassVar[str] = '01003' class PrivilegeNotGranted(PostgresWarning): - sqlstate = '01007' + sqlstate: typing.ClassVar[str] = '01007' class PrivilegeNotRevoked(PostgresWarning): - sqlstate = '01006' + sqlstate: typing.ClassVar[str] = '01006' class StringDataRightTruncation(PostgresWarning): - sqlstate = '01004' + sqlstate: typing.ClassVar[str] = '01004' class DeprecatedFeature(PostgresWarning): - sqlstate = '01P01' + sqlstate: typing.ClassVar[str] = '01P01' class NoData(PostgresWarning): - sqlstate = '02000' + sqlstate: typing.ClassVar[str] = '02000' class NoAdditionalDynamicResultSetsReturned(NoData): - sqlstate = '02001' + sqlstate: typing.ClassVar[str] = '02001' class SQLStatementNotYetCompleteError(_base.PostgresError): - sqlstate = '03000' + sqlstate: typing.ClassVar[str] = '03000' class PostgresConnectionError(_base.PostgresError): - sqlstate = '08000' + sqlstate: typing.ClassVar[str] = '08000' class ConnectionDoesNotExistError(PostgresConnectionError): - sqlstate = '08003' + sqlstate: typing.ClassVar[str] = '08003' class ConnectionFailureError(PostgresConnectionError): - sqlstate = '08006' + sqlstate: typing.ClassVar[str] = '08006' class ClientCannotConnectError(PostgresConnectionError): - sqlstate = '08001' + sqlstate: typing.ClassVar[str] = '08001' class ConnectionRejectionError(PostgresConnectionError): - sqlstate = '08004' + sqlstate: typing.ClassVar[str] = '08004' class TransactionResolutionUnknownError(PostgresConnectionError): - sqlstate = '08007' + sqlstate: typing.ClassVar[str] = '08007' class ProtocolViolationError(PostgresConnectionError): - sqlstate = '08P01' + sqlstate: typing.ClassVar[str] = '08P01' class TriggeredActionError(_base.PostgresError): - sqlstate = '09000' + sqlstate: typing.ClassVar[str] = '09000' class FeatureNotSupportedError(_base.PostgresError): - sqlstate = '0A000' + sqlstate: typing.ClassVar[str] = '0A000' class InvalidCachedStatementError(FeatureNotSupportedError): @@ -90,969 +93,969 @@ class InvalidCachedStatementError(FeatureNotSupportedError): class InvalidTransactionInitiationError(_base.PostgresError): - sqlstate = '0B000' + sqlstate: typing.ClassVar[str] = '0B000' class LocatorError(_base.PostgresError): - sqlstate = '0F000' + sqlstate: typing.ClassVar[str] = '0F000' class InvalidLocatorSpecificationError(LocatorError): - sqlstate = '0F001' + sqlstate: typing.ClassVar[str] = '0F001' class InvalidGrantorError(_base.PostgresError): - sqlstate = '0L000' + sqlstate: typing.ClassVar[str] = '0L000' class InvalidGrantOperationError(InvalidGrantorError): - sqlstate = '0LP01' + sqlstate: typing.ClassVar[str] = '0LP01' class InvalidRoleSpecificationError(_base.PostgresError): - sqlstate = '0P000' + sqlstate: typing.ClassVar[str] = '0P000' class DiagnosticsError(_base.PostgresError): - sqlstate = '0Z000' + sqlstate: typing.ClassVar[str] = '0Z000' class StackedDiagnosticsAccessedWithoutActiveHandlerError(DiagnosticsError): - sqlstate = '0Z002' + sqlstate: typing.ClassVar[str] = '0Z002' class CaseNotFoundError(_base.PostgresError): - sqlstate = '20000' + sqlstate: typing.ClassVar[str] = '20000' class CardinalityViolationError(_base.PostgresError): - sqlstate = '21000' + sqlstate: typing.ClassVar[str] = '21000' class DataError(_base.PostgresError): - sqlstate = '22000' + sqlstate: typing.ClassVar[str] = '22000' class ArraySubscriptError(DataError): - sqlstate = '2202E' + sqlstate: typing.ClassVar[str] = '2202E' class CharacterNotInRepertoireError(DataError): - sqlstate = '22021' + sqlstate: typing.ClassVar[str] = '22021' class DatetimeFieldOverflowError(DataError): - sqlstate = '22008' + sqlstate: typing.ClassVar[str] = '22008' class DivisionByZeroError(DataError): - sqlstate = '22012' + sqlstate: typing.ClassVar[str] = '22012' class ErrorInAssignmentError(DataError): - sqlstate = '22005' + sqlstate: typing.ClassVar[str] = '22005' class EscapeCharacterConflictError(DataError): - sqlstate = '2200B' + sqlstate: typing.ClassVar[str] = '2200B' class IndicatorOverflowError(DataError): - sqlstate = '22022' + sqlstate: typing.ClassVar[str] = '22022' class IntervalFieldOverflowError(DataError): - sqlstate = '22015' + sqlstate: typing.ClassVar[str] = '22015' class InvalidArgumentForLogarithmError(DataError): - sqlstate = '2201E' + sqlstate: typing.ClassVar[str] = '2201E' class InvalidArgumentForNtileFunctionError(DataError): - sqlstate = '22014' + sqlstate: typing.ClassVar[str] = '22014' class InvalidArgumentForNthValueFunctionError(DataError): - sqlstate = '22016' + sqlstate: typing.ClassVar[str] = '22016' class InvalidArgumentForPowerFunctionError(DataError): - sqlstate = '2201F' + sqlstate: typing.ClassVar[str] = '2201F' class InvalidArgumentForWidthBucketFunctionError(DataError): - sqlstate = '2201G' + sqlstate: typing.ClassVar[str] = '2201G' class InvalidCharacterValueForCastError(DataError): - sqlstate = '22018' + sqlstate: typing.ClassVar[str] = '22018' class InvalidDatetimeFormatError(DataError): - sqlstate = '22007' + sqlstate: typing.ClassVar[str] = '22007' class InvalidEscapeCharacterError(DataError): - sqlstate = '22019' + sqlstate: typing.ClassVar[str] = '22019' class InvalidEscapeOctetError(DataError): - sqlstate = '2200D' + sqlstate: typing.ClassVar[str] = '2200D' class InvalidEscapeSequenceError(DataError): - sqlstate = '22025' + sqlstate: typing.ClassVar[str] = '22025' class NonstandardUseOfEscapeCharacterError(DataError): - sqlstate = '22P06' + sqlstate: typing.ClassVar[str] = '22P06' class InvalidIndicatorParameterValueError(DataError): - sqlstate = '22010' + sqlstate: typing.ClassVar[str] = '22010' class InvalidParameterValueError(DataError): - sqlstate = '22023' + sqlstate: typing.ClassVar[str] = '22023' class InvalidPrecedingOrFollowingSizeError(DataError): - sqlstate = '22013' + sqlstate: typing.ClassVar[str] = '22013' class InvalidRegularExpressionError(DataError): - sqlstate = '2201B' + sqlstate: typing.ClassVar[str] = '2201B' class InvalidRowCountInLimitClauseError(DataError): - sqlstate = '2201W' + sqlstate: typing.ClassVar[str] = '2201W' class InvalidRowCountInResultOffsetClauseError(DataError): - sqlstate = '2201X' + sqlstate: typing.ClassVar[str] = '2201X' class InvalidTablesampleArgumentError(DataError): - sqlstate = '2202H' + sqlstate: typing.ClassVar[str] = '2202H' class InvalidTablesampleRepeatError(DataError): - sqlstate = '2202G' + sqlstate: typing.ClassVar[str] = '2202G' class InvalidTimeZoneDisplacementValueError(DataError): - sqlstate = '22009' + sqlstate: typing.ClassVar[str] = '22009' class InvalidUseOfEscapeCharacterError(DataError): - sqlstate = '2200C' + sqlstate: typing.ClassVar[str] = '2200C' class MostSpecificTypeMismatchError(DataError): - sqlstate = '2200G' + sqlstate: typing.ClassVar[str] = '2200G' class NullValueNotAllowedError(DataError): - sqlstate = '22004' + sqlstate: typing.ClassVar[str] = '22004' class NullValueNoIndicatorParameterError(DataError): - sqlstate = '22002' + sqlstate: typing.ClassVar[str] = '22002' class NumericValueOutOfRangeError(DataError): - sqlstate = '22003' + sqlstate: typing.ClassVar[str] = '22003' class SequenceGeneratorLimitExceededError(DataError): - sqlstate = '2200H' + sqlstate: typing.ClassVar[str] = '2200H' class StringDataLengthMismatchError(DataError): - sqlstate = '22026' + sqlstate: typing.ClassVar[str] = '22026' class StringDataRightTruncationError(DataError): - sqlstate = '22001' + sqlstate: typing.ClassVar[str] = '22001' class SubstringError(DataError): - sqlstate = '22011' + sqlstate: typing.ClassVar[str] = '22011' class TrimError(DataError): - sqlstate = '22027' + sqlstate: typing.ClassVar[str] = '22027' class UnterminatedCStringError(DataError): - sqlstate = '22024' + sqlstate: typing.ClassVar[str] = '22024' class ZeroLengthCharacterStringError(DataError): - sqlstate = '2200F' + sqlstate: typing.ClassVar[str] = '2200F' class PostgresFloatingPointError(DataError): - sqlstate = '22P01' + sqlstate: typing.ClassVar[str] = '22P01' class InvalidTextRepresentationError(DataError): - sqlstate = '22P02' + sqlstate: typing.ClassVar[str] = '22P02' class InvalidBinaryRepresentationError(DataError): - sqlstate = '22P03' + sqlstate: typing.ClassVar[str] = '22P03' class BadCopyFileFormatError(DataError): - sqlstate = '22P04' + sqlstate: typing.ClassVar[str] = '22P04' class UntranslatableCharacterError(DataError): - sqlstate = '22P05' + sqlstate: typing.ClassVar[str] = '22P05' class NotAnXmlDocumentError(DataError): - sqlstate = '2200L' + sqlstate: typing.ClassVar[str] = '2200L' class InvalidXmlDocumentError(DataError): - sqlstate = '2200M' + sqlstate: typing.ClassVar[str] = '2200M' class InvalidXmlContentError(DataError): - sqlstate = '2200N' + sqlstate: typing.ClassVar[str] = '2200N' class InvalidXmlCommentError(DataError): - sqlstate = '2200S' + sqlstate: typing.ClassVar[str] = '2200S' class InvalidXmlProcessingInstructionError(DataError): - sqlstate = '2200T' + sqlstate: typing.ClassVar[str] = '2200T' class DuplicateJsonObjectKeyValueError(DataError): - sqlstate = '22030' + sqlstate: typing.ClassVar[str] = '22030' class InvalidArgumentForSQLJsonDatetimeFunctionError(DataError): - sqlstate = '22031' + sqlstate: typing.ClassVar[str] = '22031' class InvalidJsonTextError(DataError): - sqlstate = '22032' + sqlstate: typing.ClassVar[str] = '22032' class InvalidSQLJsonSubscriptError(DataError): - sqlstate = '22033' + sqlstate: typing.ClassVar[str] = '22033' class MoreThanOneSQLJsonItemError(DataError): - sqlstate = '22034' + sqlstate: typing.ClassVar[str] = '22034' class NoSQLJsonItemError(DataError): - sqlstate = '22035' + sqlstate: typing.ClassVar[str] = '22035' class NonNumericSQLJsonItemError(DataError): - sqlstate = '22036' + sqlstate: typing.ClassVar[str] = '22036' class NonUniqueKeysInAJsonObjectError(DataError): - sqlstate = '22037' + sqlstate: typing.ClassVar[str] = '22037' class SingletonSQLJsonItemRequiredError(DataError): - sqlstate = '22038' + sqlstate: typing.ClassVar[str] = '22038' class SQLJsonArrayNotFoundError(DataError): - sqlstate = '22039' + sqlstate: typing.ClassVar[str] = '22039' class SQLJsonMemberNotFoundError(DataError): - sqlstate = '2203A' + sqlstate: typing.ClassVar[str] = '2203A' class SQLJsonNumberNotFoundError(DataError): - sqlstate = '2203B' + sqlstate: typing.ClassVar[str] = '2203B' class SQLJsonObjectNotFoundError(DataError): - sqlstate = '2203C' + sqlstate: typing.ClassVar[str] = '2203C' class TooManyJsonArrayElementsError(DataError): - sqlstate = '2203D' + sqlstate: typing.ClassVar[str] = '2203D' class TooManyJsonObjectMembersError(DataError): - sqlstate = '2203E' + sqlstate: typing.ClassVar[str] = '2203E' class SQLJsonScalarRequiredError(DataError): - sqlstate = '2203F' + sqlstate: typing.ClassVar[str] = '2203F' class SQLJsonItemCannotBeCastToTargetTypeError(DataError): - sqlstate = '2203G' + sqlstate: typing.ClassVar[str] = '2203G' class IntegrityConstraintViolationError(_base.PostgresError): - sqlstate = '23000' + sqlstate: typing.ClassVar[str] = '23000' class RestrictViolationError(IntegrityConstraintViolationError): - sqlstate = '23001' + sqlstate: typing.ClassVar[str] = '23001' class NotNullViolationError(IntegrityConstraintViolationError): - sqlstate = '23502' + sqlstate: typing.ClassVar[str] = '23502' class ForeignKeyViolationError(IntegrityConstraintViolationError): - sqlstate = '23503' + sqlstate: typing.ClassVar[str] = '23503' class UniqueViolationError(IntegrityConstraintViolationError): - sqlstate = '23505' + sqlstate: typing.ClassVar[str] = '23505' class CheckViolationError(IntegrityConstraintViolationError): - sqlstate = '23514' + sqlstate: typing.ClassVar[str] = '23514' class ExclusionViolationError(IntegrityConstraintViolationError): - sqlstate = '23P01' + sqlstate: typing.ClassVar[str] = '23P01' class InvalidCursorStateError(_base.PostgresError): - sqlstate = '24000' + sqlstate: typing.ClassVar[str] = '24000' class InvalidTransactionStateError(_base.PostgresError): - sqlstate = '25000' + sqlstate: typing.ClassVar[str] = '25000' class ActiveSQLTransactionError(InvalidTransactionStateError): - sqlstate = '25001' + sqlstate: typing.ClassVar[str] = '25001' class BranchTransactionAlreadyActiveError(InvalidTransactionStateError): - sqlstate = '25002' + sqlstate: typing.ClassVar[str] = '25002' class HeldCursorRequiresSameIsolationLevelError(InvalidTransactionStateError): - sqlstate = '25008' + sqlstate: typing.ClassVar[str] = '25008' class InappropriateAccessModeForBranchTransactionError( InvalidTransactionStateError): - sqlstate = '25003' + sqlstate: typing.ClassVar[str] = '25003' class InappropriateIsolationLevelForBranchTransactionError( InvalidTransactionStateError): - sqlstate = '25004' + sqlstate: typing.ClassVar[str] = '25004' class NoActiveSQLTransactionForBranchTransactionError( InvalidTransactionStateError): - sqlstate = '25005' + sqlstate: typing.ClassVar[str] = '25005' class ReadOnlySQLTransactionError(InvalidTransactionStateError): - sqlstate = '25006' + sqlstate: typing.ClassVar[str] = '25006' class SchemaAndDataStatementMixingNotSupportedError( InvalidTransactionStateError): - sqlstate = '25007' + sqlstate: typing.ClassVar[str] = '25007' class NoActiveSQLTransactionError(InvalidTransactionStateError): - sqlstate = '25P01' + sqlstate: typing.ClassVar[str] = '25P01' class InFailedSQLTransactionError(InvalidTransactionStateError): - sqlstate = '25P02' + sqlstate: typing.ClassVar[str] = '25P02' class IdleInTransactionSessionTimeoutError(InvalidTransactionStateError): - sqlstate = '25P03' + sqlstate: typing.ClassVar[str] = '25P03' class InvalidSQLStatementNameError(_base.PostgresError): - sqlstate = '26000' + sqlstate: typing.ClassVar[str] = '26000' class TriggeredDataChangeViolationError(_base.PostgresError): - sqlstate = '27000' + sqlstate: typing.ClassVar[str] = '27000' class InvalidAuthorizationSpecificationError(_base.PostgresError): - sqlstate = '28000' + sqlstate: typing.ClassVar[str] = '28000' class InvalidPasswordError(InvalidAuthorizationSpecificationError): - sqlstate = '28P01' + sqlstate: typing.ClassVar[str] = '28P01' class DependentPrivilegeDescriptorsStillExistError(_base.PostgresError): - sqlstate = '2B000' + sqlstate: typing.ClassVar[str] = '2B000' class DependentObjectsStillExistError( DependentPrivilegeDescriptorsStillExistError): - sqlstate = '2BP01' + sqlstate: typing.ClassVar[str] = '2BP01' class InvalidTransactionTerminationError(_base.PostgresError): - sqlstate = '2D000' + sqlstate: typing.ClassVar[str] = '2D000' class SQLRoutineError(_base.PostgresError): - sqlstate = '2F000' + sqlstate: typing.ClassVar[str] = '2F000' class FunctionExecutedNoReturnStatementError(SQLRoutineError): - sqlstate = '2F005' + sqlstate: typing.ClassVar[str] = '2F005' class ModifyingSQLDataNotPermittedError(SQLRoutineError): - sqlstate = '2F002' + sqlstate: typing.ClassVar[str] = '2F002' class ProhibitedSQLStatementAttemptedError(SQLRoutineError): - sqlstate = '2F003' + sqlstate: typing.ClassVar[str] = '2F003' class ReadingSQLDataNotPermittedError(SQLRoutineError): - sqlstate = '2F004' + sqlstate: typing.ClassVar[str] = '2F004' class InvalidCursorNameError(_base.PostgresError): - sqlstate = '34000' + sqlstate: typing.ClassVar[str] = '34000' class ExternalRoutineError(_base.PostgresError): - sqlstate = '38000' + sqlstate: typing.ClassVar[str] = '38000' class ContainingSQLNotPermittedError(ExternalRoutineError): - sqlstate = '38001' + sqlstate: typing.ClassVar[str] = '38001' class ModifyingExternalRoutineSQLDataNotPermittedError(ExternalRoutineError): - sqlstate = '38002' + sqlstate: typing.ClassVar[str] = '38002' class ProhibitedExternalRoutineSQLStatementAttemptedError( ExternalRoutineError): - sqlstate = '38003' + sqlstate: typing.ClassVar[str] = '38003' class ReadingExternalRoutineSQLDataNotPermittedError(ExternalRoutineError): - sqlstate = '38004' + sqlstate: typing.ClassVar[str] = '38004' class ExternalRoutineInvocationError(_base.PostgresError): - sqlstate = '39000' + sqlstate: typing.ClassVar[str] = '39000' class InvalidSqlstateReturnedError(ExternalRoutineInvocationError): - sqlstate = '39001' + sqlstate: typing.ClassVar[str] = '39001' class NullValueInExternalRoutineNotAllowedError( ExternalRoutineInvocationError): - sqlstate = '39004' + sqlstate: typing.ClassVar[str] = '39004' class TriggerProtocolViolatedError(ExternalRoutineInvocationError): - sqlstate = '39P01' + sqlstate: typing.ClassVar[str] = '39P01' class SrfProtocolViolatedError(ExternalRoutineInvocationError): - sqlstate = '39P02' + sqlstate: typing.ClassVar[str] = '39P02' class EventTriggerProtocolViolatedError(ExternalRoutineInvocationError): - sqlstate = '39P03' + sqlstate: typing.ClassVar[str] = '39P03' class SavepointError(_base.PostgresError): - sqlstate = '3B000' + sqlstate: typing.ClassVar[str] = '3B000' class InvalidSavepointSpecificationError(SavepointError): - sqlstate = '3B001' + sqlstate: typing.ClassVar[str] = '3B001' class InvalidCatalogNameError(_base.PostgresError): - sqlstate = '3D000' + sqlstate: typing.ClassVar[str] = '3D000' class InvalidSchemaNameError(_base.PostgresError): - sqlstate = '3F000' + sqlstate: typing.ClassVar[str] = '3F000' class TransactionRollbackError(_base.PostgresError): - sqlstate = '40000' + sqlstate: typing.ClassVar[str] = '40000' class TransactionIntegrityConstraintViolationError(TransactionRollbackError): - sqlstate = '40002' + sqlstate: typing.ClassVar[str] = '40002' class SerializationError(TransactionRollbackError): - sqlstate = '40001' + sqlstate: typing.ClassVar[str] = '40001' class StatementCompletionUnknownError(TransactionRollbackError): - sqlstate = '40003' + sqlstate: typing.ClassVar[str] = '40003' class DeadlockDetectedError(TransactionRollbackError): - sqlstate = '40P01' + sqlstate: typing.ClassVar[str] = '40P01' class SyntaxOrAccessError(_base.PostgresError): - sqlstate = '42000' + sqlstate: typing.ClassVar[str] = '42000' class PostgresSyntaxError(SyntaxOrAccessError): - sqlstate = '42601' + sqlstate: typing.ClassVar[str] = '42601' class InsufficientPrivilegeError(SyntaxOrAccessError): - sqlstate = '42501' + sqlstate: typing.ClassVar[str] = '42501' class CannotCoerceError(SyntaxOrAccessError): - sqlstate = '42846' + sqlstate: typing.ClassVar[str] = '42846' class GroupingError(SyntaxOrAccessError): - sqlstate = '42803' + sqlstate: typing.ClassVar[str] = '42803' class WindowingError(SyntaxOrAccessError): - sqlstate = '42P20' + sqlstate: typing.ClassVar[str] = '42P20' class InvalidRecursionError(SyntaxOrAccessError): - sqlstate = '42P19' + sqlstate: typing.ClassVar[str] = '42P19' class InvalidForeignKeyError(SyntaxOrAccessError): - sqlstate = '42830' + sqlstate: typing.ClassVar[str] = '42830' class InvalidNameError(SyntaxOrAccessError): - sqlstate = '42602' + sqlstate: typing.ClassVar[str] = '42602' class NameTooLongError(SyntaxOrAccessError): - sqlstate = '42622' + sqlstate: typing.ClassVar[str] = '42622' class ReservedNameError(SyntaxOrAccessError): - sqlstate = '42939' + sqlstate: typing.ClassVar[str] = '42939' class DatatypeMismatchError(SyntaxOrAccessError): - sqlstate = '42804' + sqlstate: typing.ClassVar[str] = '42804' class IndeterminateDatatypeError(SyntaxOrAccessError): - sqlstate = '42P18' + sqlstate: typing.ClassVar[str] = '42P18' class CollationMismatchError(SyntaxOrAccessError): - sqlstate = '42P21' + sqlstate: typing.ClassVar[str] = '42P21' class IndeterminateCollationError(SyntaxOrAccessError): - sqlstate = '42P22' + sqlstate: typing.ClassVar[str] = '42P22' class WrongObjectTypeError(SyntaxOrAccessError): - sqlstate = '42809' + sqlstate: typing.ClassVar[str] = '42809' class GeneratedAlwaysError(SyntaxOrAccessError): - sqlstate = '428C9' + sqlstate: typing.ClassVar[str] = '428C9' class UndefinedColumnError(SyntaxOrAccessError): - sqlstate = '42703' + sqlstate: typing.ClassVar[str] = '42703' class UndefinedFunctionError(SyntaxOrAccessError): - sqlstate = '42883' + sqlstate: typing.ClassVar[str] = '42883' class UndefinedTableError(SyntaxOrAccessError): - sqlstate = '42P01' + sqlstate: typing.ClassVar[str] = '42P01' class UndefinedParameterError(SyntaxOrAccessError): - sqlstate = '42P02' + sqlstate: typing.ClassVar[str] = '42P02' class UndefinedObjectError(SyntaxOrAccessError): - sqlstate = '42704' + sqlstate: typing.ClassVar[str] = '42704' class DuplicateColumnError(SyntaxOrAccessError): - sqlstate = '42701' + sqlstate: typing.ClassVar[str] = '42701' class DuplicateCursorError(SyntaxOrAccessError): - sqlstate = '42P03' + sqlstate: typing.ClassVar[str] = '42P03' class DuplicateDatabaseError(SyntaxOrAccessError): - sqlstate = '42P04' + sqlstate: typing.ClassVar[str] = '42P04' class DuplicateFunctionError(SyntaxOrAccessError): - sqlstate = '42723' + sqlstate: typing.ClassVar[str] = '42723' class DuplicatePreparedStatementError(SyntaxOrAccessError): - sqlstate = '42P05' + sqlstate: typing.ClassVar[str] = '42P05' class DuplicateSchemaError(SyntaxOrAccessError): - sqlstate = '42P06' + sqlstate: typing.ClassVar[str] = '42P06' class DuplicateTableError(SyntaxOrAccessError): - sqlstate = '42P07' + sqlstate: typing.ClassVar[str] = '42P07' class DuplicateAliasError(SyntaxOrAccessError): - sqlstate = '42712' + sqlstate: typing.ClassVar[str] = '42712' class DuplicateObjectError(SyntaxOrAccessError): - sqlstate = '42710' + sqlstate: typing.ClassVar[str] = '42710' class AmbiguousColumnError(SyntaxOrAccessError): - sqlstate = '42702' + sqlstate: typing.ClassVar[str] = '42702' class AmbiguousFunctionError(SyntaxOrAccessError): - sqlstate = '42725' + sqlstate: typing.ClassVar[str] = '42725' class AmbiguousParameterError(SyntaxOrAccessError): - sqlstate = '42P08' + sqlstate: typing.ClassVar[str] = '42P08' class AmbiguousAliasError(SyntaxOrAccessError): - sqlstate = '42P09' + sqlstate: typing.ClassVar[str] = '42P09' class InvalidColumnReferenceError(SyntaxOrAccessError): - sqlstate = '42P10' + sqlstate: typing.ClassVar[str] = '42P10' class InvalidColumnDefinitionError(SyntaxOrAccessError): - sqlstate = '42611' + sqlstate: typing.ClassVar[str] = '42611' class InvalidCursorDefinitionError(SyntaxOrAccessError): - sqlstate = '42P11' + sqlstate: typing.ClassVar[str] = '42P11' class InvalidDatabaseDefinitionError(SyntaxOrAccessError): - sqlstate = '42P12' + sqlstate: typing.ClassVar[str] = '42P12' class InvalidFunctionDefinitionError(SyntaxOrAccessError): - sqlstate = '42P13' + sqlstate: typing.ClassVar[str] = '42P13' class InvalidPreparedStatementDefinitionError(SyntaxOrAccessError): - sqlstate = '42P14' + sqlstate: typing.ClassVar[str] = '42P14' class InvalidSchemaDefinitionError(SyntaxOrAccessError): - sqlstate = '42P15' + sqlstate: typing.ClassVar[str] = '42P15' class InvalidTableDefinitionError(SyntaxOrAccessError): - sqlstate = '42P16' + sqlstate: typing.ClassVar[str] = '42P16' class InvalidObjectDefinitionError(SyntaxOrAccessError): - sqlstate = '42P17' + sqlstate: typing.ClassVar[str] = '42P17' class WithCheckOptionViolationError(_base.PostgresError): - sqlstate = '44000' + sqlstate: typing.ClassVar[str] = '44000' class InsufficientResourcesError(_base.PostgresError): - sqlstate = '53000' + sqlstate: typing.ClassVar[str] = '53000' class DiskFullError(InsufficientResourcesError): - sqlstate = '53100' + sqlstate: typing.ClassVar[str] = '53100' class OutOfMemoryError(InsufficientResourcesError): - sqlstate = '53200' + sqlstate: typing.ClassVar[str] = '53200' class TooManyConnectionsError(InsufficientResourcesError): - sqlstate = '53300' + sqlstate: typing.ClassVar[str] = '53300' class ConfigurationLimitExceededError(InsufficientResourcesError): - sqlstate = '53400' + sqlstate: typing.ClassVar[str] = '53400' class ProgramLimitExceededError(_base.PostgresError): - sqlstate = '54000' + sqlstate: typing.ClassVar[str] = '54000' class StatementTooComplexError(ProgramLimitExceededError): - sqlstate = '54001' + sqlstate: typing.ClassVar[str] = '54001' class TooManyColumnsError(ProgramLimitExceededError): - sqlstate = '54011' + sqlstate: typing.ClassVar[str] = '54011' class TooManyArgumentsError(ProgramLimitExceededError): - sqlstate = '54023' + sqlstate: typing.ClassVar[str] = '54023' class ObjectNotInPrerequisiteStateError(_base.PostgresError): - sqlstate = '55000' + sqlstate: typing.ClassVar[str] = '55000' class ObjectInUseError(ObjectNotInPrerequisiteStateError): - sqlstate = '55006' + sqlstate: typing.ClassVar[str] = '55006' class CantChangeRuntimeParamError(ObjectNotInPrerequisiteStateError): - sqlstate = '55P02' + sqlstate: typing.ClassVar[str] = '55P02' class LockNotAvailableError(ObjectNotInPrerequisiteStateError): - sqlstate = '55P03' + sqlstate: typing.ClassVar[str] = '55P03' class UnsafeNewEnumValueUsageError(ObjectNotInPrerequisiteStateError): - sqlstate = '55P04' + sqlstate: typing.ClassVar[str] = '55P04' class OperatorInterventionError(_base.PostgresError): - sqlstate = '57000' + sqlstate: typing.ClassVar[str] = '57000' class QueryCanceledError(OperatorInterventionError): - sqlstate = '57014' + sqlstate: typing.ClassVar[str] = '57014' class AdminShutdownError(OperatorInterventionError): - sqlstate = '57P01' + sqlstate: typing.ClassVar[str] = '57P01' class CrashShutdownError(OperatorInterventionError): - sqlstate = '57P02' + sqlstate: typing.ClassVar[str] = '57P02' class CannotConnectNowError(OperatorInterventionError): - sqlstate = '57P03' + sqlstate: typing.ClassVar[str] = '57P03' class DatabaseDroppedError(OperatorInterventionError): - sqlstate = '57P04' + sqlstate: typing.ClassVar[str] = '57P04' class IdleSessionTimeoutError(OperatorInterventionError): - sqlstate = '57P05' + sqlstate: typing.ClassVar[str] = '57P05' class PostgresSystemError(_base.PostgresError): - sqlstate = '58000' + sqlstate: typing.ClassVar[str] = '58000' class PostgresIOError(PostgresSystemError): - sqlstate = '58030' + sqlstate: typing.ClassVar[str] = '58030' class UndefinedFileError(PostgresSystemError): - sqlstate = '58P01' + sqlstate: typing.ClassVar[str] = '58P01' class DuplicateFileError(PostgresSystemError): - sqlstate = '58P02' + sqlstate: typing.ClassVar[str] = '58P02' class SnapshotTooOldError(_base.PostgresError): - sqlstate = '72000' + sqlstate: typing.ClassVar[str] = '72000' class ConfigFileError(_base.PostgresError): - sqlstate = 'F0000' + sqlstate: typing.ClassVar[str] = 'F0000' class LockFileExistsError(ConfigFileError): - sqlstate = 'F0001' + sqlstate: typing.ClassVar[str] = 'F0001' class FDWError(_base.PostgresError): - sqlstate = 'HV000' + sqlstate: typing.ClassVar[str] = 'HV000' class FDWColumnNameNotFoundError(FDWError): - sqlstate = 'HV005' + sqlstate: typing.ClassVar[str] = 'HV005' class FDWDynamicParameterValueNeededError(FDWError): - sqlstate = 'HV002' + sqlstate: typing.ClassVar[str] = 'HV002' class FDWFunctionSequenceError(FDWError): - sqlstate = 'HV010' + sqlstate: typing.ClassVar[str] = 'HV010' class FDWInconsistentDescriptorInformationError(FDWError): - sqlstate = 'HV021' + sqlstate: typing.ClassVar[str] = 'HV021' class FDWInvalidAttributeValueError(FDWError): - sqlstate = 'HV024' + sqlstate: typing.ClassVar[str] = 'HV024' class FDWInvalidColumnNameError(FDWError): - sqlstate = 'HV007' + sqlstate: typing.ClassVar[str] = 'HV007' class FDWInvalidColumnNumberError(FDWError): - sqlstate = 'HV008' + sqlstate: typing.ClassVar[str] = 'HV008' class FDWInvalidDataTypeError(FDWError): - sqlstate = 'HV004' + sqlstate: typing.ClassVar[str] = 'HV004' class FDWInvalidDataTypeDescriptorsError(FDWError): - sqlstate = 'HV006' + sqlstate: typing.ClassVar[str] = 'HV006' class FDWInvalidDescriptorFieldIdentifierError(FDWError): - sqlstate = 'HV091' + sqlstate: typing.ClassVar[str] = 'HV091' class FDWInvalidHandleError(FDWError): - sqlstate = 'HV00B' + sqlstate: typing.ClassVar[str] = 'HV00B' class FDWInvalidOptionIndexError(FDWError): - sqlstate = 'HV00C' + sqlstate: typing.ClassVar[str] = 'HV00C' class FDWInvalidOptionNameError(FDWError): - sqlstate = 'HV00D' + sqlstate: typing.ClassVar[str] = 'HV00D' class FDWInvalidStringLengthOrBufferLengthError(FDWError): - sqlstate = 'HV090' + sqlstate: typing.ClassVar[str] = 'HV090' class FDWInvalidStringFormatError(FDWError): - sqlstate = 'HV00A' + sqlstate: typing.ClassVar[str] = 'HV00A' class FDWInvalidUseOfNullPointerError(FDWError): - sqlstate = 'HV009' + sqlstate: typing.ClassVar[str] = 'HV009' class FDWTooManyHandlesError(FDWError): - sqlstate = 'HV014' + sqlstate: typing.ClassVar[str] = 'HV014' class FDWOutOfMemoryError(FDWError): - sqlstate = 'HV001' + sqlstate: typing.ClassVar[str] = 'HV001' class FDWNoSchemasError(FDWError): - sqlstate = 'HV00P' + sqlstate: typing.ClassVar[str] = 'HV00P' class FDWOptionNameNotFoundError(FDWError): - sqlstate = 'HV00J' + sqlstate: typing.ClassVar[str] = 'HV00J' class FDWReplyHandleError(FDWError): - sqlstate = 'HV00K' + sqlstate: typing.ClassVar[str] = 'HV00K' class FDWSchemaNotFoundError(FDWError): - sqlstate = 'HV00Q' + sqlstate: typing.ClassVar[str] = 'HV00Q' class FDWTableNotFoundError(FDWError): - sqlstate = 'HV00R' + sqlstate: typing.ClassVar[str] = 'HV00R' class FDWUnableToCreateExecutionError(FDWError): - sqlstate = 'HV00L' + sqlstate: typing.ClassVar[str] = 'HV00L' class FDWUnableToCreateReplyError(FDWError): - sqlstate = 'HV00M' + sqlstate: typing.ClassVar[str] = 'HV00M' class FDWUnableToEstablishConnectionError(FDWError): - sqlstate = 'HV00N' + sqlstate: typing.ClassVar[str] = 'HV00N' class PLPGSQLError(_base.PostgresError): - sqlstate = 'P0000' + sqlstate: typing.ClassVar[str] = 'P0000' class RaiseError(PLPGSQLError): - sqlstate = 'P0001' + sqlstate: typing.ClassVar[str] = 'P0001' class NoDataFoundError(PLPGSQLError): - sqlstate = 'P0002' + sqlstate: typing.ClassVar[str] = 'P0002' class TooManyRowsError(PLPGSQLError): - sqlstate = 'P0003' + sqlstate: typing.ClassVar[str] = 'P0003' class AssertError(PLPGSQLError): - sqlstate = 'P0004' + sqlstate: typing.ClassVar[str] = 'P0004' class InternalServerError(_base.PostgresError): - sqlstate = 'XX000' + sqlstate: typing.ClassVar[str] = 'XX000' class DataCorruptedError(InternalServerError): - sqlstate = 'XX001' + sqlstate: typing.ClassVar[str] = 'XX001' class IndexCorruptedError(InternalServerError): - sqlstate = 'XX002' + sqlstate: typing.ClassVar[str] = 'XX002' -__all__ = ( +__all__ = [ 'ActiveSQLTransactionError', 'AdminShutdownError', 'AmbiguousAliasError', 'AmbiguousColumnError', 'AmbiguousFunctionError', 'AmbiguousParameterError', @@ -1193,6 +1196,6 @@ class IndexCorruptedError(InternalServerError): 'UnterminatedCStringError', 'UntranslatableCharacterError', 'WindowingError', 'WithCheckOptionViolationError', 'WrongObjectTypeError', 'ZeroLengthCharacterStringError' -) +] __all__ += _base.__all__ diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index 00e9699a..5763e180 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -4,169 +4,36 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncpg -import sys -import textwrap +import typing +if typing.TYPE_CHECKING: + import sys -__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError', + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + +from ._postgres_message import PostgresMessage as PostgresMessage + +__all__ = ['PostgresError', 'FatalPostgresError', 'UnknownPostgresError', 'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage', 'ClientConfigurationError', 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError', 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched', - 'UnsupportedServerFeatureError') - - -def _is_asyncpg_class(cls): - modname = cls.__module__ - return modname == 'asyncpg' or modname.startswith('asyncpg.') - - -class PostgresMessageMeta(type): - - _message_map = {} - _field_map = { - 'S': 'severity', - 'V': 'severity_en', - 'C': 'sqlstate', - 'M': 'message', - 'D': 'detail', - 'H': 'hint', - 'P': 'position', - 'p': 'internal_position', - 'q': 'internal_query', - 'W': 'context', - 's': 'schema_name', - 't': 'table_name', - 'c': 'column_name', - 'd': 'data_type_name', - 'n': 'constraint_name', - 'F': 'server_source_filename', - 'L': 'server_source_line', - 'R': 'server_source_function' - } - - def __new__(mcls, name, bases, dct): - cls = super().__new__(mcls, name, bases, dct) - if cls.__module__ == mcls.__module__ and name == 'PostgresMessage': - for f in mcls._field_map.values(): - setattr(cls, f, None) - - if _is_asyncpg_class(cls): - mod = sys.modules[cls.__module__] - if hasattr(mod, name): - raise RuntimeError('exception class redefinition: {}'.format( - name)) - - code = dct.get('sqlstate') - if code is not None: - existing = mcls._message_map.get(code) - if existing is not None: - raise TypeError('{} has duplicate SQLSTATE code, which is' - 'already defined by {}'.format( - name, existing.__name__)) - mcls._message_map[code] = cls - - return cls - - @classmethod - def get_message_class_for_sqlstate(mcls, code): - return mcls._message_map.get(code, UnknownPostgresError) + 'UnsupportedServerFeatureError'] - -class PostgresMessage(metaclass=PostgresMessageMeta): - - @classmethod - def _get_error_class(cls, fields): - sqlstate = fields.get('C') - return type(cls).get_message_class_for_sqlstate(sqlstate) - - @classmethod - def _get_error_dict(cls, fields, query): - dct = { - 'query': query - } - - field_map = type(cls)._field_map - for k, v in fields.items(): - field = field_map.get(k) - if field: - dct[field] = v - - return dct - - @classmethod - def _make_constructor(cls, fields, query=None): - dct = cls._get_error_dict(fields, query) - - exccls = cls._get_error_class(fields) - message = dct.get('message', '') - - # PostgreSQL will raise an exception when it detects - # that the result type of the query has changed from - # when the statement was prepared. - # - # The original error is somewhat cryptic and unspecific, - # so we raise a custom subclass that is easier to handle - # and identify. - # - # Note that we specifically do not rely on the error - # message, as it is localizable. - is_icse = ( - exccls.__name__ == 'FeatureNotSupportedError' and - _is_asyncpg_class(exccls) and - dct.get('server_source_function') == 'RevalidateCachedQuery' - ) - - if is_icse: - exceptions = sys.modules[exccls.__module__] - exccls = exceptions.InvalidCachedStatementError - message = ('cached statement plan is invalid due to a database ' - 'schema or configuration change') - - is_prepared_stmt_error = ( - exccls.__name__ in ('DuplicatePreparedStatementError', - 'InvalidSQLStatementNameError') and - _is_asyncpg_class(exccls) - ) - - if is_prepared_stmt_error: - hint = dct.get('hint', '') - hint += textwrap.dedent("""\ - - NOTE: pgbouncer with pool_mode set to "transaction" or - "statement" does not support prepared statements properly. - You have two options: - - * if you are using pgbouncer for connection pooling to a - single server, switch to the connection pool functionality - provided by asyncpg, it is a much better option for this - purpose; - - * if you have no option of avoiding the use of pgbouncer, - then you can set statement_cache_size to 0 when creating - the asyncpg connection object. - """) - - dct['hint'] = hint - - return exccls, message, dct - - def as_dict(self): - dct = {} - for f in type(self)._field_map.values(): - val = getattr(self, f) - if val is not None: - dct[f] = val - return dct +_PM = typing.TypeVar('_PM', bound='PostgresMessage') class PostgresError(PostgresMessage, Exception): """Base class for all Postgres errors.""" - def __str__(self): - msg = self.args[0] + def __str__(self) -> str: + msg: str = self.args[0] if self.detail: msg += '\nDETAIL: {}'.format(self.detail) if self.hint: @@ -175,7 +42,7 @@ def __str__(self): return msg @classmethod - def new(cls, fields, query=None): + def new(cls, fields: dict[str, str], query: str | None = None) -> Self: exccls, message, dct = cls._make_constructor(fields, query) ex = exccls(message) ex.__dict__.update(dct) @@ -191,11 +58,20 @@ class UnknownPostgresError(FatalPostgresError): class InterfaceMessage: - def __init__(self, *, detail=None, hint=None): + args: tuple[str, ...] + detail: str | None + hint: str | None + + def __init__( + self, + *, + detail: str | None = None, + hint: str | None = None, + ) -> None: self.detail = detail self.hint = hint - def __str__(self): + def __str__(self) -> str: msg = self.args[0] if self.detail: msg += '\nDETAIL: {}'.format(self.detail) @@ -208,11 +84,17 @@ def __str__(self): class InterfaceError(InterfaceMessage, Exception): """An error caused by improper use of asyncpg API.""" - def __init__(self, msg, *, detail=None, hint=None): + def __init__( + self, + msg: str, + *, + detail: str | None = None, + hint: str | None = None, + ) -> None: InterfaceMessage.__init__(self, detail=detail, hint=hint) Exception.__init__(self, msg) - def with_msg(self, msg): + def with_msg(self, msg: str) -> Self: return type(self)( msg, detail=self.detail, @@ -241,7 +123,13 @@ class UnsupportedServerFeatureError(InterfaceError): class InterfaceWarning(InterfaceMessage, UserWarning): """A warning caused by an improper use of asyncpg API.""" - def __init__(self, msg, *, detail=None, hint=None): + def __init__( + self, + msg: str, + *, + detail: str | None = None, + hint: str | None = None, + ) -> None: InterfaceMessage.__init__(self, detail=detail, hint=hint) UserWarning.__init__(self, msg) @@ -261,7 +149,18 @@ class TargetServerAttributeNotMatched(InternalClientError): class OutdatedSchemaCacheError(InternalClientError): """A value decoding error caused by a schema change before row fetching.""" - def __init__(self, msg, *, schema=None, data_type=None, position=None): + schema_name: str | None + data_type_name: str | None + position: str | None + + def __init__( + self, + msg: str, + *, + schema: str | None = None, + data_type: str | None = None, + position: str | None = None, + ) -> None: super().__init__(msg) self.schema_name = schema self.data_type_name = data_type @@ -271,15 +170,18 @@ def __init__(self, msg, *, schema=None, data_type=None, position=None): class PostgresLogMessage(PostgresMessage): """A base class for non-error server messages.""" - def __str__(self): + def __str__(self) -> str: return '{}: {}'.format(type(self).__name__, self.message) - def __setattr__(self, name, val): + def __setattr__(self, name: str, val: object) -> None: raise TypeError('instances of {} are immutable'.format( type(self).__name__)) @classmethod - def new(cls, fields, query=None): + def new( + cls: type[_PM], fields: dict[str, str], query: str | None = None + ) -> PostgresMessage: + exccls: type[PostgresMessage] exccls, message_text, dct = cls._make_constructor(fields, query) if exccls is UnknownPostgresError: @@ -291,7 +193,7 @@ def new(cls, fields, query=None): exccls = asyncpg.PostgresWarning if issubclass(exccls, (BaseException, Warning)): - msg = exccls(message_text) + msg: PostgresMessage = exccls(message_text) else: msg = exccls() diff --git a/asyncpg/exceptions/_postgres_message.py b/asyncpg/exceptions/_postgres_message.py new file mode 100644 index 00000000..c281d7dd --- /dev/null +++ b/asyncpg/exceptions/_postgres_message.py @@ -0,0 +1,155 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +from __future__ import annotations + +import asyncpg +import sys +import textwrap + + +def _is_asyncpg_class(cls): + modname = cls.__module__ + return modname == 'asyncpg' or modname.startswith('asyncpg.') + + +class PostgresMessageMeta(type): + + _message_map = {} + _field_map = { + 'S': 'severity', + 'V': 'severity_en', + 'C': 'sqlstate', + 'M': 'message', + 'D': 'detail', + 'H': 'hint', + 'P': 'position', + 'p': 'internal_position', + 'q': 'internal_query', + 'W': 'context', + 's': 'schema_name', + 't': 'table_name', + 'c': 'column_name', + 'd': 'data_type_name', + 'n': 'constraint_name', + 'F': 'server_source_filename', + 'L': 'server_source_line', + 'R': 'server_source_function' + } + + def __new__(mcls, name, bases, dct): + cls = super().__new__(mcls, name, bases, dct) + if cls.__module__ == mcls.__module__ and name == 'PostgresMessage': + for f in mcls._field_map.values(): + setattr(cls, f, None) + + if _is_asyncpg_class(cls): + mod = sys.modules[cls.__module__] + if hasattr(mod, name): + raise RuntimeError('exception class redefinition: {}'.format( + name)) + + code = dct.get('sqlstate') + if code is not None: + existing = mcls._message_map.get(code) + if existing is not None: + raise TypeError('{} has duplicate SQLSTATE code, which is' + 'already defined by {}'.format( + name, existing.__name__)) + mcls._message_map[code] = cls + + return cls + + @classmethod + def get_message_class_for_sqlstate(mcls, code): + return mcls._message_map.get(code, asyncpg.UnknownPostgresError) + + +class PostgresMessage(metaclass=PostgresMessageMeta): + + @classmethod + def _get_error_class(cls, fields): + sqlstate = fields.get('C') + return type(cls).get_message_class_for_sqlstate(sqlstate) + + @classmethod + def _get_error_dict(cls, fields, query): + dct = { + 'query': query + } + + field_map = type(cls)._field_map + for k, v in fields.items(): + field = field_map.get(k) + if field: + dct[field] = v + + return dct + + @classmethod + def _make_constructor(cls, fields, query=None): + dct = cls._get_error_dict(fields, query) + + exccls = cls._get_error_class(fields) + message = dct.get('message', '') + + # PostgreSQL will raise an exception when it detects + # that the result type of the query has changed from + # when the statement was prepared. + # + # The original error is somewhat cryptic and unspecific, + # so we raise a custom subclass that is easier to handle + # and identify. + # + # Note that we specifically do not rely on the error + # message, as it is localizable. + is_icse = ( + exccls.__name__ == 'FeatureNotSupportedError' and + _is_asyncpg_class(exccls) and + dct.get('server_source_function') == 'RevalidateCachedQuery' + ) + + if is_icse: + exceptions = sys.modules[exccls.__module__] + exccls = exceptions.InvalidCachedStatementError + message = ('cached statement plan is invalid due to a database ' + 'schema or configuration change') + + is_prepared_stmt_error = ( + exccls.__name__ in ('DuplicatePreparedStatementError', + 'InvalidSQLStatementNameError') and + _is_asyncpg_class(exccls) + ) + + if is_prepared_stmt_error: + hint = dct.get('hint', '') + hint += textwrap.dedent("""\ + + NOTE: pgbouncer with pool_mode set to "transaction" or + "statement" does not support prepared statements properly. + You have two options: + + * if you are using pgbouncer for connection pooling to a + single server, switch to the connection pool functionality + provided by asyncpg, it is a much better option for this + purpose; + + * if you have no option of avoiding the use of pgbouncer, + then you can set statement_cache_size to 0 when creating + the asyncpg connection object. + """) + + dct['hint'] = hint + + return exccls, message, dct + + def as_dict(self): + dct = {} + for f in type(self)._field_map.values(): + val = getattr(self, f) + if val is not None: + dct[f] = val + return dct diff --git a/asyncpg/exceptions/_postgres_message.pyi b/asyncpg/exceptions/_postgres_message.pyi new file mode 100644 index 00000000..7d2d4e75 --- /dev/null +++ b/asyncpg/exceptions/_postgres_message.pyi @@ -0,0 +1,36 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +import typing + +_PM = typing.TypeVar('_PM', bound=PostgresMessage) + +class PostgresMessageMeta(type): ... + +class PostgresMessage(metaclass=PostgresMessageMeta): + severity: str | None + severity_en: str | None + sqlstate: typing.ClassVar[str] + message: str + detail: str | None + hint: str | None + position: str | None + internal_position: str | None + internal_query: str | None + context: str | None + schema_name: str | None + table_name: str | None + column_name: str | None + data_type_name: str | None + constraint_name: str | None + server_source_filename: str | None + server_source_line: str | None + server_source_function: str | None + @classmethod + def _make_constructor( + cls: type[_PM], fields: dict[str, str], query: str | None = ... + ) -> tuple[type[_PM], str, dict[str, str]]: ... + def as_dict(self) -> dict[str, str]: ... diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py index 6c2caf03..95ce0f0a 100644 --- a/asyncpg/introspection.py +++ b/asyncpg/introspection.py @@ -4,8 +4,15 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations -_TYPEINFO_13 = '''\ +import typing + +if typing.TYPE_CHECKING: + from . import protocol + + +_TYPEINFO_13: typing.Final = '''\ ( SELECT t.oid AS oid, @@ -82,7 +89,7 @@ ''' -INTRO_LOOKUP_TYPES_13 = '''\ +INTRO_LOOKUP_TYPES_13: typing.Final = '''\ WITH RECURSIVE typeinfo_tree( oid, ns, name, kind, basetype, elemtype, elemdelim, range_subtype, attrtypoids, attrnames, depth) @@ -124,7 +131,7 @@ '''.format(typeinfo=_TYPEINFO_13) -_TYPEINFO = '''\ +_TYPEINFO: typing.Final = '''\ ( SELECT t.oid AS oid, @@ -206,7 +213,7 @@ ''' -INTRO_LOOKUP_TYPES = '''\ +INTRO_LOOKUP_TYPES: typing.Final = '''\ WITH RECURSIVE typeinfo_tree( oid, ns, name, kind, basetype, elemtype, elemdelim, range_subtype, attrtypoids, attrnames, depth) @@ -248,7 +255,7 @@ '''.format(typeinfo=_TYPEINFO) -TYPE_BY_NAME = '''\ +TYPE_BY_NAME: typing.Final = '''\ SELECT t.oid, t.typelem AS elemtype, @@ -274,19 +281,19 @@ # 'b' for a base type, 'd' for a domain, 'e' for enum. -SCALAR_TYPE_KINDS = (b'b', b'd', b'e') +SCALAR_TYPE_KINDS: typing.Final = (b'b', b'd', b'e') -def is_scalar_type(typeinfo) -> bool: +def is_scalar_type(typeinfo: protocol.Record) -> bool: return ( typeinfo['kind'] in SCALAR_TYPE_KINDS and not typeinfo['elemtype'] ) -def is_domain_type(typeinfo) -> bool: - return typeinfo['kind'] == b'd' +def is_domain_type(typeinfo: protocol.Record) -> bool: + return typing.cast(bytes, typeinfo['kind']) == b'd' -def is_composite_type(typeinfo) -> bool: - return typeinfo['kind'] == b'c' +def is_composite_type(typeinfo: protocol.Record) -> bool: + return typing.cast(bytes, typeinfo['kind']) == b'c' diff --git a/asyncpg/pgproto b/asyncpg/pgproto index 1c3cad14..dbb69452 160000 --- a/asyncpg/pgproto +++ b/asyncpg/pgproto @@ -1 +1 @@ -Subproject commit 1c3cad14d53c8f3088106f4eab8f612b7293569b +Subproject commit dbb69452baaac89ae46cbae0fb6b4a267083d16f diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 06e698df..1544a6c9 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -4,94 +4,53 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio -import functools -import inspect import logging import time +import typing import warnings from . import compat from . import connection from . import exceptions +from . import pool_connection_proxy from . import protocol +if typing.TYPE_CHECKING: + import sys -logger = logging.getLogger(__name__) + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + from . import connect_utils -class PoolConnectionProxyMeta(type): +_ConnectionT = typing.TypeVar( + '_ConnectionT', bound=connection.Connection[typing.Any] +) +_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record) +_OtherRecordT = typing.TypeVar('_OtherRecordT', bound=protocol.Record) - def __new__(mcls, name, bases, dct, *, wrap=False): - if wrap: - for attrname in dir(connection.Connection): - if attrname.startswith('_') or attrname in dct: - continue +_logger = logging.getLogger(__name__) - meth = getattr(connection.Connection, attrname) - if not inspect.isfunction(meth): - continue - wrapper = mcls._wrap_connection_method(attrname) - wrapper = functools.update_wrapper(wrapper, meth) - dct[attrname] = wrapper - - if '__doc__' not in dct: - dct['__doc__'] = connection.Connection.__doc__ - - return super().__new__(mcls, name, bases, dct) - - @staticmethod - def _wrap_connection_method(meth_name): - def call_con_method(self, *args, **kwargs): - # This method will be owned by PoolConnectionProxy class. - if self._con is None: - raise exceptions.InterfaceError( - 'cannot call Connection.{}(): ' - 'connection has been released back to the pool'.format( - meth_name)) - - meth = getattr(self._con.__class__, meth_name) - return meth(self._con, *args, **kwargs) - - return call_con_method - - -class PoolConnectionProxy(connection._ConnectionProxy, - metaclass=PoolConnectionProxyMeta, - wrap=True): - - __slots__ = ('_con', '_holder') - - def __init__(self, holder: 'PoolConnectionHolder', - con: connection.Connection): - self._con = con - self._holder = holder - con._set_proxy(self) - - def __getattr__(self, attr): - # Proxy all unresolved attributes to the wrapped Connection object. - return getattr(self._con, attr) - - def _detach(self) -> connection.Connection: - if self._con is None: - return +class _SetupCallback(typing.Protocol[_RecordT]): + async def __call__( + self, + __proxy: pool_connection_proxy.PoolConnectionProxy[_RecordT] + ) -> None: + ... - con, self._con = self._con, None - con._set_proxy(None) - return con - def __repr__(self): - if self._con is None: - return '<{classname} [released] {id:#x}>'.format( - classname=self.__class__.__name__, id=id(self)) - else: - return '<{classname} {con!r} {id:#x}>'.format( - classname=self.__class__.__name__, con=self._con, id=id(self)) +class _InitCallback(typing.Protocol[_RecordT]): + async def __call__(self, __con: connection.Connection[_RecordT]) -> None: + ... -class PoolConnectionHolder: +class PoolConnectionHolder(typing.Generic[_RecordT]): __slots__ = ('_con', '_pool', '_loop', '_proxy', '_max_queries', '_setup', @@ -99,7 +58,25 @@ class PoolConnectionHolder: '_inactive_callback', '_timeout', '_generation') - def __init__(self, pool, *, max_queries, setup, max_inactive_time): + _con: connection.Connection[_RecordT] | None + _pool: Pool[_RecordT] + _proxy: pool_connection_proxy.PoolConnectionProxy[_RecordT] | None + _max_queries: int + _setup: _SetupCallback[_RecordT] | None + _max_inactive_time: float + _in_use: asyncio.Future[None] | None + _inactive_callback: asyncio.TimerHandle | None + _timeout: float | None + _generation: int | None + + def __init__( + self, + pool: Pool[_RecordT], + *, + max_queries: int, + setup: _SetupCallback[_RecordT] | None, + max_inactive_time: float + ) -> None: self._pool = pool self._con = None @@ -109,17 +86,17 @@ def __init__(self, pool, *, max_queries, setup, max_inactive_time): self._max_inactive_time = max_inactive_time self._setup = setup self._inactive_callback = None - self._in_use = None # type: asyncio.Future + self._in_use = None self._timeout = None self._generation = None - def is_connected(self): + def is_connected(self) -> bool: return self._con is not None and not self._con.is_closed() - def is_idle(self): + def is_idle(self) -> bool: return not self._in_use - async def connect(self): + async def connect(self) -> None: if self._con is not None: raise exceptions.InternalClientError( 'PoolConnectionHolder.connect() called while another ' @@ -130,7 +107,9 @@ async def connect(self): self._maybe_cancel_inactive_callback() self._setup_inactive_callback() - async def acquire(self) -> PoolConnectionProxy: + async def acquire( + self + ) -> pool_connection_proxy.PoolConnectionProxy[_RecordT]: if self._con is None or self._con.is_closed(): self._con = None await self.connect() @@ -142,9 +121,14 @@ async def acquire(self) -> PoolConnectionProxy: self._con = None await self.connect() + if typing.TYPE_CHECKING: + assert self._con is not None + self._maybe_cancel_inactive_callback() - self._proxy = proxy = PoolConnectionProxy(self, self._con) + self._proxy = proxy = pool_connection_proxy.PoolConnectionProxy( + self, self._con + ) if self._setup is not None: try: @@ -167,12 +151,15 @@ async def acquire(self) -> PoolConnectionProxy: return proxy - async def release(self, timeout): + async def release(self, timeout: float | None) -> None: if self._in_use is None: raise exceptions.InternalClientError( 'PoolConnectionHolder.release() called on ' 'a free connection holder') + if typing.TYPE_CHECKING: + assert self._con is not None + if self._con.is_closed(): # When closing, pool connections perform the necessary # cleanup, so we don't have to do anything else here. @@ -225,25 +212,25 @@ async def release(self, timeout): # Rearm the connection inactivity timer. self._setup_inactive_callback() - async def wait_until_released(self): + async def wait_until_released(self) -> None: if self._in_use is None: return else: await self._in_use - async def close(self): + async def close(self) -> None: if self._con is not None: # Connection.close() will call _release_on_close() to # finish holder cleanup. await self._con.close() - def terminate(self): + def terminate(self) -> None: if self._con is not None: # Connection.terminate() will call _release_on_close() to # finish holder cleanup. self._con.terminate() - def _setup_inactive_callback(self): + def _setup_inactive_callback(self) -> None: if self._inactive_callback is not None: raise exceptions.InternalClientError( 'pool connection inactivity timer already exists') @@ -252,12 +239,12 @@ def _setup_inactive_callback(self): self._inactive_callback = self._pool._loop.call_later( self._max_inactive_time, self._deactivate_inactive_connection) - def _maybe_cancel_inactive_callback(self): + def _maybe_cancel_inactive_callback(self) -> None: if self._inactive_callback is not None: self._inactive_callback.cancel() self._inactive_callback = None - def _deactivate_inactive_connection(self): + def _deactivate_inactive_connection(self) -> None: if self._in_use is not None: raise exceptions.InternalClientError( 'attempting to deactivate an acquired connection') @@ -271,12 +258,12 @@ def _deactivate_inactive_connection(self): # so terminate() above will not call the below. self._release_on_close() - def _release_on_close(self): + def _release_on_close(self) -> None: self._maybe_cancel_inactive_callback() self._release() self._con = None - def _release(self): + def _release(self) -> None: """Release this connection holder.""" if self._in_use is None: # The holder is not checked out. @@ -292,11 +279,14 @@ def _release(self): self._proxy._detach() self._proxy = None + if typing.TYPE_CHECKING: + assert self._pool._queue is not None + # Put ourselves back to the pool queue. self._pool._queue.put_nowait(self) -class Pool: +class Pool(typing.Generic[_RecordT]): """A connection pool. Connection pool can be used to manage a set of connections to the database. @@ -315,17 +305,42 @@ class Pool: '_setup', '_max_queries', '_max_inactive_connection_lifetime' ) - def __init__(self, *connect_args, - min_size, - max_size, - max_queries, - max_inactive_connection_lifetime, - setup, - init, - loop, - connection_class, - record_class, - **connect_kwargs): + _queue: asyncio.LifoQueue[PoolConnectionHolder[_RecordT]] | None + _loop: asyncio.AbstractEventLoop + _minsize: int + _maxsize: int + _init: _InitCallback[_RecordT] | None + _connect_args: tuple[str | None] | tuple[()] + _connect_kwargs: dict[str, object] + _working_addr: typing.Tuple[str, int] | str + _working_config: connect_utils._ClientConfiguration | None + _working_params: connect_utils._ConnectionParameters | None + _holders: list[PoolConnectionHolder[_RecordT]] + _initialized: bool + _initializing: bool + _closing: bool + _closed: bool + _connection_class: type[connection.Connection[_RecordT]] + _record_class: type[_RecordT] + _generation: int + _setup: _SetupCallback[_RecordT] | None + _max_queries: int + _max_inactive_connection_lifetime: float + + def __init__( + self, + *connect_args: str | None, + min_size: int, + max_size: int, + max_queries: int, + max_inactive_connection_lifetime: float, + setup: _SetupCallback[_RecordT] | None, + init: _InitCallback[_RecordT] | None, + loop: asyncio.AbstractEventLoop | None, + connection_class: type[_ConnectionT], + record_class: type[_RecordT], + **connect_kwargs: object + ): if len(connect_args) > 1: warnings.warn( @@ -382,7 +397,9 @@ def __init__(self, *connect_args, self._closed = False self._generation = 0 self._init = init - self._connect_args = connect_args + self._connect_args = ( + () if not len(connect_args) else (connect_args[0],) + ) self._connect_kwargs = connect_kwargs self._setup = setup @@ -390,9 +407,9 @@ def __init__(self, *connect_args, self._max_inactive_connection_lifetime = \ max_inactive_connection_lifetime - async def _async__init__(self): + async def _async__init__(self) -> Self | None: if self._initialized: - return + return None if self._initializing: raise exceptions.InterfaceError( 'pool is being initialized in another task') @@ -406,7 +423,7 @@ async def _async__init__(self): self._initializing = False self._initialized = True - async def _initialize(self): + async def _initialize(self) -> None: self._queue = asyncio.LifoQueue(maxsize=self._maxsize) for _ in range(self._maxsize): ch = PoolConnectionHolder( @@ -426,11 +443,11 @@ async def _initialize(self): # Connect the first connection holder in the queue so that # any connection issues are visible early. - first_ch = self._holders[-1] # type: PoolConnectionHolder + first_ch: PoolConnectionHolder[_RecordT] = self._holders[-1] await first_ch.connect() if self._minsize > 1: - connect_tasks = [] + connect_tasks: list[compat.Awaitable[None]] = [] for i, ch in enumerate(reversed(self._holders[:-1])): # `minsize - 1` because we already have first_ch if i >= self._minsize - 1: @@ -439,42 +456,44 @@ async def _initialize(self): await asyncio.gather(*connect_tasks) - def is_closing(self): + def is_closing(self) -> bool: """Return ``True`` if the pool is closing or is closed. .. versionadded:: 0.28.0 """ return self._closed or self._closing - def get_size(self): + def get_size(self) -> int: """Return the current number of connections in this pool. .. versionadded:: 0.25.0 """ return sum(h.is_connected() for h in self._holders) - def get_min_size(self): + def get_min_size(self) -> int: """Return the minimum number of connections in this pool. .. versionadded:: 0.25.0 """ return self._minsize - def get_max_size(self): + def get_max_size(self) -> int: """Return the maximum allowed number of connections in this pool. .. versionadded:: 0.25.0 """ return self._maxsize - def get_idle_size(self): + def get_idle_size(self) -> int: """Return the current number of idle connections in this pool. .. versionadded:: 0.25.0 """ return sum(h.is_connected() and h.is_idle() for h in self._holders) - def set_connect_args(self, dsn=None, **connect_kwargs): + def set_connect_args( + self, dsn: str | None = None, **connect_kwargs: object + ) -> None: r"""Set the new connection arguments for this pool. The new connection arguments will be used for all subsequent @@ -495,16 +514,16 @@ def set_connect_args(self, dsn=None, **connect_kwargs): .. versionadded:: 0.16.0 """ - self._connect_args = [dsn] + self._connect_args = (dsn,) self._connect_kwargs = connect_kwargs - async def _get_new_connection(self): - con = await connection.connect( + async def _get_new_connection(self) -> connection.Connection[_RecordT]: + con: connection.Connection[_RecordT] = await connection.connect( *self._connect_args, loop=self._loop, connection_class=self._connection_class, record_class=self._record_class, - **self._connect_kwargs, + **typing.cast(typing.Any, self._connect_kwargs), ) if self._init is not None: @@ -526,7 +545,9 @@ async def _get_new_connection(self): return con - async def execute(self, query: str, *args, timeout: float=None) -> str: + async def execute( + self, query: str, *args: object, timeout: float | None = None + ) -> str: """Execute an SQL command (or commands). Pool performs this operation using one of its connections. Other than @@ -538,7 +559,13 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: async with self.acquire() as con: return await con.execute(query, *args, timeout=timeout) - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany( + self, + command: str, + args: compat.Iterable[compat.Sequence[object]], + *, + timeout: float | None = None, + ) -> None: """Execute an SQL *command* for each sequence of arguments in *args*. Pool performs this operation using one of its connections. Other than @@ -551,13 +578,43 @@ async def executemany(self, command: str, args, *, timeout: float=None): async with self.acquire() as con: return await con.executemany(command, args, timeout=timeout) + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> list[_RecordT]: + ... + + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> list[_OtherRecordT]: + ... + + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> list[_RecordT] | list[_OtherRecordT]: + ... + async def fetch( self, - query, - *args, - timeout=None, - record_class=None - ) -> list: + query: str, + *args: object, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None, + ) -> list[_RecordT] | list[_OtherRecordT]: """Run a query and return the results as a list of :class:`Record`. Pool performs this operation using one of its connections. Other than @@ -574,7 +631,13 @@ async def fetch( record_class=record_class ) - async def fetchval(self, query, *args, column=0, timeout=None): + async def fetchval( + self, + query: str, + *args: object, + column: int = 0, + timeout: float | None = None, + ) -> typing.Any: """Run a query and return a value in the first row. Pool performs this operation using one of its connections. Other than @@ -588,7 +651,43 @@ async def fetchval(self, query, *args, column=0, timeout=None): return await con.fetchval( query, *args, column=column, timeout=timeout) - async def fetchrow(self, query, *args, timeout=None, record_class=None): + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> _RecordT | None: + ... + + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> _OtherRecordT | None: + ... + + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> _RecordT | _OtherRecordT | None: + ... + + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None, + ) -> _RecordT | _OtherRecordT | None: """Run a query and return the first row. Pool performs this operation using one of its connections. Other than @@ -607,22 +706,22 @@ async def fetchrow(self, query, *args, timeout=None, record_class=None): async def copy_from_table( self, - table_name, + table_name: str, *, - output, - columns=None, - schema_name=None, - timeout=None, - format=None, - oids=None, - delimiter=None, - null=None, - header=None, - quote=None, - escape=None, - force_quote=None, - encoding=None - ): + output: connection._OutputType, + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + format: connection._CopyFormat | None = None, + oids: int | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + ) -> str: """Copy table contents to a file or file-like object. Pool performs this operation using one of its connections. Other than @@ -652,20 +751,20 @@ async def copy_from_table( async def copy_from_query( self, - query, - *args, - output, - timeout=None, - format=None, - oids=None, - delimiter=None, - null=None, - header=None, - quote=None, - escape=None, - force_quote=None, - encoding=None - ): + query: str, + *args: object, + output: connection._OutputType, + timeout: float | None = None, + format: connection._CopyFormat | None = None, + oids: int | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + ) -> str: """Copy the results of a query to a file or file-like object. Pool performs this operation using one of its connections. Other than @@ -694,26 +793,26 @@ async def copy_from_query( async def copy_to_table( self, - table_name, + table_name: str, *, - source, - columns=None, - schema_name=None, - timeout=None, - format=None, - oids=None, - freeze=None, - delimiter=None, - null=None, - header=None, - quote=None, - escape=None, - force_quote=None, - force_not_null=None, - force_null=None, - encoding=None, - where=None - ): + source: connection._SourceType, + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + format: connection._CopyFormat | None = None, + oids: int | None = None, + freeze: bool | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + force_not_null: bool | compat.Iterable[str] | None = None, + force_null: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + where: str | None = None, + ) -> str: """Copy data to the specified table. Pool performs this operation using one of its connections. Other than @@ -747,14 +846,16 @@ async def copy_to_table( async def copy_records_to_table( self, - table_name, + table_name: str, *, - records, - columns=None, - schema_name=None, - timeout=None, - where=None - ): + records: compat.Iterable[ + compat.Sequence[object] + ] | compat.AsyncIterable[compat.Sequence[object]], + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + where: str | None = None, + ) -> str: """Copy a list of records to the specified table using binary COPY. Pool performs this operation using one of its connections. Other than @@ -774,7 +875,9 @@ async def copy_records_to_table( where=where ) - def acquire(self, *, timeout=None): + def acquire( + self, *, timeout: float | None = None + ) -> PoolAcquireContext[_RecordT]: """Acquire a database connection from the pool. :param float timeout: A timeout for acquiring a Connection. @@ -799,11 +902,18 @@ def acquire(self, *, timeout=None): """ return PoolAcquireContext(self, timeout) - async def _acquire(self, timeout): - async def _acquire_impl(): - ch = await self._queue.get() # type: PoolConnectionHolder + async def _acquire( + self, timeout: float | None + ) -> pool_connection_proxy.PoolConnectionProxy[_RecordT]: + async def _acquire_impl() -> pool_connection_proxy.PoolConnectionProxy[ + _RecordT + ]: + if typing.TYPE_CHECKING: + assert self._queue is not None + + ch: PoolConnectionHolder[_RecordT] = await self._queue.get() try: - proxy = await ch.acquire() # type: PoolConnectionProxy + proxy = await ch.acquire() except (Exception, asyncio.CancelledError): self._queue.put_nowait(ch) raise @@ -823,7 +933,12 @@ async def _acquire_impl(): return await compat.wait_for( _acquire_impl(), timeout=timeout) - async def release(self, connection, *, timeout=None): + async def release( + self, + connection: pool_connection_proxy.PoolConnectionProxy[_RecordT], + *, + timeout: float | None = None, + ) -> None: """Release a database connection back to the pool. :param Connection connection: @@ -836,8 +951,8 @@ async def release(self, connection, *, timeout=None): .. versionchanged:: 0.14.0 Added the *timeout* parameter. """ - if (type(connection) is not PoolConnectionProxy or - connection._holder._pool is not self): + if (type(connection) is not pool_connection_proxy.PoolConnectionProxy + or connection._holder._pool is not self): raise exceptions.InterfaceError( 'Pool.release() received invalid connection: ' '{connection!r} is not a member of this pool'.format( @@ -861,7 +976,7 @@ async def release(self, connection, *, timeout=None): # pool properly. return await asyncio.shield(ch.release(timeout)) - async def close(self): + async def close(self) -> None: """Attempt to gracefully close all connections in the pool. Wait until all pool connections are released, close them and @@ -906,13 +1021,13 @@ async def close(self): self._closed = True self._closing = False - def _warn_on_long_close(self): - logger.warning('Pool.close() is taking over 60 seconds to complete. ' - 'Check if you have any unreleased connections left. ' - 'Use asyncio.wait_for() to set a timeout for ' - 'Pool.close().') + def _warn_on_long_close(self) -> None: + _logger.warning('Pool.close() is taking over 60 seconds to complete. ' + 'Check if you have any unreleased connections left. ' + 'Use asyncio.wait_for() to set a timeout for ' + 'Pool.close().') - def terminate(self): + def terminate(self) -> None: """Terminate all connections in the pool.""" if self._closed: return @@ -921,7 +1036,7 @@ def terminate(self): ch.terminate() self._closed = True - async def expire_connections(self): + async def expire_connections(self) -> None: """Expire all currently open connections. Cause all currently open connections to get replaced on the @@ -931,7 +1046,7 @@ async def expire_connections(self): """ self._generation += 1 - def _check_init(self): + def _check_init(self) -> None: if not self._initialized: if self._initializing: raise exceptions.InterfaceError( @@ -942,67 +1057,142 @@ def _check_init(self): if self._closed: raise exceptions.InterfaceError('pool is closed') - def _drop_statement_cache(self): + def _drop_statement_cache(self) -> None: # Drop statement cache for all connections in the pool. for ch in self._holders: if ch._con is not None: ch._con._drop_local_statement_cache() - def _drop_type_cache(self): + def _drop_type_cache(self) -> None: # Drop type codec cache for all connections in the pool. for ch in self._holders: if ch._con is not None: ch._con._drop_local_type_cache() - def __await__(self): + def __await__(self) -> compat.Generator[typing.Any, None, Self | None]: return self._async__init__().__await__() - async def __aenter__(self): + async def __aenter__(self) -> Self: await self._async__init__() return self - async def __aexit__(self, *exc): + async def __aexit__(self, *exc: object) -> None: await self.close() -class PoolAcquireContext: +class PoolAcquireContext(typing.Generic[_RecordT]): __slots__ = ('timeout', 'connection', 'done', 'pool') - def __init__(self, pool, timeout): + timeout: float | None + connection: pool_connection_proxy.PoolConnectionProxy[_RecordT] | None + done: bool + pool: Pool[_RecordT] + + def __init__(self, pool: Pool[_RecordT], timeout: float | None) -> None: self.pool = pool self.timeout = timeout self.connection = None self.done = False - async def __aenter__(self): + async def __aenter__( + self + ) -> pool_connection_proxy.PoolConnectionProxy[_RecordT]: if self.connection is not None or self.done: raise exceptions.InterfaceError('a connection is already acquired') self.connection = await self.pool._acquire(self.timeout) return self.connection - async def __aexit__(self, *exc): + async def __aexit__(self, *exc: object) -> None: self.done = True con = self.connection self.connection = None + if typing.TYPE_CHECKING: + assert con is not None await self.pool.release(con) - def __await__(self): + def __await__(self) -> compat.Generator[ + typing.Any, None, pool_connection_proxy.PoolConnectionProxy[_RecordT] + ]: self.done = True return self.pool._acquire(self.timeout).__await__() -def create_pool(dsn=None, *, - min_size=10, - max_size=10, - max_queries=50000, - max_inactive_connection_lifetime=300.0, - setup=None, - init=None, - loop=None, - connection_class=connection.Connection, - record_class=protocol.Record, - **connect_kwargs): +@typing.overload +def create_pool( + dsn: str | None = ..., + *, + min_size: int = ..., + max_size: int = ..., + max_queries: int = ..., + max_inactive_connection_lifetime: float = ..., + setup: _SetupCallback[_RecordT] | None = ..., + init: _InitCallback[_RecordT] | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + connection_class: type[connection.Connection[_RecordT]] = ..., + record_class: type[_RecordT], + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + server_settings: dict[str, str] | None = ..., +) -> Pool[_RecordT]: + ... + + +@typing.overload +def create_pool( + dsn: str | None = ..., + *, + min_size: int = ..., + max_size: int = ..., + max_queries: int = ..., + max_inactive_connection_lifetime: float = ..., + setup: _SetupCallback[protocol.Record] | None = ..., + init: _InitCallback[protocol.Record] | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + connection_class: type[connection.Connection[protocol.Record]] = ..., + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + server_settings: dict[str, str] | None = ..., +) -> Pool[protocol.Record]: + ... + + +def create_pool( + dsn: str | None = None, + *, + min_size: int = 10, + max_size: int = 10, + max_queries: int = 50000, + max_inactive_connection_lifetime: float = 300.0, + setup: _SetupCallback[typing.Any] | None = None, + init: _InitCallback[typing.Any] | None = None, + loop: asyncio.AbstractEventLoop | None = None, + connection_class: type[ + connection.Connection[typing.Any] + ] = connection.Connection, + record_class: type[protocol.Record] | type[_RecordT] = protocol.Record, + **connect_kwargs: typing.Any +) -> Pool[typing.Any]: r"""Create a connection pool. Can be used either with an ``async with`` block: diff --git a/asyncpg/pool_connection_proxy.py b/asyncpg/pool_connection_proxy.py new file mode 100644 index 00000000..b4d2e4fc --- /dev/null +++ b/asyncpg/pool_connection_proxy.py @@ -0,0 +1,91 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +from __future__ import annotations + +import functools +import inspect +import typing + +from . import connection +from . import exceptions + +if typing.TYPE_CHECKING: + from . import pool + from . import protocol + + +_RecordT = typing.TypeVar('_RecordT', bound='protocol.Record') + + +class PoolConnectionProxyMeta(type): + + def __new__(mcls, name, bases, dct, *, wrap=False): + if wrap: + for attrname in dir(connection.Connection): + if attrname.startswith('_') or attrname in dct: + continue + + meth = getattr(connection.Connection, attrname) + if not inspect.isfunction(meth): + continue + + wrapper = mcls._wrap_connection_method(attrname) + wrapper = functools.update_wrapper(wrapper, meth) + dct[attrname] = wrapper + + if '__doc__' not in dct: + dct['__doc__'] = connection.Connection.__doc__ + + return super().__new__(mcls, name, bases, dct) + + @staticmethod + def _wrap_connection_method(meth_name): + def call_con_method(self, *args, **kwargs): + # This method will be owned by PoolConnectionProxy class. + if self._con is None: + raise exceptions.InterfaceError( + 'cannot call Connection.{}(): ' + 'connection has been released back to the pool'.format( + meth_name)) + + meth = getattr(self._con.__class__, meth_name) + return meth(self._con, *args, **kwargs) + + return call_con_method + + +class PoolConnectionProxy(connection._ConnectionProxy[_RecordT], + metaclass=PoolConnectionProxyMeta, + wrap=True): + + __slots__ = ('_con', '_holder') + + def __init__(self, holder: pool.PoolConnectionHolder, + con: connection.Connection[_RecordT]): + self._con = con + self._holder = holder + con._set_proxy(self) + + def __getattr__(self, attr): + # Proxy all unresolved attributes to the wrapped Connection object. + return getattr(self._con, attr) + + def _detach(self) -> connection.Connection[_RecordT]: + if self._con is None: + return + + con, self._con = self._con, None + con._set_proxy(None) + return con + + def __repr__(self): + if self._con is None: + return '<{classname} [released] {id:#x}>'.format( + classname=self.__class__.__name__, id=id(self)) + else: + return '<{classname} {con!r} {id:#x}>'.format( + classname=self.__class__.__name__, con=self._con, id=id(self)) diff --git a/asyncpg/pool_connection_proxy.pyi b/asyncpg/pool_connection_proxy.pyi new file mode 100644 index 00000000..cdb03af9 --- /dev/null +++ b/asyncpg/pool_connection_proxy.pyi @@ -0,0 +1,284 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +import contextlib +from collections.abc import ( + AsyncIterable, + Callable, + Iterable, + Iterator, + Sequence, +) +from typing import Any, TypeVar, overload + +from . import connection +from . import cursor +from . import pool +from . import prepared_stmt +from . import protocol +from . import transaction +from . import types +from .protocol import protocol as _cprotocol + +_RecordT = TypeVar('_RecordT', bound=protocol.Record) +_OtherRecordT = TypeVar('_OtherRecordT', bound=protocol.Record) + +class PoolConnectionProxyMeta(type): ... + +class PoolConnectionProxy( + connection._ConnectionProxy[_RecordT], metaclass=PoolConnectionProxyMeta +): + __slots__ = ('_con', '_holder') + _con: connection.Connection[_RecordT] + _holder: pool.PoolConnectionHolder[_RecordT] + def __init__( + self, + holder: pool.PoolConnectionHolder[_RecordT], + con: connection.Connection[_RecordT], + ) -> None: ... + def _detach(self) -> connection.Connection[_RecordT]: ... + + # The following methods are copied from Connection + async def add_listener( + self, channel: str, callback: connection.Listener + ) -> None: ... + async def remove_listener( + self, channel: str, callback: connection.Listener + ) -> None: ... + def add_log_listener(self, callback: connection.LogListener) -> None: ... + def remove_log_listener(self, callback: connection.LogListener) -> None: ... + def add_termination_listener( + self, callback: connection.TerminationListener + ) -> None: ... + def remove_termination_listener( + self, callback: connection.TerminationListener + ) -> None: ... + def add_query_logger(self, callback: connection.QueryLogger) -> None: ... + def remove_query_logger(self, callback: connection.QueryLogger) -> None: ... + def get_server_pid(self) -> int: ... + def get_server_version(self) -> types.ServerVersion: ... + def get_settings(self) -> _cprotocol.ConnectionSettings: ... + def transaction( + self, + *, + isolation: transaction.IsolationLevels | None = ..., + readonly: bool = ..., + deferrable: bool = ..., + ) -> transaction.Transaction: ... + def is_in_transaction(self) -> bool: ... + async def execute( + self, query: str, *args: object, timeout: float | None = ... + ) -> str: ... + async def executemany( + self, + command: str, + args: Iterable[Sequence[object]], + *, + timeout: float | None = ..., + ) -> None: ... + @overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: None = ..., + ) -> cursor.CursorFactory[_RecordT]: ... + @overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> cursor.CursorFactory[_OtherRecordT]: ... + @overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> cursor.CursorFactory[_RecordT] | cursor.CursorFactory[_OtherRecordT]: ... + @overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: None = ..., + ) -> prepared_stmt.PreparedStatement[_RecordT]: ... + @overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> prepared_stmt.PreparedStatement[_OtherRecordT]: ... + @overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] + ): ... + @overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> list[_RecordT]: ... + @overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> list[_OtherRecordT]: ... + @overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> list[_RecordT] | list[_OtherRecordT]: ... + async def fetchval( + self, + query: str, + *args: object, + column: int = ..., + timeout: float | None = ..., + ) -> Any: ... + @overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> _RecordT | None: ... + @overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> _OtherRecordT | None: ... + @overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> _RecordT | _OtherRecordT | None: ... + async def copy_from_table( + self, + table_name: str, + *, + output: connection._OutputType, + columns: Iterable[str] | None = ..., + schema_name: str | None = ..., + timeout: float | None = ..., + format: connection._CopyFormat | None = ..., + oids: int | None = ..., + delimiter: str | None = ..., + null: str | None = ..., + header: bool | None = ..., + quote: str | None = ..., + escape: str | None = ..., + force_quote: bool | Iterable[str] | None = ..., + encoding: str | None = ..., + ) -> str: ... + async def copy_from_query( + self, + query: str, + *args: object, + output: connection._OutputType, + timeout: float | None = ..., + format: connection._CopyFormat | None = ..., + oids: int | None = ..., + delimiter: str | None = ..., + null: str | None = ..., + header: bool | None = ..., + quote: str | None = ..., + escape: str | None = ..., + force_quote: bool | Iterable[str] | None = ..., + encoding: str | None = ..., + ) -> str: ... + async def copy_to_table( + self, + table_name: str, + *, + source: connection._SourceType, + columns: Iterable[str] | None = ..., + schema_name: str | None = ..., + timeout: float | None = ..., + format: connection._CopyFormat | None = ..., + oids: int | None = ..., + freeze: bool | None = ..., + delimiter: str | None = ..., + null: str | None = ..., + header: bool | None = ..., + quote: str | None = ..., + escape: str | None = ..., + force_quote: bool | Iterable[str] | None = ..., + force_not_null: bool | Iterable[str] | None = ..., + force_null: bool | Iterable[str] | None = ..., + encoding: str | None = ..., + where: str | None = ..., + ) -> str: ... + async def copy_records_to_table( + self, + table_name: str, + *, + records: Iterable[Sequence[object]] | AsyncIterable[Sequence[object]], + columns: Iterable[str] | None = ..., + schema_name: str | None = ..., + timeout: float | None = ..., + where: str | None = ..., + ) -> str: ... + async def set_type_codec( + self, + typename: str, + *, + schema: str = ..., + encoder: Callable[[Any], Any], + decoder: Callable[[Any], Any], + format: str = ..., + ) -> None: ... + async def reset_type_codec(self, typename: str, *, schema: str = ...) -> None: ... + async def set_builtin_type_codec( + self, + typename: str, + *, + schema: str = ..., + codec_name: str, + format: str | None = ..., + ) -> None: ... + def is_closed(self) -> bool: ... + async def close(self, *, timeout: float | None = ...) -> None: ... + def terminate(self) -> None: ... + async def reset(self, *, timeout: float | None = ...) -> None: ... + async def reload_schema_state(self) -> None: ... + @contextlib.contextmanager + def query_logger(self, callback: connection.QueryLogger) -> Iterator[None]: ... diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 8e241d67..f49163a2 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -4,20 +4,52 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import json +import typing from . import connresource from . import cursor from . import exceptions +if typing.TYPE_CHECKING: + from .protocol import protocol as _cprotocol + from . import compat + from . import connection as _connection + from . import types -class PreparedStatement(connresource.ConnectionResource): + +_RecordT = typing.TypeVar('_RecordT', bound='_cprotocol.Record') +_T = typing.TypeVar('_T') +_T_co = typing.TypeVar('_T_co', covariant=True) + + +class _Executor(typing.Protocol[_T_co]): + def __call__( + self, __protocol: _cprotocol.BaseProtocol[typing.Any] + ) -> compat.Awaitable[_T_co]: + ... + + +class PreparedStatement( + connresource.ConnectionResource, + typing.Generic[_RecordT] +): """A representation of a prepared statement.""" __slots__ = ('_state', '_query', '_last_status') - def __init__(self, connection, query, state): + _state: _cprotocol.PreparedStatementState[_RecordT] + _query: str + _last_status: bytes | None + + def __init__( + self, + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT], + ) -> None: super().__init__(connection) self._state = state self._query = query @@ -44,7 +76,7 @@ def get_query(self) -> str: return self._query @connresource.guarded - def get_statusmsg(self) -> str: + def get_statusmsg(self) -> str | None: """Return the status of the executed command. Example:: @@ -58,7 +90,7 @@ def get_statusmsg(self) -> str: return self._last_status.decode() @connresource.guarded - def get_parameters(self): + def get_parameters(self) -> tuple[types.Type, ...]: """Return a description of statement parameters types. :return: A tuple of :class:`asyncpg.types.Type`. @@ -75,7 +107,7 @@ def get_parameters(self): return self._state._get_parameters() @connresource.guarded - def get_attributes(self): + def get_attributes(self) -> tuple[types.Attribute, ...]: """Return a description of relation attributes (columns). :return: A tuple of :class:`asyncpg.types.Attribute`. @@ -100,8 +132,8 @@ def get_attributes(self): return self._state._get_attributes() @connresource.guarded - def cursor(self, *args, prefetch=None, - timeout=None) -> cursor.CursorFactory: + def cursor(self, *args: object, prefetch: int | None = None, + timeout: float | None = None) -> cursor.CursorFactory[_RecordT]: """Return a *cursor factory* for the prepared statement. :param args: Query arguments. @@ -122,7 +154,9 @@ def cursor(self, *args, prefetch=None, ) @connresource.guarded - async def explain(self, *args, analyze=False): + async def explain( + self, *args: object, analyze: bool = False + ) -> typing.Any: """Return the execution plan of the statement. :param args: Query arguments. @@ -164,7 +198,9 @@ async def explain(self, *args, analyze=False): return json.loads(data) @connresource.guarded - async def fetch(self, *args, timeout=None): + async def fetch( + self, *args: object, timeout: float | None = None + ) -> list[_RecordT]: r"""Execute the statement and return a list of :class:`Record` objects. :param str query: Query text @@ -177,7 +213,9 @@ async def fetch(self, *args, timeout=None): return data @connresource.guarded - async def fetchval(self, *args, column=0, timeout=None): + async def fetchval( + self, *args: object, column: int = 0, timeout: float | None = None + ) -> typing.Any: """Execute the statement and return a value in the first row. :param args: Query arguments. @@ -196,7 +234,9 @@ async def fetchval(self, *args, column=0, timeout=None): return data[0][column] @connresource.guarded - async def fetchrow(self, *args, timeout=None): + async def fetchrow( + self, *args: object, timeout: float | None = None + ) -> _RecordT | None: """Execute the statement and return the first row. :param str query: Query text @@ -211,7 +251,12 @@ async def fetchrow(self, *args, timeout=None): return data[0] @connresource.guarded - async def executemany(self, args, *, timeout: float=None): + async def executemany( + self, + args: compat.Iterable[compat.Sequence[object]], + *, + timeout: float | None = None + ) -> None: """Execute the statement for each sequence of arguments in *args*. :param args: An iterable containing sequences of arguments. @@ -224,7 +269,7 @@ async def executemany(self, args, *, timeout: float=None): lambda protocol: protocol.bind_execute_many( self._state, args, '', timeout)) - async def __do_execute(self, executor): + async def __do_execute(self, executor: _Executor[_T]) -> _T: protocol = self._connection._protocol try: return await executor(protocol) @@ -237,23 +282,28 @@ async def __do_execute(self, executor): self._state.mark_closed() raise - async def __bind_execute(self, args, limit, timeout): - data, status, _ = await self.__do_execute( - lambda protocol: protocol.bind_execute( - self._state, args, '', limit, True, timeout)) + async def __bind_execute( + self, args: compat.Sequence[object], limit: int, timeout: float | None + ) -> list[_RecordT]: + executor: _Executor[ + tuple[list[_RecordT], bytes, bool] + ] = lambda protocol: protocol.bind_execute( + self._state, args, '', limit, True, timeout + ) + data, status, _ = await self.__do_execute(executor) self._last_status = status return data - def _check_open(self, meth_name): + def _check_open(self, meth_name: str) -> None: if self._state.closed: raise exceptions.InterfaceError( 'cannot call PreparedStmt.{}(): ' 'the prepared statement is closed'.format(meth_name)) - def _check_conn_validity(self, meth_name): + def _check_conn_validity(self, meth_name: str) -> None: self._check_open(meth_name) super()._check_conn_validity(meth_name) - def __del__(self): + def __del__(self) -> None: self._state.detach() self._connection._maybe_gc_stmt(self._state) diff --git a/asyncpg/protocol/__init__.py b/asyncpg/protocol/__init__.py index 8b3e06a0..af9287bd 100644 --- a/asyncpg/protocol/__init__.py +++ b/asyncpg/protocol/__init__.py @@ -6,4 +6,6 @@ # flake8: NOQA +from __future__ import annotations + from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP diff --git a/asyncpg/protocol/protocol.pyi b/asyncpg/protocol/protocol.pyi new file mode 100644 index 00000000..ea468e6d --- /dev/null +++ b/asyncpg/protocol/protocol.pyi @@ -0,0 +1,300 @@ +import asyncio +import asyncio.protocols +import hmac +from codecs import CodecInfo +from collections.abc import Callable, Iterable, Iterator, Sequence +from hashlib import md5, sha256 +from typing import ( + Any, + ClassVar, + Final, + Generic, + Literal, + NewType, + TypeVar, + final, + overload, +) +from typing_extensions import TypeAlias + +import asyncpg.pgproto.pgproto + +from ..connect_utils import _ConnectionParameters +from ..pgproto.pgproto import WriteBuffer +from ..types import Attribute, Type + +_T = TypeVar('_T') +_Record = TypeVar('_Record', bound=Record) +_OtherRecord = TypeVar('_OtherRecord', bound=Record) +_PreparedStatementState = TypeVar( + '_PreparedStatementState', bound=PreparedStatementState[Any] +) + +_NoTimeoutType = NewType('_NoTimeoutType', object) +_TimeoutType: TypeAlias = float | None | _NoTimeoutType + +BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]] +BUILTIN_TYPE_OID_MAP: Final[dict[int, str]] +NO_TIMEOUT: Final[_NoTimeoutType] + +hashlib_md5 = md5 + +@final +class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext): + __pyx_vtable__: Any + def __init__(self, conn_key: object) -> None: ... + def add_python_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typeinfos: Iterable[object], + typekind: str, + encoder: Callable[[Any], Any], + decoder: Callable[[Any], Any], + format: object, + ) -> Any: ... + def clear_type_cache(self) -> None: ... + def get_data_codec( + self, oid: int, format: object = ..., ignore_custom_codec: bool = ... + ) -> Any: ... + def get_text_codec(self) -> CodecInfo: ... + def register_data_types(self, types: Iterable[object]) -> None: ... + def remove_python_codec( + self, typeoid: int, typename: str, typeschema: str + ) -> None: ... + def set_builtin_type_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typekind: str, + alias_to: str, + format: object = ..., + ) -> Any: ... + def __getattr__(self, name: str) -> Any: ... + def __reduce__(self) -> Any: ... + +@final +class PreparedStatementState(Generic[_Record]): + closed: bool + prepared: bool + name: str + query: str + refs: int + record_class: type[_Record] + ignore_custom_codec: bool + __pyx_vtable__: Any + def __init__( + self, + name: str, + query: str, + protocol: BaseProtocol[Any], + record_class: type[_Record], + ignore_custom_codec: bool, + ) -> None: ... + def _get_parameters(self) -> tuple[Type, ...]: ... + def _get_attributes(self) -> tuple[Attribute, ...]: ... + def _init_types(self) -> set[int]: ... + def _init_codecs(self) -> None: ... + def attach(self) -> None: ... + def detach(self) -> None: ... + def mark_closed(self) -> None: ... + def mark_unprepared(self) -> None: ... + def __reduce__(self) -> Any: ... + +class CoreProtocol: + backend_pid: Any + backend_secret: Any + __pyx_vtable__: Any + def __init__(self, con_params: _ConnectionParameters) -> None: ... + def is_in_transaction(self) -> bool: ... + def __reduce__(self) -> Any: ... + +class BaseProtocol(CoreProtocol, Generic[_Record]): + queries_count: Any + is_ssl: bool + __pyx_vtable__: Any + def __init__( + self, + addr: object, + connected_fut: object, + con_params: _ConnectionParameters, + record_class: type[_Record], + loop: object, + ) -> None: ... + def set_connection(self, connection: object) -> None: ... + def get_server_pid(self, *args: object, **kwargs: object) -> int: ... + def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ... + def get_record_class(self) -> type[_Record]: ... + def abort(self) -> None: ... + async def bind( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + timeout: _TimeoutType, + ) -> Any: ... + @overload + async def bind_execute( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + limit: int, + return_extra: Literal[False], + timeout: _TimeoutType, + ) -> list[_OtherRecord]: ... + @overload + async def bind_execute( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + limit: int, + return_extra: Literal[True], + timeout: _TimeoutType, + ) -> tuple[list[_OtherRecord], bytes, bool]: ... + @overload + async def bind_execute( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + limit: int, + return_extra: bool, + timeout: _TimeoutType, + ) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ... + async def bind_execute_many( + self, + state: PreparedStatementState[_OtherRecord], + args: Iterable[Sequence[object]], + portal_name: str, + timeout: _TimeoutType, + ) -> None: ... + async def close(self, timeout: _TimeoutType) -> None: ... + def _get_timeout(self, timeout: _TimeoutType) -> float | None: ... + def _is_cancelling(self) -> bool: ... + async def _wait_for_cancellation(self) -> None: ... + async def close_statement( + self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType + ) -> Any: ... + async def copy_in(self, *args: object, **kwargs: object) -> str: ... + async def copy_out(self, *args: object, **kwargs: object) -> str: ... + async def execute(self, *args: object, **kwargs: object) -> Any: ... + def is_closed(self, *args: object, **kwargs: object) -> Any: ... + def is_connected(self, *args: object, **kwargs: object) -> Any: ... + def data_received(self, data: object) -> None: ... + def connection_made(self, transport: object) -> None: ... + def connection_lost(self, exc: Exception | None) -> None: ... + def pause_writing(self, *args: object, **kwargs: object) -> Any: ... + @overload + async def prepare( + self, + stmt_name: str, + query: str, + timeout: float | None = ..., + *, + state: _PreparedStatementState, + ignore_custom_codec: bool = ..., + record_class: None, + ) -> _PreparedStatementState: ... + @overload + async def prepare( + self, + stmt_name: str, + query: str, + timeout: float | None = ..., + *, + state: None = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecord], + ) -> PreparedStatementState[_OtherRecord]: ... + async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ... + async def query(self, *args: object, **kwargs: object) -> str: ... + def resume_writing(self, *args: object, **kwargs: object) -> Any: ... + def __reduce__(self) -> Any: ... + +@final +class Codec: + __pyx_vtable__: Any + def __reduce__(self) -> Any: ... + +class DataCodecConfig: + __pyx_vtable__: Any + def __init__(self, cache_key: object) -> None: ... + def add_python_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typekind: str, + typeinfos: Iterable[object], + encoder: Callable[[ConnectionSettings, WriteBuffer, object], object], + decoder: Callable[..., object], + format: object, + xformat: object, + ) -> Any: ... + def add_types(self, types: Iterable[object]) -> Any: ... + def clear_type_cache(self) -> None: ... + def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ... + def remove_python_codec( + self, typeoid: int, typename: str, typeschema: str + ) -> Any: ... + def set_builtin_type_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typekind: str, + alias_to: str, + format: object = ..., + ) -> Any: ... + def __reduce__(self) -> Any: ... + +class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ... + +class Record: + @overload + def get(self, key: str) -> Any | None: ... + @overload + def get(self, key: str, default: _T) -> Any | _T: ... + def items(self) -> Iterator[tuple[str, Any]]: ... + def keys(self) -> Iterator[str]: ... + def values(self) -> Iterator[Any]: ... + @overload + def __getitem__(self, index: str) -> Any: ... + @overload + def __getitem__(self, index: int) -> Any: ... + @overload + def __getitem__(self, index: slice) -> tuple[Any, ...]: ... + def __iter__(self) -> Iterator[Any]: ... + def __contains__(self, x: object) -> bool: ... + def __len__(self) -> int: ... + +class Timer: + def __init__(self, budget: float | None) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, et: object, e: object, tb: object) -> None: ... + def get_remaining_budget(self) -> float: ... + def has_budget_greater_than(self, amount: float) -> bool: ... + +@final +class SCRAMAuthentication: + AUTHENTICATION_METHODS: ClassVar[list[str]] + DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int] + DIGEST = sha256 + REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]] + REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]] + SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]] + authentication_method: bytes + authorization_message: bytes | None + client_channel_binding: bytes + client_first_message_bare: bytes | None + client_nonce: bytes | None + client_proof: bytes | None + password_salt: bytes | None + password_iterations: int + server_first_message: bytes | None + server_key: hmac.HMAC | None + server_nonce: bytes | None diff --git a/asyncpg/py.typed b/asyncpg/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/asyncpg/serverversion.py b/asyncpg/serverversion.py index 31568a2e..80fca72a 100644 --- a/asyncpg/serverversion.py +++ b/asyncpg/serverversion.py @@ -4,12 +4,14 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import re +import typing from .types import ServerVersion -version_regex = re.compile( +version_regex: typing.Final = re.compile( r"(Postgre[^\s]*)?\s*" r"(?P[0-9]+)\.?" r"((?P[0-9]+)\.?)?" @@ -19,7 +21,15 @@ ) -def split_server_version_string(version_string): +class _VersionDict(typing.TypedDict): + major: int + minor: int | None + micro: int | None + releaselevel: str | None + serial: int | None + + +def split_server_version_string(version_string: str) -> ServerVersion: version_match = version_regex.search(version_string) if version_match is None: @@ -28,17 +38,17 @@ def split_server_version_string(version_string): f'version from "{version_string}"' ) - version = version_match.groupdict() + version = typing.cast(_VersionDict, version_match.groupdict()) for ver_key, ver_value in version.items(): # Cast all possible versions parts to int try: - version[ver_key] = int(ver_value) + version[ver_key] = int(ver_value) # type: ignore[literal-required, call-overload] # noqa: E501 except (TypeError, ValueError): pass - if version.get("major") < 10: + if version["major"] < 10: return ServerVersion( - version.get("major"), + version["major"], version.get("minor") or 0, version.get("micro") or 0, version.get("releaselevel") or "final", @@ -52,7 +62,7 @@ def split_server_version_string(version_string): # want to keep that behaviour consistent, i.e not fail # a major version check due to a bugfix release. return ServerVersion( - version.get("major"), + version["major"], 0, version.get("minor") or 0, version.get("releaselevel") or "final", diff --git a/asyncpg/transaction.py b/asyncpg/transaction.py index 562811e6..59a6fe7f 100644 --- a/asyncpg/transaction.py +++ b/asyncpg/transaction.py @@ -4,12 +4,17 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import enum +import typing from . import connresource from . import exceptions as apg_errors +if typing.TYPE_CHECKING: + from . import connection as _connection + class TransactionState(enum.Enum): NEW = 0 @@ -19,13 +24,16 @@ class TransactionState(enum.Enum): FAILED = 4 -ISOLATION_LEVELS = { +IsolationLevels = typing.Literal[ + 'read_committed', 'read_uncommitted', 'serializable', 'repeatable_read' +] +ISOLATION_LEVELS: typing.Final[set[IsolationLevels]] = { 'read_committed', 'read_uncommitted', 'serializable', 'repeatable_read', } -ISOLATION_LEVELS_BY_VALUE = { +ISOLATION_LEVELS_BY_VALUE: typing.Final[dict[str, IsolationLevels]] = { 'read committed': 'read_committed', 'read uncommitted': 'read_uncommitted', 'serializable': 'serializable', @@ -41,10 +49,24 @@ class Transaction(connresource.ConnectionResource): function. """ - __slots__ = ('_connection', '_isolation', '_readonly', '_deferrable', + __slots__ = ('_isolation', '_readonly', '_deferrable', '_state', '_nested', '_id', '_managed') - def __init__(self, connection, isolation, readonly, deferrable): + _isolation: IsolationLevels | None + _readonly: bool + _deferrable: bool + _state: TransactionState + _nested: bool + _id: str | None + _managed: bool + + def __init__( + self, + connection: _connection.Connection[typing.Any], + isolation: IsolationLevels | None, + readonly: bool, + deferrable: bool, + ) -> None: super().__init__(connection) if isolation and isolation not in ISOLATION_LEVELS: @@ -60,14 +82,14 @@ def __init__(self, connection, isolation, readonly, deferrable): self._id = None self._managed = False - async def __aenter__(self): + async def __aenter__(self) -> None: if self._managed: raise apg_errors.InterfaceError( 'cannot enter context: already in an `async with` block') self._managed = True await self.start() - async def __aexit__(self, extype, ex, tb): + async def __aexit__(self, extype: object, ex: object, tb: object) -> None: try: self._check_conn_validity('__aexit__') except apg_errors.InterfaceError: @@ -93,7 +115,7 @@ async def __aexit__(self, extype, ex, tb): self._managed = False @connresource.guarded - async def start(self): + async def start(self) -> None: """Enter the transaction or savepoint block.""" self.__check_state_base('start') if self._state is TransactionState.STARTED: @@ -150,7 +172,7 @@ async def start(self): else: self._state = TransactionState.STARTED - def __check_state_base(self, opname): + def __check_state_base(self, opname: str) -> None: if self._state is TransactionState.COMMITTED: raise apg_errors.InterfaceError( 'cannot {}; the transaction is already committed'.format( @@ -164,7 +186,7 @@ def __check_state_base(self, opname): 'cannot {}; the transaction is in error state'.format( opname)) - def __check_state(self, opname): + def __check_state(self, opname: str) -> None: if self._state is not TransactionState.STARTED: if self._state is TransactionState.NEW: raise apg_errors.InterfaceError( @@ -172,7 +194,7 @@ def __check_state(self, opname): opname)) self.__check_state_base(opname) - async def __commit(self): + async def __commit(self) -> None: self.__check_state('commit') if self._connection._top_xact is self: @@ -191,7 +213,7 @@ async def __commit(self): else: self._state = TransactionState.COMMITTED - async def __rollback(self): + async def __rollback(self) -> None: self.__check_state('rollback') if self._connection._top_xact is self: @@ -211,7 +233,7 @@ async def __rollback(self): self._state = TransactionState.ROLLEDBACK @connresource.guarded - async def commit(self): + async def commit(self) -> None: """Exit the transaction or savepoint block and commit changes.""" if self._managed: raise apg_errors.InterfaceError( @@ -219,15 +241,15 @@ async def commit(self): await self.__commit() @connresource.guarded - async def rollback(self): + async def rollback(self) -> None: """Exit the transaction or savepoint block and rollback changes.""" if self._managed: raise apg_errors.InterfaceError( 'cannot manually rollback from within an `async with` block') await self.__rollback() - def __repr__(self): - attrs = [] + def __repr__(self) -> str: + attrs: list[str] = [] attrs.append('state:{}'.format(self._state.name.lower())) if self._isolation is not None: diff --git a/asyncpg/types.py b/asyncpg/types.py index bd5813fc..11055509 100644 --- a/asyncpg/types.py +++ b/asyncpg/types.py @@ -4,8 +4,17 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations -import collections +import typing + +if typing.TYPE_CHECKING: + import sys + + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self from asyncpg.pgproto.types import ( BitString, Point, Path, Polygon, @@ -19,7 +28,13 @@ ) -Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema']) +class Type(typing.NamedTuple): + oid: int + name: str + kind: str + schema: str + + Type.__doc__ = 'Database data type.' Type.oid.__doc__ = 'OID of the type.' Type.name.__doc__ = 'Type name. For example "int2".' @@ -28,25 +43,61 @@ Type.schema.__doc__ = 'Name of the database schema that defines the type.' -Attribute = collections.namedtuple('Attribute', ['name', 'type']) +class Attribute(typing.NamedTuple): + name: str + type: Type + + Attribute.__doc__ = 'Database relation attribute.' Attribute.name.__doc__ = 'Attribute name.' Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.' -ServerVersion = collections.namedtuple( - 'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial']) +class ServerVersion(typing.NamedTuple): + major: int + minor: int + micro: int + releaselevel: str + serial: int + + ServerVersion.__doc__ = 'PostgreSQL server version tuple.' -class Range: - """Immutable representation of PostgreSQL `range` type.""" +class _RangeValue(typing.Protocol): + def __eq__(self, __value: object) -> bool: + ... + + def __lt__(self, __other: _RangeValue) -> bool: + ... + + def __gt__(self, __other: _RangeValue) -> bool: + ... + - __slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty' +_RV = typing.TypeVar('_RV', bound=_RangeValue) + + +class Range(typing.Generic[_RV]): + """Immutable representation of PostgreSQL `range` type.""" - def __init__(self, lower=None, upper=None, *, - lower_inc=True, upper_inc=False, - empty=False): + __slots__ = ('_lower', '_upper', '_lower_inc', '_upper_inc', '_empty') + + _lower: _RV | None + _upper: _RV | None + _lower_inc: bool + _upper_inc: bool + _empty: bool + + def __init__( + self, + lower: _RV | None = None, + upper: _RV | None = None, + *, + lower_inc: bool = True, + upper_inc: bool = False, + empty: bool = False + ) -> None: self._empty = empty if empty: self._lower = self._upper = None @@ -58,34 +109,34 @@ def __init__(self, lower=None, upper=None, *, self._upper_inc = upper is not None and upper_inc @property - def lower(self): + def lower(self) -> _RV | None: return self._lower @property - def lower_inc(self): + def lower_inc(self) -> bool: return self._lower_inc @property - def lower_inf(self): + def lower_inf(self) -> bool: return self._lower is None and not self._empty @property - def upper(self): + def upper(self) -> _RV | None: return self._upper @property - def upper_inc(self): + def upper_inc(self) -> bool: return self._upper_inc @property - def upper_inf(self): + def upper_inf(self) -> bool: return self._upper is None and not self._empty @property - def isempty(self): + def isempty(self) -> bool: return self._empty - def _issubset_lower(self, other): + def _issubset_lower(self, other: Self) -> bool: if other._lower is None: return True if self._lower is None: @@ -96,7 +147,7 @@ def _issubset_lower(self, other): and (other._lower_inc or not self._lower_inc) ) - def _issubset_upper(self, other): + def _issubset_upper(self, other: Self) -> bool: if other._upper is None: return True if self._upper is None: @@ -107,7 +158,7 @@ def _issubset_upper(self, other): and (other._upper_inc or not self._upper_inc) ) - def issubset(self, other): + def issubset(self, other: Self) -> bool: if self._empty: return True if other._empty: @@ -115,13 +166,13 @@ def issubset(self, other): return self._issubset_lower(other) and self._issubset_upper(other) - def issuperset(self, other): + def issuperset(self, other: Self) -> bool: return other.issubset(self) - def __bool__(self): + def __bool__(self) -> bool: return not self._empty - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Range): return NotImplemented @@ -132,14 +183,14 @@ def __eq__(self, other): self._upper_inc, self._empty ) == ( - other._lower, - other._upper, + other._lower, # pyright: ignore [reportUnknownMemberType] + other._upper, # pyright: ignore [reportUnknownMemberType] other._lower_inc, other._upper_inc, other._empty ) - def __hash__(self): + def __hash__(self) -> int: return hash(( self._lower, self._upper, @@ -148,7 +199,7 @@ def __hash__(self): self._empty )) - def __repr__(self): + def __repr__(self) -> str: if self._empty: desc = 'empty' else: diff --git a/asyncpg/utils.py b/asyncpg/utils.py index 3940e04d..941ee585 100644 --- a/asyncpg/utils.py +++ b/asyncpg/utils.py @@ -4,24 +4,33 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import re +import typing +if typing.TYPE_CHECKING: + from . import connection -def _quote_ident(ident): + +def _quote_ident(ident: str) -> str: return '"{}"'.format(ident.replace('"', '""')) -def _quote_literal(string): +def _quote_literal(string: str) -> str: return "'{}'".format(string.replace("'", "''")) -async def _mogrify(conn, query, args): +async def _mogrify( + conn: connection.Connection[typing.Any], + query: str, + args: tuple[typing.Any, ...] +) -> str: """Safely inline arguments to query text.""" # Introspect the target query for argument types and # build a list of safely-quoted fully-qualified type names. ps = await conn.prepare(query) - paramtypes = [] + paramtypes: list[str] = [] for t in ps.get_parameters(): if t.name.endswith('[]'): pname = '_' + t.name[:-2] @@ -40,6 +49,9 @@ async def _mogrify(conn, query, args): textified = await conn.fetchrow( 'SELECT {cols}'.format(cols=', '.join(cols)), *args) + if typing.TYPE_CHECKING: + assert textified is not None + # Finally, replace $n references with text values. return re.sub( r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query) diff --git a/pyproject.toml b/pyproject.toml index ed2340a7..7c852418 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ "Topic :: Database :: Front-Ends", ] dependencies = [ - 'async_timeout>=4.0.3; python_version < "3.12.0"' + 'async_timeout>=4.0.3; python_version < "3.12.0"', ] [project.urls] @@ -37,7 +37,9 @@ github = "https://github.com/MagicStack/asyncpg" [project.optional-dependencies] test = [ 'flake8~=6.1', + 'flake8-pyi~=24.1.0', 'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"', + 'mypy~=1.8.0' ] docs = [ 'Sphinx~=5.3.0', @@ -102,3 +104,15 @@ exclude_lines = [ "if __name__ == .__main__.", ] show_missing = true + +[tool.mypy] +incremental = true +strict = true +implicit_reexport = true + +[[tool.mypy.overrides]] +module = [ + "asyncpg._testbase", + "asyncpg._testbase.*" +] +ignore_errors = true diff --git a/setup.py b/setup.py index c4d42d82..f7c3c471 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ with open(str(_ROOT / 'asyncpg' / '_version.py')) as f: for line in f: - if line.startswith('__version__ ='): + if line.startswith('__version__: typing.Final ='): _, _, version = line.partition('=') VERSION = version.strip(" \n'\"") break diff --git a/tests/test__sourcecode.py b/tests/test__sourcecode.py index 28ffdea7..b19044d4 100644 --- a/tests/test__sourcecode.py +++ b/tests/test__sourcecode.py @@ -14,7 +14,7 @@ def find_root(): return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -class TestFlake8(unittest.TestCase): +class TestCodeQuality(unittest.TestCase): def test_flake8(self): try: @@ -38,3 +38,34 @@ def test_flake8(self): output = ex.output.decode() raise AssertionError( 'flake8 validation failed:\n{}'.format(output)) from None + + def test_mypy(self): + try: + import mypy # NoQA + except ImportError: + raise unittest.SkipTest('mypy module is missing') + + root_path = find_root() + config_path = os.path.join(root_path, 'pyproject.toml') + if not os.path.exists(config_path): + raise RuntimeError('could not locate mypy.ini file') + + try: + subprocess.run( + [ + sys.executable, + '-m', + 'mypy', + '--config-file', + config_path, + 'asyncpg' + ], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=root_path + ) + except subprocess.CalledProcessError as ex: + output = ex.output.decode() + raise AssertionError( + 'mypy validation failed:\n{}'.format(output)) from None diff --git a/tools/generate_exceptions.py b/tools/generate_exceptions.py index 0b626558..bea0d30e 100755 --- a/tools/generate_exceptions.py +++ b/tools/generate_exceptions.py @@ -13,7 +13,8 @@ import string import textwrap -from asyncpg.exceptions import _base as apg_exc +from asyncpg.exceptions import _postgres_message as _pgm_exc +from asyncpg.exceptions import _base as _apg_exc _namemap = { @@ -87,14 +88,15 @@ class {clsname}({base}): buf = '# GENERATED FROM postgresql/src/backend/utils/errcodes.txt\n' + \ '# DO NOT MODIFY, use tools/generate_exceptions.py to update\n\n' + \ - 'from ._base import * # NOQA\nfrom . import _base\n\n\n' + 'from __future__ import annotations\n\n' + \ + 'import typing\nfrom ._base import * # NOQA\nfrom . import _base\n\n\n' classes = [] clsnames = set() def _add_class(clsname, base, sqlstate, docstring): if sqlstate: - sqlstate = "sqlstate = '{}'".format(sqlstate) + sqlstate = "sqlstate: typing.ClassVar[str] = '{}'".format(sqlstate) else: sqlstate = '' @@ -150,10 +152,10 @@ def _add_class(clsname, base, sqlstate, docstring): else: base = section_class - existing = apg_exc.PostgresMessageMeta.get_message_class_for_sqlstate( + existing = _pgm_exc.PostgresMessageMeta.get_message_class_for_sqlstate( sqlstate) - if (existing and existing is not apg_exc.UnknownPostgresError and + if (existing and existing is not _apg_exc.UnknownPostgresError and existing.__doc__): docstring = '"""{}"""\n\n '.format(existing.__doc__) else: @@ -164,7 +166,7 @@ def _add_class(clsname, base, sqlstate, docstring): subclasses = _subclassmap.get(sqlstate, []) for subclass in subclasses: - existing = getattr(apg_exc, subclass, None) + existing = getattr(_apg_exc, subclass, None) if existing and existing.__doc__: docstring = '"""{}"""\n\n '.format(existing.__doc__) else: @@ -176,7 +178,7 @@ def _add_class(clsname, base, sqlstate, docstring): buf += '\n\n\n'.join(classes) _all = textwrap.wrap(', '.join('{!r}'.format(c) for c in sorted(clsnames))) - buf += '\n\n\n__all__ = (\n {}\n)'.format( + buf += '\n\n\n__all__ = [\n {}\n]'.format( '\n '.join(_all)) buf += '\n\n__all__ += _base.__all__'