# postgresql/asyncpg.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: postgresql+asyncpg
    :name: asyncpg
    :dbapi: asyncpg
    :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...]
    :url: https://magicstack.github.io/asyncpg/

The asyncpg dialect is SQLAlchemy's first Python asyncio dialect.

Using a special asyncio mediation layer, the asyncpg dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.

This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

    from sqlalchemy.ext.asyncio import create_async_engine
    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")

The dialect can also be run as a "synchronous" dialect within the
:func:`_sa.create_engine` function, which will pass "await" calls into
an ad-hoc event loop. This mode of operation is of **limited use**
and is for special testing scenarios only. The mode can be enabled by
adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
in conjunction with :func:`_sa.create_engine`::

    # for testing purposes only; do not use in production!
    engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")


.. versionadded:: 1.4

.. note::

    By default asyncpg does not decode the ``json`` and ``jsonb`` types and
    returns them as strings. SQLAlchemy sets default type decoder for ``json``
    and ``jsonb`` types using the python builtin ``json.loads`` function.
    The json implementation used can be changed by setting the attribute
    ``json_deserializer`` when creating the engine with
    :func:`create_engine` or :func:`create_async_engine`.


.. _asyncpg_prepared_statement_cache:

Prepared Statement Cache
--------------------------

The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()``
for all statements. The prepared statement objects are cached after
construction which appears to grant a 10% or more performance improvement for
statement invocation. The cache is on a per-DBAPI connection basis, which
means that the primary storage for prepared statements is within DBAPI
connections pooled within the connection pool. The size of this cache
defaults to 100 statements per DBAPI connection and may be adjusted using the
``prepared_statement_cache_size`` DBAPI argument (note that while this argument
is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the
asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect
argument)::


    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500")

To disable the prepared statement cache, use a value of zero::

    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0")

.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.


.. warning:: The ``asyncpg`` database driver necessarily uses caches for
   PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes
   such as ``ENUM`` objects are changed via DDL operations. Additionally,
   prepared statements themselves which are optionally cached by SQLAlchemy's
   driver as described above may also become "stale" when DDL has been emitted
   to the PostgreSQL database which modifies the tables or other objects
   involved in a particular prepared statement.

   The SQLAlchemy asyncpg dialect will invalidate these caches within its local
   process when statements that represent DDL are emitted on a local
   connection, but this is only controllable within a single Python process /
   database engine. If DDL changes are made from other database engines
   and/or processes, a running application may encounter asyncpg exceptions
   ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup
   failed for type <oid>")`` if it refers to pooled database connections which
   operated upon the previous structures. The SQLAlchemy asyncpg dialect will
   recover from these error cases when the driver raises these exceptions by
   clearing its internal caches as well as those of the asyncpg driver in
   response to them, but cannot prevent them from being raised in the first
   place if the cached prepared statement or asyncpg type caches have gone
   stale, nor can it retry the statement as the PostgreSQL transaction is
   invalidated when these errors occur.

Disabling the PostgreSQL JIT to improve ENUM datatype handling
---------------------------------------------------------------

Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when
using PostgreSQL ENUM datatypes, where upon the creation of new database
connections, an expensive query may be emitted in order to retrieve metadata
regarding custom types which has been shown to negatively affect performance.
To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the
client using this setting passed to :func:`_asyncio.create_async_engine`::

    engine = create_async_engine(
        "postgresql+asyncpg://user:password@localhost/tmp",
        connect_args={"server_settings": {"jit": "off"}},
    )

.. seealso::

    https://github.com/MagicStack/asyncpg/issues/727

"""  # noqa

import collections
import decimal
import json as _py_json
import re
import time

from . import json
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
from .base import ENUM
from .base import INTERVAL
from .base import OID
from .base import PGCompiler
from .base import PGDialect
from .base import PGExecutionContext
from .base import PGIdentifierPreparer
from .base import REGCLASS
from .base import UUID
from ... import exc
from ... import pool
from ... import processors
from ... import util
from ...engine import AdaptedConnection
from ...sql import sqltypes
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only


try:
    from uuid import UUID as _python_UUID  # noqa
except ImportError:
    _python_UUID = None


class AsyncpgTime(sqltypes.Time):
    """``Time`` variant that reports a DBAPI type symbol for setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return ``dbapi.TIME_W_TZ`` when ``self.timezone`` is truthy,
        otherwise ``dbapi.TIME``."""
        return dbapi.TIME_W_TZ if self.timezone else dbapi.TIME


class AsyncpgDate(sqltypes.Date):
    """``Date`` variant that reports a DBAPI type symbol for setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return the ``DATE`` symbol of the given dbapi object."""
        return dbapi.DATE


class AsyncpgDateTime(sqltypes.DateTime):
    """``DateTime`` variant that reports a DBAPI type symbol for
    setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return ``dbapi.TIMESTAMP_W_TZ`` when ``self.timezone`` is truthy,
        otherwise ``dbapi.TIMESTAMP``."""
        return dbapi.TIMESTAMP_W_TZ if self.timezone else dbapi.TIMESTAMP


class AsyncpgBoolean(sqltypes.Boolean):
    """``Boolean`` variant that reports a DBAPI type symbol for
    setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return the ``BOOLEAN`` symbol of the given dbapi object."""
        return dbapi.BOOLEAN


class AsyncPgInterval(INTERVAL):
    """PostgreSQL ``INTERVAL`` variant used by the asyncpg dialect."""

    def get_dbapi_type(self, dbapi):
        """Return the ``INTERVAL`` symbol of the given dbapi object."""
        return dbapi.INTERVAL

    @classmethod
    def adapt_emulated_to_native(cls, interval, **kw):
        """Build an :class:`AsyncPgInterval` from a generic interval type.

        Only ``interval.second_precision`` is carried over (as
        ``precision``); additional keyword arguments are accepted and
        ignored.
        """
        second_precision = interval.second_precision
        return AsyncPgInterval(precision=second_precision)


class AsyncPgEnum(ENUM):
    """PostgreSQL ``ENUM`` variant used by the asyncpg dialect."""

    def get_dbapi_type(self, dbapi):
        """Return the ``ENUM`` symbol of the given dbapi object."""
        return dbapi.ENUM


class AsyncpgInteger(sqltypes.Integer):
    """``Integer`` variant that reports a DBAPI type symbol for
    setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return the ``INTEGER`` symbol of the given dbapi object."""
        return dbapi.INTEGER


class AsyncpgBigInteger(sqltypes.BigInteger):
    """``BigInteger`` variant that reports a DBAPI type symbol for
    setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return the ``BIGINTEGER`` symbol of the given dbapi object."""
        return dbapi.BIGINTEGER


class AsyncpgJSON(json.JSON):
    """``JSON`` variant used by the asyncpg dialect."""

    def get_dbapi_type(self, dbapi):
        """Return the ``JSON`` symbol of the given dbapi object."""
        return dbapi.JSON

    def result_processor(self, dialect, coltype):
        """Return ``None``: no SQLAlchemy-level result conversion is
        applied for this type."""
        return None


class AsyncpgJSONB(json.JSONB):
    """``JSONB`` variant used by the asyncpg dialect."""

    def get_dbapi_type(self, dbapi):
        """Return the ``JSONB`` symbol of the given dbapi object."""
        return dbapi.JSONB

    def result_processor(self, dialect, coltype):
        """Return ``None``: no SQLAlchemy-level result conversion is
        applied for this type."""
        return None


class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
    """Generic JSON index type; it is not expected to be asked for a DBAPI
    type."""

    def get_dbapi_type(self, dbapi):
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError("should not be here")


class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
    """Integer JSON index type, typed as ``INTEGER`` for setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return the ``INTEGER`` symbol of the given dbapi object."""
        return dbapi.INTEGER


class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
    """String JSON index type, typed as ``STRING`` for setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return the ``STRING`` symbol of the given dbapi object."""
        return dbapi.STRING


class AsyncpgJSONPathType(json.JSONPathType):
    """JSON path type whose bound value is passed as a list of text
    tokens."""

    def bind_processor(self, dialect):
        """Return a callable that converts a path sequence into a list of
        ``util.text_type`` elements."""

        def process(value):
            # the path must already be a sequence of path elements
            assert isinstance(value, util.collections_abc.Sequence)
            return [util.text_type(token) for token in value]

        return process


class AsyncpgUUID(UUID):
    """PostgreSQL ``UUID`` variant used by the asyncpg dialect.

    When the type is configured with ``as_uuid`` false and the dialect has
    ``use_native_uuid`` set, values are converted between strings and
    Python UUID objects on the way in and out.
    """

    def get_dbapi_type(self, dbapi):
        """Return the ``UUID`` symbol of the given dbapi object."""
        return dbapi.UUID

    def bind_processor(self, dialect):
        """Return a converter from string to ``_python_UUID``, or ``None``
        when no conversion applies."""
        if self.as_uuid or not dialect.use_native_uuid:
            return None

        def process(value):
            # ``None`` passes through untouched
            return value if value is None else _python_UUID(value)

        return process

    def result_processor(self, dialect, coltype):
        """Return a converter from the driver's value to ``str``, or
        ``None`` when no conversion applies."""
        if self.as_uuid or not dialect.use_native_uuid:
            return None

        def process(value):
            # ``None`` passes through untouched
            return value if value is None else str(value)

        return process


class AsyncpgNumeric(sqltypes.Numeric):
    """``Numeric`` variant used by the asyncpg dialect."""

    def get_dbapi_type(self, dbapi):
        """Return the ``NUMBER`` symbol of the given dbapi object."""
        return dbapi.NUMBER

    def bind_processor(self, dialect):
        """Return ``None``: bound values are passed through unchanged."""
        return None

    def result_processor(self, dialect, coltype):
        """Choose the result converter for ``coltype``.

        * float column types: converted to ``Decimal`` when
          ``self.asdecimal`` is set, otherwise passed through;
        * decimal / integer column types: passed through when
          ``self.asdecimal`` is set, otherwise converted with
          ``processors.to_float``.

        :raises sqlalchemy.exc.InvalidRequestError: if ``coltype`` belongs
         to none of the known numeric type groups.
        """
        if coltype in _FLOAT_TYPES:
            if not self.asdecimal:
                # value is used as delivered by the driver
                return None
            return processors.to_decimal_processor_factory(
                decimal.Decimal, self._effective_decimal_return_scale
            )

        if coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
            # with asdecimal the value is used as delivered by the driver
            return None if self.asdecimal else processors.to_float

        raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)


class AsyncpgFloat(AsyncpgNumeric):
    """Float variant of :class:`AsyncpgNumeric`; only the DBAPI type
    differs."""

    def get_dbapi_type(self, dbapi):
        """Return the ``FLOAT`` symbol of the given dbapi object."""
        return dbapi.FLOAT


class AsyncpgREGCLASS(REGCLASS):
    """``REGCLASS`` variant, typed as ``STRING`` for setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return the ``STRING`` symbol of the given dbapi object."""
        return dbapi.STRING


class AsyncpgOID(OID):
    """``OID`` variant, typed as ``INTEGER`` for setinputsizes."""

    def get_dbapi_type(self, dbapi):
        """Return the ``INTEGER`` symbol of the given dbapi object."""
        return dbapi.INTEGER


class PGExecutionContext_asyncpg(PGExecutionContext):
    """Execution context that keeps the dialect's schema-cache
    invalidation timestamp in sync with the cursor."""

    def handle_dbapi_exception(self, e):
        """Invalidate the dialect's schema cache when ``e`` is one of the
        stale-cache errors exposed by the dbapi wrapper."""
        dbapi = self.dialect.dbapi
        stale_cache_errors = (
            dbapi.InvalidCachedStatementError,
            dbapi.InternalServerError,
        )
        if isinstance(e, stale_cache_errors):
            self.dialect._invalidate_schema_cache()

    def pre_exec(self):
        """Propagate the invalidation timestamp to the cursor before
        execution; DDL statements bump the timestamp first."""
        dialect = self.dialect
        if self.isddl:
            dialect._invalidate_schema_cache()

        self.cursor._invalidate_schema_cache_asof = (
            dialect._invalidate_schema_cache_asof
        )

        if self.compiled:
            # we have to exclude ENUM because "enum" not really a "type"
            # we can cast to, it has to be the name of the type itself.
            # for now we just omit it from casting
            self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}

    def create_server_side_cursor(self):
        """Return a cursor created with ``server_side=True``."""
        return self._dbapi_connection.cursor(server_side=True)


class PGCompiler_asyncpg(PGCompiler):
    """Statement compiler for the asyncpg dialect; adds nothing to
    :class:`PGCompiler`."""


class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
    """Identifier preparer for the asyncpg dialect; adds nothing to
    :class:`PGIdentifierPreparer`."""


class AsyncAdapt_asyncpg_cursor:
    """DBAPI-style cursor facade over an adapted asyncpg connection.

    ``execute()`` prepares the statement through the owning adapted
    connection, fetches every row up front and buffers them; the
    ``fetch*`` methods and iteration then consume that buffer.

    The buffer is a :class:`collections.deque` so that consuming rows from
    the front (``fetchone()``, iteration, ``fetchmany()``) costs O(1) per
    row rather than the O(n) of ``list.pop(0)``.
    """

    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_rows",
        "description",
        "arraysize",
        "rowcount",
        "_inputsizes",
        "_cursor",
        "_invalidate_schema_cache_asof",
    )

    server_side = False

    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self._rows = collections.deque()
        self._cursor = None
        self.description = None
        self.arraysize = 1
        self.rowcount = -1
        self._inputsizes = None
        self._invalidate_schema_cache_asof = 0

    def close(self):
        """Discard any rows still buffered."""
        self._rows.clear()

    def _handle_exception(self, error):
        """Delegate error handling to the owning adapted connection."""
        self._adapt_connection._handle_exception(error)

    def _parameter_placeholders(self, params):
        """Return the ``$n`` placeholders for ``params``.

        When input sizes were set, each placeholder whose type symbol has
        an entry in ``_pg_types`` is rendered as ``$n::<pg type>``.
        """
        if not self._inputsizes:
            return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
        else:
            return tuple(
                "$%d::%s" % (idx, typ) if typ else "$%d" % idx
                for idx, typ in enumerate(
                    (_pg_types.get(typ) for typ in self._inputsizes), 1
                )
            )

    async def _prepare_and_execute(self, operation, parameters):
        """Prepare ``operation``, run it and populate cursor state."""
        adapt_connection = self._adapt_connection

        async with adapt_connection._execute_mutex:

            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            if parameters is not None:
                # swap the "format" style markers for $1, $2, ...
                operation = operation % self._parameter_placeholders(
                    parameters
                )
            else:
                parameters = ()

            try:
                prepared_stmt, attributes = await adapt_connection._prepare(
                    operation, self._invalidate_schema_cache_asof
                )

                if attributes:
                    # 7-tuples; only name and type OID are filled in
                    self.description = [
                        (
                            attr.name,
                            attr.type.oid,
                            None,
                            None,
                            None,
                            None,
                            None,
                        )
                        for attr in attributes
                    ]
                else:
                    self.description = None

                if self.server_side:
                    self._cursor = await prepared_stmt.cursor(*parameters)
                    self.rowcount = -1
                else:
                    self._rows = collections.deque(
                        await prepared_stmt.fetch(*parameters)
                    )
                    status = prepared_stmt.get_statusmsg()

                    # the row count is the trailing number of the status tag
                    reg = re.match(
                        r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status
                    )
                    if reg:
                        self.rowcount = int(reg.group(1))
                    else:
                        self.rowcount = -1

            except Exception as error:
                self._handle_exception(error)

    async def _executemany(self, operation, seq_of_parameters):
        """Run ``operation`` once per parameter set via the driver's
        ``executemany``."""
        adapt_connection = self._adapt_connection

        async with adapt_connection._execute_mutex:
            await adapt_connection._check_type_cache_invalidation(
                self._invalidate_schema_cache_asof
            )

            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            # placeholders are derived from the first parameter set
            operation = operation % self._parameter_placeholders(
                seq_of_parameters[0]
            )

            try:
                return await self._connection.executemany(
                    operation, seq_of_parameters
                )
            except Exception as error:
                self._handle_exception(error)

    def execute(self, operation, parameters=None):
        """Execute ``operation`` with optional positional ``parameters``."""
        self._adapt_connection.await_(
            self._prepare_and_execute(operation, parameters)
        )

    def executemany(self, operation, seq_of_parameters):
        """Execute ``operation`` for every entry of ``seq_of_parameters``."""
        return self._adapt_connection.await_(
            self._executemany(operation, seq_of_parameters)
        )

    def setinputsizes(self, *inputsizes):
        """Record the positional type symbols used for placeholder casts."""
        self._inputsizes = inputsizes

    def __iter__(self):
        """Yield buffered rows, removing each one as it is produced."""
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self):
        """Return the next buffered row, or ``None`` when exhausted."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size=None):
        """Return up to ``size`` rows (default ``self.arraysize``) as a
        list."""
        if size is None:
            size = self.arraysize

        rows = self._rows
        return [rows.popleft() for _ in range(min(size, len(rows)))]

    def fetchall(self):
        """Return every remaining buffered row as a list and empty the
        buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval


class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
    """Server-side variant of the adapted cursor.

    Rows are pulled from the driver-level cursor in batches and held in
    ``_rowbuffer``. The buffer is always a :class:`collections.deque`
    (never ``None``), so ``fetchall()`` is safe to call before any other
    fetch method has filled it.
    """

    server_side = True
    __slots__ = ("_rowbuffer",)

    def __init__(self, adapt_connection):
        super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection)
        self._rowbuffer = collections.deque()

    def close(self):
        """Drop the driver cursor reference and any buffered rows."""
        self._cursor = None
        self._rowbuffer.clear()

    def _buffer_rows(self):
        """Replace the buffer with the next batch of up to 50 rows."""
        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
        self._rowbuffer = collections.deque(new_rows)

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next row, raising ``StopAsyncIteration`` at the end.

        This must be a plain coroutine: an ``async def`` containing
        ``yield`` is an async generator function, whose return value is
        not awaitable and is rejected by ``async for``. The batch is
        awaited directly because this method runs as a coroutine.
        """
        if not self._rowbuffer:
            self._rowbuffer = collections.deque(await self._cursor.fetch(50))

        if not self._rowbuffer:
            raise StopAsyncIteration

        return self._rowbuffer.popleft()

    def fetchone(self):
        """Return the next row, or ``None`` when the cursor is exhausted."""
        if not self._rowbuffer:
            self._buffer_rows()
            if not self._rowbuffer:
                return None
        return self._rowbuffer.popleft()

    def fetchmany(self, size=None):
        """Return up to ``size`` rows; with ``size=None`` behave like
        ``fetchall()``."""
        if size is None:
            return self.fetchall()

        if not self._rowbuffer:
            self._buffer_rows()

        buf = list(self._rowbuffer)
        lb = len(buf)
        if size > lb:
            # top up with exactly the number of rows still missing
            buf.extend(
                self._adapt_connection.await_(self._cursor.fetch(size - lb))
            )

        result = buf[0:size]
        self._rowbuffer = collections.deque(buf[size:])
        return result

    def fetchall(self):
        """Return the buffered rows followed by all rows remaining on the
        driver cursor."""
        ret = list(self._rowbuffer) + list(
            self._adapt_connection.await_(self._all())
        )
        self._rowbuffer.clear()
        return ret

    async def _all(self):
        """Drain the driver cursor in batches and return the rows."""
        rows = []

        # TODO: looks like we have to hand-roll some kind of batching here.
        # hardcoding for the moment but this should be improved.
        while True:
            batch = await self._cursor.fetch(1000)
            if batch:
                rows.extend(batch)
                continue
            else:
                break
        return rows

    def executemany(self, operation, seq_of_parameters):
        """Not supported for server side cursors."""
        raise NotImplementedError(
            "server side cursor doesn't support executemany yet"
        )


class AsyncAdapt_asyncpg_connection(AdaptedConnection):
    """DBAPI-style connection facade over an asyncpg connection.

    Tracks the current transaction, isolation settings, a per-connection
    prepared statement cache and the timestamp as of which the schema and
    type caches are considered valid.
    """

    __slots__ = (
        "dbapi",
        "_connection",
        "isolation_level",
        "_isolation_setting",
        "readonly",
        "deferrable",
        "_transaction",
        "_started",
        "_prepared_statement_cache",
        "_invalidate_schema_cache_asof",
        "_execute_mutex",
    )

    await_ = staticmethod(await_only)

    def __init__(self, dbapi, connection, prepared_statement_cache_size=100):
        self.dbapi = dbapi
        self._connection = connection
        self.isolation_level = self._isolation_setting = "read_committed"
        self.readonly = False
        self.deferrable = False
        self._transaction = None
        self._started = False
        self._invalidate_schema_cache_asof = time.time()
        self._execute_mutex = asyncio.Lock()

        # a size of zero (or any falsy value) disables statement caching
        self._prepared_statement_cache = (
            util.LRUCache(prepared_statement_cache_size)
            if prepared_statement_cache_size
            else None
        )

    async def _check_type_cache_invalidation(self, invalidate_timestamp):
        """Reload the driver's schema state when ``invalidate_timestamp``
        is newer than the one recorded on this connection."""
        if invalidate_timestamp <= self._invalidate_schema_cache_asof:
            return
        await self._connection.reload_schema_state()
        self._invalidate_schema_cache_asof = invalidate_timestamp

    async def _prepare(self, operation, invalidate_timestamp):
        """Return ``(prepared_statement, attributes)`` for ``operation``,
        using the statement cache when it holds a fresh enough entry."""
        await self._check_type_cache_invalidation(invalidate_timestamp)

        cache = self._prepared_statement_cache

        # asyncpg uses a type cache for the "attributes" which seems to go
        # stale independently of the PreparedStatement itself, so place that
        # collection in the cache as well.
        if cache is not None and operation in cache:
            cached_stmt, cached_attrs, cached_timestamp = cache[operation]

            # preparedstatements themselves also go stale for certain DDL
            # changes such as size of a VARCHAR changing, so there is also
            # a cross-connection invalidation timestamp
            if cached_timestamp > invalidate_timestamp:
                return cached_stmt, cached_attrs

        prepared_stmt = await self._connection.prepare(operation)
        attributes = prepared_stmt.get_attributes()
        if cache is not None:
            cache[operation] = (prepared_stmt, attributes, time.time())

        return prepared_stmt, attributes

    def _handle_exception(self, error):
        """Re-raise ``error``, translated to the emulated DBAPI exception
        hierarchy when a mapping exists for one of its classes."""
        if self._connection.is_closed():
            self._transaction = None
            self._started = False

        if isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
            # already one of ours
            raise error

        translations = self.dbapi._asyncpg_error_translate

        # the first class in the MRO that has a mapping wins
        for klass in type(error).__mro__:
            if klass not in translations:
                continue
            translated = translations[klass]("%s: %s" % (type(error), error))
            translated.pgcode = translated.sqlstate = getattr(
                error, "sqlstate", None
            )
            raise translated from error

        raise error

    @property
    def autocommit(self):
        """True when the isolation level is ``"autocommit"``."""
        return self.isolation_level == "autocommit"

    @autocommit.setter
    def autocommit(self, value):
        # turning autocommit off restores the last explicitly set level
        self.isolation_level = (
            "autocommit" if value else self._isolation_setting
        )

    def set_isolation_level(self, level):
        """Roll back any started transaction, then record ``level``."""
        if self._started:
            self.rollback()
        self.isolation_level = self._isolation_setting = level

    async def _start_transaction(self):
        """Begin a driver transaction unless in autocommit mode."""
        if self.isolation_level == "autocommit":
            return

        try:
            self._transaction = self._connection.transaction(
                isolation=self.isolation_level,
                readonly=self.readonly,
                deferrable=self.deferrable,
            )
            await self._transaction.start()
        except Exception as error:
            self._handle_exception(error)
        else:
            self._started = True

    def cursor(self, server_side=False):
        """Return a new adapted cursor, server-side when requested."""
        cursor_cls = (
            AsyncAdapt_asyncpg_ss_cursor
            if server_side
            else AsyncAdapt_asyncpg_cursor
        )
        return cursor_cls(self)

    def _finish_transaction(self, method_name):
        """Call ``method_name`` on the current transaction, if started,
        and always reset the transaction state afterwards."""
        if not self._started:
            return
        try:
            self.await_(getattr(self._transaction, method_name)())
        except Exception as error:
            self._handle_exception(error)
        finally:
            self._transaction = None
            self._started = False

    def rollback(self):
        """Roll back the current transaction, if one was started."""
        self._finish_transaction("rollback")

    def commit(self):
        """Commit the current transaction, if one was started."""
        self._finish_transaction("commit")

    def close(self):
        """Roll back, then close the driver connection."""
        self.rollback()

        self.await_(self._connection.close())

    def terminate(self):
        """Call ``terminate()`` on the driver connection."""
        self._connection.terminate()


class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
    """Adapted connection that runs awaitables through ``await_fallback``
    instead of ``await_only``."""

    __slots__ = ()

    await_ = staticmethod(await_fallback)


class AsyncAdapt_asyncpg_dbapi:
    """DBAPI-module-like wrapper around the ``asyncpg`` module.

    Provides ``connect()``, an exception hierarchy and the type symbols
    used with ``setinputsizes``.
    """

    def __init__(self, asyncpg):
        self.asyncpg = asyncpg
        self.paramstyle = "format"

    def connect(self, *arg, **kw):
        """Open an asyncpg connection and wrap it in an adapted connection.

        ``async_fallback`` and ``prepared_statement_cache_size`` are taken
        out of ``kw``; everything else is passed to ``asyncpg.connect()``.
        """
        async_fallback = kw.pop("async_fallback", False)
        prepared_statement_cache_size = kw.pop(
            "prepared_statement_cache_size", 100
        )

        if util.asbool(async_fallback):
            connection_cls = AsyncAdaptFallback_asyncpg_connection
            run = await_fallback
        else:
            connection_cls = AsyncAdapt_asyncpg_connection
            run = await_only

        return connection_cls(
            self,
            run(self.asyncpg.connect(*arg, **kw)),
            prepared_statement_cache_size=prepared_statement_cache_size,
        )

    class Error(Exception):
        pass

    class Warning(Exception):  # noqa
        pass

    class InterfaceError(Error):
        pass

    class DatabaseError(Error):
        pass

    class InternalError(DatabaseError):
        pass

    class OperationalError(DatabaseError):
        pass

    class ProgrammingError(DatabaseError):
        pass

    class IntegrityError(DatabaseError):
        pass

    class DataError(DatabaseError):
        pass

    class NotSupportedError(DatabaseError):
        pass

    class InternalServerError(InternalError):
        pass

    class InvalidCachedStatementError(NotSupportedError):
        def __init__(self, message):
            # tell the reader what the dialect does about this error
            notice = (
                " (SQLAlchemy asyncpg dialect will now invalidate "
                "all prepared caches in response to this exception)"
            )
            super(
                AsyncAdapt_asyncpg_dbapi.InvalidCachedStatementError, self
            ).__init__(message + notice)

    @util.memoized_property
    def _asyncpg_error_translate(self):
        """Mapping of asyncpg exception classes to the classes above."""
        import asyncpg

        errors = asyncpg.exceptions
        return {
            errors.IntegrityConstraintViolationError: self.IntegrityError,
            errors.PostgresError: self.Error,
            errors.SyntaxOrAccessError: self.ProgrammingError,
            errors.InterfaceError: self.InterfaceError,
            errors.InvalidCachedStatementError: self.InvalidCachedStatementError,  # noqa: E501
            errors.InternalServerError: self.InternalServerError,
        }

    def Binary(self, value):
        """Return ``value`` unchanged."""
        return value

    # type symbols handed to cursor.setinputsizes()
    STRING = util.symbol("STRING")
    TIMESTAMP = util.symbol("TIMESTAMP")
    TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
    TIME = util.symbol("TIME")
    TIME_W_TZ = util.symbol("TIME_W_TZ")
    DATE = util.symbol("DATE")
    INTERVAL = util.symbol("INTERVAL")
    NUMBER = util.symbol("NUMBER")
    FLOAT = util.symbol("FLOAT")
    BOOLEAN = util.symbol("BOOLEAN")
    INTEGER = util.symbol("INTEGER")
    BIGINTEGER = util.symbol("BIGINTEGER")
    BYTES = util.symbol("BYTES")
    DECIMAL = util.symbol("DECIMAL")
    JSON = util.symbol("JSON")
    JSONB = util.symbol("JSONB")
    ENUM = util.symbol("ENUM")
    UUID = util.symbol("UUID")
    BYTEA = util.symbol("BYTEA")

    # aliases
    DATETIME = TIMESTAMP
    BINARY = BYTEA


# Maps each dbapi type symbol to the type name rendered after "::" in a
# statement placeholder (see the cursor's _parameter_placeholders()).
_pg_types = {
    getattr(AsyncAdapt_asyncpg_dbapi, symbol_name): pg_type_name
    for symbol_name, pg_type_name in (
        ("STRING", "varchar"),
        ("TIMESTAMP", "timestamp"),
        ("TIMESTAMP_W_TZ", "timestamp with time zone"),
        ("DATE", "date"),
        ("TIME", "time"),
        ("TIME_W_TZ", "time with time zone"),
        ("INTERVAL", "interval"),
        ("NUMBER", "numeric"),
        ("FLOAT", "float"),
        ("BOOLEAN", "bool"),
        ("INTEGER", "integer"),
        ("BIGINTEGER", "bigint"),
        ("BYTES", "bytes"),
        ("DECIMAL", "decimal"),
        ("JSON", "json"),
        ("JSONB", "jsonb"),
        ("ENUM", "enum"),
        ("UUID", "uuid"),
        ("BYTEA", "bytea"),
    )
}


class PGDialect_asyncpg(PGDialect):
    """PostgreSQL dialect that talks to the database through asyncpg."""

    driver = "asyncpg"
    supports_statement_cache = True

    supports_unicode_statements = True
    supports_server_side_cursors = True

    supports_unicode_binds = True
    has_terminate = True

    default_paramstyle = "format"
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_asyncpg
    statement_compiler = PGCompiler_asyncpg
    preparer = PGIdentifierPreparer_asyncpg

    use_setinputsizes = True

    use_native_uuid = True

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Time: AsyncpgTime,
            sqltypes.Date: AsyncpgDate,
            sqltypes.DateTime: AsyncpgDateTime,
            sqltypes.Interval: AsyncPgInterval,
            INTERVAL: AsyncPgInterval,
            UUID: AsyncpgUUID,
            sqltypes.Boolean: AsyncpgBoolean,
            sqltypes.Integer: AsyncpgInteger,
            sqltypes.BigInteger: AsyncpgBigInteger,
            sqltypes.Numeric: AsyncpgNumeric,
            sqltypes.Float: AsyncpgFloat,
            sqltypes.JSON: AsyncpgJSON,
            json.JSONB: AsyncpgJSONB,
            sqltypes.JSON.JSONPathType: AsyncpgJSONPathType,
            sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType,
            sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType,
            sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType,
            sqltypes.Enum: AsyncPgEnum,
            OID: AsyncpgOID,
            REGCLASS: AsyncpgREGCLASS,
        },
    )
    is_async = True
    _invalidate_schema_cache_asof = 0

    def _invalidate_schema_cache(self):
        """Stamp the dialect with the current time as the new
        invalidation point."""
        self._invalidate_schema_cache_asof = time.time()

    @util.memoized_property
    def _dbapi_version(self):
        """Version tuple parsed from ``dbapi.__version__``, or
        ``(99, 99, 99)`` when that attribute is unavailable."""
        if not (self.dbapi and hasattr(self.dbapi, "__version__")):
            return (99, 99, 99)

        version_parts = re.findall(
            r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
        )
        return tuple(int(part) for part in version_parts)

    @classmethod
    def dbapi(cls):
        """Import asyncpg and return it wrapped in the DBAPI emulation."""
        return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))

    @util.memoized_property
    def _isolation_lookup(self):
        """Map isolation level names to the values stored on the adapted
        connection."""
        return {
            "AUTOCOMMIT": "autocommit",
            "READ COMMITTED": "read_committed",
            "REPEATABLE READ": "repeatable_read",
            "SERIALIZABLE": "serializable",
        }

    def set_isolation_level(self, connection, level):
        """Translate ``level`` and apply it to ``connection``.

        :raises sqlalchemy.exc.ArgumentError: if ``level`` is not a known
         isolation level name.
        """
        lookup_key = level.replace("_", " ")
        try:
            level = self._isolation_lookup[lookup_key]
        except KeyError as err:
            util.raise_(
                exc.ArgumentError(
                    "Invalid value '%s' for isolation_level. "
                    "Valid isolation levels for %s are %s"
                    % (level, self.name, ", ".join(self._isolation_lookup))
                ),
                replace_context=err,
            )

        connection.set_isolation_level(level)

    def set_readonly(self, connection, value):
        connection.readonly = value

    def get_readonly(self, connection):
        return connection.readonly

    def set_deferrable(self, connection, value):
        connection.deferrable = value

    def get_deferrable(self, connection):
        return connection.deferrable

    def do_terminate(self, dbapi_connection) -> None:
        dbapi_connection.terminate()

    def create_connect_args(self, url):
        """Build ``([], kwargs)`` for ``connect()`` from ``url``.

        URL query entries are merged into the keyword arguments;
        ``prepared_statement_cache_size`` and ``port`` are coerced to
        ``int``.
        """
        opts = url.translate_connect_args(username="user")
        opts.update(url.query)

        for int_option in ("prepared_statement_cache_size", "port"):
            util.coerce_kw_type(opts, int_option, int)
        return ([], opts)

    @classmethod
    def get_pool_class(cls, url):
        """Choose the pool class according to the ``async_fallback`` URL
        query flag."""
        async_fallback = url.query.get("async_fallback", False)

        return (
            pool.FallbackAsyncAdaptedQueuePool
            if util.asbool(async_fallback)
            else pool.AsyncAdaptedQueuePool
        )

    def is_disconnect(self, e, connection, cursor):
        """Report whether the error indicates a closed connection."""
        if connection:
            return connection._connection.is_closed()

        # without a connection, fall back to inspecting the error itself
        return isinstance(
            e, self.dbapi.InterfaceError
        ) and "connection is closed" in str(e)

    def do_set_input_sizes(self, cursor, list_of_tuples, context):
        """Forward the dbapi types from ``list_of_tuples`` to
        ``cursor.setinputsizes()``."""
        if self.positional:
            positional_types = [
                dbtype for key, dbtype, sqltype in list_of_tuples
            ]
            cursor.setinputsizes(*positional_types)
        else:
            named_types = {
                key: dbtype
                for key, dbtype, sqltype in list_of_tuples
                if dbtype
            }
            cursor.setinputsizes(**named_types)

    async def setup_asyncpg_json_codec(self, conn):
        """set up JSON codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        .. versionadded:: 1.4.27

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _json_decoder(bin_value):
            return deserializer(bin_value.decode())

        await asyncpg_connection.set_type_codec(
            "json",
            encoder=str.encode,
            decoder=_json_decoder,
            schema="pg_catalog",
            format="binary",
        )

    async def setup_asyncpg_jsonb_codec(self, conn):
        """set up JSONB codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        .. versionadded:: 1.4.27

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _jsonb_encoder(str_value):
            # \x01 is the prefix for jsonb used by PostgreSQL.
            # asyncpg requires it when format='binary'
            return b"\x01" + str_value.encode()

        def _jsonb_decoder(bin_value):
            # the byte is the \x01 prefix for jsonb used by PostgreSQL.
            # asyncpg returns it when format='binary'
            return deserializer(bin_value[1:].decode())

        await asyncpg_connection.set_type_codec(
            "jsonb",
            encoder=_jsonb_encoder,
            decoder=_jsonb_decoder,
            schema="pg_catalog",
            format="binary",
        )

    def on_connect(self):
        """on_connect for asyncpg

        A major component of this for asyncpg is to set up type decoders at the
        asyncpg level.

        See https://github.com/MagicStack/asyncpg/issues/623 for
        notes on JSON/JSONB implementation.

        """

        super_connect = super(PGDialect_asyncpg, self).on_connect()

        def connect(conn):
            # register both JSON codecs before the inherited hook runs
            for setup_codec in (
                self.setup_asyncpg_json_codec,
                self.setup_asyncpg_jsonb_codec,
            ):
                conn.await_(setup_codec(conn))
            if super_connect is not None:
                super_connect(conn)

        return connect

    def get_driver_connection(self, connection):
        """Return the ``_connection`` object held by the adapted
        connection."""
        return connection._connection


# module-level alias for the dialect class defined in this module
dialect = PGDialect_asyncpg